Update to tungstenite 0.10

Partially based on tungstenite commits
- 46dfd9ed3ee75b0261e9f5f71c8e70492407248b by Alexey Galakhov
- 31010fd636b3edc683199e3182ea34d799118d5b by Alexey Galakhov
pull/12/head 0.4.0
Sebastian Dröge 5 years ago committed by Sebastian Dröge
parent 016f21f5b4
commit 3a994e6e3b
  1. 4
      Cargo.toml
  2. 4
      examples/async-std-echo.rs
  3. 17
      examples/autobahn-client.rs
  4. 10
      examples/client.rs
  5. 2
      examples/gio-echo.rs
  6. 4
      examples/tokio-echo.rs
  7. 58
      src/async_std.rs
  8. 18
      src/async_tls.rs
  9. 20
      src/gio.rs
  10. 38
      src/lib.rs
  11. 59
      src/tokio.rs
  12. 2
      tests/communication.rs

@ -8,7 +8,7 @@ license = "MIT"
homepage = "https://github.com/sdroege/async-tungstenite" homepage = "https://github.com/sdroege/async-tungstenite"
repository = "https://github.com/sdroege/async-tungstenite" repository = "https://github.com/sdroege/async-tungstenite"
documentation = "https://docs.rs/async-tungstenite" documentation = "https://docs.rs/async-tungstenite"
version = "0.3.1" version = "0.4.0"
edition = "2018" edition = "2018"
readme = "README.md" readme = "README.md"
@ -30,7 +30,7 @@ futures = "0.3"
pin-project = "0.4" pin-project = "0.4"
[dependencies.tungstenite] [dependencies.tungstenite]
version = "0.9.2" version = "0.10.0"
default-features = false default-features = false
[dependencies.async-std] [dependencies.async-std]

@ -5,9 +5,9 @@ use async_std::task;
async fn run() -> Result<(), Box<dyn std::error::Error>> { async fn run() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))] #[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
let url = url::Url::parse("wss://echo.websocket.org").unwrap(); let url = "wss://echo.websocket.org";
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] #[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
let url = url::Url::parse("ws://echo.websocket.org").unwrap(); let url = "ws://echo.websocket.org";
let (mut ws_stream, _) = connect_async(url).await?; let (mut ws_stream, _) = connect_async(url).await?;

@ -1,15 +1,11 @@
use async_tungstenite::{async_std::connect_async, tungstenite::Error, tungstenite::Result}; use async_tungstenite::{async_std::connect_async, tungstenite::Error, tungstenite::Result};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use log::*; use log::*;
use url::Url;
const AGENT: &str = "Tungstenite"; const AGENT: &str = "Tungstenite";
async fn get_case_count() -> Result<u32> { async fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect_async( let (mut socket, _) = connect_async("ws://localhost:9001/getCaseCount").await?;
Url::parse("ws://localhost:9001/getCaseCount").expect("Can't connect to case count URL"),
)
.await?;
let msg = socket.next().await.expect("Can't fetch case count")?; let msg = socket.next().await.expect("Can't fetch case count")?;
socket.close(None).await?; socket.close(None).await?;
Ok(msg Ok(msg
@ -19,13 +15,10 @@ async fn get_case_count() -> Result<u32> {
} }
async fn update_reports() -> Result<()> { async fn update_reports() -> Result<()> {
let (mut socket, _) = connect_async( let (mut socket, _) = connect_async(&format!(
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}", "ws://localhost:9001/updateReports?agent={}",
AGENT AGENT
)) ))
.expect("Can't update reports"),
)
.await?; .await?;
socket.close(None).await?; socket.close(None).await?;
Ok(()) Ok(())
@ -33,11 +26,7 @@ async fn update_reports() -> Result<()> {
async fn run_test(case: u32) -> Result<()> { async fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = Url::parse(&format!( let case_url = &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT);
"ws://localhost:9001/runCase?case={}&agent={}",
case, AGENT
))
.expect("Bad testcase URL");
let (mut ws_stream, _) = connect_async(case_url).await?; let (mut ws_stream, _) = connect_async(case_url).await?;
while let Some(msg) = ws_stream.next().await { while let Some(msg) = ws_stream.next().await {

@ -25,23 +25,21 @@ async fn run() {
.nth(1) .nth(1)
.unwrap_or_else(|| panic!("this program requires at least one argument")); .unwrap_or_else(|| panic!("this program requires at least one argument"));
let url = url::Url::parse(&connect_addr).unwrap();
let (stdin_tx, stdin_rx) = futures::channel::mpsc::unbounded(); let (stdin_tx, stdin_rx) = futures::channel::mpsc::unbounded();
task::spawn(read_stdin(stdin_tx)); task::spawn(read_stdin(stdin_tx));
let (ws_stream, _) = connect_async(url).await.expect("Failed to connect"); let (ws_stream, _) = connect_async(&connect_addr)
.await
.expect("Failed to connect");
println!("WebSocket handshake has been successfully completed"); println!("WebSocket handshake has been successfully completed");
let (write, read) = ws_stream.split(); let (write, read) = ws_stream.split();
let stdin_to_ws = stdin_rx.map(Ok).forward(write); let stdin_to_ws = stdin_rx.map(Ok).forward(write);
let ws_to_stdout = { let ws_to_stdout = {
read.for_each(|message| { read.for_each(|message| async {
async {
let data = message.unwrap().into_data(); let data = message.unwrap().into_data();
async_std::io::stdout().write_all(&data).await.unwrap(); async_std::io::stdout().write_all(&data).await.unwrap();
}
}) })
}; };

@ -2,7 +2,7 @@ use async_tungstenite::{gio::connect_async, tungstenite::Message};
use futures::prelude::*; use futures::prelude::*;
async fn run() -> Result<(), Box<dyn std::error::Error>> { async fn run() -> Result<(), Box<dyn std::error::Error>> {
let url = url::Url::parse("wss://echo.websocket.org").unwrap(); let url = "wss://echo.websocket.org";
let (mut ws_stream, _) = connect_async(url).await?; let (mut ws_stream, _) = connect_async(url).await?;

@ -3,9 +3,9 @@ use futures::prelude::*;
async fn run() -> Result<(), Box<dyn std::error::Error>> { async fn run() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(any(feature = "async-tls", feature = "tokio-tls"))] #[cfg(any(feature = "async-tls", feature = "tokio-tls"))]
let url = url::Url::parse("wss://echo.websocket.org").unwrap(); let url = "wss://echo.websocket.org";
#[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))] #[cfg(not(any(feature = "async-tls", feature = "tokio-tls")))]
let url = url::Url::parse("ws://echo.websocket.org").unwrap(); let url = "ws://echo.websocket.org";
let (mut ws_stream, _) = connect_async(url).await?; let (mut ws_stream, _) = connect_async(url).await?;

@ -1,11 +1,12 @@
//! `async-std` integration. //! `async-std` integration.
use tungstenite::handshake::client::Response; use tungstenite::client::IntoClientRequest;
use tungstenite::handshake::client::{Request, Response};
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error; use tungstenite::Error;
use async_std::net::TcpStream; use async_std::net::TcpStream;
use super::{domain, Request, WebSocketStream}; use super::{domain, port, WebSocketStream};
#[cfg(feature = "async-native-tls")] #[cfg(feature = "async-native-tls")]
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
@ -16,7 +17,8 @@ pub(crate) mod async_native_tls {
use async_native_tls::TlsStream; use async_native_tls::TlsStream;
use real_async_native_tls as async_native_tls; use real_async_native_tls as async_native_tls;
use tungstenite::client::url_mode; use tungstenite::client::uri_mode;
use tungstenite::handshake::client::Request;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
@ -24,7 +26,8 @@ pub(crate) mod async_native_tls {
use crate::stream::Stream as StreamSwitcher; use crate::stream::Stream as StreamSwitcher;
use crate::{ use crate::{
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream, client_async_with_config, domain, IntoClientRequest, Response, WebSocketConfig,
WebSocketStream,
}; };
/// A stream that might be protected with TLS. /// A stream that might be protected with TLS.
@ -72,16 +75,16 @@ pub(crate) mod async_native_tls {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
let stream = wrap_stream(stream, domain, connector, mode).await?; let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await client_async_with_config(request, stream, config).await
@ -92,13 +95,12 @@ pub(crate) mod async_native_tls {
pub(crate) mod dummy_tls { pub(crate) mod dummy_tls {
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode; use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Request;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
use crate::{ use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
pub type AutoStream<S> = S; pub type AutoStream<S> = S;
type Connector = (); type Connector = ();
@ -125,16 +127,16 @@ pub(crate) mod dummy_tls {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
let stream = wrap_stream(stream, domain, connector, mode).await?; let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await client_async_with_config(request, stream, config).await
@ -160,7 +162,7 @@ pub async fn client_async_tls<R, S>(
stream: S, stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -177,7 +179,7 @@ pub async fn client_async_tls_with_config<R, S>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -194,7 +196,7 @@ pub async fn client_async_tls_with_connector<R, S>(
connector: Option<Connector>, connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -206,7 +208,7 @@ pub async fn connect_async<R>(
request: R, request: R,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
connect_async_with_config(request, None).await connect_async_with_config(request, None).await
} }
@ -217,15 +219,12 @@ pub async fn connect_async_with_config<R>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
let port = request let port = port(&request)?;
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;
@ -239,7 +238,7 @@ pub async fn connect_async_with_tls_connector<R>(
connector: Option<Connector>, connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
connect_async_with_tls_connector_and_config(request, connector, None).await connect_async_with_tls_connector_and_config(request, connector, None).await
} }
@ -252,15 +251,12 @@ pub async fn connect_async_with_tls_connector_and_config<R>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
let port = request let port = port(&request)?;
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;

@ -1,12 +1,12 @@
//! `async-tls` integration. //! `async-tls` integration.
use tungstenite::client::url_mode; use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Response; use tungstenite::handshake::client::{Request, Response};
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error; use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use super::{client_async_with_config, Request, WebSocketStream}; use super::{client_async_with_config, WebSocketStream};
use async_tls::client::TlsStream; use async_tls::client::TlsStream;
use async_tls::TlsConnector as AsyncTlsConnector; use async_tls::TlsConnector as AsyncTlsConnector;
@ -49,7 +49,7 @@ pub async fn client_async_tls<R, S>(
stream: S, stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -65,7 +65,7 @@ pub async fn client_async_tls_with_config<R, S>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -81,7 +81,7 @@ pub async fn client_async_tls_with_connector<R, S>(
connector: Option<AsyncTlsConnector>, connector: Option<AsyncTlsConnector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -98,16 +98,16 @@ pub async fn client_async_tls_with_connector_and_config<R, S>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
let stream = wrap_stream(stream, domain, connector, mode).await?; let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await client_async_with_config(request, stream, config).await

@ -7,12 +7,11 @@ use gio::prelude::*;
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode; use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Request;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use crate::{ use crate::{client_async_with_config, domain, port, Response, WebSocketConfig, WebSocketStream};
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
type MaybeTlsStream = IOStreamAsyncReadWrite<gio::SocketConnection>; type MaybeTlsStream = IOStreamAsyncReadWrite<gio::SocketConnection>;
@ -21,7 +20,7 @@ pub async fn connect_async<R>(
request: R, request: R,
) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error> ) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
connect_async_with_config(request, None).await connect_async_with_config(request, None).await
} }
@ -32,20 +31,17 @@ pub async fn connect_async_with_config<R>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error> ) -> Result<(WebSocketStream<MaybeTlsStream>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
let port = request let port = port(&request)?;
.url
.port_or_known_default()
.expect("Bug: port unknown");
let client = gio::SocketClient::new(); let client = gio::SocketClient::new();
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
if let Mode::Tls = mode { if let Mode::Tls = mode {
client.set_tls(true); client.set_tls(true);
} else { } else {

@ -54,9 +54,10 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tungstenite::{ use tungstenite::{
client::IntoClientRequest,
error::Error as WsError, error::Error as WsError,
handshake::{ handshake::{
client::{ClientHandshake, Request, Response}, client::{ClientHandshake, Response},
server::{Callback, NoCallback}, server::{Callback, NoCallback},
}, },
protocol::{Message, Role, WebSocket, WebSocketConfig}, protocol::{Message, Role, WebSocket, WebSocketConfig},
@ -92,7 +93,7 @@ pub async fn client_async<'a, R, S>(
stream: S, stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError> ) -> Result<(WebSocketStream<S>, Response), WsError>
where where
R: Into<Request<'a>> + Unpin, R: IntoClientRequest + Unpin,
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
client_async_with_config(request, stream, None).await client_async_with_config(request, stream, None).await
@ -106,11 +107,12 @@ pub async fn client_async_with_config<'a, R, S>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError> ) -> Result<(WebSocketStream<S>, Response), WsError>
where where
R: Into<Request<'a>> + Unpin, R: IntoClientRequest + Unpin,
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
let f = handshake::client_handshake(stream, move |allow_std| { let f = handshake::client_handshake(stream, move |allow_std| {
let cli_handshake = ClientHandshake::start(allow_std, request.into(), config); let request = request.into_client_request()?;
let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
cli_handshake.handshake() cli_handshake.handshake()
}); });
f.await.map_err(|e| { f.await.map_err(|e| {
@ -346,9 +348,33 @@ where
))] ))]
/// Get a domain from an URL. /// Get a domain from an URL.
#[inline] #[inline]
pub(crate) fn domain(request: &Request) -> Result<String, tungstenite::Error> { pub(crate) fn domain(
match request.url.host_str() { request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
match request.uri().host() {
Some(d) => Ok(d.to_string()), Some(d) => Ok(d.to_string()),
None => Err(tungstenite::Error::Url("no host name in the url".into())), None => Err(tungstenite::Error::Url("no host name in the url".into())),
} }
} }
#[cfg(any(
feature = "async-tls",
feature = "async-std-runtime",
feature = "tokio-runtime",
feature = "gio-runtime"
))]
/// Get the port from an URL.
#[inline]
pub(crate) fn port(
request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
request
.uri()
.port_u16()
.or_else(|| match request.uri().scheme_str() {
Some("wss") => Some(443),
Some("ws") => Some(80),
_ => None,
})
.ok_or(tungstenite::Error::Url("Url scheme not supported".into()))
}

@ -1,11 +1,12 @@
//! `tokio` integration. //! `tokio` integration.
use tungstenite::handshake::client::Response; use tungstenite::client::IntoClientRequest;
use tungstenite::handshake::client::{Request, Response};
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
use tungstenite::Error; use tungstenite::Error;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use super::{domain, Request, WebSocketStream}; use super::{domain, port, WebSocketStream};
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
@ -14,16 +15,15 @@ pub(crate) mod tokio_tls {
use real_tokio_tls::TlsConnector as AsyncTlsConnector; use real_tokio_tls::TlsConnector as AsyncTlsConnector;
use real_tokio_tls::TlsStream; use real_tokio_tls::TlsStream;
use tungstenite::client::url_mode; use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Request;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use crate::stream::Stream as StreamSwitcher; use crate::stream::Stream as StreamSwitcher;
use crate::{ use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
use super::TokioAdapter; use super::TokioAdapter;
@ -73,16 +73,16 @@ pub(crate) mod tokio_tls {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
let stream = wrap_stream(stream, domain, connector, mode).await?; let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await client_async_with_config(request, stream, config).await
@ -93,13 +93,12 @@ pub(crate) mod tokio_tls {
pub(crate) mod dummy_tls { pub(crate) mod dummy_tls {
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::client::url_mode; use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Request;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
use crate::{ use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
client_async_with_config, domain, Request, Response, WebSocketConfig, WebSocketStream,
};
pub type AutoStream<S> = S; pub type AutoStream<S> = S;
type Connector = (); type Connector = ();
@ -126,16 +125,16 @@ pub(crate) mod dummy_tls {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = uri_mode(request.uri())?;
let stream = wrap_stream(stream, domain, connector, mode).await?; let stream = wrap_stream(stream, domain, connector, mode).await?;
client_async_with_config(request, stream, config).await client_async_with_config(request, stream, config).await
@ -161,7 +160,7 @@ pub async fn client_async_tls<R, S>(
stream: S, stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -178,7 +177,7 @@ pub async fn client_async_tls_with_config<R, S>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -195,7 +194,7 @@ pub async fn client_async_tls_with_connector<R, S>(
connector: Option<Connector>, connector: Option<Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin,
AutoStream<S>: Unpin, AutoStream<S>: Unpin,
{ {
@ -213,7 +212,7 @@ pub async fn connect_async<R>(
Error, Error,
> >
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
connect_async_with_config(request, None).await connect_async_with_config(request, None).await
} }
@ -230,15 +229,12 @@ pub async fn connect_async_with_config<R>(
Error, Error,
> >
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
let port = request let port = port(&request)?;
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;
@ -258,7 +254,7 @@ pub async fn connect_async_with_tls_connector<R>(
Error, Error,
> >
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
connect_async_with_tls_connector_and_config(request, connector, None).await connect_async_with_tls_connector_and_config(request, connector, None).await
} }
@ -277,15 +273,12 @@ pub async fn connect_async_with_tls_connector_and_config<R>(
Error, Error,
> >
where where
R: Into<Request<'static>> + Unpin, R: IntoClientRequest + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let domain = domain(&request)?; let domain = domain(&request)?;
let port = request let port = port(&request)?;
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await; let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;

@ -48,7 +48,7 @@ async fn communication() {
let tcp = TcpStream::connect("0.0.0.0:12345") let tcp = TcpStream::connect("0.0.0.0:12345")
.await .await
.expect("Failed to connect"); .expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap(); let url = "ws://localhost:12345/";
let (mut stream, _) = client_async(url, tcp) let (mut stream, _) = client_async(url, tcp)
.await .await
.expect("Client failed to connect"); .expect("Client failed to connect");

Loading…
Cancel
Save