Refactor features and optional API and add support for tokio/gio async runtimes

pull/5/head 0.3.0
Sebastian Dröge 5 years ago
parent 5613f9e47d
commit c2ff77b446
  1. 15
      .travis.yml
  2. 74
      Cargo.toml
  3. 37
      README.md
  4. 31
      examples/async-std-echo.rs
  5. 2
      examples/autobahn-client.rs
  6. 2
      examples/client.rs
  7. 28
      examples/gio-echo.rs
  8. 30
      examples/tokio-echo.rs
  9. 4
      scripts/autobahn-client.sh
  10. 4
      scripts/autobahn-server.sh
  11. 268
      src/async_std.rs
  12. 114
      src/async_tls.rs
  13. 263
      src/connect.rs
  14. 137
      src/gio.rs
  15. 100
      src/lib.rs
  16. 363
      src/tokio.rs

@ -9,9 +9,18 @@ before_script:
- sudo apt-get install libssl-dev - sudo apt-get install libssl-dev
script: script:
- cargo check --release --no-default-features - cargo check
- cargo test --release - cargo check --features async-tls
- cargo test --release --no-default-features --features=native-tls,connect,async_std_runtime - cargo check --features async-std-runtime,async-tls
- cargo check --features async-std-runtime,async-native-tls
- cargo check --features async-std-runtime,async-tls,async-native-tls
- cargo check --features tokio-runtime,async-tls
- cargo check --features tokio-runtime,tokio-tls
- cargo check --features tokio-runtime,async-tls,tokio-tls
- cargo check --features gio-runtime
- cargo check --features gio-runtime,async-tls
- cargo check --features async-std-runtime,async-tls,async-native-tls,tokio-runtime,tokio-tls,gio-runtime
- cargo test --features async-std-runtime
after_success: after_success:
- sudo apt-get install python-unittest2 - sudo apt-get install python-unittest2

@ -1,24 +1,28 @@
[package] [package]
name = "async-tungstenite" name = "async-tungstenite"
description = "async-std binding for Tungstenite, the Lightweight stream-based WebSocket implementation" description = "Async binding for Tungstenite, the Lightweight stream-based WebSocket implementation"
categories = ["web-programming::websocket", "network-programming", "asynchronous", "concurrency"] categories = ["web-programming::websocket", "network-programming", "asynchronous", "concurrency"]
keywords = ["websocket", "io", "web"] keywords = ["websocket", "io", "web", "tokio", "async-std"]
authors = ["Sebastian Dröge <sebastian@centricular.com>"] authors = ["Sebastian Dröge <sebastian@centricular.com>"]
license = "MIT" license = "MIT"
homepage = "https://github.com/sdroege/async-tungstenite" homepage = "https://github.com/sdroege/async-tungstenite"
repository = "https://github.com/sdroege/async-tungstenite" repository = "https://github.com/sdroege/async-tungstenite"
documentation = "https://docs.rs/async-tungstenite" documentation = "https://docs.rs/async-tungstenite"
version = "0.2.1" version = "0.3.0"
edition = "2018" edition = "2018"
readme = "README.md"
[features] [features]
default = ["connect", "tls", "async_std_runtime"] default = []
connect = ["stream"] async-std-runtime = ["async-std"]
async_std_runtime = ["connect", "async-std"] tokio-runtime = ["tokio"]
tls-base = ["stream"] gio-runtime = ["gio", "glib"]
tls = ["async-tls", "tls-base"] async-tls = ["real-async-tls"]
native-tls = ["async-native-tls", "real-native-tls", "tls-base", "tungstenite/tls"] async-native-tls = ["async-std-runtime", "real-async-native-tls"]
stream = [] tokio-tls = ["tokio-runtime", "real-tokio-tls", "real-native-tls", "tungstenite/tls"]
[package.metadata.docs.rs]
features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-tls"]
[dependencies] [dependencies]
log = "0.4" log = "0.4"
@ -33,20 +37,64 @@ default-features = false
optional = true optional = true
version = "1.0" version = "1.0"
[dependencies.async-tls] [dependencies.real-async-tls]
optional = true optional = true
version = "0.6.0" version = "0.6.0"
package = "async-tls"
[dependencies.async-native-tls] [dependencies.real-async-native-tls]
optional = true optional = true
version = "0.1.0" version = "0.3.0"
package = "async-native-tls"
[dependencies.real-native-tls] [dependencies.real-native-tls]
optional = true optional = true
version = "0.2" version = "0.2"
package = "native-tls" package = "native-tls"
[dependencies.tokio]
optional = true
version = "0.2"
features = ["tcp", "dns"]
[dependencies.real-tokio-tls]
optional = true
version = "0.3"
package = "tokio-tls"
[dependencies.gio]
optional = true
version = "0.8"
[dependencies.glib]
optional = true
version = "0.9"
[dev-dependencies] [dev-dependencies]
url = "2.0.0" url = "2.0.0"
env_logger = "0.7" env_logger = "0.7"
async-std = { version = "1.0", features = ["attributes"] } async-std = { version = "1.0", features = ["attributes"] }
[[example]]
name = "autobahn-client"
required-features = ["async-std-runtime"]
[[example]]
name = "client"
required-features = ["async-std-runtime"]
[[example]]
name = "autobahn-server"
required-features = ["async-std-runtime"]
[[example]]
name = "echo-server"
required-features = ["async-std-runtime"]
[[example]]
name = "gio-echo"
required-features = ["gio-runtime"]
[[example]]
name = "tokio-echo"
required-features = ["tokio-runtime"]

@ -1,6 +1,8 @@
# async-tungstenite # async-tungstenite
Asynchronous WebSockets for [async-std](https://async.rs) and `std` `Future`s. Asynchronous WebSockets for [async-std](https://async.rs),
[tokio](https://tokio.rs), [gio](https://www.gtk-rs.org) and any `std`
`Future`s runtime.
[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE)
[![Crates.io](https://img.shields.io/crates/v/async-tungstenite.svg?maxAge=2592000)](https://crates.io/crates/async-tungstenite) [![Crates.io](https://img.shields.io/crates/v/async-tungstenite.svg?maxAge=2592000)](https://crates.io/crates/async-tungstenite)
@ -17,15 +19,36 @@ Add this in your `Cargo.toml`:
async-tungstenite = "*" async-tungstenite = "*"
``` ```
Take a look at the `examples/` directory for client and server examples. You may also want to get familiar with Take a look at the `examples/` directory for client and server examples. You
[`async-std`](https://async.rs/) if you don't have any experience with it. may also want to get familiar with [async-std](https://async.rs/) or
[tokio](https://tokio.rs) if you don't have any experience with it.
## What is async-tungstenite? ## What is async-tungstenite?
This crate is based on `tungstenite-rs` Rust WebSocket library and provides async-std bindings and wrappers for it, so you This crate is based on [tungstenite](https://crates.io/crates/tungstenite)
can use it with non-blocking/asynchronous `TcpStream`s from and couple it together with other crates from the async-std stack. Rust WebSocket library and provides async bindings and wrappers for it, so you
can use it with non-blocking/asynchronous `TcpStream`s from and couple it
together with other crates from the async stack. In addition, optional
integration with various other crates can be enabled via feature flags
* `async-tls`: Enables the `async_tls` module, which provides integration
with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
be used independent of any async runtime.
* `async-std-runtime`: Enables the `async_std` module, which provides
integration with the [async-std](https://async.rs) runtime.
* `async-native-tls`: Enables the additional functions in the `async_std`
module to implement TLS via
[async-native-tls](https://crates.io/crates/async-native-tls).
* `tokio-runtime`: Enables the `tokio` module, which provides integration
with the [tokio](https://tokio.rs) runtime.
* `tokio-tls`: Enables the additional functions in the `tokio` module to
implement TLS via [tokio-tls](https://crates.io/crates/tokio-tls).
* `gio-runtime`: Enables the `gio` module, which provides integration with
the [gio](https://www.gtk-rs.org) runtime.
## tokio-tungstenite ## tokio-tungstenite
Originally this crate was created as a fork of [tokio-tungstenite](https://github.com/snapview/tokio-tungstenite) Originally this crate was created as a fork of
and ported to [async-std](https://async.rs). [tokio-tungstenite](https://github.com/snapview/tokio-tungstenite) and ported
to the traits of the [`futures`](https://crates.io/crates/futures) crate.
Integration into async-std, tokio and gio was added on top of that.

@ -0,0 +1,31 @@
use async_tungstenite::{async_std::connect_async, tungstenite::Message};
use futures::prelude::*;
use async_std::task;
async fn run() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
let url = url::Url::parse("wss://echo.websocket.org").unwrap();
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
let url = url::Url::parse("ws://echo.websocket.org").unwrap();
let (mut ws_stream, _) = connect_async(url).await?;
let text = "Hello, World!";
println!("Sending: \"{}\"", text);
ws_stream.send(Message::text(text)).await?;
let msg = ws_stream
.next()
.await
.ok_or_else(|| "didn't receive anything")??;
println!("Received: {:?}", msg);
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
task::block_on(run())
}

@ -1,4 +1,4 @@
use async_tungstenite::{connect_async, tungstenite::Result}; use async_tungstenite::{async_std::connect_async, tungstenite::Result};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use log::*; use log::*;
use url::Url; use url::Url;

@ -19,7 +19,7 @@ use tungstenite::protocol::Message;
use async_std::io; use async_std::io;
use async_std::prelude::*; use async_std::prelude::*;
use async_std::task; use async_std::task;
use async_tungstenite::connect_async; use async_tungstenite::async_std::connect_async;
async fn run() { async fn run() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();

@ -0,0 +1,28 @@
use async_tungstenite::{gio::connect_async, tungstenite::Message};
use futures::prelude::*;
async fn run() -> Result<(), Box<dyn std::error::Error>> {
let url = url::Url::parse("wss://echo.websocket.org").unwrap();
let (mut ws_stream, _) = connect_async(url).await?;
let text = "Hello, World!";
println!("Sending: \"{}\"", text);
ws_stream.send(Message::text(text)).await?;
let msg = ws_stream
.next()
.await
.ok_or_else(|| "didn't receive anything")??;
println!("Received: {:?}", msg);
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Get the default main context and run our async function on it
let main_context = glib::MainContext::default();
main_context.block_on(run())
}

@ -0,0 +1,30 @@
use async_tungstenite::{tokio::connect_async, tungstenite::Message};
use futures::prelude::*;
async fn run() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(any(feature = "async-tls", feature = "tokio-tls"))]
let url = url::Url::parse("wss://echo.websocket.org").unwrap();
#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))]
let url = url::Url::parse("ws://echo.websocket.org").unwrap();
let (mut ws_stream, _) = connect_async(url).await?;
let text = "Hello, World!";
println!("Sending: \"{}\"", text);
ws_stream.send(Message::text(text)).await?;
let msg = ws_stream
.next()
.await
.ok_or_else(|| "didn't receive anything")??;
println!("Received: {:?}", msg);
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut rt = tokio::runtime::Runtime::new()?;
rt.block_on(run())
}

@ -23,10 +23,10 @@ function test_diff() {
fi fi
} }
cargo build --release --example autobahn-client cargo build --release --features async-std-runtime --example autobahn-client
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' & FUZZINGSERVER_PID=$! wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' & FUZZINGSERVER_PID=$!
sleep 3 sleep 3
echo "Server PID: ${FUZZINGSERVER_PID}" echo "Server PID: ${FUZZINGSERVER_PID}"
cargo run --release --example autobahn-client cargo run --release --features async-std-runtime --example autobahn-client
test_diff test_diff

@ -23,8 +23,8 @@ function test_diff() {
fi fi
} }
cargo build --release --example autobahn-server cargo build --release --features async-std-runtime --example autobahn-server
cargo run --release --example autobahn-server & WSSERVER_PID=$! cargo run --release --features async-std-runtime --example autobahn-server & WSSERVER_PID=$!
echo "Server PID: ${WSSERVER_PID}" echo "Server PID: ${WSSERVER_PID}"
sleep 3 sleep 3
wstest -m fuzzingclient -s 'autobahn/fuzzingclient.json' wstest -m fuzzingclient -s 'autobahn/fuzzingclient.json'

@ -0,0 +1,268 @@
//! `async-std` integration.
use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error;
use async_std::net::TcpStream;
use super::{domain, Request, WebSocketStream};
#[cfg(feature = "async-native-tls")]
use futures::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "async-native-tls")]
pub(crate) mod async_native_tls {
use async_native_tls::TlsConnector as AsyncTlsConnector;
use async_native_tls::TlsStream;
use real_async_native_tls as async_native_tls;
use tungstenite::client::url_mode;
use tungstenite::stream::Mode;
use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite};
use crate::stream::Stream as StreamSwitcher;
use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
/// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub type AutoStream<S> = MaybeTlsStream<S>;
pub type Connector = AsyncTlsConnector;
async fn wrap_stream<S>(
socket: S,
domain: String,
connector: Option<Connector>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Unpin,
{
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => {
let stream = {
let connector = if let Some(connector) = connector {
connector
} else {
AsyncTlsConnector::new()
};
connector
.connect(&domain, socket)
.await
.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
};
Ok(StreamSwitcher::Tls(stream))
}
}
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector and WebSocket configuration.
pub async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<AsyncTlsConnector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}
}
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
pub(crate) mod dummy_tls {
use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode;
use tungstenite::stream::Mode;
use tungstenite::Error;
use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
pub type AutoStream<S> = S;
type Connector = ();
async fn wrap_stream<S>(
socket: S,
_domain: String,
_connector: Option<()>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Unpin,
{
match mode {
Mode::Plain => Ok(socket),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
pub(crate) async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}
}
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream};
#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
use crate::async_tls::{client_async_tls_with_connector_and_config, AutoStream};
#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
type Connector = real_async_tls::TlsConnector;
#[cfg(feature = "async-native-tls")]
use self::async_native_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector};
#[cfg(feature = "async-native-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, None).await
}
#[cfg(feature = "async-native-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, config).await
}
#[cfg(feature = "async-native-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
pub async fn client_async_tls_with_connector<R, S>(
request: R,
stream: S,
connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, connector, None).await
}
/// Connect to a given URL.
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connect to a given URL with a given WebSocket configuration.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector_and_config(request, socket, None, config).await
}
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector<R>(
request: R,
connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_tls_connector_and_config(request, connector, None).await
}
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector_and_config<R>(
request: R,
connector: Option<Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector_and_config(request, socket, connector, config).await
}

@ -0,0 +1,114 @@
//! `async-tls` integration.
use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite};
use super::{client_async_with_config, Request, WebSocketStream};
use async_tls::client::TlsStream;
use async_tls::TlsConnector as AsyncTlsConnector;
use real_async_tls as async_tls;
use tungstenite::stream::Mode;
use crate::domain;
use crate::stream::Stream as StreamSwitcher;
type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub(crate) type AutoStream<S> = MaybeTlsStream<S>;
async fn wrap_stream<S>(
socket: S,
domain: String,
connector: Option<AsyncTlsConnector>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Unpin,
{
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => {
let stream = {
let connector = connector.unwrap_or_else(AsyncTlsConnector::new);
connector.connect(&domain, socket)?.await?
};
Ok(StreamSwitcher::Tls(stream))
}
}
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, config).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
pub async fn client_async_tls_with_connector<R, S>(
request: R,
stream: S,
connector: Option<AsyncTlsConnector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, connector, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector and WebSocket configuration.
pub async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<AsyncTlsConnector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}

@ -1,263 +0,0 @@
//! Connection helper.
use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite};
use super::{client_async_with_config, Request, WebSocketStream};
#[cfg(feature = "tls-base")]
pub(crate) mod encryption {
#[cfg(feature = "tls")]
use async_tls::client::TlsStream;
#[cfg(feature = "tls")]
use async_tls::TlsConnector as AsyncTlsConnector;
#[cfg(feature = "native-tls")]
use async_native_tls::TlsConnector as AsyncTlsConnector;
#[cfg(feature = "native-tls")]
use async_native_tls::TlsStream;
use tungstenite::stream::Mode;
use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite};
use crate::stream::Stream as StreamSwitcher;
/// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub type AutoStream<S> = MaybeTlsStream<S>;
#[cfg(feature = "tls")]
pub type Connector = async_tls::TlsConnector;
#[cfg(feature = "native-tls")]
pub type Connector = real_native_tls::TlsConnector;
pub async fn wrap_stream<S>(
socket: S,
domain: String,
connector: Option<Connector>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => {
#[cfg(feature = "tls")]
let stream = {
let connector = connector.unwrap_or_else(AsyncTlsConnector::new);
connector.connect(&domain, socket)?.await?
};
#[cfg(feature = "native-tls")]
let stream = {
let connector = if let Some(connector) = connector {
connector
} else {
let builder = real_native_tls::TlsConnector::builder();
builder.build()?
};
let connector = AsyncTlsConnector::from(connector);
connector.connect(&domain, socket).await?
};
Ok(StreamSwitcher::Tls(stream))
}
}
}
}
#[cfg(feature = "tls-base")]
pub use self::encryption::MaybeTlsStream;
#[cfg(not(feature = "tls-base"))]
pub(crate) mod encryption {
use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::stream::Mode;
use tungstenite::Error;
pub type AutoStream<S> = S;
pub type Connector = ();
pub(crate) async fn wrap_stream<S>(
socket: S,
_domain: String,
_connector: Option<()>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Ok(socket),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
}
use self::encryption::AutoStream;
/// Get a domain from an URL.
#[inline]
fn domain(request: &Request) -> Result<String, Error> {
match request.url.host_str() {
Some(d) => Ok(d.to_string()),
None => Err(Error::Url("no host name in the url".into())),
}
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, config).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
pub async fn client_async_tls_with_connector<R, S>(
request: R,
stream: S,
connector: Option<self::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, connector, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector and WebSocket configuration.
pub async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<self::encryption::Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}
#[cfg(feature = "async_std_runtime")]
pub(crate) mod async_std_runtime {
use super::*;
use async_std::net::TcpStream;
/// Connect to a given URL.
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connect to a given URL with a given WebSocket configuration.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_config(request, socket, config).await
}
#[cfg(any(feature = "tls", feature = "native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector<R>(
request: R,
connector: Option<super::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_tls_connector_and_config(request, connector, None).await
}
#[cfg(any(feature = "tls", feature = "native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector_and_config<R>(
request: R,
connector: Option<super::encryption::Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
}
#[cfg(feature = "async_std_runtime")]
pub use async_std_runtime::{connect_async, connect_async_with_config};
#[cfg(all(
feature = "async_std_runtime",
any(feature = "tls", feature = "native-tls")
))]
pub use async_std_runtime::{
connect_async_with_tls_connector, connect_async_with_tls_connector_and_config,
};

@ -0,0 +1,137 @@
//! `gio` integration.
use tungstenite::Error;
use std::io;
use gio::prelude::*;
use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode;
use tungstenite::stream::Mode;
use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
type MaybeTlsStream = IOStreamAsyncReadWrite<gio::SocketConnection>;
/// Connect to a given URL.
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connect to a given URL with a given WebSocket configuration.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let client = gio::SocketClient::new();
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
if let Mode::Tls = mode {
client.set_tls(true);
} else {
client.set_tls(false);
}
let connectable = gio::NetworkAddress::new(domain.as_str(), port);
let socket = client
.connect_async_future(&connectable)
.await
.map_err(|err| to_std_io_error(err))?;
let socket = IOStreamAsyncReadWrite::new(socket)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
client_async_with_config(request, socket, config).await
}
/// Adapter for `gio::IOStream` to provide `AsyncRead` and `AsyncWrite`.
#[derive(Debug)]
pub struct IOStreamAsyncReadWrite<T: IsA<gio::IOStream>> {
io_stream: T,
read: gio::InputStreamAsyncRead<gio::PollableInputStream>,
write: gio::OutputStreamAsyncWrite<gio::PollableOutputStream>,
}
impl<T: IsA<gio::IOStream>> IOStreamAsyncReadWrite<T> {
/// Create a new `gio::IOStream` adapter
pub fn new(stream: T) -> Result<IOStreamAsyncReadWrite<T>, T> {
let write = stream
.get_output_stream()
.and_then(|s| s.dynamic_cast::<gio::PollableOutputStream>().ok())
.and_then(|s| s.into_async_write().ok());
let read = stream
.get_input_stream()
.and_then(|s| s.dynamic_cast::<gio::PollableInputStream>().ok())
.and_then(|s| s.into_async_read().ok());
let (read, write) = match (read, write) {
(Some(read), Some(write)) => (read, write),
_ => return Err(stream),
};
Ok(IOStreamAsyncReadWrite {
io_stream: stream,
read,
write,
})
}
}
use std::pin::Pin;
use std::task::{Context, Poll};
impl<T: IsA<gio::IOStream> + Unpin> AsyncRead for IOStreamAsyncReadWrite<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut Pin::get_mut(self).read).poll_read(cx, buf)
}
}
impl<T: IsA<gio::IOStream> + Unpin> AsyncWrite for IOStreamAsyncReadWrite<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut Pin::get_mut(self).write).poll_write(cx, buf)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut Pin::get_mut(self).write).poll_close(cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut Pin::get_mut(self).write).poll_flush(cx)
}
}
fn to_std_io_error(error: glib::Error) -> io::Error {
match error.kind::<gio::IOErrorEnum>() {
Some(io_error_enum) => io::Error::new(io_error_enum.into(), error),
None => io::Error::new(io::ErrorKind::Other, error),
}
}

@ -1,12 +1,28 @@
//! Async WebSocket usage. //! Async WebSockets.
//! //!
//! This library is an implementation of WebSocket handshakes and streams. It //! This crate is based on [tungstenite](https://crates.io/crates/tungstenite)
//! is based on the crate which implements all required WebSocket protocol //! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! logic. So this crate basically just brings async_std support / async_std integration //! can use it with non-blocking/asynchronous `TcpStream`s from and couple it
//! to it. //! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags
//!
//! * `async-tls`: Enables the `async_tls` module, which provides integration
//! with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
//! be used independent of any async runtime.
//! * `async-std-runtime`: Enables the `async_std` module, which provides
//! integration with the [async-std](https://async.rs) runtime.
//! * `async-native-tls`: Enables the additional functions in the `async_std`
//! module to implement TLS via
//! [async-native-tls](https://crates.io/crates/async-native-tls).
//! * `tokio-runtime`: Enables the `tokio` module, which provides integration
//! with the [tokio](https://tokio.rs) runtime.
//! * `tokio-tls`: Enables the additional functions in the `tokio` module to
//! implement TLS via [tokio-tls](https://crates.io/crates/tokio-tls).
//! * `gio-runtime`: Enables the `gio` module, which provides integration with
//! the [gio](https://www.gtk-rs.org) runtime.
//! //!
//! Each WebSocket stream implements the required `Stream` and `Sink` traits, //! Each WebSocket stream implements the required `Stream` and `Sink` traits,
//! so the socket is just a stream of messages coming in and going out. //! making the socket a stream of WebSocket messages coming in and going out.
#![deny( #![deny(
missing_docs, missing_docs,
@ -19,11 +35,14 @@
pub use tungstenite; pub use tungstenite;
mod compat; mod compat;
#[cfg(feature = "connect")]
mod connect;
mod handshake; mod handshake;
#[cfg(feature = "stream")]
pub mod stream; #[cfg(any(
feature = "async-tls",
feature = "async-native-tls",
feature = "tokio-tls",
))]
mod stream;
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -44,20 +63,15 @@ use tungstenite::{
server, server,
}; };
#[cfg(feature = "connect")] #[cfg(feature = "async-std-runtime")]
pub use connect::{client_async_tls, client_async_tls_with_config}; pub mod async_std;
#[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))] #[cfg(feature = "async-tls")]
pub use connect::{client_async_tls_with_connector, client_async_tls_with_connector_and_config}; pub mod async_tls;
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "gio-runtime")]
pub use connect::{connect_async, connect_async_with_config}; pub mod gio;
#[cfg(all( #[cfg(feature = "tokio-runtime")]
feature = "async_std_runtime", pub mod tokio;
any(feature = "tls", feature = "native-tls")
))]
pub use connect::{connect_async_with_tls_connector, connect_async_with_tls_connector_and_config};
#[cfg(all(feature = "connect", feature = "tls-base"))]
pub use connect::MaybeTlsStream;
use std::error::Error; use std::error::Error;
use tungstenite::protocol::CloseFrame; use tungstenite::protocol::CloseFrame;
@ -324,33 +338,17 @@ where
} }
} }
#[cfg(test)] #[cfg(any(
mod tests { feature = "async-tls",
use crate::compat::AllowStd; feature = "async-std-runtime",
#[cfg(feature = "connect")] feature = "tokio-runtime",
use crate::connect::encryption::AutoStream; feature = "gio-runtime"
use crate::WebSocketStream; ))]
use futures::io::{AsyncReadExt, AsyncWriteExt}; /// Get a domain from an URL.
use std::io::{Read, Write}; #[inline]
pub(crate) fn domain(request: &Request) -> Result<String, tungstenite::Error> {
fn is_read<T: Read>() {} match request.url.host_str() {
fn is_write<T: Write>() {} Some(d) => Ok(d.to_string()),
fn is_async_read<T: AsyncReadExt>() {} None => Err(tungstenite::Error::Url("no host name in the url".into())),
fn is_async_write<T: AsyncWriteExt>() {}
fn is_unpin<T: Unpin>() {}
#[test]
fn web_socket_stream_has_traits() {
is_read::<AllowStd<async_std::net::TcpStream>>();
is_write::<AllowStd<async_std::net::TcpStream>>();
#[cfg(feature = "connect")]
is_async_read::<AutoStream<async_std::net::TcpStream>>();
#[cfg(feature = "connect")]
is_async_write::<AutoStream<async_std::net::TcpStream>>();
is_unpin::<WebSocketStream<async_std::net::TcpStream>>();
#[cfg(feature = "connect")]
is_unpin::<WebSocketStream<AutoStream<async_std::net::TcpStream>>>();
} }
} }

@ -0,0 +1,363 @@
//! `tokio` integration.
use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error;
use tokio::net::TcpStream;
use super::{domain, Request, WebSocketStream};
use futures::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tokio-tls")]
pub(crate) mod tokio_tls {
use real_tokio_tls::TlsConnector as AsyncTlsConnector;
use real_tokio_tls::TlsStream;
use tungstenite::client::url_mode;
use tungstenite::stream::Mode;
use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite};
use crate::stream::Stream as StreamSwitcher;
use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
use super::TokioAdapter;
/// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TokioAdapter<TlsStream<TokioAdapter<S>>>>;
pub type AutoStream<S> = MaybeTlsStream<S>;
pub type Connector = AsyncTlsConnector;
async fn wrap_stream<S>(
socket: S,
domain: String,
connector: Option<Connector>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Unpin,
{
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => {
let stream = {
let connector = if let Some(connector) = connector {
connector
} else {
let connector = real_native_tls::TlsConnector::builder().build()?;
AsyncTlsConnector::from(connector)
};
connector
.connect(&domain, TokioAdapter(socket))
.await
.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
};
Ok(StreamSwitcher::Tls(TokioAdapter(stream)))
}
}
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector and WebSocket configuration.
pub async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<AsyncTlsConnector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}
}
#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))]
pub(crate) mod dummy_tls {
use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode;
use tungstenite::stream::Mode;
use tungstenite::Error;
use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
pub type AutoStream<S> = S;
type Connector = ();
async fn wrap_stream<S>(
socket: S,
_domain: String,
_connector: Option<()>,
mode: Mode,
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Unpin,
{
match mode {
Mode::Plain => Ok(socket),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
pub(crate) async fn client_async_tls_with_connector_and_config<R, S>(
request: R,
stream: S,
connector: Option<Connector>,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await
}
}
#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))]
use self::dummy_tls::{client_async_tls_with_connector_and_config, AutoStream};
#[cfg(all(feature = "async-tls", not(feature = "tokio-tls")))]
use crate::async_tls::{client_async_tls_with_connector_and_config, AutoStream};
#[cfg(all(feature = "async-tls", not(feature = "tokio-tls")))]
type Connector = real_async_tls::TlsConnector;
#[cfg(feature = "tokio-tls")]
use self::tokio_tls::{client_async_tls_with_connector_and_config, AutoStream, Connector};
#[cfg(feature = "tokio-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, None).await
}
#[cfg(feature = "tokio-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, None, config).await
}
#[cfg(feature = "tokio-tls")]
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
pub async fn client_async_tls_with_connector<R, S>(
request: R,
stream: S,
connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector_and_config(request, stream, connector, None).await
}
/// Connect to a given URL.
pub async fn connect_async<R>(
request: R,
) -> Result<
(
WebSocketStream<AutoStream<TokioAdapter<TcpStream>>>,
Response,
),
Error,
>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connect to a given URL with a given WebSocket configuration.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
) -> Result<
(
WebSocketStream<AutoStream<TokioAdapter<TcpStream>>>,
Response,
),
Error,
>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector_and_config(request, TokioAdapter(socket), None, config).await
}
#[cfg(any(feature = "async-tls", feature = "tokio-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector<R>(
request: R,
connector: Option<Connector>,
) -> Result<
(
WebSocketStream<AutoStream<TokioAdapter<TcpStream>>>,
Response,
),
Error,
>
where
R: Into<Request<'static>> + Unpin,
{
connect_async_with_tls_connector_and_config(request, connector, None).await
}
#[cfg(any(feature = "async-tls", feature = "tokio-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector_and_config<R>(
request: R,
connector: Option<Connector>,
config: Option<WebSocketConfig>,
) -> Result<
(
WebSocketStream<AutoStream<TokioAdapter<TcpStream>>>,
Response,
),
Error,
>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector_and_config(request, TokioAdapter(socket), connector, config)
.await
}
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide
/// the variants from the `futures` crate and the other way around.
#[pin_project]
#[derive(Debug, Clone)]
pub struct TokioAdapter<T>(#[pin] pub T);
impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().0.poll_read(cx, buf)
}
}
impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.project().0.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.project().0.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.project().0.poll_shutdown(cx)
}
}
impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.project().0.poll_read(cx, buf)
}
}
impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.project().0.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.project().0.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().0.poll_close(cx)
}
}
Loading…
Cancel
Save