diff --git a/src/tokio.rs b/src/tokio.rs index b0eabc5..2ec63fb 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -1,6 +1,7 @@ //! `tokio` integration. use tungstenite::client::IntoClientRequest; use tungstenite::handshake::client::{Request, Response}; +use tungstenite::handshake::server::{Callback, NoCallback}; use tungstenite::protocol::WebSocketConfig; use tungstenite::Error; @@ -20,15 +21,13 @@ pub(crate) mod tokio_tls { use tungstenite::stream::Mode; use tungstenite::Error; - use futures_io::{AsyncRead, AsyncWrite}; - use crate::stream::Stream as StreamSwitcher; use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; use super::TokioAdapter; /// A stream that might be protected with TLS. - pub type MaybeTlsStream = StreamSwitcher>>>; + pub type MaybeTlsStream = StreamSwitcher, TokioAdapter>>; pub type AutoStream = MaybeTlsStream; @@ -41,10 +40,10 @@ pub(crate) mod tokio_tls { mode: Mode, ) -> Result, Error> where - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), Mode::Tls => { let stream = { let connector = if let Some(connector) = connector { @@ -54,7 +53,7 @@ pub(crate) mod tokio_tls { AsyncTlsConnector::from(connector) }; connector - .connect(&domain, TokioAdapter(socket)) + .connect(&domain, socket) .await .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? }; @@ -74,7 +73,7 @@ pub(crate) mod tokio_tls { ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest + Unpin, - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, AutoStream: Unpin, { let request: Request = request.into_client_request()?; @@ -91,16 +90,16 @@ pub(crate) mod tokio_tls { #[cfg(not(any(feature = "async-tls", feature = "tokio-native-tls")))] pub(crate) mod dummy_tls { - use futures_io::{AsyncRead, AsyncWrite}; - use tungstenite::client::{uri_mode, IntoClientRequest}; use tungstenite::handshake::client::Request; use tungstenite::stream::Mode; use tungstenite::Error; + use super::TokioAdapter; + use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; - pub type AutoStream = S; + pub type AutoStream = TokioAdapter; type Connector = (); async fn wrap_stream( @@ -110,10 +109,10 @@ pub(crate) mod dummy_tls { mode: Mode, ) -> Result, Error> where - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { match mode { - Mode::Plain => Ok(socket), + Mode::Plain => Ok(TokioAdapter(socket)), Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), } } @@ -126,7 +125,7 @@ pub(crate) mod dummy_tls { ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest + Unpin, - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, AutoStream: Unpin, { let request: Request = request.into_client_request()?; @@ -145,13 +144,147 @@ pub(crate) mod dummy_tls { use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; #[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] -use crate::async_tls::{client_async_tls_with_connector_and_config, AutoStream}; +pub(crate) mod async_tls_adapter { + use super::{ + Error, IntoClientRequest, Response, TokioAdapter, WebSocketConfig, WebSocketStream, + }; + use crate::stream::Stream as StreamSwitcher; + use std::marker::Unpin; + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: IntoClientRequest + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + AutoStream: Unpin, + { + crate::async_tls::client_async_tls_with_connector_and_config( + request, + TokioAdapter(stream), + connector, + config, + ) + .await + } + + pub type Connector = real_async_tls::TlsConnector; + + pub type MaybeTlsStream = StreamSwitcher>; + + pub type AutoStream = MaybeTlsStream>; +} +#[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] +pub use self::async_tls_adapter::client_async_tls_with_connector_and_config; #[cfg(all(feature = "async-tls", not(feature = "tokio-native-tls")))] -type Connector = real_async_tls::TlsConnector; +use self::async_tls_adapter::{AutoStream, Connector}; #[cfg(feature = "tokio-native-tls")] use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector}; +/// Creates a WebSocket handshake from a request and a stream. +/// For convenience, the user may call this with a url string, a URL, +/// or a `Request`. Calling with `Request` allows the user to add +/// a WebSocket protocol or other custom headers. +/// +/// Internally, this custom creates a handshake representation and returns +/// a future representing the resolution of the WebSocket handshake. The +/// returned future will resolve to either `WebSocketStream` or `Error` +/// depending on whether the handshake is successful. +/// +/// This is typically used for clients who have already established, for +/// example, a TCP connection to the remote server. +pub async fn client_async<'a, R, S>( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + client_async_with_config(request, stream, None).await +} + +/// The same as `client_async()` but the one can specify a websocket configuration. +/// Please refer to `client_async()` for more details. +pub async fn client_async_with_config<'a, R, S>( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + crate::client_async_with_config(request, TokioAdapter(stream), config).await +} + +/// Accepts a new WebSocket connection with the provided stream. +/// +/// This function will internally call `server::accept` to create a +/// handshake representation and returns a future representing the +/// resolution of the WebSocket handshake. The returned future will resolve +/// to either `WebSocketStream` or `Error` depending if it's successful +/// or not. +/// +/// This is typically used after a socket has been accepted from a +/// `TcpListener`. That socket is then passed to this function to perform +/// the server half of the accepting a client's websocket connection. +pub async fn accept_async(stream: S) -> Result>, Error> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + accept_hdr_async(stream, NoCallback).await +} + +/// The same as `accept_async()` but the one can specify a websocket configuration. +/// Please refer to `accept_async()` for more details. +pub async fn accept_async_with_config( + stream: S, + config: Option, +) -> Result>, Error> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + accept_hdr_async_with_config(stream, NoCallback, config).await +} + +/// Accepts a new WebSocket connection with the provided stream. +/// +/// This function does the same as `accept_async()` but accepts an extra callback +/// for header processing. The callback receives headers of the incoming +/// requests and is able to add extra headers to the reply. +pub async fn accept_hdr_async( + stream: S, + callback: C, +) -> Result>, Error> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + C: Callback + Unpin, +{ + accept_hdr_async_with_config(stream, callback, None).await +} + +/// The same as `accept_hdr_async()` but the one can specify a websocket configuration. +/// Please refer to `accept_hdr_async()` for more details. +pub async fn accept_hdr_async_with_config( + stream: S, + callback: C, + config: Option, +) -> Result>, Error> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + C: Callback + Unpin, +{ + crate::accept_hdr_async_with_config(TokioAdapter(stream), callback, config).await +} + /// Type alias for the stream type of the `client_async()` functions. pub type ClientStream = AutoStream; @@ -164,7 +297,7 @@ pub async fn client_async_tls( ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest + Unpin, - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, AutoStream: Unpin, { client_async_tls_with_connector_and_config(request, stream, None, None).await @@ -181,7 +314,7 @@ pub async fn client_async_tls_with_config( ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest + Unpin, - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, AutoStream: Unpin, { client_async_tls_with_connector_and_config(request, stream, None, config).await @@ -198,14 +331,14 @@ pub async fn client_async_tls_with_connector( ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest + Unpin, - S: 'static + AsyncRead + AsyncWrite + Unpin, + S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, AutoStream: Unpin, { client_async_tls_with_connector_and_config(request, stream, connector, None).await } /// Type alias for the stream type of the `connect_async()` functions. -pub type ConnectStream = ClientStream>; +pub type ConnectStream = ClientStream; /// Connect to a given URL. pub async fn connect_async( @@ -232,7 +365,7 @@ where let try_socket = TcpStream::connect((domain.as_str(), port)).await; let socket = try_socket.map_err(Error::Io)?; - client_async_tls_with_connector_and_config(request, TokioAdapter(socket), None, config).await + client_async_tls_with_connector_and_config(request, socket, None, config).await } #[cfg(any(feature = "async-tls", feature = "tokio-native-tls"))] @@ -264,8 +397,7 @@ where let try_socket = TcpStream::connect((domain.as_str(), port)).await; let socket = try_socket.map_err(Error::Io)?; - client_async_tls_with_connector_and_config(request, TokioAdapter(socket), connector, config) - .await + client_async_tls_with_connector_and_config(request, socket, connector, config).await } use pin_project::pin_project;