//! Connection helper. use std::io::{Read, Write}; use crate::{ client::{client_with_config, uri_mode, IntoClientRequest}, error::UrlError, handshake::client::Response, protocol::WebSocketConfig, stream::MaybeTlsStream, ClientHandshake, Error, HandshakeError, Result, WebSocket, }; /// A connector that can be used when establishing connections, allowing to control whether /// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the /// `Plain` variant. #[non_exhaustive] #[allow(missing_debug_implementations)] pub enum Connector { /// Plain (non-TLS) connector. Plain, /// `native-tls` TLS connector. #[cfg(feature = "native-tls")] NativeTls(native_tls_crate::TlsConnector), /// `rustls` TLS connector. #[cfg(feature = "__rustls-tls")] Rustls(std::sync::Arc), } mod encryption { #[cfg(feature = "native-tls")] pub mod native_tls { use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; use std::io::{Read, Write}; use crate::{ error::TlsError, stream::{MaybeTlsStream, Mode}, Error, Result, }; pub fn wrap_stream( socket: S, domain: &str, mode: Mode, tls_connector: Option, ) -> Result> where S: Read + Write, { match mode { Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), Mode::Tls => { let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok); let connector = try_connector.map_err(TlsError::Native)?; let connected = connector.connect(domain, socket); match connected { Err(e) => match e { TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())), TlsHandshakeError::WouldBlock(_) => { panic!("Bug: TLS handshake not blocked") } }, Ok(s) => Ok(MaybeTlsStream::NativeTls(s)), } } } } } #[cfg(feature = "__rustls-tls")] pub mod rustls { use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; use std::{ convert::TryFrom, io::{Read, Write}, sync::Arc, }; use crate::{ error::TlsError, stream::{MaybeTlsStream, Mode}, Result, }; pub fn wrap_stream( socket: S, domain: &str, mode: Mode, tls_connector: Option>, ) -> Result> where S: Read + Write, { match mode { Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), Mode::Tls => { let config = match tls_connector { Some(config) => config, None => { #[allow(unused_mut)] let mut root_store = RootCertStore::empty(); #[cfg(feature = "rustls-tls-native-roots")] { for cert in rustls_native_certs::load_native_certs()? { root_store .add(&rustls::Certificate(cert.0)) .map_err(TlsError::Webpki)?; } } #[cfg(feature = "rustls-tls-webpki-roots")] { root_store.add_server_trust_anchors( webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( ta.subject, ta.spki, ta.name_constraints, ) }) ); } Arc::new( ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(), ) } }; let domain = ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?; let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?; let stream = StreamOwned::new(client, socket); Ok(MaybeTlsStream::Rustls(stream)) } } } } pub mod plain { use std::io::{Read, Write}; use crate::{ error::UrlError, stream::{MaybeTlsStream, Mode}, Error, Result, }; pub fn wrap_stream(socket: S, mode: Mode) -> Result> where S: Read + Write, { match mode { Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), } } } } type TlsHandshakeError = HandshakeError>>; /// Creates a WebSocket handshake from a request and a stream, /// upgrading the stream to TLS if required. pub fn client_tls( request: R, stream: S, ) -> Result<(WebSocket>, Response), TlsHandshakeError> where R: IntoClientRequest, S: Read + Write, { client_tls_with_config(request, stream, None, None) } /// The same as [`client_tls()`] but one can specify a websocket configuration, /// and an optional connector. If no connector is specified, a default one will /// be created. /// /// Please refer to [`client_tls()`] for more details. pub fn client_tls_with_config( request: R, stream: S, config: Option, connector: Option, ) -> Result<(WebSocket>, Response), TlsHandshakeError> where R: IntoClientRequest, S: Read + Write, { let request = request.into_client_request()?; #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] let domain = match request.uri().host() { Some(d) => Ok(d.to_string()), None => Err(Error::Url(UrlError::NoHostName)), }?; let mode = uri_mode(request.uri())?; let stream = match connector { Some(conn) => match conn { #[cfg(feature = "native-tls")] Connector::NativeTls(conn) => { self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn)) } #[cfg(feature = "__rustls-tls")] Connector::Rustls(conn) => { self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn)) } Connector::Plain => self::encryption::plain::wrap_stream(stream, mode), }, None => { #[cfg(feature = "native-tls")] { self::encryption::native_tls::wrap_stream(stream, &domain, mode, None) } #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))] { self::encryption::rustls::wrap_stream(stream, &domain, mode, None) } #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] { self::encryption::plain::wrap_stream(stream, mode) } } }?; client_with_config(request, stream, config) }