diff --git a/Cargo.toml b/Cargo.toml index 24f7a00..f78a5d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,17 +13,18 @@ edition = "2018" readme = "README.md" [features] -default = [] +default = ["tokio-rustls"] async-std-runtime = ["async-std"] tokio-runtime = ["tokio"] gio-runtime = ["gio", "glib"] async-tls = ["real-async-tls"] async-native-tls = ["async-std-runtime", "real-async-native-tls"] tokio-native-tls = ["tokio-runtime", "real-tokio-native-tls", "real-native-tls", "tungstenite/tls"] +tokio-rustls = ["tokio-runtime", "real-tokio-rustls", "tungstenite/tls"] tokio-openssl = ["tokio-runtime", "real-tokio-openssl", "openssl"] [package.metadata.docs.rs] -features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls"] +features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls", "tokio-rustls"] [dependencies] log = "0.4" @@ -73,6 +74,11 @@ optional = true version = "0.1" package = "tokio-native-tls" +[dependencies.real-tokio-rustls] +optional = true +version = "^0.14" +package = "tokio-rustls" + [dependencies.gio] optional = true version = "0.9" diff --git a/README.md b/README.md index 6cf42f4..d2e899b 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ integration with various other crates can be enabled via feature flags with the [tokio](https://tokio.rs) runtime. * `tokio-native-tls`: Enables the additional functions in the `tokio` module to implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls). + * `tokio-rustls`: Enables the additional functions in the `tokio` module to + implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls). * `gio-runtime`: Enables the `gio` module, which provides integration with the [gio](https://gtk-rs.org) runtime. diff --git a/src/lib.rs b/src/lib.rs index 30d7bac..1e251f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ //! with the [tokio](https://tokio.rs) runtime. //! * `tokio-native-tls`: Enables the additional functions in the `tokio` module to //! implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls). +//! * `tokio-rustls`: Enables the additional functions in the `tokio` module to +//! implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls). //! * `tokio-openssl`: Enables the additional functions in the `tokio` module to //! implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl). //! * `gio-runtime`: Enables the `gio` module, which provides integration with @@ -43,6 +45,7 @@ mod handshake; feature = "async-tls", feature = "async-native-tls", feature = "tokio-native-tls", + feature = "tokio-rustls", feature = "tokio-openssl", ))] pub mod stream; diff --git a/src/tokio.rs b/src/tokio.rs index ed802b5..11f0081 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -11,6 +11,85 @@ use super::{domain, port, WebSocketStream}; use futures_io::{AsyncRead, AsyncWrite}; +#[cfg(feature = "tokio-rustls")] +pub(crate) mod tokio_tls { + use real_tokio_rustls::{client::TlsStream, TlsConnector as AsyncTlsConnector}; + use real_tokio_rustls::rustls::ClientConfig; + use real_tokio_rustls::webpki::DNSNameRef; + + use tungstenite::client::{uri_mode, IntoClientRequest}; + use tungstenite::handshake::client::Request; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use crate::stream::Stream as StreamSwitcher; + use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + + use super::TokioAdapter; + + /// A stream that might be protected with TLS. + pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; + + pub type AutoStream = MaybeTlsStream; + + pub type Connector = AsyncTlsConnector; + + async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, + ) -> Result, Error> + where + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + let config = ClientConfig::new(); + AsyncTlsConnector::from(std::sync::Arc::new(config)) + }; + let domain = DNSNameRef::try_from_ascii_str(&domain).map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?; + connector + .connect(domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + #[cfg(feature = "tokio-native-tls")] pub(crate) mod tokio_tls { use real_tokio_native_tls::TlsConnector as AsyncTlsConnector; @@ -183,6 +262,7 @@ pub(crate) mod tokio_tls { #[cfg(not(any( feature = "async-tls", feature = "tokio-native-tls", + feature = "tokio-rustls", feature = "tokio-openssl" )))] pub(crate) mod dummy_tls { @@ -239,6 +319,7 @@ pub(crate) mod dummy_tls { #[cfg(not(any( feature = "async-tls", feature = "tokio-native-tls", + feature = "tokio-rustls", feature = "tokio-openssl" )))] use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; @@ -288,6 +369,11 @@ use self::async_tls_adapter::{AutoStream, Connector}; #[cfg(feature = "tokio-native-tls")] use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector}; +#[cfg(feature = "tokio-rustls")] +pub use self::tokio_tls::client_async_tls_with_connector_and_config; +#[cfg(feature = "tokio-rustls")] +use self::tokio_tls::{AutoStream, Connector}; + #[cfg(feature = "tokio-openssl")] pub use self::tokio_tls::client_async_tls_with_connector_and_config; #[cfg(feature = "tokio-openssl")] @@ -543,6 +629,7 @@ where #[cfg(any( feature = "async-tls", feature = "tokio-native-tls", + feature = "tokio-rustls", feature = "tokio-openssl" ))] /// Connect to a given URL using the provided TLS connector. @@ -559,6 +646,7 @@ where #[cfg(any( feature = "async-tls", feature = "tokio-native-tls", + feature = "tokio-rustls", feature = "tokio-openssl" ))] /// Connect to a given URL using the provided TLS connector.