diff --git a/src/connect.rs b/src/connect.rs index 6296ed9..2a451fc 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -30,10 +30,15 @@ pub(crate) mod encryption { pub type MaybeTlsStream = StreamSwitcher>; pub type AutoStream = MaybeTlsStream; + #[cfg(feature = "tls")] + pub type Connector = async_tls::TlsConnector; + #[cfg(feature = "native-tls")] + pub type Connector = real_native_tls::TlsConnector; pub async fn wrap_stream( socket: S, domain: String, + connector: Option, mode: Mode, ) -> Result, Error> where @@ -44,13 +49,17 @@ pub(crate) mod encryption { Mode::Tls => { #[cfg(feature = "tls")] let stream = { - let connector = AsyncTlsConnector::new(); + let connector = connector.unwrap_or_else(|| AsyncTlsConnector::new()); connector.connect(&domain, socket)?.await? }; #[cfg(feature = "native-tls")] let stream = { - let builder = real_native_tls::TlsConnector::builder(); - let connector = builder.build()?; + let connector = if let Some(connector) = connector { + connector + } else { + let builder = real_native_tls::TlsConnector::builder(); + builder.build()? + }; let connector = AsyncTlsConnector::from(connector); connector.connect(&domain, socket).await? }; @@ -59,23 +68,23 @@ pub(crate) mod encryption { } } } - #[cfg(feature = "tls-base")] pub use self::encryption::MaybeTlsStream; #[cfg(not(feature = "tls-base"))] pub(crate) mod encryption { use futures::io::{AsyncRead, AsyncWrite}; - use futures::{future, Future}; use tungstenite::stream::Mode; use tungstenite::Error; pub type AutoStream = S; + pub type Connector = (); - pub async fn wrap_stream( + pub(crate) async fn wrap_stream( socket: S, _domain: String, + _connector: Option<()>, mode: Mode, ) -> Result, Error> where @@ -88,7 +97,7 @@ pub(crate) mod encryption { } } -use self::encryption::{wrap_stream, AutoStream}; +use self::encryption::AutoStream; /// Get a domain from an URL. #[inline] @@ -105,6 +114,22 @@ pub async fn client_async_tls( request: R, stream: S, ) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector(request, stream, None).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin, @@ -117,7 +142,7 @@ where // Make sure we check domain and mode first. URL must be valid. let mode = url_mode(&request.url)?; - let stream = wrap_stream(stream, domain, mode).await?; + let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?; client_async(request, stream).await } @@ -145,7 +170,34 @@ pub(crate) mod async_std_runtime { let socket = try_socket.map_err(Error::Io)?; client_async_tls(request, socket).await } + + #[cfg(any(feature = "tls", feature = "native-tls"))] + /// Connect to a given URL using the provided TLS connector. + pub async fn connect_async_with_tls_connector( + request: R, + connector: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + { + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector(request, socket, connector).await + } } #[cfg(feature = "async_std_runtime")] pub use async_std_runtime::connect_async; +#[cfg(all( + feature = "async_std_runtime", + any(feature = "tls", feature = "native-tls") +))] +pub use async_std_runtime::connect_async_with_tls_connector; diff --git a/src/lib.rs b/src/lib.rs index 11b8290..e2fc399 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,8 +46,15 @@ use tungstenite::{ #[cfg(feature = "connect")] pub use connect::client_async_tls; +#[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))] +pub use connect::client_async_tls_with_connector; #[cfg(feature = "async_std_runtime")] pub use connect::connect_async; +#[cfg(all( + feature = "async_std_runtime", + any(feature = "tls", feature = "native-tls") +))] +pub use connect::connect_async_with_tls_connector; #[cfg(all(feature = "connect", feature = "tls-base"))] pub use connect::MaybeTlsStream;