Skip to content

Commit

Permalink
Add RequestBuilder methods to accept invalid TLS (#40)
Browse files Browse the repository at this point in the history
* Add RequestBuilder methods to accept invalid TLS
  • Loading branch information
sbstp authored Jan 11, 2020
1 parent aeb2bd2 commit 461bb46
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 19 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ required-features = ["tls"]
name = "charset"
path = "examples/charset.rs"
required-features = ["charsets"]

[[test]]
name = "test_invalid_certs"
path = "tests/test_invalid_certs.rs"
required-features = ["tls"]
79 changes: 77 additions & 2 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use url::Url;
use crate::charsets::Charset;
use crate::error::{Error, ErrorKind, InvalidResponseKind, Result};
use crate::parsing::{parse_response, Response};
use crate::streams::BaseStream;
use crate::streams::{BaseStream, ConnectInfo};

const VERSION: &str = env!("CARGO_PKG_VERSION");

Expand Down Expand Up @@ -76,6 +76,10 @@ pub struct RequestBuilder<B = [u8; 0]> {
pub(crate) default_charset: Option<Charset>,
#[cfg(feature = "compress")]
allow_compression: bool,
#[cfg(feature = "tls")]
accept_invalid_certs: bool,
#[cfg(feature = "tls")]
accept_invalid_hostnames: bool,
}

impl RequestBuilder {
Expand Down Expand Up @@ -117,6 +121,10 @@ impl RequestBuilder {
default_charset: None,
#[cfg(feature = "compress")]
allow_compression: true,
#[cfg(feature = "tls")]
accept_invalid_certs: false,
#[cfg(feature = "tls")]
accept_invalid_hostnames: false,
})
}
}
Expand Down Expand Up @@ -252,6 +260,10 @@ impl<B> RequestBuilder<B> {
default_charset: self.default_charset,
#[cfg(feature = "compress")]
allow_compression: self.allow_compression,
#[cfg(feature = "tls")]
accept_invalid_certs: self.accept_invalid_certs,
#[cfg(feature = "tls")]
accept_invalid_hostnames: self.accept_invalid_hostnames,
}
}

Expand Down Expand Up @@ -356,6 +368,36 @@ impl<B> RequestBuilder<B> {
self.allow_compression = allow_compression;
self
}

/// Sets if this `Request` will accept invalid TLS certificates.
///
/// Accepting invalid certificates implies that invalid hostnames are accepted
/// as well.
///
/// The default value is `false`.
///
/// # Danger
/// Use this setting with care. This will accept **any** TLS certificate valid or not.
/// If you are using self signed certificates, it is much safer to add their root CA
/// to the list of trusted root CAs by your system.
#[cfg(feature = "tls")]
pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
self.accept_invalid_certs = accept_invalid_certs;
self
}

/// Sets if this `Request` will accept an invalid hostname in a TLS certificate.
///
/// The default value is `false`.
///
/// # Danger
/// Use this setting with care. This will accept TLS certificates that do not match
/// the hostname.
#[cfg(feature = "tls")]
pub fn danger_accept_invalid_hostnames(mut self, accept_invalid_hostnames: bool) -> Self {
self.accept_invalid_hostnames = accept_invalid_hostnames;
self
}
}

impl<B: AsRef<[u8]>> RequestBuilder<B> {
Expand All @@ -382,6 +424,10 @@ impl<B: AsRef<[u8]>> RequestBuilder<B> {
default_charset: self.default_charset,
#[cfg(feature = "compress")]
allow_compression: self.allow_compression,
#[cfg(feature = "tls")]
accept_invalid_certs: self.accept_invalid_certs,
#[cfg(feature = "tls")]
accept_invalid_hostnames: self.accept_invalid_hostnames,
};

header_insert(&mut prepped.headers, CONNECTION, "close")?;
Expand Down Expand Up @@ -417,6 +463,10 @@ pub struct PreparedRequest<B> {
pub(crate) default_charset: Option<Charset>,
#[cfg(feature = "compress")]
allow_compression: bool,
#[cfg(feature = "tls")]
accept_invalid_certs: bool,
#[cfg(feature = "tls")]
accept_invalid_hostnames: bool,
}

#[cfg(test)]
Expand All @@ -438,6 +488,10 @@ impl PreparedRequest<Vec<u8>> {
default_charset: None,
#[cfg(feature = "compress")]
allow_compression: true,
#[cfg(feature = "tls")]
accept_invalid_certs: false,
#[cfg(feature = "tls")]
accept_invalid_hostnames: false,
}
}
}
Expand Down Expand Up @@ -555,7 +609,16 @@ impl<B: AsRef<[u8]>> PreparedRequest<B> {
let mut redirections = 0;

loop {
let mut stream = BaseStream::connect(&url, self.connect_timeout, self.read_timeout)?;
let info = ConnectInfo {
url: &url,
connect_timeout: self.connect_timeout,
read_timeout: self.read_timeout,
#[cfg(feature = "tls")]
accept_invalid_certs: self.accept_invalid_certs,
#[cfg(feature = "tls")]
accept_invalid_hostnames: self.accept_invalid_hostnames,
};
let mut stream = BaseStream::connect(&info)?;
self.write_request(&mut stream, &url)?;
let resp = parse_response(stream, self)?;

Expand Down Expand Up @@ -637,3 +700,15 @@ fn test_header_append() {
assert!(val == "hello" || val == "world");
}
}

#[test]
#[cfg(feature = "tls")]
fn test_accept_invalid_certs_disabled_by_default() {
let builder = RequestBuilder::new(Method::GET, "https://localhost:7900");
assert_eq!(builder.accept_invalid_certs, false);
assert_eq!(builder.accept_invalid_hostnames, false);

let prepped = builder.prepare();
assert_eq!(prepped.accept_invalid_certs, false);
assert_eq!(prepped.accept_invalid_hostnames, false);
}
42 changes: 25 additions & 17 deletions src/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ use url::Url;
use crate::happy;
use crate::{ErrorKind, Result};

pub struct ConnectInfo<'u> {
pub url: &'u Url,
pub connect_timeout: Duration,
pub read_timeout: Duration,
#[cfg(feature = "tls")]
pub accept_invalid_certs: bool,
#[cfg(feature = "tls")]
pub accept_invalid_hostnames: bool,
}

#[derive(Debug)]
pub enum BaseStream {
Plain(TcpStream),
Expand All @@ -21,35 +31,33 @@ pub enum BaseStream {
}

impl BaseStream {
pub fn connect(url: &Url, connect_timeout: Duration, read_timeout: Duration) -> Result<BaseStream> {
let host = url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
let port = url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;
pub fn connect(info: &ConnectInfo) -> Result<BaseStream> {
let host = info.url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
let port = info.url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;

debug!("trying to connect to {}:{}", host, port);

match url.scheme() {
"http" => BaseStream::connect_tcp(host, port, connect_timeout, read_timeout).map(BaseStream::Plain),
match info.url.scheme() {
"http" => BaseStream::connect_tcp(host, port, info).map(BaseStream::Plain),
#[cfg(feature = "tls")]
"https" => BaseStream::connect_tls(host, port, connect_timeout, read_timeout).map(BaseStream::Tls),
"https" => BaseStream::connect_tls(host, port, info).map(BaseStream::Tls),
_ => Err(ErrorKind::InvalidBaseUrl.into()),
}
}

fn connect_tcp(host: &str, port: u16, connect_timeout: Duration, read_timeout: Duration) -> Result<TcpStream> {
let stream = happy::connect((host, port), connect_timeout)?;
stream.set_read_timeout(Some(read_timeout))?;
fn connect_tcp(host: &str, port: u16, info: &ConnectInfo) -> Result<TcpStream> {
let stream = happy::connect((host, port), info.connect_timeout)?;
stream.set_read_timeout(Some(info.read_timeout))?;
Ok(stream)
}

#[cfg(feature = "tls")]
fn connect_tls(
host: &str,
port: u16,
connect_timeout: Duration,
read_timeout: Duration,
) -> Result<TlsStream<TcpStream>> {
let connector = TlsConnector::new()?;
let stream = BaseStream::connect_tcp(host, port, connect_timeout, read_timeout)?;
fn connect_tls(host: &str, port: u16, info: &ConnectInfo) -> Result<TlsStream<TcpStream>> {
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(info.accept_invalid_certs)
.danger_accept_invalid_hostnames(info.accept_invalid_hostnames)
.build()?;
let stream = BaseStream::connect_tcp(host, port, info)?;
let tls_stream = match connector.connect(host, stream) {
Ok(stream) => stream,
Err(HandshakeError::Failure(err)) => return Err(err.into()),
Expand Down
55 changes: 55 additions & 0 deletions tests/test_invalid_certs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#[test]
fn test_error_when_self_signed() {
let res = attohttpc::get("https://self-signed.badssl.com/").send();
let err = res.err().unwrap();
match err.kind() {
attohttpc::ErrorKind::Tls(_) => (),
_ => panic!("wrong error returned!"),
}
}

#[test]
fn test_accept_invalid_certs_ok_when_self_signed() {
let res = attohttpc::get("https://self-signed.badssl.com/")
.danger_accept_invalid_certs(true)
.send();
assert!(res.is_ok());
}

#[test]
fn test_accept_invalid_certs_ok_when_wrong_host() {
let res = attohttpc::get("https://wrong-host.badssl.com/")
.danger_accept_invalid_certs(true)
.send();
assert!(res.is_ok());
}

#[test]
fn test_error_when_wrong_host() {
let res = attohttpc::get("https://wrong.host.badssl.com/").send();
let err = res.err().unwrap();
match err.kind() {
attohttpc::ErrorKind::Tls(_) => (),
_ => panic!("wrong error returned!"),
}
}

#[test]
fn test_accept_invalid_hostnames_error_when_expired() {
let res = attohttpc::get("https://expired.badssl.com/")
.danger_accept_invalid_hostnames(true)
.send();
let err = res.err().unwrap();
match err.kind() {
attohttpc::ErrorKind::Tls(_) => (),
_ => panic!("wrong error returned!"),
}
}

#[test]
fn test_accept_invalid_hostnames_ok_when_wrong_host() {
let res = attohttpc::get("https://wrong.host.badssl.com/")
.danger_accept_invalid_hostnames(true)
.send();
assert!(res.is_ok());
}

0 comments on commit 461bb46

Please sign in to comment.