diff --git a/Cargo.toml b/Cargo.toml index 2ca2b5c..3d791fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,11 +48,11 @@ version = "0.20.0" [dependencies.rustls-native-certs] optional = true -version = "0.5.0" +version = "0.6.0" [dependencies.webpki] optional = true -version = "0.21" +version = "0.22" [dependencies.webpki-roots] optional = true diff --git a/src/error.rs b/src/error.rs index 510e3f4..e224da7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -255,9 +255,13 @@ pub enum TlsError { /// Rustls error. #[cfg(feature = "__rustls-tls")] #[error("rustls error: {0}")] - Rustls(#[from] rustls::TLSError), + Rustls(#[from] rustls::Error), + /// Webpki error. + #[cfg(feature = "__rustls-tls")] + #[error("webpki error: {0}")] + Webpki(#[from] webpki::Error), /// DNS name resolution error. #[cfg(feature = "__rustls-tls")] - #[error("Invalid DNS name: {0}")] - Dns(#[from] webpki::InvalidDNSNameError), + #[error("Invalid DNS name")] + InvalidDnsName, } diff --git a/src/stream.rs b/src/stream.rs index b7fe0e4..4775230 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,6 +4,8 @@ //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. +#[cfg(feature = "__rustls-tls")] +use std::ops::Deref; use std::{ fmt::{self, Debug}, io::{Read, Result as IoResult, Write}, @@ -45,7 +47,12 @@ impl NoDelay for TlsStream { } #[cfg(feature = "__rustls-tls")] -impl NoDelay for StreamOwned { +impl NoDelay for StreamOwned +where + S: Deref>, + SD: rustls::SideData, + T: Read + Write + NoDelay, +{ fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.sock.set_nodelay(nodelay) } @@ -61,7 +68,7 @@ pub enum MaybeTlsStream { NativeTls(native_tls_crate::TlsStream), #[cfg(feature = "__rustls-tls")] /// Encrypted socket stream using `rustls`. - Rustls(rustls::StreamOwned), + Rustls(rustls::StreamOwned), } impl Debug for MaybeTlsStream { @@ -73,13 +80,13 @@ impl Debug for MaybeTlsStream { #[cfg(feature = "__rustls-tls")] Self::Rustls(s) => { struct RustlsStreamDebug<'a, S: Read + Write>( - &'a rustls::StreamOwned, + &'a rustls::StreamOwned, ); impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StreamOwned") - .field("sess", &self.0.sess) + .field("conn", &self.0.conn) .field("sock", &self.0.sock) .finish() } diff --git a/src/tls.rs b/src/tls.rs index 4f07a54..ad54de3 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -70,10 +70,10 @@ mod encryption { #[cfg(feature = "__rustls-tls")] pub mod rustls { - use rustls::{ClientConfig, ClientSession, StreamOwned}; - use webpki::DNSNameRef; + use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; use std::{ + convert::TryFrom, io::{Read, Write}, sync::Arc, }; @@ -100,24 +100,40 @@ mod encryption { Some(config) => config, None => { #[allow(unused_mut)] - let mut config = ClientConfig::new(); + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "rustls-tls-native-roots")] { - config.root_store = rustls_native_certs::load_native_certs() - .map_err(|(_, err)| err)?; + for cert in rustls_native_certs::load_native_certs()? { + root_store + .add(&rustls::Certificate(cert.0)) + .map_err(TlsError::Webpki)?; + } } #[cfg(feature = "rustls-tls-webpki-roots")] { - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + ); } - Arc::new(config) + Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(), + ) } }; - let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; - let client = ClientSession::new(&config, domain); + let domain = + ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?; + let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?; let stream = StreamOwned::new(client, socket); Ok(MaybeTlsStream::Rustls(stream)) @@ -185,7 +201,7 @@ where None => Err(Error::Url(UrlError::NoHostName)), }?; - let mode = uri_mode(&request.uri())?; + let mode = uri_mode(request.uri())?; let stream = match connector { Some(conn) => match conn {