179 lines
6 KiB
Rust
179 lines
6 KiB
Rust
|
use core::fmt::Display;
|
||
|
use std::env;
|
||
|
use std::ffi::OsString;
|
||
|
use std::io::IoSlice;
|
||
|
use std::os::unix::net::UnixDatagram;
|
||
|
|
||
|
use rustix::fd::{AsFd, BorrowedFd};
|
||
|
|
||
|
use self::error::{FdNameError, NotifyError, SanityCheckError};
|
||
|
|
||
|
pub mod error;
|
||
|
|
||
|
/// Outcome of a notification attempt.
///
/// Wraps `true` when a datagram was actually sent to the service manager and
/// `false` when `$NOTIFY_SOCKET` was not set, in which case nothing was sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Sent(bool);
|
||
|
|
||
|
pub fn notify_with_fds(
|
||
|
unset_env: bool,
|
||
|
state: &[NotifyState<'_>],
|
||
|
fds: &[BorrowedFd],
|
||
|
) -> Result<Sent, NotifyError> {
|
||
|
let env_sock = match env::var_os("NOTIFY_SOCKET") {
|
||
|
None => return Ok(Sent(false)),
|
||
|
Some(v) => v,
|
||
|
};
|
||
|
|
||
|
if unset_env {
|
||
|
env::remove_var("NOTIFY_SOCKET");
|
||
|
}
|
||
|
|
||
|
state.iter().try_for_each(NotifyState::sanity_check)?;
|
||
|
|
||
|
// If the first character of `$NOTIFY_SOCKET` is '@', the string
|
||
|
// is understood as Linux abstract namespace socket.
|
||
|
let socket_addr = match env_sock.as_encoded_bytes().strip_prefix(b"@").map(|v| {
|
||
|
// SAFETY:
|
||
|
// - Only strip ASCII '@' which is a non-empty UTF-8 substring
|
||
|
unsafe { OsString::from_encoded_bytes_unchecked(v.to_vec()) }
|
||
|
}) {
|
||
|
Some(stripped_addr) => {
|
||
|
rustix::net::SocketAddrUnix::new_abstract_name(stripped_addr.as_encoded_bytes())
|
||
|
.map_err(NotifyError::InvalidAbstractSocket)?
|
||
|
}
|
||
|
None => {
|
||
|
rustix::net::SocketAddrUnix::new(env_sock).map_err(NotifyError::InvalidSocketPath)?
|
||
|
}
|
||
|
};
|
||
|
|
||
|
let socket = UnixDatagram::unbound().map_err(NotifyError::CouldntOpenSocket)?;
|
||
|
let msg = state
|
||
|
.iter()
|
||
|
.fold(String::new(), |acc, state| format!("{acc}{state}\n"))
|
||
|
.into_bytes();
|
||
|
let msg_len = msg.len();
|
||
|
let msg_iov = IoSlice::new(&msg);
|
||
|
|
||
|
let mut ancillary = if !fds.is_empty() {
|
||
|
let mut ancillary = rustix::net::SendAncillaryBuffer::default();
|
||
|
let tmp = rustix::net::SendAncillaryMessage::ScmRights(fds);
|
||
|
if !ancillary.push(tmp) {
|
||
|
return Err(NotifyError::AncillarySize);
|
||
|
}
|
||
|
ancillary
|
||
|
} else {
|
||
|
rustix::net::SendAncillaryBuffer::default()
|
||
|
};
|
||
|
|
||
|
let sent_len = rustix::net::sendmsg_unix(
|
||
|
socket.as_fd(),
|
||
|
&socket_addr,
|
||
|
&[msg_iov],
|
||
|
&mut ancillary,
|
||
|
rustix::net::SendFlags::empty(),
|
||
|
)
|
||
|
.map_err(NotifyError::SendMsg)?;
|
||
|
|
||
|
if sent_len != msg_len {
|
||
|
return Err(NotifyError::PartialSend);
|
||
|
}
|
||
|
|
||
|
Ok(Sent(true))
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||
|
pub enum NotifyState<'a> {
|
||
|
/// D-Bus error-style error code.
|
||
|
BusError(&'a str),
|
||
|
/// errno-style error code.
|
||
|
Errno(u8),
|
||
|
/// A name for the submitted file descriptors.
|
||
|
FdName(&'a str),
|
||
|
/// Stores additional file descriptors in the service manager. Use [`notify_with_fds`] with this.
|
||
|
Fdstore,
|
||
|
/// Remove stored file descriptors. Must be used together with [`NotifyState::Fdname`].
|
||
|
FdstoreRemove,
|
||
|
/// Tell the service manager to not poll the filedescriptors for errors. This causes
|
||
|
/// systemd to hold on to broken file descriptors which must be removed manually.
|
||
|
/// Must be used together with [`NotifyState::Fdstore`].
|
||
|
FdpollDisable,
|
||
|
/// The main process ID of the service, in case of forking applications.
|
||
|
Mainpid(libc::pid_t),
|
||
|
/// Custom state change, as a `KEY=VALUE` string.
|
||
|
Other(&'a str),
|
||
|
/// Service startup is finished.
|
||
|
Ready,
|
||
|
/// Service is reloading.
|
||
|
Reloading,
|
||
|
/// Custom status change.
|
||
|
Status(&'a str),
|
||
|
/// Service is beginning to shutdown.
|
||
|
Stopping,
|
||
|
/// Tell the service manager to update the watchdog timestamp.
|
||
|
Watchdog,
|
||
|
/// Tell the service manager to execute the configured watchdog option.
|
||
|
WatchdogTrigger,
|
||
|
/// Reset watchdog timeout value during runtime.
|
||
|
/// The value is in milliseconds.
|
||
|
WatchdogUsec(u64),
|
||
|
/// Tells the service manager to extend the startup, runtime or shutdown service timeout corresponding the current state.
|
||
|
/// The value is in milliseconds.
|
||
|
ExtendTimeoutUsec(u64),
|
||
|
}
|
||
|
|
||
|
impl<'a> Display for NotifyState<'a> {
|
||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
match *self {
|
||
|
NotifyState::BusError(s) => write!(f, "BUSERROR={s}"),
|
||
|
NotifyState::Errno(e) => write!(f, "ERRNO={e}"),
|
||
|
NotifyState::FdName(name) => write!(f, "FDNAME={name}"),
|
||
|
NotifyState::Fdstore => f.write_str("FDSTORE=1"),
|
||
|
NotifyState::FdstoreRemove => f.write_str("FDSTOREREMOVE=1"),
|
||
|
NotifyState::FdpollDisable => f.write_str("FDPOLL=0"),
|
||
|
NotifyState::Mainpid(pid) => write!(f, "MAINPID={pid}"),
|
||
|
NotifyState::Other(message) => Display::fmt(message, f),
|
||
|
NotifyState::Ready => f.write_str("READY=1"),
|
||
|
NotifyState::Reloading => f.write_str("READY=1"),
|
||
|
NotifyState::Status(status) => write!(f, "STATUS={status}"),
|
||
|
NotifyState::Stopping => f.write_str("STOPPING=1"),
|
||
|
NotifyState::Watchdog => f.write_str("WATCHDOG=1"),
|
||
|
NotifyState::WatchdogTrigger => f.write_str("WATCHDOG=trigger"),
|
||
|
NotifyState::WatchdogUsec(milliseconds) => write!(f, "WATCHDOG_USEC={milliseconds}"),
|
||
|
NotifyState::ExtendTimeoutUsec(milliseconds) => {
|
||
|
write!(f, "EXTEND_TIMEOUT_USEC={milliseconds}")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<'a> NotifyState<'a> {
|
||
|
fn sanity_check(&self) -> Result<(), SanityCheckError> {
|
||
|
match self {
|
||
|
NotifyState::FdName(name) => {
|
||
|
validate_fd_name(name).map_err(SanityCheckError::InvalidFdName)
|
||
|
}
|
||
|
_ => Ok(()),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn validate_fd_name(name: &str) -> Result<(), FdNameError> {
|
||
|
if name.len() > 255 {
|
||
|
return Err(FdNameError::TooLong {
|
||
|
length: name.len(),
|
||
|
name: name.into(),
|
||
|
});
|
||
|
}
|
||
|
|
||
|
for c in name.chars() {
|
||
|
if !c.is_ascii() || c.is_ascii_control() {
|
||
|
return Err(FdNameError::NotAsciiNonControl {
|
||
|
disallowed_char: c,
|
||
|
name: name.into(),
|
||
|
});
|
||
|
}
|
||
|
if c == ':' {
|
||
|
return Err(FdNameError::ContainColon(name.into()));
|
||
|
}
|
||
|
}
|
||
|
Ok(())
|
||
|
}
|