diff --git a/.travis.yml b/.travis.yml index 34ea589..fc25514 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,6 +20,7 @@ script: - cargo check --features async-std-runtime,async-tls,async-native-tls - cargo check --features tokio-runtime,async-tls - cargo check --features tokio-runtime,tokio-native-tls + - cargo check --features tokio-runtime,tokio-openssl - cargo check --features tokio-runtime,async-tls,tokio-native-tls - cargo check --features gio-runtime - cargo check --features gio-runtime,async-tls diff --git a/Cargo.toml b/Cargo.toml index 599e42f..3016cce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ gio-runtime = ["gio", "glib"] async-tls = ["real-async-tls"] async-native-tls = ["async-std-runtime", "real-async-native-tls"] tokio-native-tls = ["tokio-runtime", "real-tokio-native-tls", "real-native-tls", "tungstenite/tls"] +tokio-openssl = ["tokio-runtime", "real-tokio-openssl", "openssl"] [package.metadata.docs.rs] features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls"] @@ -38,6 +39,15 @@ default-features = false optional = true version = "1.0" +[dependencies.real-tokio-openssl] +optional = true +version = "0.4" +package = "tokio-openssl" + +[dependencies.openssl] +optional = true +version = "0.10" + [dependencies.real-async-tls] optional = true version = "0.7" diff --git a/examples/tokio-echo.rs b/examples/tokio-echo.rs index 39dca63..9a8a805 100644 --- a/examples/tokio-echo.rs +++ b/examples/tokio-echo.rs @@ -2,9 +2,17 @@ use async_tungstenite::{tokio::connect_async, tungstenite::Message}; use futures::prelude::*; async fn run() -> Result<(), Box> { - #[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))] + #[cfg(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" + ))] let url = "wss://echo.websocket.org"; - #[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))] + #[cfg(not(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" + )))] let url = "ws://echo.websocket.org"; let (mut ws_stream, _) = connect_async(url).await?; diff --git a/src/lib.rs b/src/lib.rs index 1ea658d..30d7bac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ //! with the [tokio](https://tokio.rs) runtime. //! * `tokio-native-tls`: Enables the additional functions in the `tokio` module to //! implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls). +//! * `tokio-openssl`: Enables the additional functions in the `tokio` module to +//! implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl). //! * `gio-runtime`: Enables the `gio` module, which provides integration with //! the [gio](https://www.gtk-rs.org) runtime. //! @@ -41,6 +43,7 @@ mod handshake; feature = "async-tls", feature = "async-native-tls", feature = "tokio-native-tls", + feature = "tokio-openssl", ))] pub mod stream; diff --git a/src/tokio.rs b/src/tokio.rs index 2ec63fb..ed802b5 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -88,7 +88,103 @@ pub(crate) mod tokio_tls { } } -#[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))] +#[cfg(feature = "tokio-openssl")] +pub(crate) mod tokio_tls { + use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod}; + use real_tokio_openssl::connect; + use real_tokio_openssl::SslStream as TlsStream; + + use tungstenite::client::{uri_mode, IntoClientRequest}; + use tungstenite::handshake::client::Request; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use crate::stream::Stream as StreamSwitcher; + use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + + use super::TokioAdapter; + + /// A stream that might be protected with TLS. + pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; + + pub type AutoStream = MaybeTlsStream; + + pub type Connector = ConnectConfiguration; + + async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, + ) -> Result, Error> + where + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + SslConnector::builder(SslMethod::tls_client()) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + .build() + .configure() + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + connect(connector, &domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: IntoClientRequest + Unpin, + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + AutoStream: Unpin, + { + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" +)))] pub(crate) mod dummy_tls { use tungstenite::client::{uri_mode, IntoClientRequest}; use tungstenite::handshake::client::Request; @@ -140,7 +236,11 @@ pub(crate) mod dummy_tls { } } -#[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))] +#[cfg(not(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" +)))] use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; #[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] @@ -180,14 +280,19 @@ pub(crate) mod async_tls_adapter { pub type AutoStream = MaybeTlsStream>; } -#[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] +#[cfg(all(feature = "async-tls", not(any(feature = "tokio-native-tls", feature = "tokio-openssl"))))] pub use self::async_tls_adapter::client_async_tls_with_connector_and_config; -#[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] +#[cfg(all(feature = "async-tls", not(any(feature = "tokio-native-tls", feature = "tokio-openssl"))))] use self::async_tls_adapter::{AutoStream, Connector}; #[cfg(feature = "tokio-native-tls")] use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector}; +#[cfg(feature = "tokio-openssl")] +pub use self::tokio_tls::client_async_tls_with_connector_and_config; +#[cfg(feature = "tokio-openssl")] +use self::tokio_tls::{AutoStream, Connector}; + /// Creates a WebSocket handshake from a request and a stream. /// For convenience, the user may call this with a url string, a URL, /// or a `Request`. Calling with `Request` allows the user to add @@ -337,6 +442,73 @@ where client_async_tls_with_connector_and_config(request, stream, connector, None).await } +#[cfg(feature = "tokio-openssl")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub async fn client_async_tls( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +#[cfg(feature = "tokio-openssl")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await +} + +#[cfg(feature = "tokio-openssl")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + /// Type alias for the stream type of the `connect_async()` functions. pub type ConnectStream = ClientStream; @@ -368,7 +540,11 @@ where client_async_tls_with_connector_and_config(request, socket, None, config).await } -#[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))] +#[cfg(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" +))] /// Connect to a given URL using the provided TLS connector. pub async fn connect_async_with_tls_connector( request: R, @@ -380,7 +556,11 @@ where connect_async_with_tls_connector_and_config(request, connector, None).await } -#[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))] +#[cfg(any( + feature = "async-tls", + feature = "tokio-native-tls", + feature = "tokio-openssl" +))] /// Connect to a given URL using the provided TLS connector. pub async fn connect_async_with_tls_connector_and_config( request: R,