storefd/src/lib.rs

230 lines
10 KiB
Rust
Raw Normal View History

2023-12-07 21:13:06 +01:00
use std::ffi::{OsStr, OsString};
2023-12-07 18:02:35 +01:00
use std::fs::File;
use std::net::{TcpListener, TcpStream, UdpSocket};
use std::os::fd::RawFd;
2023-12-07 21:13:06 +01:00
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
use std::{env, process};
2023-12-07 18:02:35 +01:00
2023-12-07 21:13:06 +01:00
use error::{GetFdsError, ReceiveError, ReceiveNameError};
use rustix::fd::{AsRawFd, BorrowedFd, OwnedFd, FromRawFd};
2023-12-07 18:02:35 +01:00
use rustix::fs::FileType;
use rustix::net::SocketType;
mod error;
/// Environment variable holding the PID of the process the descriptors are meant for.
const PID_VAR: &str = "LISTEN_PID";
/// Environment variable holding the number of passed file descriptors.
const FD_NUMBER_VAR: &str = "LISTEN_FDS";
/// Environment variable holding the colon-separated names of the passed file descriptors.
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";
/// First descriptor number of the protocol; passed descriptors follow consecutively from here.
const SD_LISTEN_FDS_START: RawFd = 3;
/// File descriptor passed by systemd (or anything else implementing the `LISTEN_FDS` protocol),
/// classified by what kind of object it refers to.
#[derive(Debug)]
pub enum FileDescriptor {
    /// The file descriptor is a [`File`](std::fs::File).
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSpecial=`.
    File(File),
    /// The file descriptor is a directory.
    Directory(OwnedFd),
    /// The file descriptor is a FIFO.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenFIFO=`.
    Fifo(OwnedFd),
    /// The file descriptor is a TCP socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with an IP address.
    TcpListener(TcpListener),
    /// The file descriptor is a TCP socket that doesn't listen for incoming connections.
    TcpStream(TcpStream),
    /// The file descriptor is a UDP socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with an IP address.
    UdpSocket(UdpSocket),
    /// The file descriptor is some other IPv4/IPv6 socket (neither a datagram nor a stream socket).
    ///
    /// You should figure out what exactly it is before you use it.
    InetOther(OwnedFd),
    /// The file descriptor is a Unix stream socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with a path or starting with `@` (abstract socket).
    UnixListener(UnixListener),
    /// The file descriptor is a Unix stream socket that doesn't listen for incoming connections.
    UnixStream(UnixStream),
    /// The file descriptor is a Unix datagram socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with a path or starting with `@` (abstract socket).
    UnixDatagram(UnixDatagram),
    /// The file descriptor is a Unix `SOCK_SEQPACKET` socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSequentialPacket=`.
    UnixSequentialPaquet(OwnedFd),
    /// The file descriptor is another type of Unix socket.
    UnixOther(OwnedFd),
    /// The file descriptor is a socket of some other address family.
    ///
    /// Probably a Netlink family socket if from a .socket unit, but you should check yourself.
    SocketOther(OwnedFd),
    /// The file descriptor is a POSIX message queue.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenMessageQueue=`.
    MessageQueue(OwnedFd),
    /// The file descriptor is something else.
    Other(OwnedFd),
}
impl FileDescriptor {
/// Get any file descriptor passed by systemd or anything implementing the `LISTEN_FD` protocol.
2023-12-07 21:13:06 +01:00
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
///
/// If `unset_env` is `true` then the file descriptor are directly taken as if they were owned. This is only safe if this library is the only place taking these file descriptor but avoid unnecessary duplication of file descriptors.
///
/// If `unset_env` is `false` the file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptor colud be received which might not actually be an error. See [`ReceiveError`] for details.
2023-12-07 18:02:35 +01:00
pub fn receive(unset_env: bool) -> Result<impl IntoIterator<Item = Self>, ReceiveError> {
let pid = env::var_os(PID_VAR).ok_or(ReceiveError::NoListenPID)?;
let fds = env::var_os(FD_NUMBER_VAR).ok_or(ReceiveError::NoListenFD)?;
tracing::trace!("{PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?}");
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
if unset_env {
env::remove_var(PID_VAR);
env::remove_var(FD_NUMBER_VAR);
env::remove_var(FD_NAMES_VAR);
}
2023-12-07 21:13:06 +01:00
let pid = pid
.into_string()
2023-12-07 18:02:35 +01:00
.map_err(ReceiveError::NotUnicodeListenPID)?
.parse::<u32>()
.map_err(ReceiveError::ListenPIDParse)?;
2023-12-07 21:13:06 +01:00
let fds = fds
.into_string()
2023-12-07 18:02:35 +01:00
.map_err(ReceiveError::NotUnicodeListenFD)?
.parse::<usize>()
.map_err(ReceiveError::ListenFDParse)?;
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
let current_pid = process::id();
if current_pid != pid {
2023-12-07 21:13:06 +01:00
return Err(ReceiveError::PidMismatch {
expected: pid,
found: current_pid,
});
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
match Self::from_fds(fds, unset_env) {
2023-12-07 18:02:35 +01:00
Ok(fds) => Ok(fds),
Err(error) => Err(ReceiveError::GetFds(error)),
}
}
/// Get any file descriptor passed using `LISTEN_FD` and their names.
2023-12-07 21:13:06 +01:00
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
/// The file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptors could be obtained or the names associated with them couldn't be obtained. See [`ReceiveNameError`] for details.
2023-12-07 21:13:06 +01:00
pub fn receive_with_names(
unset_env: bool,
) -> Result<impl IntoIterator<Item = (Self, OsString)>, ReceiveNameError> {
2023-12-07 18:02:35 +01:00
let fd_names = env::var_os(FD_NAMES_VAR).ok_or(ReceiveNameError::NoListenFDName)?;
tracing::trace!("{FD_NAMES_VAR} = {fd_names:?}");
let fd_names: Vec<_> = fd_names
2023-12-07 21:13:06 +01:00
.as_encoded_bytes()
2023-12-07 18:02:35 +01:00
.split(|b| *b == b':')
.map(|name| {
// SAFETY:
// - Each `word` only contains content that originated from `OsStr::as_encoded_bytes`
// - Only split with ASCII colon which is a non-empty UTF-8 substring
unsafe { OsStr::from_encoded_bytes_unchecked(name) }
})
.map(OsStr::to_os_string)
.collect();
2023-12-07 21:13:06 +01:00
Ok(Self::receive(unset_env)?
.into_iter()
.zip(fd_names))
2023-12-07 18:02:35 +01:00
}
fn from_fds(num_fds: usize, unset_env: bool) -> Result<impl IntoIterator<Item = FileDescriptor>, GetFdsError> {
if SD_LISTEN_FDS_START.checked_add(num_fds as RawFd).is_none() {
return Err(GetFdsError::TooManyFDs(num_fds));
}
Ok((0..num_fds).map(move |fd_offset| {
2023-12-07 21:13:06 +01:00
SD_LISTEN_FDS_START
.checked_add(fd_offset as RawFd)
// SAFETY: We are receiving the fd so it should be safe
.map(|fd| FileDescriptor::from_fd(fd, unset_env))
2023-12-07 21:13:06 +01:00
.expect("Already checked against overflow.")
}))
}
fn from_fd(fd: RawFd, unset_env: bool) -> Self {
let fd = if unset_env {
// SAFETY: The environement is removed so there shouldn't be anything new that might take it and close it.
unsafe { OwnedFd::from_raw_fd(fd) }
} else {
// SAFETY: The file descriptor won't be closed by the time we duplicate it.
let fd = unsafe { BorrowedFd::borrow_raw(fd) };
rustix::fs::fcntl_dupfd_cloexec(fd, 0)
2023-12-07 21:13:06 +01:00
.expect("Couldn't duplicate the file descriptor")
};
2023-12-07 21:13:06 +01:00
let stat = rustix::fs::fstat(&fd)
.expect("This should only ever be called on a valid file descriptor.");
let file_type = FileType::from_raw_mode(stat.st_mode);
match file_type {
FileType::RegularFile => Self::File(fd.into()),
FileType::Directory => Self::Directory(fd),
FileType::Fifo => Self::Fifo(fd),
2023-12-07 21:13:06 +01:00
FileType::Socket => match rustix::net::getsockname(&fd).unwrap() {
rustix::net::SocketAddrAny::V4(_) | rustix::net::SocketAddrAny::V6(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UdpSocket(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default()
{
Self::TcpListener(fd.into())
} else {
Self::TcpStream(fd.into())
}
}
2023-12-07 21:13:06 +01:00
_ => Self::InetOther(fd),
}
}
rustix::net::SocketAddrAny::Unix(_) => {
match rustix::net::sockopt::get_socket_type(&fd).unwrap() {
SocketType::DGRAM => Self::UnixDatagram(fd.into()),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(&fd).unwrap_or_default()
{
Self::UnixListener(fd.into())
} else {
Self::UnixStream(fd.into())
}
}
2023-12-07 21:13:06 +01:00
SocketType::SEQPACKET => Self::UnixSequentialPaquet(fd),
_ => Self::UnixOther(fd),
}
}
2023-12-07 21:13:06 +01:00
_ => Self::SocketOther(fd),
},
_ => {
// `rustix` does not enable us to test if a raw fd is a mq, so we must drop to libc here.
// SAFETY: `mq_getattr` is specified to return -1 when passed a fd which is not a mq.
// Furthermore, we ignore `attr` and rely only on the return value.
let mut attr = std::mem::MaybeUninit::<libc::mq_attr>::uninit();
let res = unsafe { libc::mq_getattr(fd.as_raw_fd(), attr.as_mut_ptr()) };
if res == 0 {
Self::MessageQueue(fd)
} else {
Self::Other(fd)
}
2023-12-07 21:13:06 +01:00
}
}
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
}