2023-12-07 21:13:06 +01:00
use std ::ffi ::{ OsStr , OsString } ;
2023-12-07 18:02:35 +01:00
use std ::fs ::File ;
use std ::net ::{ TcpListener , TcpStream , UdpSocket } ;
2023-12-07 20:59:11 +01:00
use std ::os ::fd ::RawFd ;
2023-12-07 21:13:06 +01:00
use std ::os ::unix ::net ::{ UnixDatagram , UnixListener , UnixStream } ;
use std ::{ env , process } ;
2023-12-07 18:02:35 +01:00
2023-12-08 10:03:55 +01:00
use error ::{ DupError , GetFdsError , ReceiveError , ReceiveNameError } ;
use rustix ::fd ::{ AsRawFd , BorrowedFd , FromRawFd , OwnedFd } ;
2023-12-07 18:02:35 +01:00
use rustix ::fs ::FileType ;
2023-12-08 10:03:55 +01:00
use rustix ::io ::FdFlags ;
2023-12-07 18:02:35 +01:00
use rustix ::net ::SocketType ;
mod error ;
/// Environment variable holding the PID of the process the descriptors are meant for.
const PID_VAR: &str = "LISTEN_PID";
/// Environment variable holding how many descriptors were passed.
const FD_NUMBER_VAR: &str = "LISTEN_FDS";
/// Environment variable holding the `:`-separated names of the passed descriptors.
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";
/// First descriptor number used by the protocol; descriptors are numbered consecutively from here.
const SD_LISTEN_FDS_START: RawFd = 3;
2023-12-07 20:59:11 +01:00
/// File descriptor passed by systemd (or anything else implementing the `LISTEN_FDS` protocol).
///
/// When obtained through [`FileDescriptor::receive`] or [`FileDescriptor::receive_with_names`]
/// the wrapped descriptor is a duplicate of the one actually passed, so that it is safe to use
/// from rust code. The `*_no_dup` functions instead take ownership of the passed descriptor
/// itself.
#[derive(Debug)]
pub enum FileDescriptor {
    /// The file descriptor is a regular [`File`](std::fs::File).
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSpecial=`, as long as
    /// the special file is a regular file; other special files such as character devices are not
    /// reported as this variant.
    File(File),
    /// The file descriptor is a directory.
    Directory(OwnedFd),
    /// The file descriptor is a FIFO.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenFIFO=`.
    Fifo(OwnedFd),
    /// The file descriptor is a TCP socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with an IP
    /// address.
    TcpListener(TcpListener),
    /// The file descriptor is a TCP socket that doesn't listen for incoming connections.
    TcpStream(TcpStream),
    /// The file descriptor is a UDP socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with an IP
    /// address.
    UdpSocket(UdpSocket),
    /// The file descriptor is an IPv4/IPv6 socket that is neither a stream nor a datagram socket.
    ///
    /// You should figure out what exactly it is before you use it.
    InetOther(OwnedFd),
    /// The file descriptor is a Unix stream socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with a path
    /// or starting with `@` (abstract socket).
    UnixListener(UnixListener),
    /// The file descriptor is a Unix stream socket that doesn't listen for incoming connections.
    UnixStream(UnixStream),
    /// The file descriptor is a Unix datagram socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with a path
    /// or starting with `@` (abstract socket).
    UnixDatagram(UnixDatagram),
    /// The file descriptor is a Unix `SOCK_SEQPACKET` socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSequentialPacket=`.
    /// (The spelling of the variant name is kept as is because it is part of the public API.)
    UnixSequentialPaquet(OwnedFd),
    /// The file descriptor is another type of Unix socket.
    UnixOther(OwnedFd),
    /// The file descriptor is a socket whose address is neither IPv4/IPv6 nor Unix.
    ///
    /// Probably a Netlink family socket if from a .socket unit but you should check yourself.
    SocketOther(OwnedFd),
    /// The file descriptor is a POSIX message queue.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenMessageQueue=`.
    MessageQueue(OwnedFd),
    /// The file descriptor is something else (for instance a character or block device).
    Other(OwnedFd),
}
2023-12-07 20:59:11 +01:00
impl FileDescriptor {
/// Get any file descriptor passed by systemd or anything implementing the `LISTEN_FD` protocol.
2023-12-07 21:13:06 +01:00
///
2023-12-07 20:59:11 +01:00
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
2023-12-08 10:03:55 +01:00
///
/// The file descriptors are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptor colud be received which might not actually be an error. See [`ReceiveError`] for details.
2023-12-08 10:03:55 +01:00
pub fn receive (
unset_env : bool ,
) -> Result < impl IntoIterator < Item = Result < Self , DupError > > , ReceiveError > {
let fds = Self ::inner_receive ( unset_env ) ? ;
match Self ::from_fds ( fds ) {
Ok ( fds ) = > Ok ( fds ) ,
Err ( error ) = > Err ( ReceiveError ::GetFds ( error ) ) ,
}
}
/// Get any file descriptor passed by systemd or anything implementing the `LISTEN_FD` protocol.
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
///
/// The file descriptors are taken directly as [`OwnedFd`]s instead of being duplicated. In order to limit unsoundness this function therefore always unset the environment.
///
/// # Safety
///
/// This function is safe if (and only if) the received file descriptors weren't already taken into a owned rust struct.
/// (In short it needs to follow the safety constraints of [`OwnedFd::from_raw_fd`](std::os::fd::FromRawFd)).
/// The simplest way to insure that it is so is to only use functions from this crate to get these file descriptors.
///
/// # Errors
///
/// This function will fail if no file descriptor colud be received which might not actually be an error. See [`ReceiveError`] for details.
pub unsafe fn receive_no_dup ( ) -> Result < impl IntoIterator < Item = Self > , ReceiveError > {
let fds = Self ::inner_receive ( true ) ? ;
Self ::from_fds_no_dup ( fds ) . map_err ( ReceiveError ::GetFds )
}
fn inner_receive ( unset_env : bool ) -> Result < usize , ReceiveError > {
2023-12-07 18:02:35 +01:00
let pid = env ::var_os ( PID_VAR ) . ok_or ( ReceiveError ::NoListenPID ) ? ;
let fds = env ::var_os ( FD_NUMBER_VAR ) . ok_or ( ReceiveError ::NoListenFD ) ? ;
tracing ::trace! ( " {PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?} " ) ;
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
if unset_env {
env ::remove_var ( PID_VAR ) ;
env ::remove_var ( FD_NUMBER_VAR ) ;
env ::remove_var ( FD_NAMES_VAR ) ;
}
2023-12-07 21:13:06 +01:00
let pid = pid
. into_string ( )
2023-12-07 18:02:35 +01:00
. map_err ( ReceiveError ::NotUnicodeListenPID ) ?
. parse ::< u32 > ( )
. map_err ( ReceiveError ::ListenPIDParse ) ? ;
2023-12-07 21:13:06 +01:00
let fds = fds
. into_string ( )
2023-12-07 18:02:35 +01:00
. map_err ( ReceiveError ::NotUnicodeListenFD ) ?
. parse ::< usize > ( )
. map_err ( ReceiveError ::ListenFDParse ) ? ;
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
let current_pid = process ::id ( ) ;
if current_pid ! = pid {
2023-12-07 21:13:06 +01:00
return Err ( ReceiveError ::PidMismatch {
expected : pid ,
found : current_pid ,
} ) ;
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
2023-12-08 10:03:55 +01:00
Ok ( fds )
2023-12-07 18:02:35 +01:00
}
2023-12-07 20:59:11 +01:00
/// Get any file descriptor passed using `LISTEN_FD` and their names.
2023-12-07 21:13:06 +01:00
///
2023-12-07 20:59:11 +01:00
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
2023-12-08 10:03:55 +01:00
///
/// The file descriptors are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptors could be obtained or the names associated with them couldn't be obtained. See [`ReceiveNameError`] for details.
2023-12-07 21:13:06 +01:00
pub fn receive_with_names (
unset_env : bool ,
2023-12-08 10:03:55 +01:00
) -> Result < impl IntoIterator < Item = Result < ( Self , OsString ) , DupError > > , ReceiveNameError >
{
let fd_names = Self ::get_names ( ) ? ;
Ok ( Self ::receive ( unset_env ) ?
. into_iter ( )
. zip ( fd_names )
. map ( | ( res , name ) | res . map ( | fd | ( fd , name ) ) ) )
}
/// Get any file descriptor passed using `LISTEN_FD` and their names.
///
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
///
/// The file descriptors are taken directly as [`OwnedFd`]s instead of being duplicated. In order to limit unsoundness this function therefore always unset the environment.
///
/// # Safety
///
/// This function is safe if (and only if) the received file descriptors weren't already taken into a owned rust struct.
/// (In short it needs to follow the safety constraints of [`OwnedFd::from_raw_fd`](std::os::fd::FromRawFd)).
/// The simplest way to insure that it is so is to only use functions from this crate to get these file descriptors.
///
/// # Errors
///
/// This function will fail if no file descriptors could be obtained or the names associated with them couldn't be obtained. See [`ReceiveNameError`] for details.
///
///
pub unsafe fn receive_with_names_no_dup (
2023-12-07 21:13:06 +01:00
) -> Result < impl IntoIterator < Item = ( Self , OsString ) > , ReceiveNameError > {
2023-12-08 10:03:55 +01:00
let fd_names = Self ::get_names ( ) ? ;
Ok ( Self ::receive_no_dup ( ) ? . into_iter ( ) . zip ( fd_names ) )
}
fn get_names ( ) -> Result < impl IntoIterator < Item = OsString > , ReceiveNameError > {
2023-12-07 18:02:35 +01:00
let fd_names = env ::var_os ( FD_NAMES_VAR ) . ok_or ( ReceiveNameError ::NoListenFDName ) ? ;
tracing ::trace! ( " {FD_NAMES_VAR} = {fd_names:?} " ) ;
2023-12-08 10:03:55 +01:00
Ok ( fd_names
2023-12-07 21:13:06 +01:00
. as_encoded_bytes ( )
2023-12-07 18:02:35 +01:00
. split ( | b | * b = = b ':' )
. map ( | name | {
// SAFETY:
// - Each `word` only contains content that originated from `OsStr::as_encoded_bytes`
// - Only split with ASCII colon which is a non-empty UTF-8 substring
unsafe { OsStr ::from_encoded_bytes_unchecked ( name ) }
} )
2023-12-07 22:01:17 +01:00
. map ( OsStr ::to_os_string )
2023-12-08 10:03:55 +01:00
. collect ::< Vec < _ > > ( ) )
}
fn from_fds (
num_fds : usize ,
) -> Result < impl IntoIterator < Item = Result < Self , DupError > > , GetFdsError > {
if SD_LISTEN_FDS_START . checked_add ( num_fds as RawFd ) . is_none ( ) {
return Err ( GetFdsError ::TooManyFDs ( num_fds ) ) ;
}
Ok ( ( 0 .. num_fds ) . map ( | fd_offset | {
SD_LISTEN_FDS_START
. checked_add ( fd_offset as RawFd )
. map ( | fd | {
// SAFETY: The file descriptor won't be closed by the time we duplicate it.
let fd = unsafe { BorrowedFd ::borrow_raw ( fd ) } ;
rustix ::fs ::fcntl_dupfd_cloexec ( fd , 0 ) . map_err ( DupError ::from )
} )
. expect ( " Already checked against overflow. " )
. map ( Self ::from_fd )
} ) )
2023-12-07 18:02:35 +01:00
}
2023-12-08 10:03:55 +01:00
unsafe fn from_fds_no_dup (
num_fds : usize ,
) -> Result < impl IntoIterator < Item = Self > , GetFdsError > {
2023-12-07 20:59:11 +01:00
if SD_LISTEN_FDS_START . checked_add ( num_fds as RawFd ) . is_none ( ) {
return Err ( GetFdsError ::TooManyFDs ( num_fds ) ) ;
}
2023-12-08 08:53:19 +01:00
Ok ( ( 0 .. num_fds ) . map ( | fd_offset | {
2023-12-07 21:13:06 +01:00
SD_LISTEN_FDS_START
. checked_add ( fd_offset as RawFd )
2023-12-08 10:03:55 +01:00
. map ( | fd | {
let fd = unsafe { OwnedFd ::from_raw_fd ( fd ) } ;
let flags = rustix ::fs ::fcntl_getfd ( & fd ) . unwrap ( ) ;
rustix ::fs ::fcntl_setfd ( & fd , flags . union ( FdFlags ::CLOEXEC ) ) . unwrap ( ) ;
fd
} )
. map ( Self ::from_fd )
2023-12-07 21:13:06 +01:00
. expect ( " Already checked against overflow. " )
} ) )
2023-12-07 20:59:11 +01:00
}
2023-12-08 10:03:55 +01:00
fn from_fd ( fd : OwnedFd ) -> Self {
2023-12-07 21:13:06 +01:00
let stat = rustix ::fs ::fstat ( & fd )
. expect ( " This should only ever be called on a valid file descriptor. " ) ;
2023-12-07 20:59:11 +01:00
let file_type = FileType ::from_raw_mode ( stat . st_mode ) ;
match file_type {
FileType ::RegularFile = > Self ::File ( fd . into ( ) ) ,
FileType ::Directory = > Self ::Directory ( fd ) ,
FileType ::Fifo = > Self ::Fifo ( fd ) ,
2023-12-07 21:13:06 +01:00
FileType ::Socket = > match rustix ::net ::getsockname ( & fd ) . unwrap ( ) {
rustix ::net ::SocketAddrAny ::V4 ( _ ) | rustix ::net ::SocketAddrAny ::V6 ( _ ) = > {
match rustix ::net ::sockopt ::get_socket_type ( & fd ) . unwrap ( ) {
SocketType ::DGRAM = > Self ::UdpSocket ( fd . into ( ) ) ,
SocketType ::STREAM = > {
if rustix ::net ::sockopt ::get_socket_acceptconn ( & fd ) . unwrap_or_default ( )
{
Self ::TcpListener ( fd . into ( ) )
} else {
Self ::TcpStream ( fd . into ( ) )
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
_ = > Self ::InetOther ( fd ) ,
}
}
rustix ::net ::SocketAddrAny ::Unix ( _ ) = > {
match rustix ::net ::sockopt ::get_socket_type ( & fd ) . unwrap ( ) {
SocketType ::DGRAM = > Self ::UnixDatagram ( fd . into ( ) ) ,
SocketType ::STREAM = > {
if rustix ::net ::sockopt ::get_socket_acceptconn ( & fd ) . unwrap_or_default ( )
{
Self ::UnixListener ( fd . into ( ) )
} else {
Self ::UnixStream ( fd . into ( ) )
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
SocketType ::SEQPACKET = > Self ::UnixSequentialPaquet ( fd ) ,
_ = > Self ::UnixOther ( fd ) ,
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
_ = > Self ::SocketOther ( fd ) ,
2023-12-07 20:59:11 +01:00
} ,
2023-12-07 21:10:44 +01:00
_ = > {
// `rustix` does not enable us to test if a raw fd is a mq, so we must drop to libc here.
// SAFETY: `mq_getattr` is specified to return -1 when passed a fd which is not a mq.
// Furthermore, we ignore `attr` and rely only on the return value.
let mut attr = std ::mem ::MaybeUninit ::< libc ::mq_attr > ::uninit ( ) ;
let res = unsafe { libc ::mq_getattr ( fd . as_raw_fd ( ) , attr . as_mut_ptr ( ) ) } ;
if res = = 0 {
Self ::MessageQueue ( fd )
} else {
Self ::Other ( fd )
}
2023-12-07 21:13:06 +01:00
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
}