From c2ff77b4469577c11e69b25344cec37f7300f3de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 5 Jan 2020 04:07:50 +0200 Subject: [PATCH] Refactor features and optional API and add support for tokio/gio async runtimes --- .travis.yml | 15 +- Cargo.toml | 74 ++++++-- README.md | 37 +++- examples/async-std-echo.rs | 31 +++ examples/autobahn-client.rs | 2 +- examples/client.rs | 2 +- examples/gio-echo.rs | 28 +++ examples/tokio-echo.rs | 30 +++ scripts/autobahn-client.sh | 4 +- scripts/autobahn-server.sh | 4 +- src/async_std.rs | 268 ++++++++++++++++++++++++++ src/async_tls.rs | 114 +++++++++++ src/connect.rs | 263 -------------------------- src/gio.rs | 137 ++++++++++++++ src/lib.rs | 100 +++++----- src/tokio.rs | 363 ++++++++++++++++++++++++++++++++++++ 16 files changed, 1129 insertions(+), 343 deletions(-) create mode 100644 examples/async-std-echo.rs create mode 100644 examples/gio-echo.rs create mode 100644 examples/tokio-echo.rs create mode 100644 src/async_std.rs create mode 100644 src/async_tls.rs delete mode 100644 src/connect.rs create mode 100644 src/gio.rs create mode 100644 src/tokio.rs diff --git a/.travis.yml b/.travis.yml index 392bf78..8ea85a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,9 +9,18 @@ before_script: - sudo apt-get install libssl-dev script: - - cargo check --release --no-default-features - - cargo test --release - - cargo test --release --no-default-features --features=native-tls,connect,async_std_runtime + - cargo check + - cargo check --features async-tls + - cargo check --features async-std-runtime,async-tls + - cargo check --features async-std-runtime,async-native-tls + - cargo check --features async-std-runtime,async-tls,async-native-tls + - cargo check --features tokio-runtime,async-tls + - cargo check --features tokio-runtime,tokio-tls + - cargo check --features tokio-runtime,async-tls,tokio-tls + - cargo check --features gio-runtime + - cargo check --features gio-runtime,async-tls + - cargo check --features async-std-runtime,async-tls,async-native-tls,tokio-runtime,tokio-tls,gio-runtime + - cargo test --features async-std-runtime after_success: - sudo apt-get install python-unittest2 diff --git a/Cargo.toml b/Cargo.toml index f540046..03cc47b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,24 +1,28 @@ [package] name = "async-tungstenite" -description = "async-std binding for Tungstenite, the Lightweight stream-based WebSocket implementation" +description = "Async binding for Tungstenite, the Lightweight stream-based WebSocket implementation" categories = ["web-programming::websocket", "network-programming", "asynchronous", "concurrency"] -keywords = ["websocket", "io", "web"] +keywords = ["websocket", "io", "web", "tokio", "async-std"] authors = ["Sebastian Dröge "] license = "MIT" homepage = "https://github.com/sdroege/async-tungstenite" repository = "https://github.com/sdroege/async-tungstenite" documentation = "https://docs.rs/async-tungstenite" -version = "0.2.1" +version = "0.3.0" edition = "2018" +readme = "README.md" [features] -default = ["connect", "tls", "async_std_runtime"] -connect = ["stream"] -async_std_runtime = ["connect", "async-std"] -tls-base = ["stream"] -tls = ["async-tls", "tls-base"] -native-tls = ["async-native-tls", "real-native-tls", "tls-base", "tungstenite/tls"] -stream = [] +default = [] +async-std-runtime = ["async-std"] +tokio-runtime = ["tokio"] +gio-runtime = ["gio", "glib"] +async-tls = ["real-async-tls"] +async-native-tls = ["async-std-runtime", "real-async-native-tls"] +tokio-tls = ["tokio-runtime", "real-tokio-tls", "real-native-tls", "tungstenite/tls"] + +[package.metadata.docs.rs] +features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-tls"] [dependencies] log = "0.4" @@ -33,20 +37,64 @@ default-features = false optional = true version = "1.0" -[dependencies.async-tls] +[dependencies.real-async-tls] optional = true version = "0.6.0" +package = "async-tls" -[dependencies.async-native-tls] +[dependencies.real-async-native-tls] optional = true -version = "0.1.0" +version = "0.3.0" +package = "async-native-tls" [dependencies.real-native-tls] optional = true version = "0.2" package = "native-tls" +[dependencies.tokio] +optional = true +version = "0.2" +features = ["tcp", "dns"] + +[dependencies.real-tokio-tls] +optional = true +version = "0.3" +package = "tokio-tls" + +[dependencies.gio] +optional = true +version = "0.8" + +[dependencies.glib] +optional = true +version = "0.9" + [dev-dependencies] url = "2.0.0" env_logger = "0.7" async-std = { version = "1.0", features = ["attributes"] } + +[[example]] +name = "autobahn-client" +required-features = ["async-std-runtime"] + +[[example]] +name = "client" +required-features = ["async-std-runtime"] + +[[example]] +name = "autobahn-server" +required-features = ["async-std-runtime"] + +[[example]] +name = "echo-server" +required-features = ["async-std-runtime"] + +[[example]] +name = "gio-echo" +required-features = ["gio-runtime"] + +[[example]] +name = "tokio-echo" +required-features = ["tokio-runtime"] diff --git a/README.md b/README.md index bd3229c..63c4f2a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # async-tungstenite -Asynchronous WebSockets for [async-std](https://async.rs) and `std` `Future`s. +Asynchronous WebSockets for [async-std](https://async.rs), +[tokio](https://tokio.rs), [gio](https://www.gtk-rs.org) and any `std` +`Future`s runtime. [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![Crates.io](https://img.shields.io/crates/v/async-tungstenite.svg?maxAge=2592000)](https://crates.io/crates/async-tungstenite) @@ -17,15 +19,36 @@ Add this in your `Cargo.toml`: async-tungstenite = "*" ``` -Take a look at the `examples/` directory for client and server examples. You may also want to get familiar with -[`async-std`](https://async.rs/) if you don't have any experience with it. +Take a look at the `examples/` directory for client and server examples. You +may also want to get familiar with [async-std](https://async.rs/) or +[tokio](https://tokio.rs) if you don't have any experience with it. ## What is async-tungstenite? -This crate is based on `tungstenite-rs` Rust WebSocket library and provides async-std bindings and wrappers for it, so you -can use it with non-blocking/asynchronous `TcpStream`s from and couple it together with other crates from the async-std stack. +This crate is based on [tungstenite](https://crates.io/crates/tungstenite) +Rust WebSocket library and provides async bindings and wrappers for it, so you +can use it with non-blocking/asynchronous `TcpStream`s from and couple it +together with other crates from the async stack. In addition, optional +integration with various other crates can be enabled via feature flags + + * `async-tls`: Enables the `async_tls` module, which provides integration + with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can + be used independent of any async runtime. + * `async-std-runtime`: Enables the `async_std` module, which provides + integration with the [async-std](https://async.rs) runtime. + * `async-native-tls`: Enables the additional functions in the `async_std` + module to implement TLS via + [async-native-tls](https://crates.io/crates/async-native-tls). + * `tokio-runtime`: Enables the `tokio` module, which provides integration + with the [tokio](https://tokio.rs) runtime. + * `tokio-tls`: Enables the additional functions in the `tokio` module to + implement TLS via [tokio-tls](https://crates.io/crates/tokio-tls). + * `gio-runtime`: Enables the `gio` module, which provides integration with + the [gio](https://www.gtk-rs.org) runtime. ## tokio-tungstenite -Originally this crate was created as a fork of [tokio-tungstenite](https://github.com/snapview/tokio-tungstenite) -and ported to [async-std](https://async.rs). +Originally this crate was created as a fork of +[tokio-tungstenite](https://github.com/snapview/tokio-tungstenite) and ported +to the traits of the [`futures`](https://crates.io/crates/futures) crate. +Integration into async-std, tokio and gio was added on top of that. diff --git a/examples/async-std-echo.rs b/examples/async-std-echo.rs new file mode 100644 index 0000000..7119900 --- /dev/null +++ b/examples/async-std-echo.rs @@ -0,0 +1,31 @@ +use async_tungstenite::{async_std::connect_async, tungstenite::Message}; +use futures::prelude::*; + +use async_std::task; + +async fn run() -> Result<(), Box> { + #[cfg(any(feature = "async-tls", feature = "async-native-tls"))] + let url = url::Url::parse("wss://echo.websocket.org").unwrap(); + #[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] + let url = url::Url::parse("ws://echo.websocket.org").unwrap(); + + let (mut ws_stream, _) = connect_async(url).await?; + + let text = "Hello, World!"; + + println!("Sending: \"{}\"", text); + ws_stream.send(Message::text(text)).await?; + + let msg = ws_stream + .next() + .await + .ok_or_else(|| "didn't receive anything")??; + + println!("Received: {:?}", msg); + + Ok(()) +} + +fn main() -> Result<(), Box> { + task::block_on(run()) +} diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index dea802c..2f36585 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -1,4 +1,4 @@ -use async_tungstenite::{connect_async, tungstenite::Result}; +use async_tungstenite::{async_std::connect_async, tungstenite::Result}; use futures::{SinkExt, StreamExt}; use log::*; use url::Url; diff --git a/examples/client.rs b/examples/client.rs index c8dec0d..376baf0 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -19,7 +19,7 @@ use tungstenite::protocol::Message; use async_std::io; use async_std::prelude::*; use async_std::task; -use async_tungstenite::connect_async; +use async_tungstenite::async_std::connect_async; async fn run() { let _ = env_logger::try_init(); diff --git a/examples/gio-echo.rs b/examples/gio-echo.rs new file mode 100644 index 0000000..7c9d5e9 --- /dev/null +++ b/examples/gio-echo.rs @@ -0,0 +1,28 @@ +use async_tungstenite::{gio::connect_async, tungstenite::Message}; +use futures::prelude::*; + +async fn run() -> Result<(), Box> { + let url = url::Url::parse("wss://echo.websocket.org").unwrap(); + + let (mut ws_stream, _) = connect_async(url).await?; + + let text = "Hello, World!"; + + println!("Sending: \"{}\"", text); + ws_stream.send(Message::text(text)).await?; + + let msg = ws_stream + .next() + .await + .ok_or_else(|| "didn't receive anything")??; + + println!("Received: {:?}", msg); + + Ok(()) +} + +fn main() -> Result<(), Box> { + // Get the default main context and run our async function on it + let main_context = glib::MainContext::default(); + main_context.block_on(run()) +} diff --git a/examples/tokio-echo.rs b/examples/tokio-echo.rs new file mode 100644 index 0000000..9ab022c --- /dev/null +++ b/examples/tokio-echo.rs @@ -0,0 +1,30 @@ +use async_tungstenite::{tokio::connect_async, tungstenite::Message}; +use futures::prelude::*; + +async fn run() -> Result<(), Box> { + #[cfg(any(feature = "async-tls", feature = "tokio-tls"))] + let url = url::Url::parse("wss://echo.websocket.org").unwrap(); + #[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))] + let url = url::Url::parse("ws://echo.websocket.org").unwrap(); + + let (mut ws_stream, _) = connect_async(url).await?; + + let text = "Hello, World!"; + + println!("Sending: \"{}\"", text); + ws_stream.send(Message::text(text)).await?; + + let msg = ws_stream + .next() + .await + .ok_or_else(|| "didn't receive anything")??; + + println!("Received: {:?}", msg); + + Ok(()) +} + +fn main() -> Result<(), Box> { + let mut rt = tokio::runtime::Runtime::new()?; + rt.block_on(run()) +} diff --git a/scripts/autobahn-client.sh b/scripts/autobahn-client.sh index d78b52d..cd63693 100755 --- a/scripts/autobahn-client.sh +++ b/scripts/autobahn-client.sh @@ -23,10 +23,10 @@ function test_diff() { fi } -cargo build --release --example autobahn-client +cargo build --release --features async-std-runtime --example autobahn-client wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' & FUZZINGSERVER_PID=$! sleep 3 echo "Server PID: ${FUZZINGSERVER_PID}" -cargo run --release --example autobahn-client +cargo run --release --features async-std-runtime --example autobahn-client test_diff diff --git a/scripts/autobahn-server.sh b/scripts/autobahn-server.sh index b244d73..86be38a 100755 --- a/scripts/autobahn-server.sh +++ b/scripts/autobahn-server.sh @@ -23,8 +23,8 @@ function test_diff() { fi } -cargo build --release --example autobahn-server -cargo run --release --example autobahn-server & WSSERVER_PID=$! +cargo build --release --features async-std-runtime --example autobahn-server +cargo run --release --features async-std-runtime --example autobahn-server & WSSERVER_PID=$! echo "Server PID: ${WSSERVER_PID}" sleep 3 wstest -m fuzzingclient -s 'autobahn/fuzzingclient.json' diff --git a/src/async_std.rs b/src/async_std.rs new file mode 100644 index 0000000..54b1cbd --- /dev/null +++ b/src/async_std.rs @@ -0,0 +1,268 @@ +//! `async-std` integration. +use tungstenite::handshake::client::Response; +use tungstenite::protocol::WebSocketConfig; +use tungstenite::Error; + +use async_std::net::TcpStream; + +use super::{domain, Request, WebSocketStream}; + +#[cfg(feature = "async-native-tls")] +use futures::io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "async-native-tls")] +pub(crate) mod async_native_tls { + use async_native_tls::TlsConnector as AsyncTlsConnector; + use async_native_tls::TlsStream; + use real_async_native_tls as async_native_tls; + + use tungstenite::client::url_mode; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use futures::io::{AsyncRead, AsyncWrite}; + + use crate::stream::Stream as StreamSwitcher; + use crate::{ + client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, + }; + + /// A stream that might be protected with TLS. + pub type MaybeTlsStream = StreamSwitcher>; + + pub type AutoStream = MaybeTlsStream; + + pub type Connector = AsyncTlsConnector; + + async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + AsyncTlsConnector::new() + }; + connector + .connect(&domain, socket) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(stream)) + } + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into(); + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] +pub(crate) mod dummy_tls { + use futures::io::{AsyncRead, AsyncWrite}; + + use tungstenite::client::url_mode; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use crate::{ + client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, + }; + + pub type AutoStream = S; + type Connector = (); + + async fn wrap_stream( + socket: S, + _domain: String, + _connector: Option<()>, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(socket), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + } + } + + pub(crate) async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into(); + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] +use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; + +#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))] +use crate::async_tls::{client_async_tls_with_connector_and_config, AutoStream}; +#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))] +type Connector = real_async_tls::TlsConnector; + +#[cfg(feature = "async-native-tls")] +use self::async_native_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector}; + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub async fn client_async_tls( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await +} + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + +/// Connect to a given URL. +pub async fn connect_async( + request: R, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, +{ + connect_async_with_config(request, None).await +} + +/// Connect to a given URL with a given WebSocket configuration. +pub async fn connect_async_with_config( + request: R, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, socket, None, config).await +} + +#[cfg(any(feature = "async-tls", feature = "async-native-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector( + request: R, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, +{ + connect_async_with_tls_connector_and_config(request, connector, None).await +} + +#[cfg(any(feature = "async-tls", feature = "async-native-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector_and_config( + request: R, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, socket, connector, config).await +} diff --git a/src/async_tls.rs b/src/async_tls.rs new file mode 100644 index 0000000..8d35622 --- /dev/null +++ b/src/async_tls.rs @@ -0,0 +1,114 @@ +//! `async-tls` integration. +use tungstenite::client::url_mode; +use tungstenite::handshake::client::Response; +use tungstenite::protocol::WebSocketConfig; +use tungstenite::Error; + +use futures::io::{AsyncRead, AsyncWrite}; + +use super::{client_async_with_config, Request, WebSocketStream}; + +use async_tls::client::TlsStream; +use async_tls::TlsConnector as AsyncTlsConnector; +use real_async_tls as async_tls; + +use tungstenite::stream::Mode; + +use crate::domain; +use crate::stream::Stream as StreamSwitcher; + +type MaybeTlsStream = StreamSwitcher>; + +pub(crate) type AutoStream = MaybeTlsStream; + +async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, +) -> Result, Error> +where + S: 'static + AsyncRead + AsyncWrite + Unpin, +{ + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Tls => { + let stream = { + let connector = connector.unwrap_or_else(AsyncTlsConnector::new); + connector.connect(&domain, socket)?.await? + }; + Ok(StreamSwitcher::Tls(stream)) + } + } +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub async fn client_async_tls( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector and WebSocket configuration. +pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await +} diff --git a/src/connect.rs b/src/connect.rs deleted file mode 100644 index b83d4e7..0000000 --- a/src/connect.rs +++ /dev/null @@ -1,263 +0,0 @@ -//! Connection helper. -use tungstenite::client::url_mode; -use tungstenite::handshake::client::Response; -use tungstenite::protocol::WebSocketConfig; -use tungstenite::Error; - -use futures::io::{AsyncRead, AsyncWrite}; - -use super::{client_async_with_config, Request, WebSocketStream}; - -#[cfg(feature = "tls-base")] -pub(crate) mod encryption { - #[cfg(feature = "tls")] - use async_tls::client::TlsStream; - #[cfg(feature = "tls")] - use async_tls::TlsConnector as AsyncTlsConnector; - - #[cfg(feature = "native-tls")] - use async_native_tls::TlsConnector as AsyncTlsConnector; - #[cfg(feature = "native-tls")] - use async_native_tls::TlsStream; - - use tungstenite::stream::Mode; - use tungstenite::Error; - - use futures::io::{AsyncRead, AsyncWrite}; - - use crate::stream::Stream as StreamSwitcher; - - /// A stream that might be protected with TLS. - pub type MaybeTlsStream = StreamSwitcher>; - - pub type AutoStream = MaybeTlsStream; - #[cfg(feature = "tls")] - pub type Connector = async_tls::TlsConnector; - #[cfg(feature = "native-tls")] - pub type Connector = real_native_tls::TlsConnector; - - pub async fn wrap_stream( - socket: S, - domain: String, - connector: Option, - mode: Mode, - ) -> Result, Error> - where - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(socket)), - Mode::Tls => { - #[cfg(feature = "tls")] - let stream = { - let connector = connector.unwrap_or_else(AsyncTlsConnector::new); - connector.connect(&domain, socket)?.await? - }; - #[cfg(feature = "native-tls")] - let stream = { - let connector = if let Some(connector) = connector { - connector - } else { - let builder = real_native_tls::TlsConnector::builder(); - builder.build()? - }; - let connector = AsyncTlsConnector::from(connector); - connector.connect(&domain, socket).await? - }; - Ok(StreamSwitcher::Tls(stream)) - } - } - } -} -#[cfg(feature = "tls-base")] -pub use self::encryption::MaybeTlsStream; - -#[cfg(not(feature = "tls-base"))] -pub(crate) mod encryption { - use futures::io::{AsyncRead, AsyncWrite}; - - use tungstenite::stream::Mode; - use tungstenite::Error; - - pub type AutoStream = S; - pub type Connector = (); - - pub(crate) async fn wrap_stream( - socket: S, - _domain: String, - _connector: Option<()>, - mode: Mode, - ) -> Result, Error> - where - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - { - match mode { - Mode::Plain => Ok(socket), - Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), - } - } -} - -use self::encryption::AutoStream; - -/// Get a domain from an URL. -#[inline] -fn domain(request: &Request) -> Result { - match request.url.host_str() { - Some(d) => Ok(d.to_string()), - None => Err(Error::Url("no host name in the url".into())), - } -} - -/// Creates a WebSocket handshake from a request and a stream, -/// upgrading the stream to TLS if required. -pub async fn client_async_tls( - request: R, - stream: S, -) -> Result<(WebSocketStream>, Response), Error> -where - R: Into> + Unpin, - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - AutoStream: Unpin, -{ - client_async_tls_with_connector_and_config(request, stream, None, None).await -} - -/// Creates a WebSocket handshake from a request and a stream, -/// upgrading the stream to TLS if required and using the given -/// WebSocket configuration. -pub async fn client_async_tls_with_config( - request: R, - stream: S, - config: Option, -) -> Result<(WebSocketStream>, Response), Error> -where - R: Into> + Unpin, - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - AutoStream: Unpin, -{ - client_async_tls_with_connector_and_config(request, stream, None, config).await -} - -/// Creates a WebSocket handshake from a request and a stream, -/// upgrading the stream to TLS if required and using the given -/// connector. -pub async fn client_async_tls_with_connector( - request: R, - stream: S, - connector: Option, -) -> Result<(WebSocketStream>, Response), Error> -where - R: Into> + Unpin, - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - AutoStream: Unpin, -{ - client_async_tls_with_connector_and_config(request, stream, connector, None).await -} - -/// Creates a WebSocket handshake from a request and a stream, -/// upgrading the stream to TLS if required and using the given -/// connector and WebSocket configuration. -pub async fn client_async_tls_with_connector_and_config( - request: R, - stream: S, - connector: Option, - config: Option, -) -> Result<(WebSocketStream>, Response), Error> -where - R: Into> + Unpin, - S: 'static + AsyncRead + AsyncWrite + Send + Unpin, - AutoStream: Unpin, -{ - let request: Request = request.into(); - - let domain = domain(&request)?; - - // Make sure we check domain and mode first. URL must be valid. - let mode = url_mode(&request.url)?; - - let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?; - client_async_with_config(request, stream, config).await -} - -#[cfg(feature = "async_std_runtime")] -pub(crate) mod async_std_runtime { - use super::*; - use async_std::net::TcpStream; - - /// Connect to a given URL. - pub async fn connect_async( - request: R, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: Into> + Unpin, - { - connect_async_with_config(request, None).await - } - - /// Connect to a given URL with a given WebSocket configuration. - pub async fn connect_async_with_config( - request: R, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: Into> + Unpin, - { - let request: Request = request.into(); - - let domain = domain(&request)?; - let port = request - .url - .port_or_known_default() - .expect("Bug: port unknown"); - - let try_socket = TcpStream::connect((domain.as_str(), port)).await; - let socket = try_socket.map_err(Error::Io)?; - client_async_tls_with_config(request, socket, config).await - } - - #[cfg(any(feature = "tls", feature = "native-tls"))] - /// Connect to a given URL using the provided TLS connector. - pub async fn connect_async_with_tls_connector( - request: R, - connector: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: Into> + Unpin, - { - connect_async_with_tls_connector_and_config(request, connector, None).await - } - - #[cfg(any(feature = "tls", feature = "native-tls"))] - /// Connect to a given URL using the provided TLS connector. - pub async fn connect_async_with_tls_connector_and_config( - request: R, - connector: Option, - config: Option, - ) -> Result<(WebSocketStream>, Response), Error> - where - R: Into> + Unpin, - { - let request: Request = request.into(); - - let domain = domain(&request)?; - let port = request - .url - .port_or_known_default() - .expect("Bug: port unknown"); - - let try_socket = TcpStream::connect((domain.as_str(), port)).await; - let socket = try_socket.map_err(Error::Io)?; - client_async_tls_with_connector_and_config(request, socket, connector, config).await - } -} - -#[cfg(feature = "async_std_runtime")] -pub use async_std_runtime::{connect_async, connect_async_with_config}; -#[cfg(all( - feature = "async_std_runtime", - any(feature = "tls", feature = "native-tls") -))] -pub use async_std_runtime::{ - connect_async_with_tls_connector, connect_async_with_tls_connector_and_config, -}; diff --git a/src/gio.rs b/src/gio.rs new file mode 100644 index 0000000..6bf6a27 --- /dev/null +++ b/src/gio.rs @@ -0,0 +1,137 @@ +//! `gio` integration. +use tungstenite::Error; + +use std::io; + +use gio::prelude::*; + +use futures::io::{AsyncRead, AsyncWrite}; + +use tungstenite::client::url_mode; +use tungstenite::stream::Mode; + +use crate::{ + client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, +}; + +type MaybeTlsStream = IOStreamAsyncReadWrite; + +/// Connect to a given URL. +pub async fn connect_async( + request: R, +) -> Result<(WebSocketStream, Response), Error> +where + R: Into> + Unpin, +{ + connect_async_with_config(request, None).await +} + +/// Connect to a given URL with a given WebSocket configuration. +pub async fn connect_async_with_config( + request: R, + config: Option, +) -> Result<(WebSocketStream, Response), Error> +where + R: Into> + Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let client = gio::SocketClient::new(); + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + if let Mode::Tls = mode { + client.set_tls(true); + } else { + client.set_tls(false); + } + + let connectable = gio::NetworkAddress::new(domain.as_str(), port); + + let socket = client + .connect_async_future(&connectable) + .await + .map_err(|err| to_std_io_error(err))?; + let socket = IOStreamAsyncReadWrite::new(socket) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?; + + client_async_with_config(request, socket, config).await +} + +/// Adapter for `gio::IOStream` to provide `AsyncRead` and `AsyncWrite`. +#[derive(Debug)] +pub struct IOStreamAsyncReadWrite> { + io_stream: T, + read: gio::InputStreamAsyncRead, + write: gio::OutputStreamAsyncWrite, +} + +impl> IOStreamAsyncReadWrite { + /// Create a new `gio::IOStream` adapter + pub fn new(stream: T) -> Result, T> { + let write = stream + .get_output_stream() + .and_then(|s| s.dynamic_cast::().ok()) + .and_then(|s| s.into_async_write().ok()); + + let read = stream + .get_input_stream() + .and_then(|s| s.dynamic_cast::().ok()) + .and_then(|s| s.into_async_read().ok()); + + let (read, write) = match (read, write) { + (Some(read), Some(write)) => (read, write), + _ => return Err(stream), + }; + + Ok(IOStreamAsyncReadWrite { + io_stream: stream, + read, + write, + }) + } +} + +use std::pin::Pin; +use std::task::{Context, Poll}; + +impl + Unpin> AsyncRead for IOStreamAsyncReadWrite { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut Pin::get_mut(self).read).poll_read(cx, buf) + } +} + +impl + Unpin> AsyncWrite for IOStreamAsyncReadWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut Pin::get_mut(self).write).poll_write(cx, buf) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut Pin::get_mut(self).write).poll_close(cx) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut Pin::get_mut(self).write).poll_flush(cx) + } +} + +fn to_std_io_error(error: glib::Error) -> io::Error { + match error.kind::() { + Some(io_error_enum) => io::Error::new(io_error_enum.into(), error), + None => io::Error::new(io::ErrorKind::Other, error), + } +} diff --git a/src/lib.rs b/src/lib.rs index 1f095d4..665805f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,28 @@ -//! Async WebSocket usage. +//! Async WebSockets. //! -//! This library is an implementation of WebSocket handshakes and streams. It -//! is based on the crate which implements all required WebSocket protocol -//! logic. So this crate basically just brings async_std support / async_std integration -//! to it. +//! This crate is based on [tungstenite](https://crates.io/crates/tungstenite) +//! Rust WebSocket library and provides async bindings and wrappers for it, so you +//! can use it with non-blocking/asynchronous `TcpStream`s from and couple it +//! together with other crates from the async stack. In addition, optional +//! integration with various other crates can be enabled via feature flags +//! +//! * `async-tls`: Enables the `async_tls` module, which provides integration +//! with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can +//! be used independent of any async runtime. +//! * `async-std-runtime`: Enables the `async_std` module, which provides +//! integration with the [async-std](https://async.rs) runtime. +//! * `async-native-tls`: Enables the additional functions in the `async_std` +//! module to implement TLS via +//! [async-native-tls](https://crates.io/crates/async-native-tls). +//! * `tokio-runtime`: Enables the `tokio` module, which provides integration +//! with the [tokio](https://tokio.rs) runtime. +//! * `tokio-tls`: Enables the additional functions in the `tokio` module to +//! implement TLS via [tokio-tls](https://crates.io/crates/tokio-tls). +//! * `gio-runtime`: Enables the `gio` module, which provides integration with +//! the [gio](https://www.gtk-rs.org) runtime. //! //! Each WebSocket stream implements the required `Stream` and `Sink` traits, -//! so the socket is just a stream of messages coming in and going out. +//! making the socket a stream of WebSocket messages coming in and going out. #![deny( missing_docs, @@ -19,11 +35,14 @@ pub use tungstenite; mod compat; -#[cfg(feature = "connect")] -mod connect; mod handshake; -#[cfg(feature = "stream")] -pub mod stream; + +#[cfg(any( + feature = "async-tls", + feature = "async-native-tls", + feature = "tokio-tls", +))] +mod stream; use std::io::{Read, Write}; @@ -44,20 +63,15 @@ use tungstenite::{ server, }; -#[cfg(feature = "connect")] -pub use connect::{client_async_tls, client_async_tls_with_config}; -#[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))] -pub use connect::{client_async_tls_with_connector, client_async_tls_with_connector_and_config}; -#[cfg(feature = "async_std_runtime")] -pub use connect::{connect_async, connect_async_with_config}; -#[cfg(all( - feature = "async_std_runtime", - any(feature = "tls", feature = "native-tls") -))] -pub use connect::{connect_async_with_tls_connector, connect_async_with_tls_connector_and_config}; +#[cfg(feature = "async-std-runtime")] +pub mod async_std; +#[cfg(feature = "async-tls")] +pub mod async_tls; +#[cfg(feature = "gio-runtime")] +pub mod gio; +#[cfg(feature = "tokio-runtime")] +pub mod tokio; -#[cfg(all(feature = "connect", feature = "tls-base"))] -pub use connect::MaybeTlsStream; use std::error::Error; use tungstenite::protocol::CloseFrame; @@ -324,33 +338,17 @@ where } } -#[cfg(test)] -mod tests { - use crate::compat::AllowStd; - #[cfg(feature = "connect")] - use crate::connect::encryption::AutoStream; - use crate::WebSocketStream; - use futures::io::{AsyncReadExt, AsyncWriteExt}; - use std::io::{Read, Write}; - - fn is_read() {} - fn is_write() {} - fn is_async_read() {} - fn is_async_write() {} - fn is_unpin() {} - - #[test] - fn web_socket_stream_has_traits() { - is_read::>(); - is_write::>(); - - #[cfg(feature = "connect")] - is_async_read::>(); - #[cfg(feature = "connect")] - is_async_write::>(); - - is_unpin::>(); - #[cfg(feature = "connect")] - is_unpin::>>(); +#[cfg(any( + feature = "async-tls", + feature = "async-std-runtime", + feature = "tokio-runtime", + feature = "gio-runtime" +))] +/// Get a domain from an URL. +#[inline] +pub(crate) fn domain(request: &Request) -> Result { + match request.url.host_str() { + Some(d) => Ok(d.to_string()), + None => Err(tungstenite::Error::Url("no host name in the url".into())), } } diff --git a/src/tokio.rs b/src/tokio.rs new file mode 100644 index 0000000..a6aa783 --- /dev/null +++ b/src/tokio.rs @@ -0,0 +1,363 @@ +//! `tokio` integration. +use tungstenite::handshake::client::Response; +use tungstenite::protocol::WebSocketConfig; +use tungstenite::Error; + +use tokio::net::TcpStream; + +use super::{domain, Request, WebSocketStream}; + +use futures::io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "tokio-tls")] +pub(crate) mod tokio_tls { + use real_tokio_tls::TlsConnector as AsyncTlsConnector; + use real_tokio_tls::TlsStream; + + use tungstenite::client::url_mode; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use futures::io::{AsyncRead, AsyncWrite}; + + use crate::stream::Stream as StreamSwitcher; + use crate::{ + client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, + }; + + use super::TokioAdapter; + + /// A stream that might be protected with TLS. + pub type MaybeTlsStream = StreamSwitcher>>>; + + pub type AutoStream = MaybeTlsStream; + + pub type Connector = AsyncTlsConnector; + + async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + let connector = real_native_tls::TlsConnector::builder().build()?; + AsyncTlsConnector::from(connector) + }; + connector + .connect(&domain, TokioAdapter(socket)) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? + }; + Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + } + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into(); + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))] +pub(crate) mod dummy_tls { + use futures::io::{AsyncRead, AsyncWrite}; + + use tungstenite::client::url_mode; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use crate::{ + client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, + }; + + pub type AutoStream = S; + type Connector = (); + + async fn wrap_stream( + socket: S, + _domain: String, + _connector: Option<()>, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(socket), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + } + } + + pub(crate) async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into(); + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))] +use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream}; + +#[cfg(all(feature = "async-tls", not(feature = "tokio-tls")))] +use crate::async_tls::{client_async_tls_with_connector_and_config, AutoStream}; +#[cfg(all(feature = "async-tls", not(feature = "tokio-tls")))] +type Connector = real_async_tls::TlsConnector; + +#[cfg(feature = "tokio-tls")] +use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector}; + +#[cfg(feature = "tokio-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub async fn client_async_tls( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +#[cfg(feature = "tokio-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await +} + +#[cfg(feature = "tokio-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + +/// Connect to a given URL. +pub async fn connect_async( + request: R, +) -> Result< + ( + WebSocketStream>>, + Response, + ), + Error, +> +where + R: Into> + Unpin, +{ + connect_async_with_config(request, None).await +} + +/// Connect to a given URL with a given WebSocket configuration. +pub async fn connect_async_with_config( + request: R, + config: Option, +) -> Result< + ( + WebSocketStream>>, + Response, + ), + Error, +> +where + R: Into> + Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, TokioAdapter(socket), None, config).await +} + +#[cfg(any(feature = "async-tls", feature = "tokio-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector( + request: R, + connector: Option, +) -> Result< + ( + WebSocketStream>>, + Response, + ), + Error, +> +where + R: Into> + Unpin, +{ + connect_async_with_tls_connector_and_config(request, connector, None).await +} + +#[cfg(any(feature = "async-tls", feature = "tokio-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector_and_config( + request: R, + connector: Option, + config: Option, +) -> Result< + ( + WebSocketStream>>, + Response, + ), + Error, +> +where + R: Into> + Unpin, +{ + let request: Request = request.into(); + + let domain = domain(&request)?; + let port = request + .url + .port_or_known_default() + .expect("Bug: port unknown"); + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, TokioAdapter(socket), connector, config) + .await +} + +use pin_project::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide +/// the variants from the `futures` crate and the other way around. +#[pin_project] +#[derive(Debug, Clone)] +pub struct TokioAdapter(#[pin] pub T); + +impl AsyncRead for TokioAdapter { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().0.poll_read(cx, buf) + } +} + +impl AsyncWrite for TokioAdapter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().0.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_shutdown(cx) + } +} + +impl tokio::io::AsyncRead for TokioAdapter { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().0.poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for TokioAdapter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().0.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().0.poll_close(cx) + } +}