From 6575aedcd8754788de0a9a94ee5549a2b145ba92 Mon Sep 17 00:00:00 2001
From: Josh Matthews <josh@joshmatthews.net>
Date: Tue, 30 Jun 2020 16:08:12 -0400
Subject: [PATCH] Add openssl support.

---
 .travis.yml            |   1 +
 Cargo.toml             |  10 +++
 examples/tokio-echo.rs |  12 ++-
 src/lib.rs             |   3 +
 src/tokio.rs           | 192 +++++++++++++++++++++++++++++++++++++++--
 5 files changed, 210 insertions(+), 8 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 34ea589..fc25514 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -20,6 +20,7 @@ script:
   - cargo check --features async-std-runtime,async-tls,async-native-tls
   - cargo check --features tokio-runtime,async-tls
   - cargo check --features tokio-runtime,tokio-native-tls
+  - cargo check --features tokio-runtime,tokio-openssl
   - cargo check --features tokio-runtime,async-tls,tokio-native-tls
   - cargo check --features gio-runtime
   - cargo check --features gio-runtime,async-tls
diff --git a/Cargo.toml b/Cargo.toml
index 599e42f..3016cce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,7 @@ gio-runtime = ["gio", "glib"]
 async-tls = ["real-async-tls"]
 async-native-tls = ["async-std-runtime", "real-async-native-tls"]
 tokio-native-tls = ["tokio-runtime", "real-tokio-native-tls", "real-native-tls", "tungstenite/tls"]
+tokio-openssl = ["tokio-runtime", "real-tokio-openssl", "openssl"]
 
 [package.metadata.docs.rs]
 features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls"]
@@ -38,6 +39,15 @@ default-features = false
 optional = true
 version = "1.0"
 
+[dependencies.real-tokio-openssl]
+optional = true
+version = "0.4"
+package = "tokio-openssl"
+
+[dependencies.openssl]
+optional = true
+version = "0.10"
+
 [dependencies.real-async-tls]
 optional = true
 version = "0.7"
diff --git a/examples/tokio-echo.rs b/examples/tokio-echo.rs
index 39dca63..9a8a805 100644
--- a/examples/tokio-echo.rs
+++ b/examples/tokio-echo.rs
@@ -2,9 +2,17 @@ use async_tungstenite::{tokio::connect_async, tungstenite::Message};
 use futures::prelude::*;
 
 async fn run() -> Result<(), Box<dyn std::error::Error>> {
-    #[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))]
+    #[cfg(any(
+        feature = "async-tls",
+        feature = "tokio-native-tls",
+        feature = "tokio-openssl"
+    ))]
     let url = "wss://echo.websocket.org";
-    #[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))]
+    #[cfg(not(any(
+        feature = "async-tls",
+        feature = "tokio-native-tls",
+        feature = "tokio-openssl"
+    )))]
     let url = "ws://echo.websocket.org";
 
     let (mut ws_stream, _) = connect_async(url).await?;
diff --git a/src/lib.rs b/src/lib.rs
index 1ea658d..30d7bac 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -18,6 +18,8 @@
 //!    with the [tokio](https://tokio.rs) runtime.
 //!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
 //!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
+//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
+//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
 //!  * `gio-runtime`: Enables the `gio` module, which provides integration with
 //!    the [gio](https://www.gtk-rs.org) runtime.
 //!
@@ -41,6 +43,7 @@ mod handshake;
     feature = "async-tls",
     feature = "async-native-tls",
     feature = "tokio-native-tls",
+    feature = "tokio-openssl",
 ))]
 pub mod stream;
 
diff --git a/src/tokio.rs b/src/tokio.rs
index 2ec63fb..ed802b5 100644
--- a/src/tokio.rs
+++ b/src/tokio.rs
@@ -88,7 +88,103 @@ pub(crate) mod tokio_tls {
     }
 }
 
-#[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))]
+#[cfg(feature = "tokio-openssl")]
+pub(crate) mod tokio_tls {
+    use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod};
+    use real_tokio_openssl::connect;
+    use real_tokio_openssl::SslStream as TlsStream;
+
+    use tungstenite::client::{uri_mode, IntoClientRequest};
+    use tungstenite::handshake::client::Request;
+    use tungstenite::stream::Mode;
+    use tungstenite::Error;
+
+    use crate::stream::Stream as StreamSwitcher;
+    use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
+
+    use super::TokioAdapter;
+
+    /// A stream that might be protected with TLS.
+    pub type MaybeTlsStream<S> = StreamSwitcher<TokioAdapter<S>, TokioAdapter<TlsStream<S>>>;
+
+    pub type AutoStream<S> = MaybeTlsStream<S>;
+
+    pub type Connector = ConnectConfiguration;
+
+    async fn wrap_stream<S>(
+        socket: S,
+        domain: String,
+        connector: Option<Connector>,
+        mode: Mode,
+    ) -> Result<AutoStream<S>, Error>
+    where
+        S: 'static
+            + tokio::io::AsyncRead
+            + tokio::io::AsyncWrite
+            + Unpin
+            + std::fmt::Debug
+            + Send
+            + Sync,
+    {
+        match mode {
+            Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))),
+            Mode::Tls => {
+                let stream = {
+                    let connector = if let Some(connector) = connector {
+                        connector
+                    } else {
+                        SslConnector::builder(SslMethod::tls_client())
+                            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
+                            .build()
+                            .configure()
+                            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
+                    };
+                    connect(connector, &domain, socket)
+                        .await
+                        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
+                };
+                Ok(StreamSwitcher::Tls(TokioAdapter(stream)))
+            }
+        }
+    }
+
+    /// Creates a WebSocket handshake from a request and a stream,
+    /// upgrading the stream to TLS if required and using the given
+    /// connector and WebSocket configuration.
+    pub async fn client_async_tls_with_connector_and_config<R, S>(
+        request: R,
+        stream: S,
+        connector: Option<Connector>,
+        config: Option<WebSocketConfig>,
+    ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
+    where
+        R: IntoClientRequest + Unpin,
+        S: 'static
+            + tokio::io::AsyncRead
+            + tokio::io::AsyncWrite
+            + Unpin
+            + std::fmt::Debug
+            + Send
+            + Sync,
+        AutoStream<S>: Unpin,
+    {
+        let request: Request = request.into_client_request()?;
+
+        let domain = domain(&request)?;
+
+        // Make sure we check domain and mode first. URL must be valid.
+        let mode = uri_mode(request.uri())?;
+
+        let stream = wrap_stream(stream, domain, connector, mode).await?;
+        client_async_with_config(request, stream, config).await
+    }
+}
+
+#[cfg(not(any(
+    feature = "async-tls",
+    feature = "tokio-native-tls",
+    feature = "tokio-openssl"
+)))]
 pub(crate) mod dummy_tls {
     use tungstenite::client::{uri_mode, IntoClientRequest};
     use tungstenite::handshake::client::Request;
@@ -140,7 +236,11 @@ pub(crate) mod dummy_tls {
     }
 }
 
-#[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))]
+#[cfg(not(any(
+    feature = "async-tls",
+    feature = "tokio-native-tls",
+    feature = "tokio-openssl"
+)))]
 use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream};
 
 #[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))]
@@ -180,14 +280,19 @@ pub(crate) mod async_tls_adapter {
 
     pub type AutoStream<S> = MaybeTlsStream<TokioAdapter<S>>;
 }
-#[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))]
+#[cfg(all(feature = "async-tls", not(any(feature = "tokio-native-tls", feature = "tokio-openssl"))))]
 pub use self::async_tls_adapter::client_async_tls_with_connector_and_config;
-#[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))]
+#[cfg(all(feature = "async-tls", not(any(feature = "tokio-native-tls", feature = "tokio-openssl"))))]
 use self::async_tls_adapter::{AutoStream, Connector};
 
 #[cfg(feature = "tokio-native-tls")]
 use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector};
 
+#[cfg(feature = "tokio-openssl")]
+pub use self::tokio_tls::client_async_tls_with_connector_and_config;
+#[cfg(feature = "tokio-openssl")]
+use self::tokio_tls::{AutoStream, Connector};
+
 /// Creates a WebSocket handshake from a request and a stream.
 /// For convenience, the user may call this with a url string, a URL,
 /// or a `Request`. Calling with `Request` allows the user to add
@@ -337,6 +442,73 @@ where
     client_async_tls_with_connector_and_config(request, stream, connector, None).await
 }
 
+#[cfg(feature = "tokio-openssl")]
+/// Creates a WebSocket handshake from a request and a stream,
+/// upgrading the stream to TLS if required.
+pub async fn client_async_tls<R, S>(
+    request: R,
+    stream: S,
+) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
+where
+    R: IntoClientRequest + Unpin,
+    S: 'static
+        + tokio::io::AsyncRead
+        + tokio::io::AsyncWrite
+        + Unpin
+        + std::fmt::Debug
+        + Send
+        + Sync,
+    AutoStream<S>: Unpin,
+{
+    client_async_tls_with_connector_and_config(request, stream, None, None).await
+}
+
+#[cfg(feature = "tokio-openssl")]
+/// Creates a WebSocket handshake from a request and a stream,
+/// upgrading the stream to TLS if required and using the given
+/// WebSocket configuration.
+pub async fn client_async_tls_with_config<R, S>(
+    request: R,
+    stream: S,
+    config: Option<WebSocketConfig>,
+) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
+where
+    R: IntoClientRequest + Unpin,
+    S: 'static
+        + tokio::io::AsyncRead
+        + tokio::io::AsyncWrite
+        + Unpin
+        + std::fmt::Debug
+        + Send
+        + Sync,
+    AutoStream<S>: Unpin,
+{
+    client_async_tls_with_connector_and_config(request, stream, None, config).await
+}
+
+#[cfg(feature = "tokio-openssl")]
+/// Creates a WebSocket handshake from a request and a stream,
+/// upgrading the stream to TLS if required and using the given
+/// connector.
+pub async fn client_async_tls_with_connector<R, S>(
+    request: R,
+    stream: S,
+    connector: Option<Connector>,
+) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
+where
+    R: IntoClientRequest + Unpin,
+    S: 'static
+        + tokio::io::AsyncRead
+        + tokio::io::AsyncWrite
+        + Unpin
+        + std::fmt::Debug
+        + Send
+        + Sync,
+    AutoStream<S>: Unpin,
+{
+    client_async_tls_with_connector_and_config(request, stream, connector, None).await
+}
+
 /// Type alias for the stream type of the `connect_async()` functions.
 pub type ConnectStream = ClientStream<TcpStream>;
 
@@ -368,7 +540,11 @@ where
     client_async_tls_with_connector_and_config(request, socket, None, config).await
 }
 
-#[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))]
+#[cfg(any(
+    feature = "async-tls",
+    feature = "tokio-native-tls",
+    feature = "tokio-openssl"
+))]
 /// Connect to a given URL using the provided TLS connector.
 pub async fn connect_async_with_tls_connector<R>(
     request: R,
@@ -380,7 +556,11 @@ where
     connect_async_with_tls_connector_and_config(request, connector, None).await
 }
 
-#[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))]
+#[cfg(any(
+    feature = "async-tls",
+    feature = "tokio-native-tls",
+    feature = "tokio-openssl"
+))]
 /// Connect to a given URL using the provided TLS connector.
 pub async fn connect_async_with_tls_connector_and_config<R>(
     request: R,