committed by
B. Blechschmidt
11 changed files with 615 additions and 48 deletions
@ -0,0 +1,230 @@ |
|||
#![cfg(target_os = "linux")] |
|||
|
|||
use crate::{error, SocketDomain, SocketProtocol}; |
|||
use nix::{ |
|||
errno::Errno, |
|||
fcntl::{self, FdFlag}, |
|||
sys::socket::{cmsg_space, getsockopt, recvmsg, sendmsg, sockopt, ControlMessage, ControlMessageOwned, MsgFlags, SockType}, |
|||
}; |
|||
use serde::{Deserialize, Serialize}; |
|||
use std::{ |
|||
io::{ErrorKind, IoSlice, IoSliceMut, Result}, |
|||
ops::DerefMut, |
|||
os::fd::{AsFd, AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}, |
|||
}; |
|||
use tokio::net::{TcpSocket, UdpSocket, UnixDatagram}; |
|||
|
|||
const REQUEST_BUFFER_SIZE: usize = 64; |
|||
|
|||
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] |
|||
struct Request { |
|||
protocol: SocketProtocol, |
|||
domain: SocketDomain, |
|||
number: u32, |
|||
} |
|||
|
|||
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] |
|||
enum Response { |
|||
Ok, |
|||
} |
|||
|
|||
/// Reconstruct socket from raw `fd`
|
|||
pub fn reconstruct_socket(fd: RawFd) -> Result<OwnedFd> { |
|||
// Check if `fd` is valid
|
|||
let fd_flags = fcntl::fcntl(fd, fcntl::F_GETFD)?; |
|||
|
|||
// `fd` is confirmed to be valid so it should be closed
|
|||
let socket = unsafe { OwnedFd::from_raw_fd(fd) }; |
|||
|
|||
// Insert CLOEXEC flag to the `fd` to prevent further propagation across `execve(2)` calls
|
|||
let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?; |
|||
if !fd_flags.contains(FdFlag::FD_CLOEXEC) { |
|||
fd_flags.insert(FdFlag::FD_CLOEXEC); |
|||
fcntl::fcntl(fd, fcntl::F_SETFD(fd_flags))?; |
|||
} |
|||
|
|||
Ok(socket) |
|||
} |
|||
|
|||
/// Reconstruct transfer socket from `fd`
|
|||
///
|
|||
/// Panics if called outside of tokio runtime
|
|||
pub fn reconstruct_transfer_socket(fd: OwnedFd) -> Result<UnixDatagram> { |
|||
// Check if socket of type DATAGRAM
|
|||
let sock_type = getsockopt(&fd, sockopt::SockType)?; |
|||
if !matches!(sock_type, SockType::Datagram) { |
|||
return Err(ErrorKind::InvalidInput.into()); |
|||
} |
|||
|
|||
let std_socket: std::os::unix::net::UnixDatagram = fd.into(); |
|||
std_socket.set_nonblocking(true)?; |
|||
|
|||
// Fails if tokio context is absent
|
|||
Ok(UnixDatagram::from_std(std_socket).unwrap()) |
|||
} |
|||
|
|||
/// Create pair of interconnected sockets one of which is set to stay open across `execve(2)` calls.
|
|||
pub async fn create_transfer_socket_pair() -> std::io::Result<(UnixDatagram, OwnedFd)> { |
|||
let (local, remote) = tokio::net::UnixDatagram::pair()?; |
|||
|
|||
let remote_fd: OwnedFd = remote.into_std().unwrap().into(); |
|||
|
|||
// Get `remote_fd` flags
|
|||
let fd_flags = fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_GETFD)?; |
|||
|
|||
// Remove CLOEXEC flag from the `remote_fd` to allow propagating across `execve(2)`
|
|||
let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?; |
|||
fd_flags.remove(FdFlag::FD_CLOEXEC); |
|||
fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_SETFD(fd_flags))?; |
|||
|
|||
Ok((local, remote_fd)) |
|||
} |
|||
|
|||
pub trait TransferableSocket: Sized { |
|||
fn from_fd(fd: OwnedFd) -> Result<Self>; |
|||
fn domain() -> SocketProtocol; |
|||
} |
|||
|
|||
impl TransferableSocket for TcpSocket { |
|||
fn from_fd(fd: OwnedFd) -> Result<Self> { |
|||
// Check if socket is of type STREAM
|
|||
let sock_type = getsockopt(&fd, sockopt::SockType)?; |
|||
if !matches!(sock_type, SockType::Stream) { |
|||
return Err(ErrorKind::InvalidInput.into()); |
|||
} |
|||
|
|||
let std_stream: std::net::TcpStream = fd.into(); |
|||
std_stream.set_nonblocking(true)?; |
|||
|
|||
Ok(TcpSocket::from_std_stream(std_stream)) |
|||
} |
|||
|
|||
fn domain() -> SocketProtocol { |
|||
SocketProtocol::Tcp |
|||
} |
|||
} |
|||
|
|||
impl TransferableSocket for UdpSocket { |
|||
/// Panics if called outside of tokio runtime
|
|||
fn from_fd(fd: OwnedFd) -> Result<Self> { |
|||
// Check if socket is of type DATAGRAM
|
|||
let sock_type = getsockopt(&fd, sockopt::SockType)?; |
|||
if !matches!(sock_type, SockType::Datagram) { |
|||
return Err(ErrorKind::InvalidInput.into()); |
|||
} |
|||
|
|||
let std_socket: std::net::UdpSocket = fd.into(); |
|||
std_socket.set_nonblocking(true)?; |
|||
|
|||
Ok(UdpSocket::try_from(std_socket).unwrap()) |
|||
} |
|||
|
|||
fn domain() -> SocketProtocol { |
|||
SocketProtocol::Udp |
|||
} |
|||
} |
|||
|
|||
/// Send [`Request`] to `socket` and return received [`TransferableSocket`]s
|
|||
///
|
|||
/// Panics if called outside of tokio runtime
|
|||
pub async fn request_sockets<S, T>(mut socket: S, domain: SocketDomain, number: u32) -> error::Result<Vec<T>> |
|||
where |
|||
S: DerefMut<Target = UnixDatagram>, |
|||
T: TransferableSocket, |
|||
{ |
|||
// Borrow socket as mut to prevent multiple simultaneous requests
|
|||
let socket = socket.deref_mut(); |
|||
|
|||
// Send request
|
|||
let request = bincode::serialize(&Request { |
|||
protocol: T::domain(), |
|||
domain, |
|||
number, |
|||
})?; |
|||
|
|||
socket.send(&request[..]).await?; |
|||
|
|||
// Receive response
|
|||
loop { |
|||
socket.readable().await?; |
|||
|
|||
let mut buf = [0_u8; REQUEST_BUFFER_SIZE]; |
|||
let mut iov = [IoSliceMut::new(&mut buf[..])]; |
|||
let mut cmsg = Vec::with_capacity(cmsg_space::<RawFd>() * number as usize); |
|||
|
|||
let msg = recvmsg::<()>(socket.as_fd().as_raw_fd(), &mut iov, Some(&mut cmsg), MsgFlags::empty()); |
|||
|
|||
let msg = match msg { |
|||
Err(Errno::EAGAIN) => continue, |
|||
msg => msg?, |
|||
}; |
|||
|
|||
// Parse response
|
|||
let response = &msg.iovs().next().unwrap()[..msg.bytes]; |
|||
let response: Response = bincode::deserialize(response)?; |
|||
if !matches!(response, Response::Ok) { |
|||
return Err("Request for new sockets failed".into()); |
|||
} |
|||
|
|||
// Process received file descriptors
|
|||
let mut sockets = Vec::<T>::with_capacity(number as usize); |
|||
for cmsg in msg.cmsgs() { |
|||
if let ControlMessageOwned::ScmRights(fds) = cmsg { |
|||
for fd in fds { |
|||
if fd < 0 { |
|||
return Err("Received socket is invalid".into()); |
|||
} |
|||
|
|||
let owned_fd = reconstruct_socket(fd)?; |
|||
sockets.push(T::from_fd(owned_fd)?); |
|||
} |
|||
} |
|||
} |
|||
|
|||
return Ok(sockets); |
|||
} |
|||
} |
|||
|
|||
/// Process [`Request`]s received from `socket`
|
|||
///
|
|||
/// Panics if called outside of tokio runtime
|
|||
pub async fn process_socket_requests(socket: &UnixDatagram) -> error::Result<()> { |
|||
loop { |
|||
let mut buf = [0_u8; REQUEST_BUFFER_SIZE]; |
|||
|
|||
let len = socket.recv(&mut buf[..]).await?; |
|||
|
|||
let request: Request = bincode::deserialize(&buf[..len])?; |
|||
|
|||
let response = Response::Ok; |
|||
let buf = bincode::serialize(&response)?; |
|||
|
|||
let mut owned_fd_buf: Vec<OwnedFd> = Vec::with_capacity(request.number as usize); |
|||
for _ in 0..request.number { |
|||
let fd = match request.protocol { |
|||
SocketProtocol::Tcp => match request.domain { |
|||
SocketDomain::IpV4 => tokio::net::TcpSocket::new_v4(), |
|||
SocketDomain::IpV6 => tokio::net::TcpSocket::new_v6(), |
|||
} |
|||
.map(|s| unsafe { OwnedFd::from_raw_fd(s.into_raw_fd()) }), |
|||
SocketProtocol::Udp => match request.domain { |
|||
SocketDomain::IpV4 => tokio::net::UdpSocket::bind("0.0.0.0:0").await, |
|||
SocketDomain::IpV6 => tokio::net::UdpSocket::bind("[::]:0").await, |
|||
} |
|||
.map(|s| s.into_std().unwrap().into()), |
|||
}; |
|||
match fd { |
|||
Err(err) => log::warn!("Failed to allocate socket: {err}"), |
|||
Ok(fd) => owned_fd_buf.push(fd), |
|||
}; |
|||
} |
|||
|
|||
socket.writable().await?; |
|||
|
|||
let raw_fd_buf: Vec<RawFd> = owned_fd_buf.iter().map(|fd| fd.as_raw_fd()).collect(); |
|||
let cmsg = ControlMessage::ScmRights(&raw_fd_buf[..]); |
|||
let iov = [IoSlice::new(&buf[..])]; |
|||
|
|||
sendmsg::<()>(socket.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None)?; |
|||
} |
|||
} |
|||
Loading…
Reference in new issue