notify: Add an equivalent to sd_notify_barrier

This also add an RAII guard for that can be used for the same purpose.
(Mostly for use in a closure.)
This commit is contained in:
Mathieu Trossevin 2024-01-05 21:27:09 +01:00
parent 7d930c3e42
commit d60d906483
4 changed files with 202 additions and 17 deletions

View file

@ -13,7 +13,7 @@ rust-version = "1.74.0"
[features]
default = []
listenfd = ["dep:rustix", "rustix/fs", "dep:libc"]
notify = ["dep:libc", "dep:rustix"]
notify = ["dep:libc", "dep:rustix", "rustix/pipe", "rustix/event"]
[dependencies]
libc = { version = "0.2.150", optional = true }

View file

@ -121,3 +121,42 @@ impl Display for OtherStateError {
}
impl Error for OtherStateError {}
#[derive(Debug)]
pub enum BarrierError {
Notify(NotifyError),
FailedPipeCreation(Errno),
FailedPolling(Errno),
Timedout,
}
impl From<NotifyError> for BarrierError {
fn from(value: NotifyError) -> Self {
Self::Notify(value)
}
}
impl Display for BarrierError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Notify(_) => write!(f, "Couldn't notify of barrier."),
Self::FailedPipeCreation(errno) => write!(
f,
"Couldn't create pipe to serve as a notify syncronisation barrier : error {errno}."
),
Self::FailedPolling(errno) => write!(f, "poll failed with error : {errno}."),
Self::Timedout => write!(f, "Notification synchronisation timedout."),
}
}
}
impl Error for BarrierError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::Notify(source) => Some(source),
Self::FailedPipeCreation(source) => Some(source),
Self::FailedPolling(source) => Some(source),
Self::Timedout => None,
}
}
}

View file

@ -20,9 +20,15 @@ use std::ffi::OsString;
use std::io::IoSlice;
use std::os::unix::net::UnixDatagram;
use rustix::fd::{AsFd, BorrowedFd};
use rustix::{
fd::{AsFd, BorrowedFd},
pipe::PipeFlags,
};
use self::error::{NewNotifierError, NotifyError};
use self::{
error::{BarrierError, NewNotifierError, NotifyError},
types::PollTimeout,
};
pub mod error;
pub mod types;
@ -159,6 +165,92 @@ impl Notifier {
pub fn notify(&self, state: &[NotifyState<'_>]) -> Result<(), NotifyError> {
self.notify_with_fds(state, &[])
}
pub fn barrier(&self, timeout: PollTimeout) -> Result<(), BarrierError> {
let (to_poll, sent) = rustix::pipe::pipe_with(PipeFlags::CLOEXEC)
.map_err(BarrierError::FailedPipeCreation)?;
self.notify_with_fds(
&[NotifyState::Other(types::OtherState::barrier())],
&[sent.as_fd()],
)?;
core::mem::drop(sent);
let to_poll = rustix::event::PollFd::new(&to_poll, rustix::event::PollFlags::HUP);
rustix::event::poll(&mut [to_poll], timeout.into())
.map_err(BarrierError::FailedPolling)
.and_then(|events| {
if events == 0_usize {
return Err(BarrierError::Timedout);
}
Ok(())
})
}
pub fn guard(&self, timeout: PollTimeout) -> NotifyBarrierGuard<'_> {
NotifyBarrierGuard {
notifier: self,
timeout,
}
}
pub fn with_guard<F, T>(&self, timeout: PollTimeout, f: F) -> T
where
F: FnOnce(NotifyBarrierGuard) -> T,
{
f(self.guard(timeout))
}
}
pub struct NotifyBarrierGuard<'a> {
notifier: &'a Notifier,
timeout: PollTimeout,
}
impl<'a> NotifyBarrierGuard<'a> {
/// Notify service manager about status changes.
///
/// Send a notification to the manager about service status changes. Also see [`notify_with_fds()`](Self::notify_with_fds) to send file descriptors.
///
/// # Errors
///
/// This function will error out if the passed [`NotifyState`] do not follow the rules set by systemd or if they couldn't be fully sent.
pub fn notify(&self, state: &[NotifyState<'_>]) -> Result<(), NotifyError> {
self.notifier.notify(state)
}
/// Notify service manager about status change and send file descriptors.
///
/// Use this together with [`NotifyState::FdStore`]. Otherwise works like [`notify()`](Self::notify).
///
/// # Errors
///
/// This function will error out if the passed [`NotifyState`] do not follow the rules set by systemd or if they couldn't be fully sent.
pub fn notify_with_fds(
&self,
state: &[NotifyState<'_>],
fds: &[BorrowedFd],
) -> Result<(), NotifyError> {
self.notifier.notify_with_fds(state, fds)
}
pub fn barrier(&self, timeout: PollTimeout) -> Result<(), BarrierError> {
self.notifier.barrier(timeout)
}
pub fn with_guard<F, T>(&self, timeout: PollTimeout, f: F) -> T
where
F: FnOnce(Self) -> T,
{
f(self.notifier.guard(timeout))
}
}
impl Drop for NotifyBarrierGuard<'_> {
fn drop(&mut self) {
self.barrier(self.timeout).unwrap_or_default();
}
}
/// Check for watchdog support at runtime
@ -240,24 +332,24 @@ pub enum NotifyState<'a> {
impl<'a> Display for NotifyState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
NotifyState::BusError(types::BusError(s)) => write!(f, "BUSERROR={s}"),
NotifyState::BusError(s) => write!(f, "BUSERROR={s}"),
NotifyState::Errno(e) => write!(f, "ERRNO={e}"),
NotifyState::FdName(types::FdName(name)) => write!(f, "FDNAME={name}"),
NotifyState::FdName(name) => write!(f, "FDNAME={name}"),
NotifyState::FdStore => f.write_str("FDSTORE=1"),
NotifyState::FdStoreRemove => f.write_str("FDSTOREREMOVE=1"),
NotifyState::FdpollDisable => f.write_str("FDPOLL=0"),
NotifyState::Mainpid(pid) => write!(f, "MAINPID={pid}"),
NotifyState::Other(types::OtherState(message)) => f.write_str(message),
NotifyState::Other(message) => f.write_str(message.as_ref()),
NotifyState::Ready => f.write_str("READY=1"),
NotifyState::Reloading => f.write_str("RELOADING=1"),
NotifyState::Status(types::StatusLine(status)) => write!(f, "STATUS={status}"),
NotifyState::Status(status) => write!(f, "STATUS={status}"),
NotifyState::Stopping => f.write_str("STOPPING=1"),
NotifyState::Watchdog => f.write_str("WATCHDOG=1"),
NotifyState::WatchdogTrigger => f.write_str("WATCHDOG=trigger"),
NotifyState::WatchdogUsec(types::Milliseconds(milliseconds)) => {
NotifyState::WatchdogUsec(milliseconds) => {
write!(f, "WATCHDOG_USEC={milliseconds}")
}
NotifyState::ExtendTimeoutUsec(types::Milliseconds(milliseconds)) => {
NotifyState::ExtendTimeoutUsec(milliseconds) => {
write!(f, "EXTEND_TIMEOUT_USEC={milliseconds}")
}
}

View file

@ -1,5 +1,7 @@
//! Newtypes used by [`NotifyState`](super::NotifyState)
use std::fmt::Display;
use super::error;
/// Allowed File descriptor name.
@ -11,7 +13,7 @@ use super::error;
/// * Doesn't contains control characters.
/// * Doesn't contains a colon (`:`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FdName<'a>(pub(super) &'a str);
pub struct FdName<'a>(&'a str);
impl<'a> TryFrom<&'a str> for FdName<'a> {
type Error = error::FdNameError;
@ -46,11 +48,18 @@ impl AsRef<str> for FdName<'_> {
}
}
impl Display for FdName<'_> {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(self.0, f)
}
}
/// A status line for [`NotifyState::Status`](super::NotifyState::Status).
///
/// As the name explains it needs to be a single line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatusLine<'a>(pub(super) &'a str);
pub struct StatusLine<'a>(&'a str);
impl<'a> TryFrom<&'a str> for StatusLine<'a> {
type Error = error::StatusLineError;
@ -70,8 +79,15 @@ impl AsRef<str> for StatusLine<'_> {
}
}
impl Display for StatusLine<'_> {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(self.0, f)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Milliseconds(pub(super) u64);
pub struct Milliseconds(u64);
impl From<u64> for Milliseconds {
fn from(value: u64) -> Self {
@ -79,9 +95,32 @@ impl From<u64> for Milliseconds {
}
}
impl AsRef<u64> for Milliseconds {
fn as_ref(&self) -> &u64 {
&self.0
impl From<Milliseconds> for u64 {
fn from(value: Milliseconds) -> Self {
value.0
}
}
impl Display for Milliseconds {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PollTimeout(i32);
impl From<i32> for PollTimeout {
fn from(value: i32) -> Self {
Self(value)
}
}
impl From<PollTimeout> for i32 {
#[inline]
fn from(value: PollTimeout) -> Self {
value.0
}
}
@ -89,7 +128,7 @@ impl AsRef<u64> for Milliseconds {
///
/// Right now it doesn't impose any additional constraint on [`str`]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BusError<'a>(pub(super) &'a str);
pub struct BusError<'a>(&'a str);
impl<'a> From<&'a str> for BusError<'a> {
fn from(value: &'a str) -> Self {
@ -98,13 +137,21 @@ impl<'a> From<&'a str> for BusError<'a> {
}
impl AsRef<str> for BusError<'_> {
#[inline]
fn as_ref(&self) -> &str {
self.0
}
}
impl Display for BusError<'_> {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(self.0, f)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OtherState<'a>(pub(super) &'a str);
pub struct OtherState<'a>(&'a str);
impl<'a> TryFrom<&'a str> for OtherState<'a> {
type Error = error::OtherStateError;
@ -127,7 +174,14 @@ impl<'a> TryFrom<&'a str> for OtherState<'a> {
}
impl AsRef<str> for OtherState<'_> {
#[inline]
fn as_ref(&self) -> &str {
self.0
}
}
impl OtherState<'_> {
pub(super) const fn barrier() -> Self {
Self("BARRIER=1")
}
}