chore(fmt)

This commit is contained in:
Mathieu Trossevin 2023-12-07 21:13:06 +01:00
parent 6984d1a0ea
commit a0728e086f
2 changed files with 91 additions and 78 deletions

View file

@ -13,10 +13,7 @@ pub enum ReceiveError {
NoListenFD,
NotUnicodeListenFD(OsString),
ListenFDParse(ParseIntError),
PidMismatch{
expected: u32,
found: u32,
},
PidMismatch { expected: u32, found: u32 },
GetFds(GetFdsError),
}
@ -45,7 +42,10 @@ impl Error for ReceiveError {
ReceiveError::NoListenFD => None,
ReceiveError::NotUnicodeListenFD(_) => None,
ReceiveError::ListenFDParse(error) => Some(error),
ReceiveError::PidMismatch { expected: _, found: _ } => None,
ReceiveError::PidMismatch {
expected: _,
found: _,
} => None,
ReceiveError::GetFds(error) => Some(error),
}
}
@ -66,7 +66,10 @@ impl From<ReceiveError> for ReceiveNameError {
impl Display for ReceiveNameError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReceiveNameError::NoListenFDName => write!(f, "Couldn't find FDs name : the variable {FD_NAMES_VAR} doesn't exists."),
ReceiveNameError::NoListenFDName => write!(
f,
"Couldn't find FDs name : the variable {FD_NAMES_VAR} doesn't exists."
),
ReceiveNameError::Receive(error) => Display::fmt(error, f),
}
}
@ -94,4 +97,4 @@ impl Display for GetFdsError {
}
}
impl Error for GetFdsError {}
impl Error for GetFdsError {}

View file

@ -1,12 +1,12 @@
use std::ffi::{OsStr, OsString};
use std::fs::File;
use std::net::{TcpListener, TcpStream, UdpSocket};
use std::os::unix::net::{UnixListener, UnixStream, UnixDatagram};
use std::{env, process};
use std::ffi::{OsString, OsStr};
use std::os::fd::RawFd;
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
use std::{env, process};
use error::{ReceiveError, ReceiveNameError, GetFdsError};
use rustix::fd::{BorrowedFd, OwnedFd, AsRawFd};
use error::{GetFdsError, ReceiveError, ReceiveNameError};
use rustix::fd::{AsRawFd, BorrowedFd, OwnedFd};
use rustix::fs::FileType;
use rustix::net::SocketType;
@ -18,56 +18,56 @@ const FD_NUMBER_VAR: &str = "LISTEN_FDS";
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";
/// File Descriptor passed by systemd.
///
///
/// They are duplicated from the actual passed file descriptor so as to be safe to use from rust code.
#[derive(Debug)]
pub enum FileDescriptor {
/// The file descriptor is a [`File`](std::fs::File).
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenSpecial=`.
File(File),
/// The file descriptor is a directory.
Directory(OwnedFd),
/// The file descriptor is a FIFO
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenFIFO=`
Fifo(OwnedFd),
/// The file descriptor is a TCP socket listening for incoming connexions.
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with an IP address.
TcpListener(TcpListener),
/// The file descriptor is a TCP socket that doesn't listen for incoming connexions.
TcpStream(TcpStream),
/// The file descriptor is an UDP .
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with an IP address.
UdpSocket(UdpSocket),
/// The file descriptor is another inet socket
///
///
/// You should figure out what exactly it is before you use it.
InetOther(OwnedFd),
/// The file descriptor is a Unix Stream Socket listening for incoming connexions.
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with a path or starting with `@` (abstract socket).
UnixListener(UnixListener),
/// The file descriptor is a Unix Stream Socket.
UnixStream(UnixStream),
/// The file descriptor is a Unix Datagram Socket.
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with a path or starting with `@` (abstract socket).
UnixDatagram(UnixDatagram),
/// The file descriptor is a Unix SEQPACKET Socket.
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenSequentialPacket=`
UnixSequentialPaquet(OwnedFd),
/// The file descriptor is another type of Unix Socket.
UnixOther(OwnedFd),
/// The file descriptor is some other socket type.
///
///
/// Probably a Netlink family socket if from a .socket unit but you should check yourself.
SocketOther(OwnedFd),
/// The file descriptor is a message queue.
///
///
/// If this is an FD provided by a .socket unit it corresponds to `ListenMessageQueue=`
MessageQueue(OwnedFd),
/// The file descriptor is something else.
@ -76,34 +76,39 @@ pub enum FileDescriptor {
impl FileDescriptor {
/// Get any file descriptor passed by systemd or anything implementing the `LISTEN_FD` protocol.
///
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
/// The file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
pub fn receive(unset_env: bool) -> Result<impl IntoIterator<Item = Self>, ReceiveError> {
let pid = env::var_os(PID_VAR).ok_or(ReceiveError::NoListenPID)?;
let fds = env::var_os(FD_NUMBER_VAR).ok_or(ReceiveError::NoListenFD)?;
tracing::trace!("{PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?}");
if unset_env {
env::remove_var(PID_VAR);
env::remove_var(FD_NUMBER_VAR);
env::remove_var(FD_NAMES_VAR);
}
let pid = pid.into_string()
let pid = pid
.into_string()
.map_err(ReceiveError::NotUnicodeListenPID)?
.parse::<u32>()
.map_err(ReceiveError::ListenPIDParse)?;
let fds = fds.into_string()
let fds = fds
.into_string()
.map_err(ReceiveError::NotUnicodeListenFD)?
.parse::<usize>()
.map_err(ReceiveError::ListenFDParse)?;
let current_pid = process::id();
if current_pid != pid {
return Err(ReceiveError::PidMismatch{expected: pid, found: current_pid});
return Err(ReceiveError::PidMismatch {
expected: pid,
found: current_pid,
});
}
match Self::from_fds(fds) {
Ok(fds) => Ok(fds),
Err(error) => Err(ReceiveError::GetFds(error)),
@ -111,13 +116,16 @@ impl FileDescriptor {
}
/// Get any file descriptor passed using `LISTEN_FD` and their names.
///
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
/// The file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
pub fn receive_with_names(unset_env: bool) -> Result<impl IntoIterator<Item = (Self, OsString)>, ReceiveNameError> {
pub fn receive_with_names(
unset_env: bool,
) -> Result<impl IntoIterator<Item = (Self, OsString)>, ReceiveNameError> {
let fd_names = env::var_os(FD_NAMES_VAR).ok_or(ReceiveNameError::NoListenFDName)?;
tracing::trace!("{FD_NAMES_VAR} = {fd_names:?}");
let fd_names = fd_names.as_encoded_bytes()
let fd_names = fd_names
.as_encoded_bytes()
.split(|b| *b == b':')
.map(|name| {
// SAFETY:
@ -126,7 +134,10 @@ impl FileDescriptor {
unsafe { OsStr::from_encoded_bytes_unchecked(name) }
})
.map(OsStr::to_os_string);
Ok(Self::receive(unset_env)?.into_iter().zip(fd_names).collect::<Vec<_>>())
Ok(Self::receive(unset_env)?
.into_iter()
.zip(fd_names)
.collect::<Vec<_>>())
}
fn from_fds(num_fds: usize) -> Result<impl IntoIterator<Item = FileDescriptor>, GetFdsError> {
@ -134,61 +145,60 @@ impl FileDescriptor {
return Err(GetFdsError::TooManyFDs(num_fds));
}
Ok((0..num_fds)
.map(|fd_offset| {
SD_LISTEN_FDS_START
.checked_add(fd_offset as RawFd)
// SAFETY: We are receiving the fd so it should be safe
.map(|fd| FileDescriptor::from_fd(fd, 0))
.expect("Already checked against overflow.")
})
)
Ok((0..num_fds).map(|fd_offset| {
SD_LISTEN_FDS_START
.checked_add(fd_offset as RawFd)
// SAFETY: We are receiving the fd so it should be safe
.map(|fd| FileDescriptor::from_fd(fd, 0))
.expect("Already checked against overflow.")
}))
}
fn from_fd(fd: RawFd, min_new: RawFd) -> Self {
let fd = {
// SAFETY: The file descriptor won't be closed by the time we duplicate it.
let fd = unsafe { BorrowedFd::borrow_raw(fd) };
rustix::fs::fcntl_dupfd_cloexec(fd, min_new).expect("Couldn't duplicate the file descriptor")
rustix::fs::fcntl_dupfd_cloexec(fd, min_new)
.expect("Couldn't duplicate the file descriptor")
};
let stat = rustix::fs::fstat(&fd).expect("This should only ever be called on a valid file descriptor.");
let stat = rustix::fs::fstat(&fd)
.expect("This should only ever be called on a valid file descriptor.");
let file_type = FileType::from_raw_mode(stat.st_mode);
match file_type {
FileType::RegularFile => Self::File(fd.into()),
FileType::Directory => Self::Directory(fd),
FileType::Fifo => Self::Fifo(fd),
FileType::Socket => {
match rustix::net::getsockname(&fd).unwrap() {
rustix::net::SocketAddrAny::V4(_) |
rustix::net::SocketAddrAny::V6(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UdpSocket(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default() {
Self::TcpListener(fd.into())
} else {
Self::TcpStream(fd.into())
}
},
_ => Self::InetOther(fd),
FileType::Socket => match rustix::net::getsockname(&fd).unwrap() {
rustix::net::SocketAddrAny::V4(_) | rustix::net::SocketAddrAny::V6(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UdpSocket(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default()
{
Self::TcpListener(fd.into())
} else {
Self::TcpStream(fd.into())
}
}
},
rustix::net::SocketAddrAny::Unix(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UnixDatagram(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default() {
Self::UnixListener(fd.into())
} else {
Self::UnixStream(fd.into())
}
},
SocketType::SEQPACKET => Self::UnixSequentialPaquet(fd),
_ => Self::UnixOther(fd),
}
},
_ => Self::SocketOther(fd),
_ => Self::InetOther(fd),
}
}
rustix::net::SocketAddrAny::Unix(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UnixDatagram(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default()
{
Self::UnixListener(fd.into())
} else {
Self::UnixStream(fd.into())
}
}
SocketType::SEQPACKET => Self::UnixSequentialPaquet(fd),
_ => Self::UnixOther(fd),
}
}
_ => Self::SocketOther(fd),
},
_ => {
// `rustix` does not enable us to test if a raw fd is a mq, so we must drop to libc here.
@ -201,7 +211,7 @@ impl FileDescriptor {
} else {
Self::Other(fd)
}
},
}
}
}
}
}