diff --git a/tokio/src/io/join.rs b/tokio/src/io/join.rs new file mode 100644 index 00000000000..dbc7043b67e --- /dev/null +++ b/tokio/src/io/join.rs @@ -0,0 +1,117 @@ +//! Join two values implementing `AsyncRead` and `AsyncWrite` into a single one. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Join two values implementing `AsyncRead` and `AsyncWrite` into a +/// single handle. +pub fn join(reader: R, writer: W) -> Join +where + R: AsyncRead, + W: AsyncWrite, +{ + Join { reader, writer } +} + +pin_project_lite::pin_project! { + /// Joins two values implementing `AsyncRead` and `AsyncWrite` into a + /// single handle. + #[derive(Debug)] + pub struct Join { + #[pin] + reader: R, + #[pin] + writer: W, + } +} + +impl Join +where + R: AsyncRead, + W: AsyncWrite, +{ + /// Splits this `Join` back into its `AsyncRead` and `AsyncWrite` + /// components. + pub fn into_inner(self) -> (R, W) { + (self.reader, self.writer) + } + + /// Returns a reference to the inner reader. + pub fn reader(&self) -> &R { + &self.reader + } + + /// Returns a reference to the inner writer. + pub fn writer(&self) -> &W { + &self.writer + } + + /// Returns a mutable reference to the inner reader. + pub fn reader_mut(&mut self) -> &mut R { + &mut self.reader + } + + /// Returns a mutable reference to the inner writer. + pub fn writer_mut(&mut self) -> &mut W { + &mut self.writer + } + + /// Returns a pinned mutable reference to the inner reader. + pub fn reader_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().reader + } + + /// Returns a pinned mutable reference to the inner writer. + pub fn writer_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().writer + } +} + +impl AsyncRead for Join +where + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().reader.poll_read(cx, buf) + } +} + +impl AsyncWrite for Join +where + W: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().writer.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().writer.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 0fd6cc2c5cb..ff35a0e0f7e 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -265,6 +265,8 @@ cfg_io_std! { cfg_io_util! { mod split; pub use split::{split, ReadHalf, WriteHalf}; + mod join; + pub use join::{join, Join}; pub(crate) mod seek; pub(crate) mod util; diff --git a/tokio/tests/io_join.rs b/tokio/tests/io_join.rs new file mode 100644 index 00000000000..69b09393311 --- /dev/null +++ b/tokio/tests/io_join.rs @@ -0,0 +1,83 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{join, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Join, ReadBuf}; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +struct R; + +impl AsyncRead for R { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.put_slice(&[b'z']); + Poll::Ready(Ok(())) + } +} + +struct W; + +impl AsyncWrite for W { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Poll::Ready(Ok(2)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[test] +fn is_send_and_sync() { + fn assert_bound() {} + + assert_bound::>(); +} + +#[test] +fn method_delegation() { + let mut rw = join(R, W); + let mut buf = [0; 1]; + + tokio_test::block_on(async move { + assert_eq!(1, rw.read(&mut buf).await.unwrap()); + assert_eq!(b'z', buf[0]); + + assert_eq!(1, rw.write(&[b'x']).await.unwrap()); + assert_eq!( + 2, + rw.write_vectored(&[io::IoSlice::new(&[b'x'])]) + .await + .unwrap() + ); + assert!(rw.is_write_vectored()); + + assert!(rw.flush().await.is_ok()); + assert!(rw.shutdown().await.is_ok()); + }); +}