diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index 166c31a61c0..05be5ab2f71 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -55,6 +55,9 @@ use then::Then; mod try_next; use try_next::TryNext; +mod peekable; +use peekable::Peekable; + cfg_time! { pub(crate) mod timeout; pub(crate) mod timeout_repeating; @@ -1176,6 +1179,31 @@ pub trait StreamExt: Stream { assert!(max_size > 0, "`max_size` must be non-zero."); ChunksTimeout::new(self, max_size, duration) } + + /// Turns the stream into a peekable stream, whose first element can be peeked at without being + /// consumed. + /// ```rust + /// use tokio_stream::{self as stream, StreamExt}; + /// + /// #[tokio::main] + /// # async fn _unused() {} + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// let iter = vec![1, 2, 3, 4].into_iter(); + /// let mut stream = stream::iter(iter).peekable(); + /// + /// assert_eq!(*stream.peek().await.unwrap(), 1); + /// assert_eq!(*stream.peek().await.unwrap(), 1); + /// assert_eq!(stream.next().await.unwrap(), 1); + /// assert_eq!(*stream.peek().await.unwrap(), 2); + /// } + /// ``` + fn peekable(self) -> Peekable + where + Self: Sized, + { + Peekable::new(self) + } } impl StreamExt for St where St: Stream {} diff --git a/tokio-stream/src/stream_ext/peekable.rs b/tokio-stream/src/stream_ext/peekable.rs new file mode 100644 index 00000000000..7545ee409f0 --- /dev/null +++ b/tokio-stream/src/stream_ext/peekable.rs @@ -0,0 +1,50 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_core::Stream; +use pin_project_lite::pin_project; + +use crate::stream_ext::Fuse; +use crate::StreamExt; + +pin_project! { + /// Stream returned by the [`chain`](super::StreamExt::peekable) method. + pub struct Peekable { + peek: Option, + #[pin] + stream: Fuse, + } +} + +impl Peekable { + pub(crate) fn new(stream: T) -> Self { + let stream = stream.fuse(); + Self { peek: None, stream } + } + + /// Peek at the next item in the stream. + pub async fn peek(&mut self) -> Option<&T::Item> + where + T: Unpin, + { + if let Some(ref it) = self.peek { + Some(it) + } else { + self.peek = self.next().await; + self.peek.as_ref() + } + } +} + +impl Stream for Peekable { + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if let Some(it) = this.peek.take() { + Poll::Ready(Some(it)) + } else { + this.stream.poll_next(cx) + } + } +}