Start adding support for NOTIFY_SOCKET

This commit is contained in:
Mathieu Trossevin 2023-12-08 14:46:54 +01:00
parent bf2091bb63
commit c142afbc2c
Signed by: mtrossevin
GPG key ID: D1DBB7EA828374E9
4 changed files with 236 additions and 0 deletions

View file

@ -8,6 +8,7 @@ edition = "2021"
[features]
default = []
listenfd = ["dep:rustix", "rustix/fs", "rustix/net", "dep:libc"]
notify = ["dep:libc", "dep:rustix", "rustix/net"]
[dependencies]
libc = { version = "0.2.150", optional = true }

View file

@ -1,2 +1,4 @@
#[cfg(feature = "listenfd")]
pub mod listen;
#[cfg(feature = "notify")]
pub mod notify;

55
src/notify/error.rs Normal file
View file

@ -0,0 +1,55 @@
use core::fmt::Display;
use std::error::Error;
use rustix::io::Errno;
#[derive(Debug)]
pub enum NotifyError {
SanityCheck(SanityCheckError),
InvalidAbstractSocket(Errno),
InvalidSocketPath(Errno),
CouldntOpenSocket(std::io::Error),
SendMsg(Errno),
AncillarySize,
PartialSend,
}
impl From<SanityCheckError> for NotifyError {
fn from(value: SanityCheckError) -> Self {
Self::SanityCheck(value)
}
}
#[derive(Debug)]
pub enum SanityCheckError {
InvalidFdName(FdNameError),
}
impl Display for SanityCheckError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SanityCheckError::InvalidFdName(error) => {
write!(f, "The value of FDNAME was invalid : {error}.")
}
}
}
}
impl Error for SanityCheckError {}
#[derive(Debug)]
pub enum FdNameError {
TooLong { length: usize, name: String },
NotAsciiNonControl { disallowed_char: char, name: String },
ContainColon(String),
}
impl Display for FdNameError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FdNameError::TooLong { length, name } => write!(f, "The file descriptor name {name:?} is too long (is {length} characters and should be less than 255)."),
FdNameError::NotAsciiNonControl { disallowed_char, name } => write!(f, "The file descriptor name {name:?} contains invalid character '{disallowed_char}' (only ASCII allowed)."),
FdNameError::ContainColon(name) => write!(f, "The file descriptor name {name:?} contains a colon (':') which isn't allowed."),
}
}
}

178
src/notify/mod.rs Normal file
View file

@ -0,0 +1,178 @@
use core::fmt::Display;
use std::env;
use std::ffi::OsString;
use std::io::IoSlice;
use std::os::unix::net::UnixDatagram;
use rustix::fd::{AsFd, BorrowedFd};
use self::error::{FdNameError, NotifyError, SanityCheckError};
pub mod error;
pub struct Sent(bool);
pub fn notify_with_fds(
unset_env: bool,
state: &[NotifyState<'_>],
fds: &[BorrowedFd],
) -> Result<Sent, NotifyError> {
let env_sock = match env::var_os("NOTIFY_SOCKET") {
None => return Ok(Sent(false)),
Some(v) => v,
};
if unset_env {
env::remove_var("NOTIFY_SOCKET");
}
state.iter().try_for_each(NotifyState::sanity_check)?;
// If the first character of `$NOTIFY_SOCKET` is '@', the string
// is understood as Linux abstract namespace socket.
let socket_addr = match env_sock.as_encoded_bytes().strip_prefix(b"@").map(|v| {
// SAFETY:
// - Only strip ASCII '@' which is a non-empty UTF-8 substring
unsafe { OsString::from_encoded_bytes_unchecked(v.to_vec()) }
}) {
Some(stripped_addr) => {
rustix::net::SocketAddrUnix::new_abstract_name(stripped_addr.as_encoded_bytes())
.map_err(NotifyError::InvalidAbstractSocket)?
}
None => {
rustix::net::SocketAddrUnix::new(env_sock).map_err(NotifyError::InvalidSocketPath)?
}
};
let socket = UnixDatagram::unbound().map_err(NotifyError::CouldntOpenSocket)?;
let msg = state
.iter()
.fold(String::new(), |acc, state| format!("{acc}{state}\n"))
.into_bytes();
let msg_len = msg.len();
let msg_iov = IoSlice::new(&msg);
let mut ancillary = if !fds.is_empty() {
let mut ancillary = rustix::net::SendAncillaryBuffer::default();
let tmp = rustix::net::SendAncillaryMessage::ScmRights(fds);
if !ancillary.push(tmp) {
return Err(NotifyError::AncillarySize);
}
ancillary
} else {
rustix::net::SendAncillaryBuffer::default()
};
let sent_len = rustix::net::sendmsg_unix(
socket.as_fd(),
&socket_addr,
&[msg_iov],
&mut ancillary,
rustix::net::SendFlags::empty(),
)
.map_err(NotifyError::SendMsg)?;
if sent_len != msg_len {
return Err(NotifyError::PartialSend);
}
Ok(Sent(true))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NotifyState<'a> {
/// D-Bus error-style error code.
BusError(&'a str),
/// errno-style error code.
Errno(u8),
/// A name for the submitted file descriptors.
FdName(&'a str),
/// Stores additional file descriptors in the service manager. Use [`notify_with_fds`] with this.
Fdstore,
/// Remove stored file descriptors. Must be used together with [`NotifyState::Fdname`].
FdstoreRemove,
/// Tell the service manager to not poll the filedescriptors for errors. This causes
/// systemd to hold on to broken file descriptors which must be removed manually.
/// Must be used together with [`NotifyState::Fdstore`].
FdpollDisable,
/// The main process ID of the service, in case of forking applications.
Mainpid(libc::pid_t),
/// Custom state change, as a `KEY=VALUE` string.
Other(&'a str),
/// Service startup is finished.
Ready,
/// Service is reloading.
Reloading,
/// Custom status change.
Status(&'a str),
/// Service is beginning to shutdown.
Stopping,
/// Tell the service manager to update the watchdog timestamp.
Watchdog,
/// Tell the service manager to execute the configured watchdog option.
WatchdogTrigger,
/// Reset watchdog timeout value during runtime.
/// The value is in milliseconds.
WatchdogUsec(u64),
/// Tells the service manager to extend the startup, runtime or shutdown service timeout corresponding the current state.
/// The value is in milliseconds.
ExtendTimeoutUsec(u64),
}
impl<'a> Display for NotifyState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
NotifyState::BusError(s) => write!(f, "BUSERROR={s}"),
NotifyState::Errno(e) => write!(f, "ERRNO={e}"),
NotifyState::FdName(name) => write!(f, "FDNAME={name}"),
NotifyState::Fdstore => f.write_str("FDSTORE=1"),
NotifyState::FdstoreRemove => f.write_str("FDSTOREREMOVE=1"),
NotifyState::FdpollDisable => f.write_str("FDPOLL=0"),
NotifyState::Mainpid(pid) => write!(f, "MAINPID={pid}"),
NotifyState::Other(message) => Display::fmt(message, f),
NotifyState::Ready => f.write_str("READY=1"),
NotifyState::Reloading => f.write_str("READY=1"),
NotifyState::Status(status) => write!(f, "STATUS={status}"),
NotifyState::Stopping => f.write_str("STOPPING=1"),
NotifyState::Watchdog => f.write_str("WATCHDOG=1"),
NotifyState::WatchdogTrigger => f.write_str("WATCHDOG=trigger"),
NotifyState::WatchdogUsec(milliseconds) => write!(f, "WATCHDOG_USEC={milliseconds}"),
NotifyState::ExtendTimeoutUsec(milliseconds) => {
write!(f, "EXTEND_TIMEOUT_USEC={milliseconds}")
}
}
}
}
impl<'a> NotifyState<'a> {
fn sanity_check(&self) -> Result<(), SanityCheckError> {
match self {
NotifyState::FdName(name) => {
validate_fd_name(name).map_err(SanityCheckError::InvalidFdName)
}
_ => Ok(()),
}
}
}
fn validate_fd_name(name: &str) -> Result<(), FdNameError> {
if name.len() > 255 {
return Err(FdNameError::TooLong {
length: name.len(),
name: name.into(),
});
}
for c in name.chars() {
if !c.is_ascii() || c.is_ascii_control() {
return Err(FdNameError::NotAsciiNonControl {
disallowed_char: c,
name: name.into(),
});
}
if c == ':' {
return Err(FdNameError::ContainColon(name.into()));
}
}
Ok(())
}