Initial commit

This commit is contained in:
Mathieu Trossevin 2023-12-07 18:02:35 +01:00
commit d670bad1e9
Signed by: mtrossevin
GPG key ID: D1DBB7EA828374E9
5 changed files with 406 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

156
Cargo.lock generated Normal file
View file

@ -0,0 +1,156 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "bitflags"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
[[package]]
name = "errno"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "libc"
version = "0.2.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
[[package]]
name = "linux-raw-sys"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
[[package]]
name = "log"
version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "rustix"
version = "0.38.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a"
dependencies = [
"bitflags",
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
]
[[package]]
name = "storefd"
version = "0.1.0"
dependencies = [
"rustix",
"tracing",
]
[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"log",
"pin-project-lite",
"tracing-core",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
[[package]]
name = "windows_i686_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
[[package]]
name = "windows_i686_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"

10
Cargo.toml Normal file
View file

@ -0,0 +1,10 @@
[package]
name = "storefd"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rustix = { version = "0.38.26", features = ["fs", "net"] }
tracing = { version = "0.1.40", default-features = false, features = ["std", "log"] }

97
src/error.rs Normal file
View file

@ -0,0 +1,97 @@
use core::fmt::Display;
use core::num::ParseIntError;
use std::error::Error;
use std::ffi::OsString;
use crate::{FD_NAMES_VAR, FD_NUMBER_VAR, PID_VAR};
#[derive(Debug)]
pub enum ReceiveError {
NoListenPID,
NotUnicodeListenPID(OsString),
ListenPIDParse(ParseIntError),
NoListenFD,
NotUnicodeListenFD(OsString),
ListenFDParse(ParseIntError),
PidMismatch{
expected: u32,
found: u32,
},
GetFds(GetFdsError),
}
impl Display for ReceiveError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Couldn't receive file descriptors :")?;
match self {
ReceiveError::NoListenPID => write!(f, "The variable {PID_VAR} doesn't exists."),
ReceiveError::NotUnicodeListenPID(var) => write!(f, "The variable {PID_VAR} isn't unicode (this should never happen): it is {var:?}"),
ReceiveError::ListenPIDParse(error) => write!(f, "Couldn't parse {PID_VAR} as a `u32`: {error}"),
ReceiveError::NoListenFD => write!(f, "The variable {FD_NUMBER_VAR} doesn't exists."),
ReceiveError::NotUnicodeListenFD(var) => write!(f, "The variable {FD_NUMBER_VAR} isn't unicode (this should never happen): it is {var:?}"),
ReceiveError::ListenFDParse(error) => write!(f, "Couldn't parse {FD_NUMBER_VAR} as a `u32`: {error}"),
ReceiveError::PidMismatch{expected, found} => write!(f, "PID mismatch! Was {found} but should have been {expected}."),
ReceiveError::GetFds(error) => Display::fmt(error, f),
}
}
}
impl Error for ReceiveError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
ReceiveError::NoListenPID => None,
ReceiveError::NotUnicodeListenPID(_) => None,
ReceiveError::ListenPIDParse(error) => Some(error),
ReceiveError::NoListenFD => None,
ReceiveError::NotUnicodeListenFD(_) => None,
ReceiveError::ListenFDParse(error) => Some(error),
ReceiveError::PidMismatch { expected: _, found: _ } => None,
ReceiveError::GetFds(error) => Some(error),
}
}
}
#[derive(Debug)]
pub enum ReceiveNameError {
NoListenFDName,
Receive(ReceiveError),
}
impl From<ReceiveError> for ReceiveNameError {
fn from(value: ReceiveError) -> Self {
Self::Receive(value)
}
}
impl Display for ReceiveNameError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReceiveNameError::NoListenFDName => write!(f, "Couldn't find FDs name : the variable {FD_NAMES_VAR} doesn't exists."),
ReceiveNameError::Receive(error) => Display::fmt(error, f),
}
}
}
impl Error for ReceiveNameError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
ReceiveNameError::NoListenFDName => None,
ReceiveNameError::Receive(error) => Some(error),
}
}
}
#[derive(Debug)]
pub enum GetFdsError {
TooManyFDs(usize),
}
impl Display for GetFdsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GetFdsError::TooManyFDs(size) => write!(f, "Too many file descriptors ({size})"),
}
}
}
impl Error for GetFdsError {}

142
src/lib.rs Normal file
View file

@ -0,0 +1,142 @@
use std::fs::File;
use std::net::{TcpListener, TcpStream, UdpSocket};
use std::os::unix::net::{UnixListener, UnixStream, UnixDatagram};
use std::{env, process};
use std::ffi::{OsString, OsStr};
use std::os::fd::{RawFd, FromRawFd};
use error::{ReceiveError, ReceiveNameError, GetFdsError};
use rustix::fd::BorrowedFd;
use rustix::fs::FileType;
use rustix::net::SocketType;
mod error;
const SD_LISTEN_FDS_START: RawFd = 3;
const PID_VAR: &str = "LISTEN_PID";
const FD_NUMBER_VAR: &str = "LISTEN_FDS";
const FD_NAMES_VAR: &str = "LISTEN_FDNAMES";
pub enum StoredFd {
File(File),
Directory(RawFd),
Fifo(RawFd),
Special(RawFd),
TcpListener(TcpListener),
TcpStream(TcpStream),
UdpSocket(UdpSocket),
InetOther(RawFd),
UnixListener(UnixListener),
UnixStream(UnixStream),
UnixDatagram(UnixDatagram),
UnixOther(RawFd),
SocketOther(RawFd),
MessageQueue(RawFd),
Other(RawFd),
}
impl FromRawFd for StoredFd {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
let stat = rustix::fs::fstat(BorrowedFd::borrow_raw(fd)).expect("This should only ever be called on a valid file descriptor.");
let file_type = FileType::from_raw_mode(stat.st_mode);
match file_type {
FileType::RegularFile => Self::File(File::from_raw_fd(fd)),
FileType::Directory => Self::Directory(fd),
FileType::Fifo => Self::Fifo(fd),
FileType::Socket => {
let borrowed_fd = BorrowedFd::borrow_raw(fd);
match rustix::net::getsockname(borrowed_fd).unwrap() {
rustix::net::SocketAddrAny::V4(_) |
rustix::net::SocketAddrAny::V6(_) => {
match rustix::net::sockopt::get_socket_type(borrowed_fd).unwrap() {
SocketType::DGRAM => Self::UdpSocket(UdpSocket::from_raw_fd(fd)),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(borrowed_fd).unwrap_or_default() {
Self::TcpListener(TcpListener::from_raw_fd(fd))
} else {
Self::TcpStream(TcpStream::from_raw_fd(fd))
}
},
_ => Self::InetOther(fd),
}
},
rustix::net::SocketAddrAny::Unix(_) => {
match rustix::net::sockopt::get_socket_type(borrowed_fd).unwrap() {
SocketType::DGRAM => Self::UnixDatagram(UnixDatagram::from_raw_fd(fd)),
SocketType::STREAM => {
if rustix::net::sockopt::get_socket_acceptconn(borrowed_fd).unwrap_or_default() {
Self::UnixListener(UnixListener::from_raw_fd(fd))
} else {
Self::UnixStream(UnixStream::from_raw_fd(fd))
}
},
_ => Self::UnixOther(fd),
}
},
_ => Self::SocketOther(fd),
}
},
_ => Self::Other(fd)
}
}
}
impl StoredFd {
pub fn receive(unset_env: bool) -> Result<impl IntoIterator<Item = Self>, ReceiveError> {
let pid = env::var_os(PID_VAR).ok_or(ReceiveError::NoListenPID)?;
let fds = env::var_os(FD_NUMBER_VAR).ok_or(ReceiveError::NoListenFD)?;
tracing::trace!("{PID_VAR} = {pid:?}; {FD_NUMBER_VAR} = {fds:?}");
if unset_env {
env::remove_var(PID_VAR);
env::remove_var(FD_NUMBER_VAR);
env::remove_var(FD_NAMES_VAR);
}
let pid = pid.into_string()
.map_err(ReceiveError::NotUnicodeListenPID)?
.parse::<u32>()
.map_err(ReceiveError::ListenPIDParse)?;
let fds = fds.into_string()
.map_err(ReceiveError::NotUnicodeListenFD)?
.parse::<usize>()
.map_err(ReceiveError::ListenFDParse)?;
let current_pid = process::id();
if current_pid != pid {
return Err(ReceiveError::PidMismatch{expected: pid, found: current_pid});
}
match Self::from_fds(fds) {
Ok(fds) => Ok(fds),
Err(error) => Err(ReceiveError::GetFds(error)),
}
}
pub fn receive_with_names(unset_env: bool) -> Result<Vec<(Self, OsString)>, ReceiveNameError> {
let fd_names = env::var_os(FD_NAMES_VAR).ok_or(ReceiveNameError::NoListenFDName)?;
tracing::trace!("{FD_NAMES_VAR} = {fd_names:?}");
let fd_names = fd_names.as_encoded_bytes()
.split(|b| *b == b':')
.map(|name| {
// SAFETY:
// - Each `word` only contains content that originated from `OsStr::as_encoded_bytes`
// - Only split with ASCII colon which is a non-empty UTF-8 substring
unsafe { OsStr::from_encoded_bytes_unchecked(name) }
})
.map(OsStr::to_os_string);
Ok(Self::receive(unset_env)?.into_iter().zip(fd_names).collect::<Vec<_>>())
}
fn from_fds(num_fds: usize) -> Result<Vec<StoredFd>, GetFdsError> {
(0..num_fds)
.map(|fd_offset| {
SD_LISTEN_FDS_START
.checked_add(fd_offset as RawFd)
.ok_or(GetFdsError::TooManyFDs(num_fds))
// SAFETY: We are receiving the fd so it should be safe
.map(|fd| unsafe { StoredFd::from_raw_fd(fd) })
})
.collect()
}
}