Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(io,net): add split #206

Merged
merged 5 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compio-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository = { workspace = true }

[dependencies]
compio-buf = { workspace = true, features = ["arrayvec"] }
futures-util = { workspace = true }
paste = { workspace = true }

[dev-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions compio-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ mod buffer;
#[cfg(feature = "compat")]
pub mod compat;
mod read;
mod split;
pub mod util;
mod write;

pub(crate) type IoResult<T> = std::io::Result<T>;

pub use read::*;
pub use split::*;
pub use util::{copy, null, repeat};
pub use write::*;
75 changes: 75 additions & 0 deletions compio-io/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::sync::Arc;

use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use futures_util::lock::Mutex;

use crate::{AsyncRead, AsyncWrite, IoResult};

/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
/// [`AsyncRead`] and [`AsyncWrite`] handles.
pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
let stream = Arc::new(Mutex::new(stream));
(ReadHalf(stream.clone()), WriteHalf(stream))
}

/// The readable half of a value returned from [`split`].
#[derive(Debug)]
pub struct ReadHalf<T>(Arc<Mutex<T>>);

impl<T: Unpin> ReadHalf<T> {
/// Reunites with a previously split [`WriteHalf`].
///
/// # Panics
///
/// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from
/// the same [`split`] operation this method will panic.
/// This can be checked ahead of time by comparing the stored pointer
/// of the two halves.
#[track_caller]
pub fn unsplit(self, w: WriteHalf<T>) -> T {
if Arc::ptr_eq(&self.0, &w.0) {
drop(w);
let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
inner.into_inner()
} else {
#[cold]
fn panic_unrelated() -> ! {
panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
}

panic_unrelated()
}
}
}

impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.lock().await.read_vectored(buf).await
}
}

/// The writable half of a value returned from [`split`].
#[derive(Debug)]
pub struct WriteHalf<T>(Arc<Mutex<T>>);

impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.write_vectored(buf).await
}

async fn flush(&mut self) -> IoResult<()> {
self.0.lock().await.flush().await
}

async fn shutdown(&mut self) -> IoResult<()> {
self.0.lock().await.shutdown().await
}
}
18 changes: 17 additions & 1 deletion compio-io/tests/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::Cursor;

use compio_buf::{arrayvec::ArrayVec, BufResult, IoBuf, IoBufMut};
use compio_io::{
AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
split, AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
AsyncWriteAtExt, AsyncWriteExt,
};

Expand Down Expand Up @@ -355,3 +355,19 @@ async fn read_to_end_at() {
assert_eq!(len, 4);
assert_eq!(buf, [4, 5, 1, 4]);
}

#[tokio::test]
async fn split_unsplit() {
let src = Cursor::new([1, 1, 4, 5, 1, 4]);
let (mut read, mut write) = split(src);

let (len, buf) = read.read([0, 0, 0]).await.unwrap();
assert_eq!(len, 3);
assert_eq!(buf, [1, 1, 4]);

let (len, _) = write.write([2, 2, 2]).await.unwrap();
assert_eq!(len, 3);

let src = read.unsplit(write);
assert_eq!(src.into_inner(), [1, 1, 4, 2, 2, 2]);
}
2 changes: 2 additions & 0 deletions compio-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

mod resolve;
mod socket;
pub(crate) mod split;
mod tcp;
mod udp;
mod unix;

pub use resolve::ToSocketAddrsAsync;
pub(crate) use resolve::{each_addr, first_addr_buf};
pub(crate) use socket::*;
pub use split::*;
pub use tcp::*;
pub use udp::*;
pub use unix::*;
135 changes: 135 additions & 0 deletions compio-net/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::{error::Error, fmt, io, ops::Deref, sync::Arc};

use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_io::{AsyncRead, AsyncWrite};

pub(crate) fn split<T>(stream: &T) -> (ReadHalf<T>, WriteHalf<T>)
where
for<'a> &'a T: AsyncRead + AsyncWrite,
{
(ReadHalf(stream), WriteHalf(stream))
}

/// Borrowed read half.
#[derive(Debug)]
pub struct ReadHalf<'a, T>(&'a T);

impl<T> AsyncRead for ReadHalf<'_, T>
where
for<'a> &'a T: AsyncRead,
{
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.read_vectored(buf).await
}
}

/// Borrowed write half.
#[derive(Debug)]
pub struct WriteHalf<'a, T>(&'a T);

impl<T> AsyncWrite for WriteHalf<'_, T>
where
for<'a> &'a T: AsyncWrite,
{
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.write_vectored(buf).await
}

async fn flush(&mut self) -> io::Result<()> {
self.0.flush().await
}

async fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown().await
}
}

pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>)
where
for<'a> &'a T: AsyncRead + AsyncWrite,
{
let stream = Arc::new(stream);
(OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream))
}

/// Owned read half.
#[derive(Debug)]
pub struct OwnedReadHalf<T>(Arc<T>);

impl<T: Unpin> OwnedReadHalf<T> {
/// Attempts to put the two halves of a `TcpStream` back together and
/// recover the original socket. Succeeds only if the two halves
/// originated from the same call to `into_split`.
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
if Arc::ptr_eq(&self.0, &w.0) {
drop(w);
Ok(Arc::try_unwrap(self.0)
.ok()
.expect("`Arc::try_unwrap` failed"))
} else {
Err(ReuniteError(self, w))
}
}
}

impl<T> AsyncRead for OwnedReadHalf<T>
where
for<'a> &'a T: AsyncRead,
{
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.deref().read_vectored(buf).await
}
}

/// Owned write half.
#[derive(Debug)]
pub struct OwnedWriteHalf<T>(Arc<T>);

impl<T> AsyncWrite for OwnedWriteHalf<T>
where
for<'a> &'a T: AsyncWrite,
{
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().write_vectored(buf).await
}

async fn flush(&mut self) -> io::Result<()> {
self.0.deref().flush().await
}

async fn shutdown(&mut self) -> io::Result<()> {
self.0.deref().shutdown().await
}
}

/// Error indicating that two halves were not from the same socket, and thus
/// could not be reunited.
#[derive(Debug)]
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);

impl<T> fmt::Display for ReuniteError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}

impl<T: fmt::Debug> Error for ReuniteError<T> {}
21 changes: 20 additions & 1 deletion compio-net/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
use socket2::{Protocol, SockAddr, Type};

use crate::{Socket, ToSocketAddrsAsync};
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf};

/// A TCP socket server, listening for connections.
///
Expand Down Expand Up @@ -203,6 +203,25 @@ impl TcpStream {
.local_addr()
.map(|addr| addr.as_socket().expect("should be SocketAddr"))
}

/// Splits a [`TcpStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// This method is more efficient than
/// [`into_split`](TcpStream::into_split), but the halves cannot
/// be moved into independently spawned tasks.
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
crate::split(self)
}

/// Splits a [`TcpStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](TcpStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
}
}

impl AsyncRead for TcpStream {
Expand Down
21 changes: 20 additions & 1 deletion compio-net/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
use socket2::{Domain, SockAddr, Type};

use crate::Socket;
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};

/// A Unix socket server, listening for connections.
///
Expand Down Expand Up @@ -159,6 +159,25 @@ impl UnixStream {
pub fn local_addr(&self) -> io::Result<SockAddr> {
self.inner.local_addr()
}

/// Splits a [`UnixStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// This method is more efficient than
/// [`into_split`](UnixStream::into_split), but the halves cannot
/// be moved into independently spawned tasks.
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
crate::split(self)
}

/// Splits a [`UnixStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](UnixStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
}
}

impl AsyncRead for UnixStream {
Expand Down
Loading