Skip to content

Commit

Permalink
Rewrite suspicious code around ancillary data
Browse files Browse the repository at this point in the history
253 seems to be a hardcoded value without any particular reasoning. Move it to a constant and use a
rounder value of 128.

There were also troubles regarding alignment: the API asks us to pass an arbitrary byte buffer and
then performs unaligned reads/writes. Workaround that by aligning the buffer manually.

For more information, see rust-lang/rust#76915 (comment)
  • Loading branch information
purplesyringa committed Jan 3, 2024
1 parent 6b950bd commit 8b8234a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
37 changes: 30 additions & 7 deletions src/platform/unix/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,39 @@
//! ```

use crate::{Deserializer, Object, Serializer};
use nix::libc::{AF_UNIX, SOCK_CLOEXEC, SOCK_SEQPACKET};
use nix::libc::{cmsghdr, AF_UNIX, SOCK_CLOEXEC, SOCK_SEQPACKET};
use std::io::{Error, ErrorKind, IoSlice, IoSliceMut, Result};
use std::marker::PhantomData;
use std::os::unix::{
io::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd},
net::{AncillaryData, SocketAncillary, UnixStream},
};

const fn round_to_usize(n: usize) -> usize {
const ALIGNMENT: usize = std::mem::size_of::<usize>();
(n + ALIGNMENT - 1) / ALIGNMENT * ALIGNMENT
}

pub(crate) const MAX_PACKET_SIZE: usize = 16 * 1024;
pub(crate) const MAX_PACKET_FDS: usize = 128;
pub(crate) const ANCILLARY_BUFFER_SIZE: usize =
round_to_usize(MAX_PACKET_FDS * std::mem::size_of::<i32>())
+ round_to_usize(std::mem::size_of::<cmsghdr>());

// https://github.com/rust-lang/rust/issues/76915#issuecomment-1875845773
pub(crate) struct AncillaryBuffer {
_alignment: [usize; 0],
pub(crate) data: [u8; ANCILLARY_BUFFER_SIZE],
}

impl AncillaryBuffer {
pub(crate) fn new() -> Self {
Self {
_alignment: [],
data: [0u8; ANCILLARY_BUFFER_SIZE],
}
}
}

/// The transmitting side of a unidirectional channel.
///
Expand Down Expand Up @@ -91,19 +115,18 @@ fn send_on_fd<T: Object>(fd: &UnixStream, value: &T) -> Result<()> {
let fds = s.drain_handles();
let serialized = s.into_vec();

let mut ancillary_buffer = [0; 253];

// Send the data and pass file descriptors
let mut buffer_pos: usize = 0;
let mut fds_pos: usize = 0;

loop {
let buffer_end = serialized.len().min(buffer_pos + MAX_PACKET_SIZE - 1);
let fds_end = fds.len().min(fds_pos + 253);
let fds_end = fds.len().min(fds_pos + MAX_PACKET_FDS);

let is_last = buffer_end == serialized.len() && fds_end == fds.len();

let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
let mut ancillary_buffer = AncillaryBuffer::new();
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer.data);
if !ancillary.add_fds(&fds[fds_pos..fds_end]) {
return Err(Error::new(ErrorKind::Other, "Too many fds to pass"));
}
Expand Down Expand Up @@ -131,14 +154,14 @@ unsafe fn recv_on_fd<T: Object>(fd: &UnixStream) -> Result<Option<T>> {
let mut serialized: Vec<u8> = Vec::new();
let mut buffer_pos: usize = 0;

let mut ancillary_buffer = [0; 253];
let mut received_fds: Vec<OwnedFd> = Vec::new();

loop {
serialized.resize(buffer_pos + MAX_PACKET_SIZE - 1, 0);

let mut marker = [0];
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
let mut ancillary_buffer = AncillaryBuffer::new();
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer.data);

let n_read = fd.recv_vectored_with_ancillary(
&mut [
Expand Down
15 changes: 8 additions & 7 deletions src/platform/unix/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
//! ```

use crate::{
entry, imp, ipc::MAX_PACKET_SIZE, subprocess, Deserializer, FnOnceObject, Object, Serializer,
entry, imp,
ipc::{AncillaryBuffer, MAX_PACKET_FDS, MAX_PACKET_SIZE},
subprocess, Deserializer, FnOnceObject, Object, Serializer,
};
use nix::libc::pid_t;
use std::io::{Error, ErrorKind, IoSlice, IoSliceMut, Result};
Expand Down Expand Up @@ -105,19 +107,18 @@ async fn send_on_fd<T: Object>(fd: &UnixSeqpacket, value: &T) -> Result<()> {
(s.drain_handles(), s.into_vec())
};

let mut ancillary_buffer = [0; 253];

// Send the data and pass file descriptors
let mut buffer_pos: usize = 0;
let mut fds_pos: usize = 0;

loop {
let buffer_end = serialized.len().min(buffer_pos + MAX_PACKET_SIZE - 1);
let fds_end = fds.len().min(fds_pos + 253);
let fds_end = fds.len().min(fds_pos + MAX_PACKET_FDS);

let is_last = buffer_end == serialized.len() && fds_end == fds.len();

let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
let mut ancillary_buffer = AncillaryBuffer::new();
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer.data);
if !ancillary.add_fds(&fds[fds_pos..fds_end]) {
return Err(Error::new(ErrorKind::Other, "Too many fds to pass"));
}
Expand Down Expand Up @@ -147,14 +148,14 @@ async unsafe fn recv_on_fd<T: Object>(fd: &UnixSeqpacket) -> Result<Option<T>> {
let mut serialized: Vec<u8> = Vec::new();
let mut buffer_pos: usize = 0;

let mut ancillary_buffer = [0; 253];
let mut received_fds: Vec<OwnedFd> = Vec::new();

loop {
serialized.resize(buffer_pos + MAX_PACKET_SIZE - 1, 0);

let mut marker = [0];
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
let mut ancillary_buffer = AncillaryBuffer::new();
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer.data);

let n_read = fd
.recv_vectored_with_ancillary(
Expand Down

0 comments on commit 8b8234a

Please sign in to comment.