diff --git a/tokio/src/io/split.rs b/tokio/src/io/split.rs index 63f0960e4f3..2602929cdd1 100644 --- a/tokio/src/io/split.rs +++ b/tokio/src/io/split.rs @@ -6,13 +6,11 @@ use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use std::cell::UnsafeCell; use std::fmt; use std::io; use std::pin::Pin; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::{Acquire, Release}; use std::sync::Arc; +use std::sync::Mutex; use std::task::{Context, Poll}; cfg_io_util! { @@ -38,8 +36,7 @@ cfg_io_util! { let is_write_vectored = stream.is_write_vectored(); let inner = Arc::new(Inner { - locked: AtomicBool::new(false), - stream: UnsafeCell::new(stream), + stream: Mutex::new(stream), is_write_vectored, }); @@ -54,13 +51,19 @@ cfg_io_util! { } struct Inner { - locked: AtomicBool, - stream: UnsafeCell, + stream: Mutex, is_write_vectored: bool, } -struct Guard<'a, T> { - inner: &'a Inner, +impl Inner { + fn with_lock(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R { + let mut guard = self.stream.lock().unwrap(); + + // safety: we do not move the stream. + let stream = unsafe { Pin::new_unchecked(&mut *guard) }; + + f(stream) + } } impl ReadHalf { @@ -90,7 +93,7 @@ impl ReadHalf { .ok() .expect("`Arc::try_unwrap` failed"); - inner.stream.into_inner() + inner.stream.into_inner().unwrap() } else { panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") } @@ -111,8 +114,7 @@ impl AsyncRead for ReadHalf { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_read(cx, buf) + self.inner.with_lock(|stream| stream.poll_read(cx, buf)) } } @@ -122,18 +124,15 @@ impl AsyncWrite for WriteHalf { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write(cx, buf) + self.inner.with_lock(|stream| stream.poll_write(cx, buf)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_flush(cx) + self.inner.with_lock(|stream| stream.poll_flush(cx)) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_shutdown(cx) + self.inner.with_lock(|stream| stream.poll_shutdown(cx)) } fn poll_write_vectored( @@ -141,8 +140,8 @@ impl AsyncWrite for WriteHalf { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write_vectored(cx, bufs) + self.inner + .with_lock(|stream| stream.poll_write_vectored(cx, bufs)) } fn is_write_vectored(&self) -> bool { @@ -150,39 +149,6 @@ impl AsyncWrite for WriteHalf { } } -impl Inner { - fn poll_lock(&self, cx: &mut Context<'_>) -> Poll> { - if self - .locked - .compare_exchange(false, true, Acquire, Acquire) - .is_ok() - { - Poll::Ready(Guard { inner: self }) - } else { - // Spin... but investigate a better strategy - - std::thread::yield_now(); - cx.waker().wake_by_ref(); - - Poll::Pending - } - } -} - -impl Guard<'_, T> { - fn stream_pin(&mut self) -> Pin<&mut T> { - // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual - // exclusion. - unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } - } -} - -impl Drop for Guard<'_, T> { - fn drop(&mut self) { - self.inner.locked.store(false, Release); - } -} - unsafe impl Send for ReadHalf {} unsafe impl Send for WriteHalf {} unsafe impl Sync for ReadHalf {}