diff --git a/compio-io/Cargo.toml b/compio-io/Cargo.toml index 89d9366a..e4fa6a37 100644 --- a/compio-io/Cargo.toml +++ b/compio-io/Cargo.toml @@ -12,6 +12,7 @@ repository = { workspace = true } [dependencies] compio-buf = { workspace = true, features = ["arrayvec"] } +futures-util = { workspace = true } paste = { workspace = true } [dev-dependencies] diff --git a/compio-io/src/lib.rs b/compio-io/src/lib.rs index ffe28fe1..92ba98df 100644 --- a/compio-io/src/lib.rs +++ b/compio-io/src/lib.rs @@ -107,11 +107,13 @@ mod buffer; #[cfg(feature = "compat")] pub mod compat; mod read; +mod split; pub mod util; mod write; pub(crate) type IoResult = std::io::Result; pub use read::*; +pub use split::*; pub use util::{copy, null, repeat}; pub use write::*; diff --git a/compio-io/src/split.rs b/compio-io/src/split.rs new file mode 100644 index 00000000..cbbbe725 --- /dev/null +++ b/compio-io/src/split.rs @@ -0,0 +1,75 @@ +use std::sync::Arc; + +use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use futures_util::lock::Mutex; + +use crate::{AsyncRead, AsyncWrite, IoResult}; + +/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate +/// [`AsyncRead`] and [`AsyncWrite`] handles. +pub fn split(stream: T) -> (ReadHalf, WriteHalf) { + let stream = Arc::new(Mutex::new(stream)); + (ReadHalf(stream.clone()), WriteHalf(stream)) +} + +/// The readable half of a value returned from [`split`]. +#[derive(Debug)] +pub struct ReadHalf(Arc>); + +impl ReadHalf { + /// Reunites with a previously split [`WriteHalf`]. + /// + /// # Panics + /// + /// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from + /// the same [`split`] operation this method will panic. + /// This can be checked ahead of time by comparing the stored pointer + /// of the two halves. + #[track_caller] + pub fn unsplit(self, w: WriteHalf) -> T { + if Arc::ptr_eq(&self.0, &w.0) { + drop(w); + let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed"); + inner.into_inner() + } else { + #[cold] + fn panic_unrelated() -> ! { + panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.") + } + + panic_unrelated() + } + } +} + +impl AsyncRead for ReadHalf { + async fn read(&mut self, buf: B) -> BufResult { + self.0.lock().await.read(buf).await + } + + async fn read_vectored(&mut self, buf: V) -> BufResult { + self.0.lock().await.read_vectored(buf).await + } +} + +/// The writable half of a value returned from [`split`]. +#[derive(Debug)] +pub struct WriteHalf(Arc>); + +impl AsyncWrite for WriteHalf { + async fn write(&mut self, buf: B) -> BufResult { + self.0.lock().await.write(buf).await + } + + async fn write_vectored(&mut self, buf: B) -> BufResult { + self.0.lock().await.write_vectored(buf).await + } + + async fn flush(&mut self) -> IoResult<()> { + self.0.lock().await.flush().await + } + + async fn shutdown(&mut self) -> IoResult<()> { + self.0.lock().await.shutdown().await + } +} diff --git a/compio-io/tests/io.rs b/compio-io/tests/io.rs index 0c7c1e1c..82a2bd24 100644 --- a/compio-io/tests/io.rs +++ b/compio-io/tests/io.rs @@ -2,7 +2,7 @@ use std::io::Cursor; use compio_buf::{arrayvec::ArrayVec, BufResult, IoBuf, IoBufMut}; use compio_io::{ - AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt, + split, AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt, AsyncWriteAtExt, AsyncWriteExt, }; @@ -355,3 +355,19 @@ async fn read_to_end_at() { assert_eq!(len, 4); assert_eq!(buf, [4, 5, 1, 4]); } + +#[tokio::test] +async fn split_unsplit() { + let src = Cursor::new([1, 1, 4, 5, 1, 4]); + let (mut read, mut write) = split(src); + + let (len, buf) = read.read([0, 0, 0]).await.unwrap(); + assert_eq!(len, 3); + assert_eq!(buf, [1, 1, 4]); + + let (len, _) = write.write([2, 2, 2]).await.unwrap(); + assert_eq!(len, 3); + + let src = read.unsplit(write); + assert_eq!(src.into_inner(), [1, 1, 4, 2, 2, 2]); +} diff --git a/compio-net/src/lib.rs b/compio-net/src/lib.rs index d18d7647..72ca50fe 100644 --- a/compio-net/src/lib.rs +++ b/compio-net/src/lib.rs @@ -7,6 +7,7 @@ mod resolve; mod socket; +pub(crate) mod split; mod tcp; mod udp; mod unix; @@ -14,6 +15,7 @@ mod unix; pub use resolve::ToSocketAddrsAsync; pub(crate) use resolve::{each_addr, first_addr_buf}; pub(crate) use socket::*; +pub use split::*; pub use tcp::*; pub use udp::*; pub use unix::*; diff --git a/compio-net/src/split.rs b/compio-net/src/split.rs new file mode 100644 index 00000000..680d2edc --- /dev/null +++ b/compio-net/src/split.rs @@ -0,0 +1,135 @@ +use std::{error::Error, fmt, io, ops::Deref, sync::Arc}; + +use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_io::{AsyncRead, AsyncWrite}; + +pub(crate) fn split(stream: &T) -> (ReadHalf, WriteHalf) +where + for<'a> &'a T: AsyncRead + AsyncWrite, +{ + (ReadHalf(stream), WriteHalf(stream)) +} + +/// Borrowed read half. +#[derive(Debug)] +pub struct ReadHalf<'a, T>(&'a T); + +impl AsyncRead for ReadHalf<'_, T> +where + for<'a> &'a T: AsyncRead, +{ + async fn read(&mut self, buf: B) -> BufResult { + self.0.read(buf).await + } + + async fn read_vectored(&mut self, buf: V) -> BufResult { + self.0.read_vectored(buf).await + } +} + +/// Borrowed write half. +#[derive(Debug)] +pub struct WriteHalf<'a, T>(&'a T); + +impl AsyncWrite for WriteHalf<'_, T> +where + for<'a> &'a T: AsyncWrite, +{ + async fn write(&mut self, buf: B) -> BufResult { + self.0.write(buf).await + } + + async fn write_vectored(&mut self, buf: B) -> BufResult { + self.0.write_vectored(buf).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.0.flush().await + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.0.shutdown().await + } +} + +pub(crate) fn into_split(stream: T) -> (OwnedReadHalf, OwnedWriteHalf) +where + for<'a> &'a T: AsyncRead + AsyncWrite, +{ + let stream = Arc::new(stream); + (OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream)) +} + +/// Owned read half. +#[derive(Debug)] +pub struct OwnedReadHalf(Arc); + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to `into_split`. + pub fn reunite(self, w: OwnedWriteHalf) -> Result> { + if Arc::ptr_eq(&self.0, &w.0) { + drop(w); + Ok(Arc::try_unwrap(self.0) + .ok() + .expect("`Arc::try_unwrap` failed")) + } else { + Err(ReuniteError(self, w)) + } + } +} + +impl AsyncRead for OwnedReadHalf +where + for<'a> &'a T: AsyncRead, +{ + async fn read(&mut self, buf: B) -> BufResult { + self.0.deref().read(buf).await + } + + async fn read_vectored(&mut self, buf: V) -> BufResult { + self.0.deref().read_vectored(buf).await + } +} + +/// Owned write half. +#[derive(Debug)] +pub struct OwnedWriteHalf(Arc); + +impl AsyncWrite for OwnedWriteHalf +where + for<'a> &'a T: AsyncWrite, +{ + async fn write(&mut self, buf: B) -> BufResult { + self.0.deref().write(buf).await + } + + async fn write_vectored(&mut self, buf: B) -> BufResult { + self.0.deref().write_vectored(buf).await + } + + async fn flush(&mut self) -> io::Result<()> { + self.0.deref().flush().await + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.0.deref().shutdown().await + } +} + +/// Error indicating that two halves were not from the same socket, and thus +/// could not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index 6187de73..cd025a4f 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite}; use compio_runtime::{impl_attachable, impl_try_as_raw_fd}; use socket2::{Protocol, SockAddr, Type}; -use crate::{Socket, ToSocketAddrsAsync}; +use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf}; /// A TCP socket server, listening for connections. /// @@ -203,6 +203,25 @@ impl TcpStream { .local_addr() .map(|addr| addr.as_socket().expect("should be SocketAddr")) } + + /// Splits a [`TcpStream`] into a read half and a write half, which can be + /// used to read and write the stream concurrently. + /// + /// This method is more efficient than + /// [`into_split`](TcpStream::into_split), but the halves cannot + /// be moved into independently spawned tasks. + pub fn split(&self) -> (ReadHalf, WriteHalf) { + crate::split(self) + } + + /// Splits a [`TcpStream`] into a read half and a write half, which can be + /// used to read and write the stream concurrently. + /// + /// Unlike [`split`](TcpStream::split), the owned halves can be moved to + /// separate tasks, however this comes at the cost of a heap allocation. + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + crate::into_split(self) + } } impl AsyncRead for TcpStream { diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index 29d722be..85e23fc8 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite}; use compio_runtime::{impl_attachable, impl_try_as_raw_fd}; use socket2::{Domain, SockAddr, Type}; -use crate::Socket; +use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf}; /// A Unix socket server, listening for connections. /// @@ -159,6 +159,25 @@ impl UnixStream { pub fn local_addr(&self) -> io::Result { self.inner.local_addr() } + + /// Splits a [`UnixStream`] into a read half and a write half, which can be + /// used to read and write the stream concurrently. + /// + /// This method is more efficient than + /// [`into_split`](UnixStream::into_split), but the halves cannot + /// be moved into independently spawned tasks. + pub fn split(&self) -> (ReadHalf, WriteHalf) { + crate::split(self) + } + + /// Splits a [`UnixStream`] into a read half and a write half, which can be + /// used to read and write the stream concurrently. + /// + /// Unlike [`split`](UnixStream::split), the owned halves can be moved to + /// separate tasks, however this comes at the cost of a heap allocation. + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + crate::into_split(self) + } } impl AsyncRead for UnixStream { diff --git a/compio-net/tests/split.rs b/compio-net/tests/split.rs new file mode 100644 index 00000000..4d8b51a6 --- /dev/null +++ b/compio-net/tests/split.rs @@ -0,0 +1,100 @@ +use std::io::{Read, Write}; + +use compio_buf::BufResult; +use compio_io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use compio_net::{TcpStream, UnixListener, UnixStream}; + +#[compio_macros::test] +async fn tcp_split() { + const MSG: &[u8] = b"split"; + + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = compio_runtime::spawn_blocking(move || { + let (mut stream, _) = listener.accept().unwrap(); + stream.write_all(MSG).unwrap(); + + let mut read_buf = [0u8; 32]; + let read_len = stream.read(&mut read_buf).unwrap(); + assert_eq!(&read_buf[..read_len], MSG); + }); + + let stream = TcpStream::connect(&addr).await.unwrap(); + let (mut read_half, mut write_half) = stream.into_split(); + + let read_buf = [0u8; 32]; + let (read_res, buf) = read_half.read(read_buf).await.unwrap(); + assert_eq!(read_res, MSG.len()); + assert_eq!(&buf[..MSG.len()], MSG); + + write_half.write_all(MSG).await.unwrap(); + handle.await; +} + +#[compio_macros::test] +async fn tcp_unsplit() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = compio_runtime::spawn_blocking(move || { + drop(listener.accept().unwrap()); + drop(listener.accept().unwrap()); + }); + + let stream1 = TcpStream::connect(&addr).await.unwrap(); + let (read1, write1) = stream1.into_split(); + + let stream2 = TcpStream::connect(&addr).await.unwrap(); + let (_, write2) = stream2.into_split(); + + let read1 = match read1.reunite(write2) { + Ok(_) => panic!("Reunite should not succeed"), + Err(err) => err.0, + }; + + read1.reunite(write1).expect("Reunite should succeed"); + + handle.await; +} + +#[compio_macros::test] +async fn unix_split() { + let dir = tempfile::Builder::new() + .prefix("compio-uds-split-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("connect.sock"); + + let listener = UnixListener::bind(&sock_path).unwrap(); + + let client = UnixStream::connect(&sock_path).unwrap(); + let (server, _) = listener.accept().await.unwrap(); + + let (mut a_read, mut a_write) = server.into_split(); + let (mut b_read, mut b_write) = client.into_split(); + + let (a_response, b_response) = futures_util::future::try_join( + send_recv_all(&mut a_read, &mut a_write, b"A"), + send_recv_all(&mut b_read, &mut b_write, b"B"), + ) + .await + .unwrap(); + + assert_eq!(a_response, b"B"); + assert_eq!(b_response, b"A"); +} + +async fn send_recv_all( + read: &mut R, + write: &mut W, + input: &'static [u8], +) -> std::io::Result> { + write.write_all(input).await.0?; + write.shutdown().await?; + + let output = Vec::with_capacity(2); + let BufResult(res, buf) = read.read_exact(output).await; + assert_eq!(res.unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof); + Ok(buf) +}