From af5ad735bc995468f25fa492735875a7473a8dbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20DOUIN?= Date: Sun, 31 Mar 2024 09:34:21 +0200 Subject: [PATCH] Handle tag desynchronization (#284) --- src/client.rs | 32 ++++++++++++++++++++++++++++++-- src/error.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 39d2b34..2937004 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,6 +8,8 @@ use std::ops::{Deref, DerefMut}; use std::str; use std::sync::mpsc; +use crate::error::TagMismatch; + use super::authenticator::Authenticator; use super::error::{Bad, Bye, Error, No, ParseError, Result, ValidateError}; use super::extensions; @@ -171,6 +173,17 @@ pub struct Connection { pub greeting_read: bool, } +impl Connection { + /// Manually increment the current tag. + /// + /// This function can be manually executed by callers when the + /// previous tag was not reused, for example when a timeout did + /// not write anything on the stream. + pub fn skip_tag(&mut self) { + self.tag += 1; + } +} + /// A builder for the append command #[must_use] pub struct AppendCmd<'a, T: Read + Write> { @@ -1495,7 +1508,12 @@ impl Connection { fn run_command(&mut self, untagged_command: &str) -> Result<()> { let command = self.create_command(untagged_command); - self.write_line(command.into_bytes().as_slice()) + let result = self.write_line(command.into_bytes().as_slice()); + if result.is_err() { + // rollback tag increased in create_command() + self.tag -= 1; + } + result } fn run_command_and_read_response(&mut self, untagged_command: &str) -> Result> { @@ -1547,7 +1565,17 @@ impl Connection { .. }, )) => { - assert_eq!(tag.as_bytes(), match_tag.as_bytes()); + // check if tag matches + if tag.as_bytes() != match_tag.as_bytes() { + let expect = self.tag; + let actual = tag + .0 + .trim_start_matches(TAG_PREFIX) + .parse::() + .map_err(|_| tag.as_bytes().to_vec()); + break Err(Error::TagMismatch(TagMismatch { expect, actual })); + } + Some(match status { Status::Bad | Status::No | Status::Bye => Err(( status, diff --git a/src/error.rs b/src/error.rs index d11601a..d368d25 100644 --- a/src/error.rs +++ b/src/error.rs @@ -105,6 +105,10 @@ pub enum Error { /// In response to a STATUS command, the server sent OK without actually sending any STATUS /// responses first. MissingStatusResponse, + /// The server responded with a different command tag than the one we just sent. + /// + /// A new session must generally be established to recover from this. + TagMismatch(TagMismatch), /// StartTls is not available on the server StartTlsNotAvailable, /// Returns when Tls is not configured @@ -175,6 +179,7 @@ impl fmt::Display for Error { Error::Append => f.write_str("Could not append mail to mailbox"), Error::Unexpected(ref r) => write!(f, "Unexpected Response: {:?}", r), Error::MissingStatusResponse => write!(f, "Missing STATUS Response"), + Error::TagMismatch(ref data) => write!(f, "Mismatched Tag: {:?}", data), Error::StartTlsNotAvailable => write!(f, "StartTls is not available on the server"), Error::TlsNotConfigured => { write!(f, "TLS was requested, but no TLS features are enabled") @@ -203,6 +208,7 @@ impl StdError for Error { Error::Append => "Could not append mail to mailbox", Error::Unexpected(_) => "Unexpected Response", Error::MissingStatusResponse => "Missing STATUS Response", + Error::TagMismatch(ref e) => e.description(), Error::StartTlsNotAvailable => "StartTls is not available on the server", Error::TlsNotConfigured => "TLS was requested, but no TLS features are enabled", } @@ -218,6 +224,7 @@ impl StdError for Error { #[cfg(feature = "native-tls")] Error::TlsHandshake(ref e) => Some(e), Error::Parse(ParseError::DataNotUtf8(_, ref e)) => Some(e), + Error::TagMismatch(ref e) => Some(e), _ => None, } } @@ -296,6 +303,28 @@ impl StdError for ValidateError { } } +/// The server responded with a different command tag than last one we sent. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct TagMismatch { + /// Expected tag number + pub(crate) expect: u32, + /// Actual tag number, 0 if parse failed + pub(crate) actual: std::result::Result>, +} + +impl fmt::Display for TagMismatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Expected tag number is {}, actual {:?}", + self.expect, self.actual + ) + } +} + +impl StdError for TagMismatch {} + #[cfg(test)] mod tests { use super::*;