Merge pull request #226 from dnaka91/connector
Add a connector to configure TLS configpull/222/head
commit
239f8e293f
@ -0,0 +1,219 @@ |
||||
//! Connection helper.
|
||||
use std::io::{Read, Write}; |
||||
|
||||
use crate::{ |
||||
client::{client_with_config, uri_mode, IntoClientRequest}, |
||||
error::UrlError, |
||||
handshake::client::Response, |
||||
protocol::WebSocketConfig, |
||||
stream::MaybeTlsStream, |
||||
ClientHandshake, Error, HandshakeError, Result, WebSocket, |
||||
}; |
||||
|
||||
/// A connector that can be used when establishing connections, allowing to control whether
|
||||
/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
|
||||
/// `Plain` variant.
|
||||
#[non_exhaustive] |
||||
#[allow(missing_debug_implementations)] |
||||
pub enum Connector { |
||||
/// Plain (non-TLS) connector.
|
||||
Plain, |
||||
/// `native-tls` TLS connector.
|
||||
#[cfg(feature = "native-tls")] |
||||
NativeTls(native_tls_crate::TlsConnector), |
||||
/// `rustls` TLS connector.
|
||||
#[cfg(feature = "__rustls-tls")] |
||||
Rustls(std::sync::Arc<rustls::ClientConfig>), |
||||
} |
||||
|
||||
mod encryption { |
||||
#[cfg(feature = "native-tls")] |
||||
pub mod native_tls { |
||||
use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; |
||||
|
||||
use std::io::{Read, Write}; |
||||
|
||||
use crate::{ |
||||
error::TlsError, |
||||
stream::{MaybeTlsStream, Mode}, |
||||
Error, Result, |
||||
}; |
||||
|
||||
pub fn wrap_stream<S>( |
||||
socket: S, |
||||
domain: &str, |
||||
mode: Mode, |
||||
tls_connector: Option<TlsConnector>, |
||||
) -> Result<MaybeTlsStream<S>> |
||||
where |
||||
S: Read + Write, |
||||
{ |
||||
match mode { |
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
||||
Mode::Tls => { |
||||
let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok); |
||||
let connector = try_connector.map_err(TlsError::Native)?; |
||||
let connected = connector.connect(domain, socket); |
||||
match connected { |
||||
Err(e) => match e { |
||||
TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())), |
||||
TlsHandshakeError::WouldBlock(_) => { |
||||
panic!("Bug: TLS handshake not blocked") |
||||
} |
||||
}, |
||||
Ok(s) => Ok(MaybeTlsStream::NativeTls(s)), |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[cfg(feature = "__rustls-tls")] |
||||
pub mod rustls { |
||||
use rustls::{ClientConfig, ClientSession, StreamOwned}; |
||||
use webpki::DNSNameRef; |
||||
|
||||
use std::{ |
||||
io::{Read, Write}, |
||||
sync::Arc, |
||||
}; |
||||
|
||||
use crate::{ |
||||
error::TlsError, |
||||
stream::{MaybeTlsStream, Mode}, |
||||
Result, |
||||
}; |
||||
|
||||
pub fn wrap_stream<S>( |
||||
socket: S, |
||||
domain: &str, |
||||
mode: Mode, |
||||
tls_connector: Option<Arc<ClientConfig>>, |
||||
) -> Result<MaybeTlsStream<S>> |
||||
where |
||||
S: Read + Write, |
||||
{ |
||||
match mode { |
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
||||
Mode::Tls => { |
||||
let config = match tls_connector { |
||||
Some(config) => config, |
||||
None => { |
||||
#[allow(unused_mut)] |
||||
let mut config = ClientConfig::new(); |
||||
#[cfg(feature = "rustls-tls-native-roots")] |
||||
{ |
||||
config.root_store = rustls_native_certs::load_native_certs() |
||||
.map_err(|(_, err)| err)?; |
||||
} |
||||
#[cfg(feature = "rustls-tls-webpki-roots")] |
||||
{ |
||||
config |
||||
.root_store |
||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); |
||||
} |
||||
|
||||
Arc::new(config) |
||||
} |
||||
}; |
||||
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; |
||||
let client = ClientSession::new(&config, domain); |
||||
let stream = StreamOwned::new(client, socket); |
||||
|
||||
Ok(MaybeTlsStream::Rustls(stream)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub mod plain { |
||||
use std::io::{Read, Write}; |
||||
|
||||
use crate::{ |
||||
error::UrlError, |
||||
stream::{MaybeTlsStream, Mode}, |
||||
Error, Result, |
||||
}; |
||||
|
||||
pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>> |
||||
where |
||||
S: Read + Write, |
||||
{ |
||||
match mode { |
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
||||
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>; |
||||
|
||||
/// Creates a WebSocket handshake from a request and a stream,
|
||||
/// upgrading the stream to TLS if required.
|
||||
pub fn client_tls<R, S>( |
||||
request: R, |
||||
stream: S, |
||||
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> |
||||
where |
||||
R: IntoClientRequest, |
||||
S: Read + Write, |
||||
{ |
||||
client_tls_with_config(request, stream, None, None) |
||||
} |
||||
|
||||
/// The same as [`client_tls()`] but one can specify a websocket configuration,
|
||||
/// and an optional connector. If no connector is specified, a default one will
|
||||
/// be created.
|
||||
///
|
||||
/// Please refer to [`client_tls()`] for more details.
|
||||
pub fn client_tls_with_config<R, S>( |
||||
request: R, |
||||
stream: S, |
||||
config: Option<WebSocketConfig>, |
||||
connector: Option<Connector>, |
||||
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> |
||||
where |
||||
R: IntoClientRequest, |
||||
S: Read + Write, |
||||
{ |
||||
let request = request.into_client_request()?; |
||||
|
||||
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] |
||||
let domain = match request.uri().host() { |
||||
Some(d) => Ok(d.to_string()), |
||||
None => Err(Error::Url(UrlError::NoHostName)), |
||||
}?; |
||||
|
||||
let mode = uri_mode(&request.uri())?; |
||||
|
||||
let stream = match connector { |
||||
Some(conn) => match conn { |
||||
#[cfg(feature = "native-tls")] |
||||
Connector::NativeTls(conn) => { |
||||
self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn)) |
||||
} |
||||
#[cfg(feature = "__rustls-tls")] |
||||
Connector::Rustls(conn) => { |
||||
self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn)) |
||||
} |
||||
Connector::Plain => self::encryption::plain::wrap_stream(stream, mode), |
||||
}, |
||||
None => { |
||||
#[cfg(feature = "native-tls")] |
||||
{ |
||||
self::encryption::native_tls::wrap_stream(stream, &domain, mode, None) |
||||
} |
||||
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))] |
||||
{ |
||||
self::encryption::rustls::wrap_stream(stream, &domain, mode, None) |
||||
} |
||||
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] |
||||
{ |
||||
self::encryption::plain::wrap_stream(stream, mode) |
||||
} |
||||
} |
||||
}?; |
||||
|
||||
client_with_config(request, stream, config) |
||||
} |
Loading…
Reference in new issue