From fb8bbca9edb9e0c5fca33340fae04bde39a5f61e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Thu, 12 Dec 2019 22:52:01 +0200 Subject: [PATCH] Add optional support for TLS via async-native-tls instead of async-tls (rustls) Can be enabled with the "native-tls" feature instead of just "tls". --- Cargo.toml | 15 +++++++++++++-- src/connect.rs | 32 +++++++++++++++++++++++--------- src/lib.rs | 2 +- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b6b5c97..577065a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,10 @@ edition = "2018" [features] default = ["connect", "tls", "async_std_runtime"] connect = ["stream"] -async_std_runtime = ["connect", "tls", "async-std"] -tls = ["async-tls", "stream"] +async_std_runtime = ["connect", "async-std"] +tls-base = ["stream"] +tls = ["async-tls", "tls-base"] +native-tls = ["async-native-tls", "real-native-tls", "tls-base", "tungstenite/tls"] stream = [] [dependencies] @@ -35,6 +37,15 @@ version = "1.0" optional = true version = "0.6.0" +[dependencies.async-native-tls] +optional = true +version = "0.1.0" + +[dependencies.real-native-tls] +optional = true +version = "0.2" +package = "native-tls" + [dev-dependencies] url = "2.0.0" env_logger = "0.7" diff --git a/src/connect.rs b/src/connect.rs index 56187d9..6296ed9 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -7,11 +7,18 @@ use futures::io::{AsyncRead, AsyncWrite}; use super::{client_async, Request, WebSocketStream}; -#[cfg(feature = "tls")] +#[cfg(feature = "tls-base")] pub(crate) mod encryption { + #[cfg(feature = "tls")] use async_tls::client::TlsStream; + #[cfg(feature = "tls")] use async_tls::TlsConnector as AsyncTlsConnector; + #[cfg(feature = "native-tls")] + use async_native_tls::TlsConnector as AsyncTlsConnector; + #[cfg(feature = "native-tls")] + use async_native_tls::TlsStream; + use tungstenite::stream::Mode; use tungstenite::Error; @@ -35,21 +42,28 @@ pub(crate) mod encryption { match mode { Mode::Plain => Ok(StreamSwitcher::Plain(socket)), Mode::Tls => { - let stream = AsyncTlsConnector::new(); - let connected = stream.connect(&domain, socket)?.await; - match connected { - Err(e) => Err(Error::Io(e)), - Ok(s) => Ok(StreamSwitcher::Tls(s)), - } + #[cfg(feature = "tls")] + let stream = { + let connector = AsyncTlsConnector::new(); + connector.connect(&domain, socket)?.await? + }; + #[cfg(feature = "native-tls")] + let stream = { + let builder = real_native_tls::TlsConnector::builder(); + let connector = builder.build()?; + let connector = AsyncTlsConnector::from(connector); + connector.connect(&domain, socket).await? + }; + Ok(StreamSwitcher::Tls(stream)) } } } } -#[cfg(feature = "tls")] +#[cfg(feature = "tls-base")] pub use self::encryption::MaybeTlsStream; -#[cfg(not(feature = "tls"))] +#[cfg(not(feature = "tls-base"))] pub(crate) mod encryption { use futures::io::{AsyncRead, AsyncWrite}; use futures::{future, Future}; diff --git a/src/lib.rs b/src/lib.rs index 751ca7b..11b8290 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ pub use connect::client_async_tls; #[cfg(feature = "async_std_runtime")] pub use connect::connect_async; -#[cfg(all(feature = "connect", feature = "tls"))] +#[cfg(all(feature = "connect", feature = "tls-base"))] pub use connect::MaybeTlsStream; use std::error::Error; use tungstenite::protocol::CloseFrame;