diff --git a/src/tokio.rs b/src/tokio.rs index 4b2a2a4..8410aeb 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -11,402 +11,63 @@ use super::{domain, port, WebSocketStream}; use futures_io::{AsyncRead, AsyncWrite}; -#[cfg(all(feature = "tokio-rustls", not(feature = "tokio-native-tls")))] -pub(crate) mod tokio_tls { - use real_tokio_rustls::rustls::ClientConfig; - use real_tokio_rustls::webpki::DNSNameRef; - use real_tokio_rustls::{client::TlsStream, TlsConnector as AsyncTlsConnector}; - - use tungstenite::client::{uri_mode, IntoClientRequest}; - use tungstenite::handshake::client::Request; - use tungstenite::stream::Mode; - use tungstenite::Error; - - use crate::stream::Stream as StreamSwitcher; - use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; - - use super::TokioAdapter; - - /// A stream that might be protected with TLS. - pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; - - pub type AutoStream = MaybeTlsStream; - - pub type Connector = AsyncTlsConnector; - - async fn wrap_stream( - socket: S, - domain: String, - connector: Option, - mode: Mode, - ) -> Result, Error> - where - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), - Mode::Tls => { - let stream = { - let connector = if let Some(connector) = connector { - connector - } else { - let mut config = ClientConfig::new(); - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - AsyncTlsConnector::from(std::sync::Arc::new(config)) - }; - let domain = DNSNameRef::try_from_ascii_str(&domain) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?; - connector - .connect(domain, socket) - .await - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? - }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) - } - } - } - - /// Creates a WebSocket handshake from a request and a stream, - /// upgrading the stream to TLS if required and using the given - /// connector and WebSocket configuration. - pub async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: IntoClientRequest + Unpin, - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - AutoStream: Unpin, - { - let request: Request = request.into_client_request()?; - - let domain = domain(&request)?; - - // Make sure we check domain and mode first. URL must be valid. - let mode = uri_mode(request.uri())?; - - let stream = wrap_stream(stream, domain, connector, mode).await?; - client_async_with_config(request, stream, config).await - } -} - #[cfg(feature = "tokio-native-tls")] -pub(crate) mod tokio_tls { - use real_tokio_native_tls::TlsConnector as AsyncTlsConnector; - use real_tokio_native_tls::TlsStream; - - use tungstenite::client::{uri_mode, IntoClientRequest}; - use tungstenite::handshake::client::Request; - use tungstenite::stream::Mode; - use tungstenite::Error; - - use crate::stream::Stream as StreamSwitcher; - use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; - - use super::TokioAdapter; - - /// A stream that might be protected with TLS. - pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; - - pub type AutoStream = MaybeTlsStream; - - pub type Connector = AsyncTlsConnector; - - async fn wrap_stream( - socket: S, - domain: String, - connector: Option, - mode: Mode, - ) -> Result, Error> - where - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), - Mode::Tls => { - let stream = { - let connector = if let Some(connector) = connector { - connector - } else { - let connector = real_native_tls::TlsConnector::builder().build()?; - AsyncTlsConnector::from(connector) - }; - connector - .connect(&domain, socket) - .await - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? - }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) - } - } - } +#[path = "tokio/native_tls.rs"] +mod tls; - /// Creates a WebSocket handshake from a request and a stream, - /// upgrading the stream to TLS if required and using the given - /// connector and WebSocket configuration. - pub async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: IntoClientRequest + Unpin, - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - AutoStream: Unpin, - { - let request: Request = request.into_client_request()?; - - let domain = domain(&request)?; - - // Make sure we check domain and mode first. URL must be valid. - let mode = uri_mode(request.uri())?; - - let stream = wrap_stream(stream, domain, connector, mode).await?; - client_async_with_config(request, stream, config).await - } -} +#[cfg(all(feature = "tokio-rustls", not(feature = "tokio-native-tls")))] +#[path = "tokio/rustls.rs"] +mod tls; #[cfg(all( feature = "tokio-openssl", not(any(feature = "tokio-native-tls", feature = "tokio-rustls")) ))] -pub(crate) mod tokio_tls { - use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod}; - use real_tokio_openssl::connect; - use real_tokio_openssl::SslStream as TlsStream; - - use tungstenite::client::{uri_mode, IntoClientRequest}; - use tungstenite::handshake::client::Request; - use tungstenite::stream::Mode; - use tungstenite::Error; - - use crate::stream::Stream as StreamSwitcher; - use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; - - use super::TokioAdapter; - - /// A stream that might be protected with TLS. - pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; - - pub type AutoStream = MaybeTlsStream; - - pub type Connector = ConnectConfiguration; - - async fn wrap_stream( - socket: S, - domain: String, - connector: Option, - mode: Mode, - ) -> Result, Error> - where - S: 'static - + tokio::io::AsyncRead - + tokio::io::AsyncWrite - + Unpin - + std::fmt::Debug - + Send - + Sync, - { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), - Mode::Tls => { - let stream = { - let connector = if let Some(connector) = connector { - connector - } else { - SslConnector::builder(SslMethod::tls_client()) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? - .build() - .configure() - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? - }; - connect(connector, &domain, socket) - .await - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? - }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) - } - } - } - - /// Creates a WebSocket handshake from a request and a stream, - /// upgrading the stream to TLS if required and using the given - /// connector and WebSocket configuration. - pub async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: IntoClientRequest + Unpin, - S: 'static - + tokio::io::AsyncRead - + tokio::io::AsyncWrite - + Unpin - + std::fmt::Debug - + Send - + Sync, - AutoStream: Unpin, - { - let request: Request = request.into_client_request()?; - - let domain = domain(&request)?; - - // Make sure we check domain and mode first. URL must be valid. - let mode = uri_mode(request.uri())?; - - let stream = wrap_stream(stream, domain, connector, mode).await?; - client_async_with_config(request, stream, config).await - } -} - -#[cfg(not(any( - feature = "async-tls", - feature = "tokio-native-tls", - feature = "tokio-rustls", - feature = "tokio-openssl" -)))] -pub(crate) mod dummy_tls { - use tungstenite::client::{uri_mode, IntoClientRequest}; - use tungstenite::handshake::client::Request; - use tungstenite::stream::Mode; - use tungstenite::Error; - - use super::TokioAdapter; - - use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; - - pub type AutoStream = TokioAdapter; - type Connector = (); - - async fn wrap_stream( - socket: S, - _domain: String, - _connector: Option<()>, - mode: Mode, - ) -> Result, Error> - where - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - { - match mode { - Mode::Plain => Ok(TokioAdapter(socket)), - Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), - } - } - - pub(crate) async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: IntoClientRequest + Unpin, - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - AutoStream: Unpin, - { - let request: Request = request.into_client_request()?; - - let domain = domain(&request)?; - - // Make sure we check domain and mode first. URL must be valid. - let mode = uri_mode(request.uri())?; - - let stream = wrap_stream(stream, domain, connector, mode).await?; - client_async_with_config(request, stream, config).await - } -} - -#[cfg(not(any( - feature = "async-tls", - feature = "tokio-native-tls", - feature = "tokio-rustls", - feature = "tokio-openssl" -)))] -use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; +#[path = "tokio/openssl.rs"] +mod tls; #[cfg(all( feature = "async-tls", not(any( - feature = "tokio-rustls", feature = "tokio-native-tls", - feature = "tokio-openssl" - )) -))] -pub(crate) mod async_tls_adapter { - use super::{ - Error, IntoClientRequest, Response, TokioAdapter, WebSocketConfig, WebSocketStream, - }; - use crate::stream::Stream as StreamSwitcher; - use std::marker::Unpin; - - /// Creates a WebSocket handshake from a request and a stream, - /// upgrading the stream to TLS if required and using the given - /// connector and WebSocket configuration. - pub async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: IntoClientRequest + Unpin, - S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - AutoStream: Unpin, - { - crate::async_tls::client_async_tls_with_connector_and_config( - request, - TokioAdapter(stream), - connector, - config, - ) - .await - } - - pub type Connector = real_async_tls::TlsConnector; - - pub type MaybeTlsStream = StreamSwitcher>; - - pub type AutoStream = MaybeTlsStream>; -} - -#[cfg(all( - feature = "async-tls", - not(any( feature = "tokio-rustls", - feature = "tokio-native-tls", feature = "tokio-openssl" )) ))] -pub use self::async_tls_adapter::client_async_tls_with_connector_and_config; -#[cfg(all( - feature = "async-tls", - not(any( - feature = "tokio-rustls", - feature = "tokio-native-tls", - feature = "tokio-openssl" - )) -))] -use self::async_tls_adapter::{AutoStream, Connector}; +#[path = "tokio/async_tls.rs"] +mod tls; -#[cfg(any( +#[cfg(not(any( + feature = "tokio-native-tls", feature = "tokio-rustls", + feature = "tokio-openssl", + feature = "async-tls" +)))] +#[path = "tokio/dummy_tls.rs"] +mod tls; + +#[cfg(any( feature = "tokio-native-tls", - feature = "tokio-openssl" + feature = "tokio-rustls", + feature = "tokio-openssl", + feature = "async-tls", ))] -pub use self::tokio_tls::client_async_tls_with_connector_and_config; +pub use self::tls::client_async_tls_with_connector_and_config; #[cfg(any( - feature = "tokio-rustls", feature = "tokio-native-tls", - feature = "tokio-openssl" + feature = "tokio-rustls", + feature = "tokio-openssl", + feature = "async-tls" ))] -use self::tokio_tls::{AutoStream, Connector}; +use self::tls::{AutoStream, Connector}; + +#[cfg(not(any( + feature = "tokio-native-tls", + feature = "tokio-rustls", + feature = "tokio-openssl", + feature = "async-tls" +)))] +use self::tls::{client_async_tls_with_connector_and_config, AutoStream}; /// Creates a WebSocket handshake from a request and a stream. /// For convenience, the user may call this with a url string, a URL, diff --git a/src/tokio/async_tls.rs b/src/tokio/async_tls.rs new file mode 100644 index 0000000..d233172 --- /dev/null +++ b/src/tokio/async_tls.rs @@ -0,0 +1,32 @@ +use super::{Error, IntoClientRequest, Response, TokioAdapter, WebSocketConfig, WebSocketStream}; +use crate::stream::Stream as StreamSwitcher; +use std::marker::Unpin; + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, +{ + crate::async_tls::client_async_tls_with_connector_and_config( + request, + TokioAdapter(stream), + connector, + config, + ) + .await +} + +pub type Connector = real_async_tls::TlsConnector; + +pub type MaybeTlsStream = StreamSwitcher>; + +pub type AutoStream = MaybeTlsStream>; diff --git a/src/tokio/dummy_tls.rs b/src/tokio/dummy_tls.rs new file mode 100644 index 0000000..090a769 --- /dev/null +++ b/src/tokio/dummy_tls.rs @@ -0,0 +1,49 @@ +use tungstenite::client::{uri_mode, IntoClientRequest}; +use tungstenite::handshake::client::Request; +use tungstenite::stream::Mode; +use tungstenite::Error; + +use super::TokioAdapter; + +use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + +pub type AutoStream = TokioAdapter; + +type Connector = (); + +async fn wrap_stream( + socket: S, + _domain: String, + _connector: Option<()>, + mode: Mode, +) -> Result, Error> +where + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + match mode { + Mode::Plain => Ok(TokioAdapter(socket)), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + } +} + +pub(crate) async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await +} diff --git a/src/tokio/native_tls.rs b/src/tokio/native_tls.rs new file mode 100644 index 0000000..8c45118 --- /dev/null +++ b/src/tokio/native_tls.rs @@ -0,0 +1,73 @@ +use real_tokio_native_tls::TlsConnector as AsyncTlsConnector; +use real_tokio_native_tls::TlsStream; + +use tungstenite::client::{uri_mode, IntoClientRequest}; +use tungstenite::handshake::client::Request; +use tungstenite::stream::Mode; +use tungstenite::Error; + +use crate::stream::Stream as StreamSwitcher; +use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + +use super::TokioAdapter; + +/// A stream that might be protected with TLS. +pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; + +pub type AutoStream = MaybeTlsStream; + +pub type Connector = AsyncTlsConnector; + +async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, +) -> Result, Error> +where + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + let connector = real_native_tls::TlsConnector::builder().build()?; + AsyncTlsConnector::from(connector) + }; + connector + .connect(&domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await +} diff --git a/src/tokio/openssl.rs b/src/tokio/openssl.rs new file mode 100644 index 0000000..3202ae2 --- /dev/null +++ b/src/tokio/openssl.rs @@ -0,0 +1,88 @@ +use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod}; +use real_tokio_openssl::connect; +use real_tokio_openssl::SslStream as TlsStream; + +use tungstenite::client::{uri_mode, IntoClientRequest}; +use tungstenite::handshake::client::Request; +use tungstenite::stream::Mode; +use tungstenite::Error; + +use crate::stream::Stream as StreamSwitcher; +use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + +use super::TokioAdapter; + +/// A stream that might be protected with TLS. +pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; + +pub type AutoStream = MaybeTlsStream; + +pub type Connector = ConnectConfiguration; + +async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, +) -> Result, Error> +where + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, +{ + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + SslConnector::builder(SslMethod::tls_client()) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + .build() + .configure() + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + connect(connector, &domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + + tokio::io::AsyncRead + + tokio::io::AsyncWrite + + Unpin + + std::fmt::Debug + + Send + + Sync, + AutoStream: Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await +} diff --git a/src/tokio/rustls.rs b/src/tokio/rustls.rs new file mode 100644 index 0000000..7bfbd28 --- /dev/null +++ b/src/tokio/rustls.rs @@ -0,0 +1,79 @@ +use real_tokio_rustls::rustls::ClientConfig; +use real_tokio_rustls::webpki::DNSNameRef; +use real_tokio_rustls::{client::TlsStream, TlsConnector as AsyncTlsConnector}; + +use tungstenite::client::{uri_mode, IntoClientRequest}; +use tungstenite::handshake::client::Request; +use tungstenite::stream::Mode; +use tungstenite::Error; + +use crate::stream::Stream as StreamSwitcher; +use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + +use super::TokioAdapter; + +/// A stream that might be protected with TLS. +pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; + +pub type AutoStream = MaybeTlsStream; + +pub type Connector = AsyncTlsConnector; + +async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, +) -> Result, Error> +where + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + let mut config = ClientConfig::new(); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + AsyncTlsConnector::from(std::sync::Arc::new(config)) + }; + let domain = DNSNameRef::try_from_ascii_str(&domain) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?; + connector + .connect(domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await +}