diff --git a/src/connect.rs b/src/connect.rs index 2a451fc..50cd33b 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,11 +1,12 @@ //! Connection helper. use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; +use tungstenite::protocol::WebSocketConfig; use tungstenite::Error; use futures::io::{AsyncRead, AsyncWrite}; -use super::{client_async, Request, WebSocketStream}; +use super::{client_async_with_config, Request, WebSocketStream}; #[cfg(feature = "tls-base")] pub(crate) mod encryption { @@ -119,7 +120,23 @@ where S: 'static + AsyncRead + AsyncWrite + Send + Unpin, AutoStream: Unpin, { - client_async_tls_with_connector(request, stream, None).await + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await } /// Creates a WebSocket handshake from a request and a stream, @@ -130,6 +147,23 @@ pub async fn client_async_tls_with_connector( stream: S, connector: Option, ) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin, @@ -143,7 +177,7 @@ where let mode = url_mode(&request.url)?; let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?; - client_async(request, stream).await + client_async_with_config(request, stream, config).await } #[cfg(feature = "async_std_runtime")] @@ -155,6 +189,17 @@ pub(crate) mod async_std_runtime { pub async fn connect_async( request: R, ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + { + connect_async_with_config(request, None).await + } + + /// Connect to a given URL with a given WebSocket configuration. + pub async fn connect_async_with_config( + request: R, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, { @@ -168,7 +213,7 @@ pub(crate) mod async_std_runtime { let try_socket = TcpStream::connect((domain.as_str(), port)).await; let socket = try_socket.map_err(Error::Io)?; - client_async_tls(request, socket).await + client_async_tls_with_config(request, socket, config).await } #[cfg(any(feature = "tls", feature = "native-tls"))] @@ -177,6 +222,19 @@ pub(crate) mod async_std_runtime { request: R, connector: Option, ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + { + connect_async_with_tls_connector_and_config(request, connector, None).await + } + + #[cfg(any(feature = "tls", feature = "native-tls"))] + /// Connect to a given URL using the provided TLS connector. + pub async fn connect_async_with_tls_connector_and_config( + request: R, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, { @@ -190,14 +248,16 @@ pub(crate) mod async_std_runtime { let try_socket = TcpStream::connect((domain.as_str(), port)).await; let socket = try_socket.map_err(Error::Io)?; - client_async_tls_with_connector(request, socket, connector).await + client_async_tls_with_connector_and_config(request, socket, connector, config).await } } #[cfg(feature = "async_std_runtime")] -pub use async_std_runtime::connect_async; +pub use async_std_runtime::{connect_async, connect_async_with_config}; #[cfg(all( feature = "async_std_runtime", any(feature = "tls", feature = "native-tls") ))] -pub use async_std_runtime::connect_async_with_tls_connector; +pub use async_std_runtime::{ + connect_async_with_tls_connector, connect_async_with_tls_connector_and_config, +}; diff --git a/src/lib.rs b/src/lib.rs index e2fc399..1f095d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,16 +45,16 @@ use tungstenite::{ }; #[cfg(feature = "connect")] -pub use connect::client_async_tls; +pub use connect::{client_async_tls, client_async_tls_with_config}; #[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))] -pub use connect::client_async_tls_with_connector; +pub use connect::{client_async_tls_with_connector, client_async_tls_with_connector_and_config}; #[cfg(feature = "async_std_runtime")] -pub use connect::connect_async; +pub use connect::{connect_async, connect_async_with_config}; #[cfg(all( feature = "async_std_runtime", any(feature = "tls", feature = "native-tls") ))] -pub use connect::connect_async_with_tls_connector; +pub use connect::{connect_async_with_tls_connector, connect_async_with_tls_connector_and_config}; #[cfg(all(feature = "connect", feature = "tls-base"))] pub use connect::MaybeTlsStream;