From 64c0c1c3abe5b1f9e55deb6441e14ea68ceb6605 Mon Sep 17 00:00:00 2001 From: Edward Rudd Date: Tue, 29 Aug 2023 22:48:58 -0400 Subject: [PATCH] fixup! Change the client builder so that it abstracts away connecting to TLS or non-TLS connections and what TLS provider is used. --- src/client.rs | 1 + src/client_builder.rs | 176 ++++++++++++++++------------------- src/error.rs | 11 +++ tests/builder_integration.rs | 27 ++++-- tests/imap_integration.rs | 6 +- 5 files changed, 117 insertions(+), 104 deletions(-) diff --git a/src/client.rs b/src/client.rs index 236de66..141b9b1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -340,6 +340,7 @@ impl Client { /// /// This consumes `self` since the Client is not much use without /// an underlying transport. + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub(crate) fn into_inner(self) -> Result { let res = self.conn.stream.into_inner()?; Ok(res) diff --git a/src/client_builder.rs b/src/client_builder.rs index f7435b1..afa9003 100644 --- a/src/client_builder.rs +++ b/src/client_builder.rs @@ -1,4 +1,4 @@ -use crate::{Client, Connection, Result}; +use crate::{Client, Connection, Error, Result}; use lazy_static::lazy_static; use std::io::{Read, Write}; @@ -8,7 +8,6 @@ use std::net::TcpStream; use native_tls::TlsConnector as NativeTlsConnector; use crate::extensions::idle::SetReadTimeout; -use imap_proto::Capability; #[cfg(feature = "rustls-tls")] use rustls_connector::{ rustls, @@ -57,10 +56,16 @@ lazy_static! { #[derive(Clone, Debug, PartialEq, Eq)] pub enum ConnectionMode { /// Automatically detect what connection mode should be used. - /// This will use TLS if the port is 993 and StartTLls if available. + /// This will use TLS if the port is 993, and StartTLls if the server says it's available. + /// If only Plaintext is available it will error out. + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + AutoTls, + /// Automatically detect what connection mode should be used. + /// This will use TLS if the port is 993, and StartTLls if the server says it's available. + /// Finally it will fallback to Plaintext Auto, /// A plain unencrypted Tcp connection - Tcp, + Plaintext, /// an encrypted TLS connection #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] Tls, @@ -80,7 +85,7 @@ pub enum TlsKind { /// Use the Rustls backend #[cfg(feature = "rustls-tls")] Rust, - /// Use whatever backend is available (rustls used it both are compiled in) + /// Use whatever backend is available (uses rustls if both are available) Any, } @@ -143,7 +148,7 @@ where #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] tls_kind: TlsKind, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - no_tls_verify: bool, + skip_tls_verify: bool, } impl ClientBuilder @@ -155,28 +160,17 @@ where ClientBuilder { domain, port, + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + mode: ConnectionMode::AutoTls, + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] mode: ConnectionMode::Auto, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] tls_kind: TlsKind::Any, #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - no_tls_verify: false, + skip_tls_verify: false, } } - /// Use [`STARTTLS`](https://tools.ietf.org/html/rfc2595) for this connection. - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - #[deprecated = "Use mode(ConnectionMode::StartTls) instead"] - pub fn starttls(&mut self) -> &mut Self { - self.mode(ConnectionMode::StartTls) - } - - /// Use TLS for this connection. - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - #[deprecated = "Use mode(ConnectionMode::Tls) instead"] - pub fn tls(&mut self) -> &mut Self { - self.mode(ConnectionMode::Tls) - } - /// Sets the Connection mode to use for this connection pub fn mode(&mut self, mode: ConnectionMode) -> &mut Self { self.mode = mode; @@ -204,56 +198,11 @@ where /// [`native_tls::TlsConnectorBuilder::danger_accept_invalid_hostnames`], /// [`rustls::ClientConfig::dangerous`] #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - pub fn danger_no_tls_verify(&mut self, no_tls_verify: bool) -> &mut Self { - self.no_tls_verify = no_tls_verify; + pub fn danger_skip_tls_verify(&mut self, skip_tls_verify: bool) -> &mut Self { + self.skip_tls_verify = skip_tls_verify; self } - /// Makes a [`Client`] connection using the native TLS backend. - /// - /// Forces TLS unless [`ClientBuilder::mode(ConnectionMode::StartTls)`] was called - /// - /// ```no_run - /// # use imap::ClientBuilder; - /// # {} #[cfg(feature = "rustls-tls")] - /// # fn main() -> Result<(), imap::Error> { - /// use imap::TlsKind; - /// let client = ClientBuilder::new("imap.example.com", 993).native_tls()?; - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "native-tls")] - #[deprecated = "Use tls_kind(TlsKind::Native).connect() instead"] - pub fn native_tls(&mut self) -> Result> { - if self.mode == ConnectionMode::Tcp { - self.mode(ConnectionMode::Tls); - } - - self.tls_kind(TlsKind::Native).connect() - } - - /// Makes a [`Client`] connection using the rustls TLS backend. - /// - /// Forces TLS unless [`ClientBuilder::starttls()`] was called - /// - /// ```no_run - /// # use imap::ClientBuilder; - /// # {} #[cfg(feature = "rustls-tls")] - /// # fn main() -> Result<(), imap::Error> { - /// let client = ClientBuilder::new("imap.example.com", 993).rustls()?; - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls-tls")] - #[deprecated = "Use tls_kind(TlsKind::Rust).connect() instead"] - pub fn rustls(&mut self) -> Result> { - if self.mode == ConnectionMode::Tcp { - self.mode(ConnectionMode::Tls); - } - - self.tls_kind(TlsKind::Rust).connect() - } - /// Make a [`Client`] using the configuration. /// /// ```no_run @@ -266,12 +215,12 @@ where /// # } /// ``` pub fn connect(&self) -> Result> { - self.connect_with(|_domain, tcp| { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - return self.build_tls_connection(tcp); - #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] - return Ok(tcp); - }) + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + return self.connect_with(|_domain, tcp| self.build_tls_connection(tcp)); + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + return self.connect_with(|_domain, _tcp| -> Result { + return Err(Error::TlsNotConfigured); + }); } /// Make a [`Client`] using a custom initialization. This function is intended @@ -286,9 +235,10 @@ where /// - domain: [`&str`] /// - tcp: [`TcpStream`] /// - /// and yield a `Result` where `C` is `Read + Write`. It should only perform - /// TLS initialization over the given `tcp` socket and return the encrypted stream - /// object, such as a [`native_tls::TlsStream`] or a [`rustls_connector::TlsStream`]. + /// and yield a `Result` where `C` is `Read + Write + Send + SetReadTimeout + 'static,`. + /// It should only perform TLS initialization over the given `tcp` socket and return the + /// encrypted stream object, such as a [`native_tls::TlsStream`] or a + /// [`rustls_connector::TlsStream`]. /// /// If the caller is using `STARTTLS` and previously called [`starttls`](Self::starttls) /// then the `tcp` socket given to the `handshake` function will be connected and will @@ -308,42 +258,56 @@ where /// # Ok(()) /// # } /// ``` + #[allow(unused_variables)] pub fn connect_with(&self, handshake: F) -> Result> where F: FnOnce(&str, TcpStream) -> Result, C: Read + Write + Send + SetReadTimeout + 'static, { + #[allow(unused_mut)] let mut greeting_read = false; let tcp = TcpStream::connect((self.domain.as_ref(), self.port))?; let stream: Connection = match self.mode { - ConnectionMode::Auto => { + ConnectionMode::AutoTls => { + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] if self.port == 993 { Box::new(handshake(self.domain.as_ref(), tcp)?) } else { - let mut client = Client::new(tcp); - client.read_greeting()?; + let (stream, upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?; greeting_read = true; - let capabilities = client.capabilities()?; - if capabilities.has(&Capability::Atom("STARTTLS".into())) { - client.run_command_and_check_ok("STARTTLS")?; - let tcp = client.into_inner()?; - Box::new(handshake(self.domain.as_ref(), tcp)?) - } else { - Box::new(client.into_inner()?) + if !upgraded { + Err(Error::StartTlsNotAvailable)? } + stream } + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + Err(Error::TlsNotConfigured)? } - ConnectionMode::Tcp => Box::new(tcp), + ConnectionMode::Auto => { + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + if self.port == 993 { + Box::new(handshake(self.domain.as_ref(), tcp)?) + } else { + let (stream, _upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?; + greeting_read = true; + + stream + } + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + Box::new(tcp) + } + ConnectionMode::Plaintext => Box::new(tcp), #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] ConnectionMode::StartTls => { - let mut client = Client::new(tcp); - client.read_greeting()?; + let (stream, upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?; greeting_read = true; - client.run_command_and_check_ok("STARTTLS")?; - let tcp = client.into_inner()?; - Box::new(handshake(self.domain.as_ref(), tcp)?) + + if !upgraded { + Err(Error::StartTlsNotAvailable)? + } + stream } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] ConnectionMode::Tls => Box::new(handshake(self.domain.as_ref(), tcp)?), @@ -359,6 +323,28 @@ where Ok(client) } + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + fn upgrade_tls( + &self, + mut client: Client, + handshake: F, + ) -> Result<(Connection, bool)> + where + F: FnOnce(&str, TcpStream) -> Result, + C: Read + Write + Send + SetReadTimeout + 'static, + { + client.read_greeting()?; + + let capabilities = client.capabilities()?; + if capabilities.has(&imap_proto::Capability::Atom("STARTTLS".into())) { + client.run_command_and_check_ok("STARTTLS")?; + let tcp = client.into_inner()?; + Ok((Box::new(handshake(self.domain.as_ref(), tcp)?), true)) + } else { + Ok((Box::new(client.into_inner()?), false)) + } + } + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] fn build_tls_connection(&self, tcp: TcpStream) -> Result { match self.tls_kind { @@ -386,7 +372,7 @@ where .with_safe_defaults() .with_root_certificates(CACERTS.clone()) .with_no_client_auth(); - if self.no_tls_verify { + if self.skip_tls_verify { let no_cert_verifier = NoCertVerification; config .dangerous() @@ -399,7 +385,7 @@ where #[cfg(feature = "native-tls")] fn build_tls_native(&self, tcp: TcpStream) -> Result { let mut builder = NativeTlsConnector::builder(); - if self.no_tls_verify { + if self.skip_tls_verify { builder.danger_accept_invalid_certs(true); builder.danger_accept_invalid_hostnames(true); } diff --git a/src/error.rs b/src/error.rs index e98880c..a881251 100644 --- a/src/error.rs +++ b/src/error.rs @@ -105,6 +105,11 @@ pub enum Error { /// In response to a STATUS command, the server sent OK without actually sending any STATUS /// responses first. MissingStatusResponse, + /// StartTls is not available on the server + StartTlsNotAvailable, + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + /// Returns when Tls is not configured + TlsNotConfigured, } impl From for Error { @@ -171,6 +176,9 @@ impl fmt::Display for Error { Error::Append => f.write_str("Could not append mail to mailbox"), Error::Unexpected(ref r) => write!(f, "Unexpected Response: {:?}", r), Error::MissingStatusResponse => write!(f, "Missing STATUS Response"), + Error::StartTlsNotAvailable => write!(f, "StartTls is not available on the server"), + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + Error::TlsNotConfigured => write!(f, "No Tls feature is available"), } } } @@ -195,6 +203,9 @@ impl StdError for Error { Error::Append => "Could not append mail to mailbox", Error::Unexpected(_) => "Unexpected Response", Error::MissingStatusResponse => "Missing STATUS Response", + Error::StartTlsNotAvailable => "StartTls is not available on the server", + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))] + Error::TlsNotConfigured => "No Tls feature is available", } } diff --git a/tests/builder_integration.rs b/tests/builder_integration.rs index d8e490c..b3c4232 100644 --- a/tests/builder_integration.rs +++ b/tests/builder_integration.rs @@ -36,7 +36,7 @@ fn starttls_force() { let user = "starttls@localhost"; let host = test_host(); let c = imap::ClientBuilder::new(&host, test_imap_port()) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .mode(ConnectionMode::StartTls) .connect() .unwrap(); @@ -51,7 +51,7 @@ fn tls_force() { let user = "tls@localhost"; let host = test_host(); let c = imap::ClientBuilder::new(&host, test_imaps_port()) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .mode(ConnectionMode::Tls) .connect() .unwrap(); @@ -66,7 +66,7 @@ fn tls_force_rustls() { let user = "tls@localhost"; let host = test_host(); let c = imap::ClientBuilder::new(&host, test_imaps_port()) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .tls_kind(imap::TlsKind::Rust) .mode(ConnectionMode::Tls) .connect() @@ -82,7 +82,7 @@ fn tls_force_native() { let user = "tls@localhost"; let host = test_host(); let c = imap::ClientBuilder::new(&host, test_imaps_port()) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .tls_kind(imap::TlsKind::Native) .mode(ConnectionMode::Tls) .connect() @@ -92,13 +92,28 @@ fn tls_force_native() { assert!(list_mailbox(&mut s).is_ok()); } +#[test] +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +fn auto_tls() { + let user = "auto@localhost"; + let host = test_host(); + let mut builder = imap::ClientBuilder::new(&host, test_imap_port()); + builder.danger_skip_tls_verify(true); + + let c = builder.connect().unwrap(); + let mut s = c.login(user, user).unwrap(); + s.debug = true; + assert!(list_mailbox(&mut s).is_ok()); +} + #[test] fn auto() { let user = "auto@localhost"; let host = test_host(); let mut builder = imap::ClientBuilder::new(&host, test_imap_port()); + builder.mode(ConnectionMode::Auto); #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - builder.danger_no_tls_verify(true); + builder.danger_skip_tls_verify(true); let c = builder.connect().unwrap(); let mut s = c.login(user, user).unwrap(); @@ -111,7 +126,7 @@ fn raw_force() { let user = "raw@localhost"; let host = test_host(); let c = imap::ClientBuilder::new(&host, test_imap_port()) - .mode(ConnectionMode::Tcp) + .mode(ConnectionMode::Plaintext) .connect() .unwrap(); let mut s = c.login(user, user).unwrap(); diff --git a/tests/imap_integration.rs b/tests/imap_integration.rs index 598d566..6a90568 100644 --- a/tests/imap_integration.rs +++ b/tests/imap_integration.rs @@ -68,7 +68,7 @@ fn session_with_options(user: &str, clean: bool) -> imap::Session { let host = test_host(); let mut s = imap::ClientBuilder::new(&host, test_imaps_port()) .mode(ConnectionMode::Tls) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .connect() .unwrap() .login(user, user) @@ -135,7 +135,7 @@ fn connect_insecure_then_secure() { // Not supported on greenmail because of https://github.com/greenmail-mail-test/greenmail/issues/135 imap::ClientBuilder::new(&host, test_imap_port()) .mode(ConnectionMode::StartTls) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .connect() .unwrap(); } @@ -145,7 +145,7 @@ fn connect_secure() { let host = test_host(); imap::ClientBuilder::new(&host, test_imaps_port()) .mode(ConnectionMode::Tls) - .danger_no_tls_verify(true) + .danger_skip_tls_verify(true) .connect() .unwrap(); }