From f784ba469fa3aa8d1c45beab378f82c44f1ebea9 Mon Sep 17 00:00:00 2001 From: Gus Caplan Date: Sat, 16 Dec 2023 11:19:26 -0800 Subject: [PATCH] io: add Join utility --- tokio/src/io/join.rs | 133 +++++++++++++++++++++++++++++++++++++++++ tokio/src/io/mod.rs | 2 + tokio/tests/io_join.rs | 83 +++++++++++++++++++++++++ 3 files changed, 218 insertions(+) create mode 100644 tokio/src/io/join.rs create mode 100644 tokio/tests/io_join.rs diff --git a/tokio/src/io/join.rs b/tokio/src/io/join.rs new file mode 100644 index 00000000000..0b20131ffd7 --- /dev/null +++ b/tokio/src/io/join.rs @@ -0,0 +1,133 @@ +//! Join two values implementing `AsyncRead` and `AsyncWrite` into a single one. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + pin_project_lite::pin_project! { + /// Joins two values implementing `AsyncRead` and `AsyncWrite` into a + /// single handle. + pub struct Join { + #[pin] + reader: R, + #[pin] + writer: W, + } + } + + impl Join + where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, + { + /// Join two values implementing `AsyncRead` and `AsyncWrite` into a + /// single handle. + pub fn new(reader: R, writer: W) -> Self { + Self { reader, writer } + } + + /// Splits this `Join` back into its `AsyncRead` and `AsyncWrite` + /// components. + pub fn split(self) -> (R, W) { + (self.reader, self.writer) + } + + /// Returns a reference to the inner reader. + pub fn reader(&self) -> &R { + &self.reader + } + + /// Returns a reference to the inner writer. + pub fn writer(&self) -> &W { + &self.writer + } + + /// Returns a mutable reference to the inner reader. + pub fn reader_mut(&mut self) -> &mut R { + &mut self.reader + } + + /// Returns a mutable reference to the inner writer. + pub fn writer_mut(&mut self) -> &mut W { + &mut self.writer + } + } + + impl AsyncRead for Join + where + R: AsyncRead + Unpin, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().reader.poll_read(cx, buf) + } + } + + impl AsyncWrite for Join + where + W: AsyncWrite + Unpin, + { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().writer.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().writer.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>] + ) -> Poll> { + self.project().writer.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } + } + + impl Clone for Join + where R: Clone, W: Clone + { + fn clone(&self) -> Self { + Join { + writer: self.writer.clone(), + reader: self.reader.clone(), + } + } + } + + impl fmt::Debug for Join + where R: fmt::Debug, W: fmt::Debug + { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Join") + .field("writer", &self.writer) + .field("reader", &self.reader) + .finish() + } + } +} diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 0fd6cc2c5cb..11bdf75f4cf 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -265,6 +265,8 @@ cfg_io_std! { cfg_io_util! { mod split; pub use split::{split, ReadHalf, WriteHalf}; + mod join; + pub use join::Join; pub(crate) mod seek; pub(crate) mod util; diff --git a/tokio/tests/io_join.rs b/tokio/tests/io_join.rs new file mode 100644 index 00000000000..29f2b7745ca --- /dev/null +++ b/tokio/tests/io_join.rs @@ -0,0 +1,83 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi does not support panic recovery + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Join, ReadBuf}; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +struct R; + +impl AsyncRead for R { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.put_slice(&[b'z']); + Poll::Ready(Ok(())) + } +} + +struct W; + +impl AsyncWrite for W { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Poll::Ready(Ok(2)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[test] +fn is_send_and_sync() { + fn assert_bound() {} + + assert_bound::>(); +} + +#[test] +fn method_delegation() { + let mut rw = Join::new(R, W); + let mut buf = [0; 1]; + + tokio_test::block_on(async move { + assert_eq!(1, rw.read(&mut buf).await.unwrap()); + assert_eq!(b'z', buf[0]); + + assert_eq!(1, rw.write(&[b'x']).await.unwrap()); + assert_eq!( + 2, + rw.write_vectored(&[io::IoSlice::new(&[b'x'])]) + .await + .unwrap() + ); + assert!(rw.is_write_vectored()); + + assert!(rw.flush().await.is_ok()); + assert!(rw.shutdown().await.is_ok()); + }); +}