|
|
|
//! Connection helper.
|
|
|
|
use std::io::{Read, Write};
|
|
|
|
|
|
|
|
use crate::{
|
|
|
|
client::{client_with_config, uri_mode, IntoClientRequest},
|
|
|
|
error::UrlError,
|
|
|
|
handshake::client::Response,
|
|
|
|
protocol::WebSocketConfig,
|
|
|
|
stream::MaybeTlsStream,
|
|
|
|
ClientHandshake, Error, HandshakeError, Result, WebSocket,
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Chooses how the underlying stream is secured while a connection is being set up.
///
/// Depending on the enabled crate features, a TLS session can be created through
/// `native-tls` or through `rustls`; the `Plain` variant disables TLS altogether.
#[allow(missing_debug_implementations)]
#[non_exhaustive]
pub enum Connector {
    /// Do not use TLS: the stream is used as-is.
    Plain,
    /// Establish TLS with the supplied `native-tls` connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Establish TLS with `rustls`, using the supplied shared client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
|
|
|
|
|
|
|
|
mod encryption {
|
|
|
|
/// Stream wrapping backed by the `native-tls` crate.
#[cfg(feature = "native-tls")]
pub mod native_tls {
    use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};

    use std::io::{Read, Write};

    use crate::{
        error::TlsError,
        stream::{MaybeTlsStream, Mode},
        Error, Result,
    };

    /// Wraps `socket` according to `mode`.
    ///
    /// * `Mode::Plain` returns the socket untouched as `MaybeTlsStream::Plain`.
    /// * `Mode::Tls` runs a `native-tls` handshake for `domain`, using
    ///   `tls_connector` when one is given and a freshly created
    ///   `TlsConnector` otherwise.
    ///
    /// # Errors
    ///
    /// Fails when the default connector cannot be created or when the TLS
    /// handshake reports a failure.
    ///
    /// # Panics
    ///
    /// Panics if the handshake reports `WouldBlock`; this function treats an
    /// interrupted handshake as a bug.
    pub fn wrap_stream<S>(
        socket: S,
        domain: &str,
        mode: Mode,
        tls_connector: Option<TlsConnector>,
    ) -> Result<MaybeTlsStream<S>>
    where
        S: Read + Write,
    {
        match mode {
            Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
            Mode::Tls => {
                // Prefer the caller's connector; only build a default one on demand.
                let connector = match tls_connector {
                    Some(provided) => provided,
                    None => TlsConnector::new().map_err(TlsError::Native)?,
                };

                match connector.connect(domain, socket) {
                    Ok(tls_stream) => Ok(MaybeTlsStream::NativeTls(tls_stream)),
                    Err(TlsHandshakeError::Failure(cause)) => Err(Error::Tls(cause.into())),
                    Err(TlsHandshakeError::WouldBlock(_)) => {
                        panic!("Bug: TLS handshake not blocked")
                    }
                }
            }
        }
    }
}
|
|
|
|
|
|
|
|
/// Stream wrapping backed by the `rustls` crate.
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
    use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};

    use std::{
        convert::TryFrom,
        io::{Read, Write},
        sync::Arc,
    };

    use crate::{
        error::TlsError,
        stream::{MaybeTlsStream, Mode},
        Result,
    };

    /// Builds the client configuration used when the caller supplies none.
    ///
    /// The root store is filled from the platform trust store and/or the bundled
    /// `webpki-roots`, depending on which crate features are enabled.
    fn default_client_config() -> Result<Arc<ClientConfig>> {
        // Without any root-store feature nothing mutates the store.
        #[allow(unused_mut)]
        let mut root_store = RootCertStore::empty();

        #[cfg(feature = "rustls-tls-native-roots")]
        {
            let der_certs: Vec<Vec<u8>> = rustls_native_certs::load_native_certs()?
                .into_iter()
                .map(|cert| cert.0)
                .collect();
            // Platform stores can contain entries that cannot be parsed. Adding on
            // a best-effort basis keeps one bad entry from breaking every TLS
            // connection; unparsable entries are simply not trusted.
            let (_added, _ignored) = root_store.add_parsable_certificates(&der_certs);
        }
        #[cfg(feature = "rustls-tls-webpki-roots")]
        {
            root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(
                |ta| {
                    rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                        ta.subject,
                        ta.spki,
                        ta.name_constraints,
                    )
                },
            ));
        }

        Ok(Arc::new(
            ClientConfig::builder()
                .with_safe_defaults()
                .with_root_certificates(root_store)
                .with_no_client_auth(),
        ))
    }

    /// Wraps `socket` according to `mode`.
    ///
    /// * `Mode::Plain` returns the socket untouched as `MaybeTlsStream::Plain`.
    /// * `Mode::Tls` creates a `rustls` client connection for `domain` over the
    ///   socket, using `tls_connector` when one is given and a default
    ///   configuration otherwise.
    ///
    /// # Errors
    ///
    /// Fails when the platform certificates cannot be loaded (only with the
    /// `rustls-tls-native-roots` feature and no explicit configuration), when
    /// `domain` is not a valid server name, or when `rustls` rejects the
    /// connection setup.
    pub fn wrap_stream<S>(
        socket: S,
        domain: &str,
        mode: Mode,
        tls_connector: Option<Arc<ClientConfig>>,
    ) -> Result<MaybeTlsStream<S>>
    where
        S: Read + Write,
    {
        match mode {
            Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
            Mode::Tls => {
                let config = match tls_connector {
                    Some(config) => config,
                    None => default_client_config()?,
                };
                let server_name =
                    ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?;
                let client =
                    ClientConnection::new(config, server_name).map_err(TlsError::Rustls)?;

                Ok(MaybeTlsStream::Rustls(StreamOwned::new(client, socket)))
            }
        }
    }
}
|
|
|
|
|
|
|
|
pub mod plain {
|
|
|
|
use std::io::{Read, Write};
|
|
|
|
|
|
|
|
use crate::{
|
|
|
|
error::UrlError,
|
|
|
|
stream::{MaybeTlsStream, Mode},
|
|
|
|
Error, Result,
|
|
|
|
};
|
|
|
|
|
|
|
|
pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
|
|
|
|
where
|
|
|
|
S: Read + Write,
|
|
|
|
{
|
|
|
|
match mode {
|
|
|
|
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
|
|
|
|
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Error type returned by the TLS-aware client functions below: a
/// `HandshakeError` for a `ClientHandshake` running over a `MaybeTlsStream<S>`.
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
|
|
|
|
|
|
|
|
/// Creates a WebSocket handshake from a request and a stream,
|
|
|
|
/// upgrading the stream to TLS if required.
|
|
|
|
pub fn client_tls<R, S>(
|
|
|
|
request: R,
|
|
|
|
stream: S,
|
|
|
|
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
|
|
|
|
where
|
|
|
|
R: IntoClientRequest,
|
|
|
|
S: Read + Write,
|
|
|
|
{
|
|
|
|
client_tls_with_config(request, stream, None, None)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// The same as [`client_tls()`] but one can specify a websocket configuration,
|
|
|
|
/// and an optional connector. If no connector is specified, a default one will
|
|
|
|
/// be created.
|
|
|
|
///
|
|
|
|
/// Please refer to [`client_tls()`] for more details.
|
|
|
|
pub fn client_tls_with_config<R, S>(
|
|
|
|
request: R,
|
|
|
|
stream: S,
|
|
|
|
config: Option<WebSocketConfig>,
|
|
|
|
connector: Option<Connector>,
|
|
|
|
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
|
|
|
|
where
|
|
|
|
R: IntoClientRequest,
|
|
|
|
S: Read + Write,
|
|
|
|
{
|
|
|
|
let request = request.into_client_request()?;
|
|
|
|
|
|
|
|
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
|
|
|
let domain = match request.uri().host() {
|
|
|
|
Some(d) => Ok(d.to_string()),
|
|
|
|
None => Err(Error::Url(UrlError::NoHostName)),
|
|
|
|
}?;
|
|
|
|
|
|
|
|
let mode = uri_mode(request.uri())?;
|
|
|
|
|
|
|
|
let stream = match connector {
|
|
|
|
Some(conn) => match conn {
|
|
|
|
#[cfg(feature = "native-tls")]
|
|
|
|
Connector::NativeTls(conn) => {
|
|
|
|
self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
|
|
|
|
}
|
|
|
|
#[cfg(feature = "__rustls-tls")]
|
|
|
|
Connector::Rustls(conn) => {
|
|
|
|
self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
|
|
|
|
}
|
|
|
|
Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
|
|
|
|
},
|
|
|
|
None => {
|
|
|
|
#[cfg(feature = "native-tls")]
|
|
|
|
{
|
|
|
|
self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
|
|
|
|
}
|
|
|
|
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
|
|
|
|
{
|
|
|
|
self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
|
|
|
|
}
|
|
|
|
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
|
|
|
|
{
|
|
|
|
self::encryption::plain::wrap_stream(stream, mode)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}?;
|
|
|
|
|
|
|
|
client_with_config(request, stream, config)
|
|
|
|
}
|