142 lines
5.6 KiB
Rust
142 lines
5.6 KiB
Rust
|
use std::fs::File;
|
||
|
use std::net::{TcpListener, TcpStream, UdpSocket};
|
||
|
use std::os::unix::net::{UnixListener, UnixStream, UnixDatagram};
|
||
|
use std::{env, process};
|
||
|
use std::ffi::{OsString, OsStr};
|
||
|
use std::os::fd::{RawFd, FromRawFd};
|
||
|
|
||
|
use error::{ReceiveError, ReceiveNameError, GetFdsError};
|
||
|
use rustix::fd::BorrowedFd;
|
||
|
use rustix::fs::FileType;
|
||
|
use rustix::net::SocketType;
|
||
|
|
||
|
mod error;
|
||
|
|
||
|
// Environment variables of the socket-activation protocol (see sd_listen_fds(3)).

/// Holds the PID of the process the descriptors are meant for.
const PID_VAR: &str = "LISTEN_PID";
/// Holds how many descriptors were passed.
const FD_NUMBER_VAR: &str = "LISTEN_FDS";
/// Holds the colon-separated list of descriptor names.
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";

/// First passed descriptor number, matching sd-daemon's `SD_LISTEN_FDS_START`;
/// the passed descriptors occupy consecutive numbers starting here.
const SD_LISTEN_FDS_START: RawFd = 3;
|
||
|
|
||
|
/// A file descriptor received through socket activation (`LISTEN_FDS`),
/// classified by the kind of object it refers to.
///
/// Variants wrapping a standard-library type own their descriptor and close it
/// when dropped. Variants holding a bare [`RawFd`] do **not** close it on drop;
/// the caller is responsible for its lifetime.
#[derive(Debug)]
pub enum StoredFd {
    /// A regular file.
    File(File),
    /// A directory.
    Directory(RawFd),
    /// A FIFO (named pipe).
    Fifo(RawFd),
    /// A special file, presumably one passed via systemd's `ListenSpecial=`
    /// (e.g. a character device).
    Special(RawFd),
    /// An IPv4/IPv6 stream socket that is accepting connections.
    TcpListener(TcpListener),
    /// An IPv4/IPv6 stream socket that is not accepting connections.
    TcpStream(TcpStream),
    /// An IPv4/IPv6 datagram socket.
    UdpSocket(UdpSocket),
    /// An IPv4/IPv6 socket of some other type.
    InetOther(RawFd),
    /// A Unix-domain stream socket that is accepting connections.
    UnixListener(UnixListener),
    /// A Unix-domain stream socket that is not accepting connections.
    UnixStream(UnixStream),
    /// A Unix-domain datagram socket.
    UnixDatagram(UnixDatagram),
    /// A Unix-domain socket of some other type.
    UnixOther(RawFd),
    /// A socket of some other address family.
    SocketOther(RawFd),
    /// A POSIX message queue, presumably from systemd's `ListenMessageQueue=`.
    /// NOTE(review): `from_raw_fd` has no branch that yields this variant.
    MessageQueue(RawFd),
    /// Any other kind of file descriptor.
    Other(RawFd),
}
|
||
|
|
||
|
impl FromRawFd for StoredFd {
|
||
|
unsafe fn from_raw_fd(fd: RawFd) -> Self {
|
||
|
let stat = rustix::fs::fstat(BorrowedFd::borrow_raw(fd)).expect("This should only ever be called on a valid file descriptor.");
|
||
|
let file_type = FileType::from_raw_mode(stat.st_mode);
|
||
|
match file_type {
|
||
|
FileType::RegularFile => Self::File(File::from_raw_fd(fd)),
|
||
|
FileType::Directory => Self::Directory(fd),
|
||
|
FileType::Fifo => Self::Fifo(fd),
|
||
|
FileType::Socket => {
|
||
|
let borrowed_fd = BorrowedFd::borrow_raw(fd);
|
||
|
match rustix::net::getsockname(borrowed_fd).unwrap() {
|
||
|
rustix::net::SocketAddrAny::V4(_) |
|
||
|
rustix::net::SocketAddrAny::V6(_) => {
|
||
|
match rustix::net::sockopt::get_socket_type(borrowed_fd).unwrap() {
|
||
|
SocketType::DGRAM => Self::UdpSocket(UdpSocket::from_raw_fd(fd)),
|
||
|
SocketType::STREAM => {
|
||
|
if rustix::net::sockopt::get_socket_acceptconn(borrowed_fd).unwrap_or_default() {
|
||
|
Self::TcpListener(TcpListener::from_raw_fd(fd))
|
||
|
} else {
|
||
|
Self::TcpStream(TcpStream::from_raw_fd(fd))
|
||
|
}
|
||
|
},
|
||
|
_ => Self::InetOther(fd),
|
||
|
}
|
||
|
},
|
||
|
rustix::net::SocketAddrAny::Unix(_) => {
|
||
|
match rustix::net::sockopt::get_socket_type(borrowed_fd).unwrap() {
|
||
|
SocketType::DGRAM => Self::UnixDatagram(UnixDatagram::from_raw_fd(fd)),
|
||
|
SocketType::STREAM => {
|
||
|
if rustix::net::sockopt::get_socket_acceptconn(borrowed_fd).unwrap_or_default() {
|
||
|
Self::UnixListener(UnixListener::from_raw_fd(fd))
|
||
|
} else {
|
||
|
Self::UnixStream(UnixStream::from_raw_fd(fd))
|
||
|
}
|
||
|
},
|
||
|
_ => Self::UnixOther(fd),
|
||
|
}
|
||
|
},
|
||
|
_ => Self::SocketOther(fd),
|
||
|
}
|
||
|
},
|
||
|
_ => Self::Other(fd)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl StoredFd {
|
||
|
pub fn receive(unset_env: bool) -> Result<impl IntoIterator<Item = Self>, ReceiveError> {
|
||
|
let pid = env::var_os(PID_VAR).ok_or(ReceiveError::NoListenPID)?;
|
||
|
let fds = env::var_os(FD_NUMBER_VAR).ok_or(ReceiveError::NoListenFD)?;
|
||
|
tracing::trace!("{PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?}");
|
||
|
|
||
|
if unset_env {
|
||
|
env::remove_var(PID_VAR);
|
||
|
env::remove_var(FD_NUMBER_VAR);
|
||
|
env::remove_var(FD_NAMES_VAR);
|
||
|
}
|
||
|
|
||
|
let pid = pid.into_string()
|
||
|
.map_err(ReceiveError::NotUnicodeListenPID)?
|
||
|
.parse::<u32>()
|
||
|
.map_err(ReceiveError::ListenPIDParse)?;
|
||
|
let fds = fds.into_string()
|
||
|
.map_err(ReceiveError::NotUnicodeListenFD)?
|
||
|
.parse::<usize>()
|
||
|
.map_err(ReceiveError::ListenFDParse)?;
|
||
|
|
||
|
let current_pid = process::id();
|
||
|
if current_pid != pid {
|
||
|
return Err(ReceiveError::PidMismatch{expected: pid, found: current_pid});
|
||
|
}
|
||
|
|
||
|
match Self::from_fds(fds) {
|
||
|
Ok(fds) => Ok(fds),
|
||
|
Err(error) => Err(ReceiveError::GetFds(error)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub fn receive_with_names(unset_env: bool) -> Result<Vec<(Self, OsString)>, ReceiveNameError> {
|
||
|
let fd_names = env::var_os(FD_NAMES_VAR).ok_or(ReceiveNameError::NoListenFDName)?;
|
||
|
tracing::trace!("{FD_NAMES_VAR} = {fd_names:?}");
|
||
|
let fd_names = fd_names.as_encoded_bytes()
|
||
|
.split(|b| *b == b':')
|
||
|
.map(|name| {
|
||
|
// SAFETY:
|
||
|
// - Each `word` only contains content that originated from `OsStr::as_encoded_bytes`
|
||
|
// - Only split with ASCII colon which is a non-empty UTF-8 substring
|
||
|
unsafe { OsStr::from_encoded_bytes_unchecked(name) }
|
||
|
})
|
||
|
.map(OsStr::to_os_string);
|
||
|
Ok(Self::receive(unset_env)?.into_iter().zip(fd_names).collect::<Vec<_>>())
|
||
|
}
|
||
|
|
||
|
fn from_fds(num_fds: usize) -> Result<Vec<StoredFd>, GetFdsError> {
|
||
|
(0..num_fds)
|
||
|
.map(|fd_offset| {
|
||
|
SD_LISTEN_FDS_START
|
||
|
.checked_add(fd_offset as RawFd)
|
||
|
.ok_or(GetFdsError::TooManyFDs(num_fds))
|
||
|
// SAFETY: We are receiving the fd so it should be safe
|
||
|
.map(|fd| unsafe { StoredFd::from_raw_fd(fd) })
|
||
|
})
|
||
|
.collect()
|
||
|
}
|
||
|
}
|