Add client_async_tls and connect_async variants with WebSocketConfig parameter

pull/5/head
Sebastian Dröge 5 years ago
parent b4d5a9e84e
commit c647de44ef
  1. 74
      src/connect.rs
  2. 8
      src/lib.rs

@ -1,11 +1,12 @@
//! Connection helper. //! Connection helper.
use tungstenite::client::url_mode; use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response; use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error; use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use super::{client_async, Request, WebSocketStream}; use super::{client_async_with_config, Request, WebSocketStream};
#[cfg(feature = "tls-base")] #[cfg(feature = "tls-base")]
pub(crate) mod encryption { pub(crate) mod encryption {
@ -119,7 +120,23 @@ where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
client_async_tls_with_connector(request, stream, None).await client_async_tls_with_connector_and_config(request, stream, None, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, config).await
} }
/// Creates a WebSocket handshake from a request and a stream, /// Creates a WebSocket handshake from a request and a stream,
@ -130,6 +147,23 @@ pub async fn client_async_tls_with_connector<R, S>(
stream: S, stream: S,
connector: Option<self::encryption::Connector>, connector: Option<self::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, connector, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector and WebSocket configuration.
pub async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<self::encryption::Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
@ -143,7 +177,7 @@ where
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?; let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?;
client_async(request, stream).await client_async_with_config(request, stream, config).await
} }
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "async_std_runtime")]
@ -155,6 +189,17 @@ pub(crate) mod async_std_runtime {
pub async fn connect_async<R>( pub async fn connect_async<R>(
request: R, request: R,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connect to a given URL with a given WebSocket configuration.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: Into<Request<'static>> + Unpin,
{ {
@ -168,7 +213,7 @@ pub(crate) mod async_std_runtime {
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;
client_async_tls(request, socket).await client_async_tls_with_config(request, socket, config).await
} }
#[cfg(any(feature = "tls", feature = "native-tls"))] #[cfg(any(feature = "tls", feature = "native-tls"))]
@ -177,6 +222,19 @@ pub(crate) mod async_std_runtime {
request: R, request: R,
connector: Option<super::encryption::Connector>, connector: Option<super::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_tls_connector_and_config(request, connector, None).await
}
#[cfg(any(feature = "tls", feature = "native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector_and_config<R>(
request: R,
connector: Option<super::encryption::Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: Into<Request<'static>> + Unpin,
{ {
@ -190,14 +248,16 @@ pub(crate) mod async_std_runtime {
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector(request, socket, connector).await client_async_tls_with_connector_and_config(request, socket, connector, config).await
} }
} }
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "async_std_runtime")]
pub use async_std_runtime::connect_async; pub use async_std_runtime::{connect_async, connect_async_with_config};
#[cfg(all( #[cfg(all(
feature = "async_std_runtime", feature = "async_std_runtime",
any(feature = "tls", feature = "native-tls") any(feature = "tls", feature = "native-tls")
))] ))]
pub use async_std_runtime::connect_async_with_tls_connector; pub use async_std_runtime::{
connect_async_with_tls_connector, connect_async_with_tls_connector_and_config,
};

@ -45,16 +45,16 @@ use tungstenite::{
}; };
#[cfg(feature = "connect")] #[cfg(feature = "connect")]
pub use connect::client_async_tls; pub use connect::{client_async_tls, client_async_tls_with_config};
#[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))] #[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))]
pub use connect::client_async_tls_with_connector; pub use connect::{client_async_tls_with_connector, client_async_tls_with_connector_and_config};
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "async_std_runtime")]
pub use connect::connect_async; pub use connect::{connect_async, connect_async_with_config};
#[cfg(all( #[cfg(all(
feature = "async_std_runtime", feature = "async_std_runtime",
any(feature = "tls", feature = "native-tls") any(feature = "tls", feature = "native-tls")
))] ))]
pub use connect::connect_async_with_tls_connector; pub use connect::{connect_async_with_tls_connector, connect_async_with_tls_connector_and_config};
#[cfg(all(feature = "connect", feature = "tls-base"))] #[cfg(all(feature = "connect", feature = "tls-base"))]
pub use connect::MaybeTlsStream; pub use connect::MaybeTlsStream;

Loading…
Cancel
Save