diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 957342666..9c28dd036 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -27,7 +27,7 @@ codegen = ["async-trait"] compression = ["flate2"] default = ["transport", "codegen", "prost"] prost = ["prost1", "prost-derive"] -tls = ["transport", "tokio-rustls"] +tls = ["rustls-pemfile", "transport", "tokio-rustls"] tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-roots-common = ["tls"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] @@ -79,9 +79,10 @@ tower = {version = "0.4.7", features = ["balance", "buffer", "discover", "limit" tracing-futures = {version = "0.2", optional = true} # rustls -rustls-native-certs = {version = "0.5", optional = true} -tokio-rustls = {version = "0.22", optional = true} -webpki-roots = {version = "0.21.1", optional = true} +rustls-pemfile = { version = "0.2.1", optional = true } +rustls-native-certs = { version = "0.6.1", optional = true } +tokio-rustls = { version = "0.23.1", optional = true } +webpki-roots = { version = "0.22.1", optional = true } # compression flate2 = {version = "1.0", optional = true} diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index 00e640f76..a8c2d6096 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -14,7 +14,6 @@ pub struct ClientTlsConfig { domain: Option, cert: Option, identity: Option, - rustls_raw: Option, } #[cfg(feature = "tls")] @@ -36,7 +35,6 @@ impl ClientTlsConfig { domain: None, cert: None, identity: None, - rustls_raw: None, } } @@ -49,8 +47,6 @@ impl ClientTlsConfig { } /// Sets the CA Certificate against which to verify the server's TLS certificate. - /// - /// This has no effect if `rustls_client_config` is used to configure Rustls. pub fn ca_certificate(self, ca_certificate: Certificate) -> Self { ClientTlsConfig { cert: Some(ca_certificate), @@ -59,8 +55,6 @@ impl ClientTlsConfig { } /// Sets the client identity to present to the server. - /// - /// This has no effect if `rustls_client_config` is used to configure Rustls. pub fn identity(self, identity: Identity) -> Self { ClientTlsConfig { identity: Some(identity), @@ -68,26 +62,11 @@ impl ClientTlsConfig { } } - /// Use options specified by the given `ClientConfig` to configure TLS. - /// - /// This overrides all other TLS options set via other means. - pub fn rustls_client_config(self, config: tokio_rustls::rustls::ClientConfig) -> Self { - ClientTlsConfig { - rustls_raw: Some(config), - ..self - } - } - pub(crate) fn tls_connector(&self, uri: Uri) -> Result { let domain = match &self.domain { None => uri.host().ok_or_else(Error::new_invalid_uri)?.to_string(), Some(domain) => domain.clone(), }; - match &self.rustls_raw { - None => { - TlsConnector::new_with_rustls_cert(self.cert.clone(), self.identity.clone(), domain) - } - Some(c) => TlsConnector::new_with_rustls_raw(c.clone(), domain), - } + TlsConnector::new(self.cert.clone(), self.identity.clone(), domain) } } diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 40d60b232..53bd47c31 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -7,7 +7,7 @@ use crate::transport::Certificate; #[cfg(feature = "tls")] use std::sync::Arc; #[cfg(feature = "tls")] -use tokio_rustls::{rustls::Session, server::TlsStream}; +use tokio_rustls::server::TlsStream; /// Trait that connected IO resources implement and use to produce info about the connection. /// @@ -115,10 +115,10 @@ where let (inner, session) = self.get_ref(); let inner = inner.connect_info(); - let certs = if let Some(certs) = session.get_peer_certificates() { + let certs = if let Some(certs) = session.peer_certificates() { let certs = certs .into_iter() - .map(|c| Certificate::from_pem(c.0)) + .map(|c| Certificate::from_pem(c)) .collect(); Some(Arc::new(certs)) } else { diff --git a/tonic/src/transport/server/tls.rs b/tonic/src/transport/server/tls.rs index 999ec3035..b6c7ec974 100644 --- a/tonic/src/transport/server/tls.rs +++ b/tonic/src/transport/server/tls.rs @@ -11,7 +11,6 @@ use std::fmt; pub struct ServerTlsConfig { identity: Option, client_ca_root: Option, - rustls_raw: Option, } #[cfg(feature = "tls")] @@ -28,7 +27,6 @@ impl ServerTlsConfig { ServerTlsConfig { identity: None, client_ca_root: None, - rustls_raw: None, } } @@ -48,24 +46,7 @@ impl ServerTlsConfig { } } - /// Use options specified by the given `ServerConfig` to configure TLS. - /// - /// This overrides all other TLS options set via other means. - pub fn rustls_server_config( - &mut self, - config: tokio_rustls::rustls::ServerConfig, - ) -> &mut Self { - self.rustls_raw = Some(config); - self - } - pub(crate) fn tls_acceptor(&self) -> Result { - match &self.rustls_raw { - None => TlsAcceptor::new_with_rustls_identity( - self.identity.clone().unwrap(), - self.client_ca_root.clone(), - ), - Some(config) => TlsAcceptor::new_with_rustls_raw(config.clone()), - } + TlsAcceptor::new(self.identity.clone().unwrap(), self.client_ca_root.clone()) } } diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index c4d216b83..d0625ef0b 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,6 +3,8 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; +#[cfg(feature = "tls-roots-common")] +use std::convert::TryInto; use std::task::{Context, Poll}; use tower::make::MakeConnection; use tower_service::Service; @@ -39,22 +41,18 @@ impl Connector { #[cfg(feature = "tls-roots-common")] fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option { - use tokio_rustls::webpki::DNSNameRef; - if self.tls.is_some() { return self.tls.clone(); } - match (scheme, host) { - (Some("https"), Some(host)) => { - if DNSNameRef::try_from_ascii(host.as_bytes()).is_ok() { - TlsConnector::new_with_rustls_cert(None, None, host.to_owned()).ok() - } else { - None - } - } - _ => None, - } + let host = match (scheme, host) { + (Some("https"), Some(host)) => host, + _ => return None, + }; + + host.try_into() + .ok() + .and_then(|dns| TlsConnector::new(None, None, dns).ok()) } } diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 5bd960c48..38512985d 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -5,12 +5,13 @@ use crate::transport::{ }; #[cfg(feature = "tls-roots")] use rustls_native_certs; +#[cfg(feature = "tls")] +use std::convert::TryInto; use std::{fmt, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(feature = "tls")] use tokio_rustls::{ - rustls::{ClientConfig, NoClientAuth, ServerConfig, Session}, - webpki::DNSNameRef, + rustls::{ClientConfig, RootCertStore, ServerConfig, ServerName}, TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, }; @@ -31,58 +32,59 @@ enum TlsError { #[derive(Clone)] pub(crate) struct TlsConnector { config: Arc, - domain: Arc, + domain: Arc, } impl TlsConnector { #[cfg(feature = "tls")] - pub(crate) fn new_with_rustls_cert( + pub(crate) fn new( ca_cert: Option, identity: Option, domain: String, ) -> Result { - let mut config = ClientConfig::new(); - config.set_protocols(&[Vec::from(ALPN_H2)]); - - if let Some(identity) = identity { - let (client_cert, client_key) = rustls_keys::load_identity(identity)?; - config.set_single_client_cert(client_cert, client_key)?; - } + let builder = ClientConfig::builder().with_safe_defaults(); + let mut roots = RootCertStore::empty(); #[cfg(feature = "tls-roots")] { - config.root_store = match rustls_native_certs::load_native_certs() { - Ok(store) | Err((Some(store), _)) => store, - Err((None, error)) => return Err(error.into()), + match rustls_native_certs::load_native_certs() { + Ok(certs) => roots.add_parsable_certificates( + &certs.into_iter().map(|cert| cert.0).collect::>(), + ), + Err(error) => return Err(error.into()), }; } #[cfg(feature = "tls-webpki-roots")] { - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + use tokio_rustls::rustls::OwnedTrustAnchor; + + roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); } if let Some(cert) = ca_cert { - let mut buf = std::io::Cursor::new(&cert.pem[..]); - config.root_store.add_pem_file(&mut buf).unwrap(); + rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?; } - Ok(Self { - config: Arc::new(config), - domain: Arc::new(domain), - }) - } + let builder = builder.with_root_certificates(roots); + let mut config = match identity { + Some(identity) => { + let (client_cert, client_key) = rustls_keys::load_identity(identity)?; + builder.with_single_cert(client_cert, client_key)? + } + None => builder.with_no_client_auth(), + }; - #[cfg(feature = "tls")] - pub(crate) fn new_with_rustls_raw( - config: tokio_rustls::rustls::ClientConfig, - domain: String, - ) -> Result { + config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec()); Ok(Self { config: Arc::new(config), - domain: Arc::new(domain), + domain: Arc::new(domain.as_str().try_into()?), }) } @@ -91,15 +93,13 @@ impl TlsConnector { I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let tls_io = { - let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned(); - let io = RustlsConnector::from(self.config.clone()) - .connect(dns.as_ref(), io) + .connect(self.domain.as_ref().to_owned(), io) .await?; let (_, session) = io.get_ref(); - match session.get_alpn_protocol() { + match session.alpn_protocol() { Some(b) if b == b"h2" => (), _ => return Err(TlsError::H2NotNegotiated.into()), }; @@ -124,39 +124,26 @@ pub(crate) struct TlsAcceptor { impl TlsAcceptor { #[cfg(feature = "tls")] - pub(crate) fn new_with_rustls_identity( + pub(crate) fn new( identity: Identity, client_ca_root: Option, ) -> Result { - let (cert, key) = rustls_keys::load_identity(identity)?; + let builder = ServerConfig::builder().with_safe_defaults(); - let mut config = match client_ca_root { - None => ServerConfig::new(NoClientAuth::new()), + let builder = match client_ca_root { + None => builder.with_no_client_auth(), Some(cert) => { - let mut cert = std::io::Cursor::new(&cert.pem[..]); - - let mut client_root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); - if client_root_cert_store.add_pem_file(&mut cert).is_err() { - return Err(Box::new(TlsError::CertificateParseError)); - } - - let client_auth = - tokio_rustls::rustls::AllowAnyAuthenticatedClient::new(client_root_cert_store); - ServerConfig::new(client_auth) + use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; + let mut roots = RootCertStore::empty(); + rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?; + builder.with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots)) } }; - config.set_single_cert(cert, key)?; - config.set_protocols(&[Vec::from(ALPN_H2)]); - Ok(Self { - inner: Arc::new(config), - }) - } + let (cert, key) = rustls_keys::load_identity(identity)?; + let mut config = builder.with_single_cert(cert, key)?; - #[cfg(feature = "tls")] - pub(crate) fn new_with_rustls_raw( - config: tokio_rustls::rustls::ServerConfig, - ) -> Result { + config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec()); Ok(Self { inner: Arc::new(config), }) @@ -194,7 +181,9 @@ impl std::error::Error for TlsError {} #[cfg(feature = "tls")] mod rustls_keys { - use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey}; + use std::io::Cursor; + + use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore}; use crate::transport::service::tls::TlsError; use crate::transport::Identity; @@ -203,17 +192,17 @@ mod rustls_keys { mut cursor: std::io::Cursor<&[u8]>, ) -> Result { // First attempt to load the private key assuming it is PKCS8-encoded - if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) { - if !keys.is_empty() { - return Ok(keys.remove(0)); + if let Ok(mut keys) = rustls_pemfile::pkcs8_private_keys(&mut cursor) { + if let Some(key) = keys.pop() { + return Ok(PrivateKey(key)); } } // If it not, try loading the private key as an RSA key cursor.set_position(0); - if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) { - if !keys.is_empty() { - return Ok(keys.remove(0)); + if let Ok(mut keys) = rustls_pemfile::rsa_private_keys(&mut cursor) { + if let Some(key) = keys.pop() { + return Ok(PrivateKey(key)); } } @@ -226,8 +215,8 @@ mod rustls_keys { ) -> Result<(Vec, PrivateKey), crate::Error> { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); - match pemfile::certs(&mut cert) { - Ok(certs) => certs, + match rustls_pemfile::certs(&mut cert) { + Ok(certs) => certs.into_iter().map(Certificate).collect(), Err(_) => return Err(Box::new(TlsError::CertificateParseError)), } }; @@ -244,4 +233,15 @@ mod rustls_keys { Ok((cert, key)) } + + pub(crate) fn add_certs_from_pem( + mut certs: Cursor<&[u8]>, + roots: &mut RootCertStore, + ) -> Result<(), crate::Error> { + let (_, ignored) = roots.add_parsable_certificates(&rustls_pemfile::certs(&mut certs)?); + match ignored == 0 { + true => Ok(()), + false => Err(Box::new(TlsError::CertificateParseError)), + } + } }