Skip to content

Commit

Permalink
fixup! Change the client builder so that it abstracts away connecting…
Browse files Browse the repository at this point in the history
… to TLS or non-TLS connections and what TLS provider is used.
  • Loading branch information
urkle committed Sep 6, 2023
1 parent d7275c8 commit 64c0c1c
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 104 deletions.
1 change: 1 addition & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ impl<T: Read + Write> Client<T> {
///
/// This consumes `self` since the Client is not much use without
/// an underlying transport.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
pub(crate) fn into_inner(self) -> Result<T> {
let res = self.conn.stream.into_inner()?;
Ok(res)
Expand Down
176 changes: 81 additions & 95 deletions src/client_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Client, Connection, Result};
use crate::{Client, Connection, Error, Result};

use lazy_static::lazy_static;
use std::io::{Read, Write};
Expand All @@ -8,7 +8,6 @@ use std::net::TcpStream;
use native_tls::TlsConnector as NativeTlsConnector;

use crate::extensions::idle::SetReadTimeout;
use imap_proto::Capability;
#[cfg(feature = "rustls-tls")]
use rustls_connector::{
rustls,
Expand Down Expand Up @@ -57,10 +56,16 @@ lazy_static! {
#[derive(Clone, Debug, PartialEq, Eq)]

Check warning on line 56 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L56

Added line #L56 was not covered by tests
pub enum ConnectionMode {
/// Automatically detect what connection mode should be used.
/// This will use TLS if the port is 993 and StartTLls if available.
/// This will use TLS if the port is 993, and StartTLls if the server says it's available.
/// If only Plaintext is available it will error out.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
AutoTls,
/// Automatically detect what connection mode should be used.
/// This will use TLS if the port is 993, and StartTLls if the server says it's available.
/// Finally it will fallback to Plaintext
Auto,
/// A plain unencrypted Tcp connection
Tcp,
Plaintext,
/// an encrypted TLS connection
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Tls,
Expand All @@ -80,7 +85,7 @@ pub enum TlsKind {
/// Use the Rustls backend
#[cfg(feature = "rustls-tls")]
Rust,
/// Use whatever backend is available (rustls used it both are compiled in)
/// Use whatever backend is available (uses rustls if both are available)
Any,
}

Expand Down Expand Up @@ -143,7 +148,7 @@ where
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
tls_kind: TlsKind,
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
no_tls_verify: bool,
skip_tls_verify: bool,
}

impl<D> ClientBuilder<D>
Expand All @@ -155,28 +160,17 @@ where
ClientBuilder {
domain,
port,
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
mode: ConnectionMode::AutoTls,
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
mode: ConnectionMode::Auto,
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
tls_kind: TlsKind::Any,
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
no_tls_verify: false,
skip_tls_verify: false,
}
}

/// Use [`STARTTLS`](https://tools.ietf.org/html/rfc2595) for this connection.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[deprecated = "Use mode(ConnectionMode::StartTls) instead"]
pub fn starttls(&mut self) -> &mut Self {
self.mode(ConnectionMode::StartTls)
}

/// Use TLS for this connection.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[deprecated = "Use mode(ConnectionMode::Tls) instead"]
pub fn tls(&mut self) -> &mut Self {
self.mode(ConnectionMode::Tls)
}

/// Sets the Connection mode to use for this connection
pub fn mode(&mut self, mode: ConnectionMode) -> &mut Self {
self.mode = mode;
Expand Down Expand Up @@ -204,56 +198,11 @@ where
/// [`native_tls::TlsConnectorBuilder::danger_accept_invalid_hostnames`],
/// [`rustls::ClientConfig::dangerous`]
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
pub fn danger_no_tls_verify(&mut self, no_tls_verify: bool) -> &mut Self {
self.no_tls_verify = no_tls_verify;
pub fn danger_skip_tls_verify(&mut self, skip_tls_verify: bool) -> &mut Self {
self.skip_tls_verify = skip_tls_verify;
self
}

/// Makes a [`Client`] connection using the native TLS backend.
///
/// Forces TLS unless [`ClientBuilder::mode(ConnectionMode::StartTls)`] was called
///
/// ```no_run
/// # use imap::ClientBuilder;
/// # {} #[cfg(feature = "rustls-tls")]
/// # fn main() -> Result<(), imap::Error> {
/// use imap::TlsKind;
/// let client = ClientBuilder::new("imap.example.com", 993).native_tls()?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "native-tls")]
#[deprecated = "Use tls_kind(TlsKind::Native).connect() instead"]
pub fn native_tls(&mut self) -> Result<Client<Connection>> {
if self.mode == ConnectionMode::Tcp {
self.mode(ConnectionMode::Tls);
}

self.tls_kind(TlsKind::Native).connect()
}

/// Makes a [`Client`] connection using the rustls TLS backend.
///
/// Forces TLS unless [`ClientBuilder::starttls()`] was called
///
/// ```no_run
/// # use imap::ClientBuilder;
/// # {} #[cfg(feature = "rustls-tls")]
/// # fn main() -> Result<(), imap::Error> {
/// let client = ClientBuilder::new("imap.example.com", 993).rustls()?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "rustls-tls")]
#[deprecated = "Use tls_kind(TlsKind::Rust).connect() instead"]
pub fn rustls(&mut self) -> Result<Client<Connection>> {
if self.mode == ConnectionMode::Tcp {
self.mode(ConnectionMode::Tls);
}

self.tls_kind(TlsKind::Rust).connect()
}

/// Make a [`Client`] using the configuration.
///
/// ```no_run
Expand All @@ -266,12 +215,12 @@ where
/// # }
/// ```
pub fn connect(&self) -> Result<Client<Connection>> {
self.connect_with(|_domain, tcp| {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
return self.build_tls_connection(tcp);
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
return Ok(tcp);
})
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
return self.connect_with(|_domain, tcp| self.build_tls_connection(tcp));
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
return self.connect_with(|_domain, _tcp| -> Result<Connection> {
return Err(Error::TlsNotConfigured);
});
}

/// Make a [`Client`] using a custom initialization. This function is intended
Expand All @@ -286,9 +235,10 @@ where
/// - domain: [`&str`]
/// - tcp: [`TcpStream`]
///
/// and yield a `Result<C>` where `C` is `Read + Write`. It should only perform
/// TLS initialization over the given `tcp` socket and return the encrypted stream
/// object, such as a [`native_tls::TlsStream`] or a [`rustls_connector::TlsStream`].
/// and yield a `Result<C>` where `C` is `Read + Write + Send + SetReadTimeout + 'static,`.
/// It should only perform TLS initialization over the given `tcp` socket and return the
/// encrypted stream object, such as a [`native_tls::TlsStream`] or a
/// [`rustls_connector::TlsStream`].
///
/// If the caller is using `STARTTLS` and previously called [`starttls`](Self::starttls)
/// then the `tcp` socket given to the `handshake` function will be connected and will
Expand All @@ -308,42 +258,56 @@ where
/// # Ok(())
/// # }
/// ```
#[allow(unused_variables)]
pub fn connect_with<F, C>(&self, handshake: F) -> Result<Client<Connection>>
where
F: FnOnce(&str, TcpStream) -> Result<C>,
C: Read + Write + Send + SetReadTimeout + 'static,
{
#[allow(unused_mut)]
let mut greeting_read = false;
let tcp = TcpStream::connect((self.domain.as_ref(), self.port))?;

let stream: Connection = match self.mode {
ConnectionMode::Auto => {
ConnectionMode::AutoTls => {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
if self.port == 993 {
Box::new(handshake(self.domain.as_ref(), tcp)?)

Check warning on line 275 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L275

Added line #L275 was not covered by tests
} else {
let mut client = Client::new(tcp);
client.read_greeting()?;
let (stream, upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?;
greeting_read = true;

let capabilities = client.capabilities()?;
if capabilities.has(&Capability::Atom("STARTTLS".into())) {
client.run_command_and_check_ok("STARTTLS")?;
let tcp = client.into_inner()?;
Box::new(handshake(self.domain.as_ref(), tcp)?)
} else {
Box::new(client.into_inner()?)
if !upgraded {
Err(Error::StartTlsNotAvailable)?

Check warning on line 281 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L281

Added line #L281 was not covered by tests
}
stream
}
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
Err(Error::TlsNotConfigured)?
}
ConnectionMode::Tcp => Box::new(tcp),
ConnectionMode::Auto => {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
if self.port == 993 {
Box::new(handshake(self.domain.as_ref(), tcp)?)

Check warning on line 291 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L291

Added line #L291 was not covered by tests
} else {
let (stream, _upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?;
greeting_read = true;

stream
}
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
Box::new(tcp)
}
ConnectionMode::Plaintext => Box::new(tcp),
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
ConnectionMode::StartTls => {
let mut client = Client::new(tcp);
client.read_greeting()?;
let (stream, upgraded) = self.upgrade_tls(Client::new(tcp), handshake)?;
greeting_read = true;
client.run_command_and_check_ok("STARTTLS")?;
let tcp = client.into_inner()?;
Box::new(handshake(self.domain.as_ref(), tcp)?)

if !upgraded {
Err(Error::StartTlsNotAvailable)?

Check warning on line 308 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L308

Added line #L308 was not covered by tests
}
stream
}
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
ConnectionMode::Tls => Box::new(handshake(self.domain.as_ref(), tcp)?),
Expand All @@ -359,6 +323,28 @@ where
Ok(client)
}

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
fn upgrade_tls<F, C>(
&self,
mut client: Client<TcpStream>,
handshake: F,
) -> Result<(Connection, bool)>
where
F: FnOnce(&str, TcpStream) -> Result<C>,
C: Read + Write + Send + SetReadTimeout + 'static,
{
client.read_greeting()?;

let capabilities = client.capabilities()?;
if capabilities.has(&imap_proto::Capability::Atom("STARTTLS".into())) {
client.run_command_and_check_ok("STARTTLS")?;
let tcp = client.into_inner()?;
Ok((Box::new(handshake(self.domain.as_ref(), tcp)?), true))
} else {
Ok((Box::new(client.into_inner()?), false))

Check warning on line 344 in src/client_builder.rs

View check run for this annotation

Codecov / codecov/patch

src/client_builder.rs#L344

Added line #L344 was not covered by tests
}
}

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
fn build_tls_connection(&self, tcp: TcpStream) -> Result<Connection> {
match self.tls_kind {
Expand Down Expand Up @@ -386,7 +372,7 @@ where
.with_safe_defaults()
.with_root_certificates(CACERTS.clone())
.with_no_client_auth();
if self.no_tls_verify {
if self.skip_tls_verify {
let no_cert_verifier = NoCertVerification;
config
.dangerous()
Expand All @@ -399,7 +385,7 @@ where
#[cfg(feature = "native-tls")]
fn build_tls_native(&self, tcp: TcpStream) -> Result<Connection> {
let mut builder = NativeTlsConnector::builder();
if self.no_tls_verify {
if self.skip_tls_verify {
builder.danger_accept_invalid_certs(true);
builder.danger_accept_invalid_hostnames(true);
}
Expand Down
11 changes: 11 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ pub enum Error {
/// In response to a STATUS command, the server sent OK without actually sending any STATUS
/// responses first.
MissingStatusResponse,
/// StartTls is not available on the server
StartTlsNotAvailable,
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
/// Returns when Tls is not configured
TlsNotConfigured,
}

impl From<IoError> for Error {
Expand Down Expand Up @@ -171,6 +176,9 @@ impl fmt::Display for Error {
Error::Append => f.write_str("Could not append mail to mailbox"),
Error::Unexpected(ref r) => write!(f, "Unexpected Response: {:?}", r),
Error::MissingStatusResponse => write!(f, "Missing STATUS Response"),
Error::StartTlsNotAvailable => write!(f, "StartTls is not available on the server"),

Check warning on line 179 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L179

Added line #L179 was not covered by tests
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
Error::TlsNotConfigured => write!(f, "No Tls feature is available"),
}
}
}
Expand All @@ -195,6 +203,9 @@ impl StdError for Error {
Error::Append => "Could not append mail to mailbox",
Error::Unexpected(_) => "Unexpected Response",
Error::MissingStatusResponse => "Missing STATUS Response",
Error::StartTlsNotAvailable => "StartTls is not available on the server",

Check warning on line 206 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L206

Added line #L206 was not covered by tests
#[cfg(all(not(feature = "native-tls"), not(feature = "rustls-tls")))]
Error::TlsNotConfigured => "No Tls feature is available",
}
}

Expand Down
Loading

0 comments on commit 64c0c1c

Please sign in to comment.