2023-12-07 21:13:06 +01:00
use std ::ffi ::{ OsStr , OsString } ;
2023-12-07 18:02:35 +01:00
use std ::fs ::File ;
use std ::net ::{ TcpListener , TcpStream , UdpSocket } ;
2023-12-07 20:59:11 +01:00
use std ::os ::fd ::RawFd ;
2023-12-07 21:13:06 +01:00
use std ::os ::unix ::net ::{ UnixDatagram , UnixListener , UnixStream } ;
use std ::{ env , process } ;
2023-12-07 18:02:35 +01:00
2023-12-07 21:13:06 +01:00
use error ::{ GetFdsError , ReceiveError , ReceiveNameError } ;
use rustix ::fd ::{ AsRawFd , BorrowedFd , OwnedFd } ;
2023-12-07 18:02:35 +01:00
use rustix ::fs ::FileType ;
use rustix ::net ::SocketType ;
// Error types used by this module (`GetFdsError`, `ReceiveError`, `ReceiveNameError`).
mod error;
/// Number of the first passed file descriptor; the passed descriptors are consecutive from here.
const SD_LISTEN_FDS_START: RawFd = 3;
/// Environment variable holding the PID of the process the descriptors are intended for.
const PID_VAR: &str = "LISTEN_PID";
/// Environment variable holding the number of passed file descriptors.
const FD_NUMBER_VAR: &str = "LISTEN_FDS";
/// Environment variable holding the colon-separated names of the passed file descriptors.
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";
2023-12-07 20:59:11 +01:00
/// File descriptor passed by systemd (or anything else implementing the `LISTEN_FDS` protocol).
///
/// Each value owns a duplicate of the passed file descriptor, so it can be used and closed from
/// Rust code without affecting the original descriptor.
#[derive(Debug)]
pub enum FileDescriptor {
    /// The file descriptor is a regular [`File`](std::fs::File).
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSpecial=`.
    File(File),
    /// The file descriptor is a directory.
    Directory(OwnedFd),
    /// The file descriptor is a FIFO.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenFIFO=`.
    Fifo(OwnedFd),
    /// The file descriptor is a TCP socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with an IP address.
    TcpListener(TcpListener),
    /// The file descriptor is a TCP socket that doesn't listen for incoming connections.
    TcpStream(TcpStream),
    /// The file descriptor is a UDP socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with an IP address.
    UdpSocket(UdpSocket),
    /// The file descriptor is another IPv4/IPv6 socket.
    ///
    /// You should figure out what exactly it is before you use it.
    InetOther(OwnedFd),
    /// The file descriptor is a Unix stream socket listening for incoming connections.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenStream=` with a path or starting with `@` (abstract socket).
    UnixListener(UnixListener),
    /// The file descriptor is a Unix stream socket that doesn't listen for incoming connections.
    UnixStream(UnixStream),
    /// The file descriptor is a Unix datagram socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenDatagram=` with a path or starting with `@` (abstract socket).
    UnixDatagram(UnixDatagram),
    /// The file descriptor is a Unix SEQPACKET socket.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenSequentialPacket=`.
    /// (The variant name's spelling is kept as-is for backward compatibility.)
    UnixSequentialPaquet(OwnedFd),
    /// The file descriptor is another type of Unix socket.
    UnixOther(OwnedFd),
    /// The file descriptor is a socket of some other address family.
    ///
    /// Probably a Netlink family socket if from a .socket unit, but you should check yourself.
    SocketOther(OwnedFd),
    /// The file descriptor is a POSIX message queue.
    ///
    /// If this is an FD provided by a .socket unit it corresponds to `ListenMessageQueue=`.
    MessageQueue(OwnedFd),
    /// The file descriptor is something else.
    Other(OwnedFd),
}
2023-12-07 20:59:11 +01:00
impl FileDescriptor {
/// Get any file descriptor passed by systemd or anything implementing the `LISTEN_FD` protocol.
2023-12-07 21:13:06 +01:00
///
2023-12-07 20:59:11 +01:00
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
/// The file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptor colud be received which might not actually be an error. See [`ReceiveError`] for details.
2023-12-07 18:02:35 +01:00
pub fn receive ( unset_env : bool ) -> Result < impl IntoIterator < Item = Self > , ReceiveError > {
let pid = env ::var_os ( PID_VAR ) . ok_or ( ReceiveError ::NoListenPID ) ? ;
let fds = env ::var_os ( FD_NUMBER_VAR ) . ok_or ( ReceiveError ::NoListenFD ) ? ;
tracing ::trace! ( " {PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?} " ) ;
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
if unset_env {
env ::remove_var ( PID_VAR ) ;
env ::remove_var ( FD_NUMBER_VAR ) ;
env ::remove_var ( FD_NAMES_VAR ) ;
}
2023-12-07 21:13:06 +01:00
let pid = pid
. into_string ( )
2023-12-07 18:02:35 +01:00
. map_err ( ReceiveError ::NotUnicodeListenPID ) ?
. parse ::< u32 > ( )
. map_err ( ReceiveError ::ListenPIDParse ) ? ;
2023-12-07 21:13:06 +01:00
let fds = fds
. into_string ( )
2023-12-07 18:02:35 +01:00
. map_err ( ReceiveError ::NotUnicodeListenFD ) ?
. parse ::< usize > ( )
. map_err ( ReceiveError ::ListenFDParse ) ? ;
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
let current_pid = process ::id ( ) ;
if current_pid ! = pid {
2023-12-07 21:13:06 +01:00
return Err ( ReceiveError ::PidMismatch {
expected : pid ,
found : current_pid ,
} ) ;
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
2023-12-07 18:02:35 +01:00
match Self ::from_fds ( fds ) {
Ok ( fds ) = > Ok ( fds ) ,
Err ( error ) = > Err ( ReceiveError ::GetFds ( error ) ) ,
}
}
2023-12-07 20:59:11 +01:00
/// Get any file descriptor passed using `LISTEN_FD` and their names.
2023-12-07 21:13:06 +01:00
///
2023-12-07 20:59:11 +01:00
/// This isn't necessarily limited to File descriptor of listening sockets, IPCs or FIFOs but also anything that is in the file descriptor store.
/// The file descriptores are duplicated using [`fcntl_dupfd_cloexec`](rustix::fs::fcntl_dupfd_cloexec) so they can safely be used from rust and will not be propagated to children process automatically.
2023-12-07 21:39:33 +01:00
///
/// # Errors
///
/// This function will fail if no file descriptors could be obtained or the names associated with them couldn't be obtained. See [`ReceiveNameError`] for details.
2023-12-07 21:13:06 +01:00
pub fn receive_with_names (
unset_env : bool ,
) -> Result < impl IntoIterator < Item = ( Self , OsString ) > , ReceiveNameError > {
2023-12-07 18:02:35 +01:00
let fd_names = env ::var_os ( FD_NAMES_VAR ) . ok_or ( ReceiveNameError ::NoListenFDName ) ? ;
tracing ::trace! ( " {FD_NAMES_VAR} = {fd_names:?} " ) ;
2023-12-07 21:13:06 +01:00
let fd_names = fd_names
. as_encoded_bytes ( )
2023-12-07 18:02:35 +01:00
. split ( | b | * b = = b ':' )
. map ( | name | {
// SAFETY:
// - Each `word` only contains content that originated from `OsStr::as_encoded_bytes`
// - Only split with ASCII colon which is a non-empty UTF-8 substring
unsafe { OsStr ::from_encoded_bytes_unchecked ( name ) }
} )
. map ( OsStr ::to_os_string ) ;
2023-12-07 21:13:06 +01:00
Ok ( Self ::receive ( unset_env ) ?
. into_iter ( )
. zip ( fd_names )
. collect ::< Vec < _ > > ( ) )
2023-12-07 18:02:35 +01:00
}
2023-12-07 20:59:11 +01:00
fn from_fds ( num_fds : usize ) -> Result < impl IntoIterator < Item = FileDescriptor > , GetFdsError > {
if SD_LISTEN_FDS_START . checked_add ( num_fds as RawFd ) . is_none ( ) {
return Err ( GetFdsError ::TooManyFDs ( num_fds ) ) ;
}
2023-12-07 21:13:06 +01:00
Ok ( ( 0 .. num_fds ) . map ( | fd_offset | {
SD_LISTEN_FDS_START
. checked_add ( fd_offset as RawFd )
// SAFETY: We are receiving the fd so it should be safe
. map ( | fd | FileDescriptor ::from_fd ( fd , 0 ) )
. expect ( " Already checked against overflow. " )
} ) )
2023-12-07 20:59:11 +01:00
}
fn from_fd ( fd : RawFd , min_new : RawFd ) -> Self {
let fd = {
// SAFETY: The file descriptor won't be closed by the time we duplicate it.
let fd = unsafe { BorrowedFd ::borrow_raw ( fd ) } ;
2023-12-07 21:13:06 +01:00
rustix ::fs ::fcntl_dupfd_cloexec ( fd , min_new )
. expect ( " Couldn't duplicate the file descriptor " )
2023-12-07 20:59:11 +01:00
} ;
2023-12-07 21:13:06 +01:00
let stat = rustix ::fs ::fstat ( & fd )
. expect ( " This should only ever be called on a valid file descriptor. " ) ;
2023-12-07 20:59:11 +01:00
let file_type = FileType ::from_raw_mode ( stat . st_mode ) ;
match file_type {
FileType ::RegularFile = > Self ::File ( fd . into ( ) ) ,
FileType ::Directory = > Self ::Directory ( fd ) ,
FileType ::Fifo = > Self ::Fifo ( fd ) ,
2023-12-07 21:13:06 +01:00
FileType ::Socket = > match rustix ::net ::getsockname ( & fd ) . unwrap ( ) {
rustix ::net ::SocketAddrAny ::V4 ( _ ) | rustix ::net ::SocketAddrAny ::V6 ( _ ) = > {
match rustix ::net ::sockopt ::get_socket_type ( & fd ) . unwrap ( ) {
SocketType ::DGRAM = > Self ::UdpSocket ( fd . into ( ) ) ,
SocketType ::STREAM = > {
if rustix ::net ::sockopt ::get_socket_acceptconn ( & fd ) . unwrap_or_default ( )
{
Self ::TcpListener ( fd . into ( ) )
} else {
Self ::TcpStream ( fd . into ( ) )
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
_ = > Self ::InetOther ( fd ) ,
}
}
rustix ::net ::SocketAddrAny ::Unix ( _ ) = > {
match rustix ::net ::sockopt ::get_socket_type ( & fd ) . unwrap ( ) {
SocketType ::DGRAM = > Self ::UnixDatagram ( fd . into ( ) ) ,
SocketType ::STREAM = > {
if rustix ::net ::sockopt ::get_socket_acceptconn ( & fd ) . unwrap_or_default ( )
{
Self ::UnixListener ( fd . into ( ) )
} else {
Self ::UnixStream ( fd . into ( ) )
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
SocketType ::SEQPACKET = > Self ::UnixSequentialPaquet ( fd ) ,
_ = > Self ::UnixOther ( fd ) ,
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 21:13:06 +01:00
_ = > Self ::SocketOther ( fd ) ,
2023-12-07 20:59:11 +01:00
} ,
2023-12-07 21:10:44 +01:00
_ = > {
// `rustix` does not enable us to test if a raw fd is a mq, so we must drop to libc here.
// SAFETY: `mq_getattr` is specified to return -1 when passed a fd which is not a mq.
// Furthermore, we ignore `attr` and rely only on the return value.
let mut attr = std ::mem ::MaybeUninit ::< libc ::mq_attr > ::uninit ( ) ;
let res = unsafe { libc ::mq_getattr ( fd . as_raw_fd ( ) , attr . as_mut_ptr ( ) ) } ;
if res = = 0 {
Self ::MessageQueue ( fd )
} else {
Self ::Other ( fd )
}
2023-12-07 21:13:06 +01:00
}
2023-12-07 20:59:11 +01:00
}
2023-12-07 18:02:35 +01:00
}
2023-12-07 21:13:06 +01:00
}