Tokio 0.2 Conversion

Update to use tokio 0.2 ecosystem to integrate with tungstenite.
pull/1/head
Danny Browning 5 years ago committed by Danny Browning
parent 59ca2c885e
commit 3821e0952a
  1. 2
      .travis.yml
  2. 26
      Cargo.toml
  3. 58
      examples/autobahn-client.rs
  4. 66
      examples/autobahn-server.rs
  5. 69
      examples/client.rs
  6. 165
      examples/server.rs
  7. 123
      src/compat.rs
  8. 131
      src/connect.rs
  9. 181
      src/handshake.rs
  10. 363
      src/lib.rs
  11. 125
      src/stream.rs
  12. 79
      tests/communication.rs
  13. 63
      tests/handshakes.rs

@ -1,4 +1,6 @@
language: rust language: rust
rust:
- nightly-2019-09-05
before_script: before_script:
- export PATH="$PATH:$HOME/.cargo/bin" - export PATH="$PATH:$HOME/.cargo/bin"

@ -8,21 +8,25 @@ license = "MIT"
homepage = "https://github.com/snapview/tokio-tungstenite" homepage = "https://github.com/snapview/tokio-tungstenite"
documentation = "https://docs.rs/tokio-tungstenite/0.9.0" documentation = "https://docs.rs/tokio-tungstenite/0.9.0"
repository = "https://github.com/snapview/tokio-tungstenite" repository = "https://github.com/snapview/tokio-tungstenite"
version = "0.9.0" version = "0.10.0-alpha.1"
edition = "2018" edition = "2018"
[features] [features]
default = ["connect", "tls"] default = ["connect", "tls"]
connect = ["tokio-dns-unofficial", "tokio-tcp", "stream"] connect = ["tokio-dns-unofficial", "tokio-net", "stream"]
tls = ["tokio-tls", "native-tls", "stream", "tungstenite/tls"] tls = ["tokio-tls", "native-tls", "stream", "tungstenite/tls"]
stream = ["bytes"] stream = ["bytes"]
[dependencies] [dependencies]
futures = "0.1.23" log = "0.4"
tokio-io = "0.1.7" futures-preview = { version = "0.3.0-alpha.19", features = ["async-await"] }
pin-project = "0.4.0-alpha.9"
tokio-io = "0.2.0-alpha.6"
[dependencies.tungstenite] [dependencies.tungstenite]
version = "0.9.1" #version = "0.9.1"
git = "https://github.com/snapview/tungstenite-rs.git"
branch = "master"
default-features = false default-features = false
[dependencies.bytes] [dependencies.bytes]
@ -35,18 +39,18 @@ version = "0.2.0"
[dependencies.tokio-dns-unofficial] [dependencies.tokio-dns-unofficial]
optional = true optional = true
version = "0.3.1" #version = "0.4.0"
git = "https://github.com/sbstp/tokio-dns.git"
[dependencies.tokio-tcp] [dependencies.tokio-net]
optional = true optional = true
version = "0.1.0" version = "0.2.0-alpha.6"
[dependencies.tokio-tls] [dependencies.tokio-tls]
optional = true optional = true
version = "0.2.0" version = "0.3.0-alpha.6"
[dev-dependencies] [dev-dependencies]
tokio = "0.1.7" tokio = "0.2.0-alpha.6"
url = "2.0.0" url = "2.0.0"
env_logger = "0.6.1" env_logger = "0.6.1"
log = "0.4.6"

@ -1,33 +1,32 @@
use futures::{Future, Stream}; use futures::StreamExt;
use log::*; use log::*;
use tokio_tungstenite::{ use tokio_tungstenite::{connect_async, tungstenite::Result};
connect_async,
tungstenite::{connect, Error as WsError, Result},
};
use url::Url; use url::Url;
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { async fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let (mut socket, _) =
let msg = socket.read_message()?; connect_async(Url::parse("ws://localhost:9001/getCaseCount").unwrap()).await?;
socket.close(None)?; let msg = socket.next().await.unwrap()?;
socket.close(None).await?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
} }
fn update_reports() -> Result<()> { async fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect_async(
Url::parse(&format!( Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}", "ws://localhost:9001/updateReports?agent={}",
AGENT AGENT
)) ))
.unwrap(), .unwrap(),
)?; )
socket.close(None)?; .await?;
socket.close(None).await?;
Ok(()) Ok(())
} }
fn run_test(case: u32) { async fn run_test(case: u32) {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = Url::parse(&format!( let case_url = Url::parse(&format!(
"ws://localhost:9001/runCase?case={}&agent={}", "ws://localhost:9001/runCase?case={}&agent={}",
@ -35,31 +34,24 @@ fn run_test(case: u32) {
)) ))
.unwrap(); .unwrap();
let job = connect_async(case_url) let (mut ws_stream, _) = connect_async(case_url).await.expect("Connect error");
.map_err(|err| error!("Connect error: {}", err)) while let Some(msg) = ws_stream.next().await {
.and_then(|(ws_stream, _)| { let msg = msg.expect("Failed to get message");
let (sink, stream) = ws_stream.split(); if msg.is_text() || msg.is_binary() {
stream ws_stream.send(msg).await.expect("Write error");
.filter(|msg| msg.is_text() || msg.is_binary()) }
.forward(sink) }
.and_then(|(_stream, _sink)| Ok(()))
.map_err(|err| match err {
WsError::ConnectionClosed => (),
err => info!("WS error {}", err),
})
});
tokio::run(job)
} }
fn main() { #[tokio::main]
async fn main() {
env_logger::init(); env_logger::init();
let total = get_case_count().unwrap(); let total = get_case_count().await.unwrap();
for case in 1..(total + 1) { for case in 1..(total + 1) {
run_test(case) run_test(case).await
} }
update_reports().unwrap(); update_reports().await.unwrap();
} }

@ -1,42 +1,42 @@
use futures::{Future, Stream}; use futures::StreamExt;
use log::*; use log::*;
use tokio::net::TcpListener; use std::net::{SocketAddr, ToSocketAddrs};
use tokio_tungstenite::{accept_async, tungstenite::Error as WsError}; use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::accept_async;
fn main() { async fn accept_connection(peer: SocketAddr, stream: TcpStream) {
env_logger::init(); let mut ws_stream = accept_async(stream).await.expect("Failed to accept");
let mut runtime = tokio::runtime::Builder::new().build().unwrap(); info!("New WebSocket connection: {}", peer);
let addr = "127.0.0.1:9002".parse().unwrap(); while let Some(msg) = ws_stream.next().await {
let socket = TcpListener::bind(&addr).unwrap(); let msg = msg.expect("Failed to get request");
info!("Listening on: {}", addr); if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await.expect("Failed to send response");
}
}
}
let srv = socket #[tokio::main]
.incoming() async fn main() {
.map_err(Into::into) env_logger::init();
.for_each(move |stream| {
let peer = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", peer);
accept_async(stream).and_then(move |ws_stream| { let addr = "127.0.0.1:9002"
info!("New WebSocket connection: {}", peer); .to_socket_addrs()
let (sink, stream) = ws_stream.split(); .expect("Not a valid address")
let job = stream .next()
.filter(|msg| msg.is_text() || msg.is_binary()) .expect("Not a socket address");
.forward(sink) let socket = TcpListener::bind(&addr).await.unwrap();
.and_then(|(_stream, _sink)| Ok(())) let mut incoming = socket.incoming();
.map_err(|err| match err { info!("Listening on: {}", addr);
WsError::ConnectionClosed => (),
err => info!("WS error: {}", err),
});
tokio::spawn(job); while let Some(stream) = incoming.next().await {
Ok(()) let stream = stream.expect("Failed to get stream");
}) let peer = stream
}); .peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", peer);
runtime.block_on(srv).unwrap(); tokio::spawn(accept_connection(peer, stream));
}
} }

@ -11,17 +11,19 @@
//! You can use this example together with the `server` example. //! You can use this example together with the `server` example.
use std::env; use std::env;
use std::io::{self, Read, Write}; use std::io::{self, Write};
use std::thread;
use futures::sync::mpsc; use futures::StreamExt;
use futures::{Future, Sink, Stream}; use log::*;
use tungstenite::protocol::Message; use tungstenite::protocol::Message;
use tokio::io::AsyncReadExt;
use tokio_tungstenite::connect_async; use tokio_tungstenite::connect_async;
use tokio_tungstenite::stream::PeerAddr;
fn main() { #[tokio::main]
async fn main() {
let _ = env_logger::try_init();
// Specify the server address to which the client will be connecting. // Specify the server address to which the client will be connecting.
let connect_addr = env::args() let connect_addr = env::args()
.nth(1) .nth(1)
@ -33,9 +35,8 @@ fn main() {
// loop, so we farm out that work to a separate thread. This thread will // loop, so we farm out that work to a separate thread. This thread will
// read data from stdin and then send it to the event loop over a standard // read data from stdin and then send it to the event loop over a standard
// futures channel. // futures channel.
let (stdin_tx, stdin_rx) = mpsc::channel(0); let (stdin_tx, mut stdin_rx) = futures::channel::mpsc::unbounded();
thread::spawn(|| read_stdin(stdin_tx)); tokio::spawn(read_stdin(stdin_tx));
let stdin_rx = stdin_rx.map_err(|_| panic!()); // errors not possible on rx
// After the TCP connection has been established, we set up our client to // After the TCP connection has been established, we set up our client to
// start forwarding data. // start forwarding data.
@ -53,53 +54,29 @@ fn main() {
// finishes. If we don't have any more data to read or we won't receive any // finishes. If we don't have any more data to read or we won't receive any
// more work from the remote then we can exit. // more work from the remote then we can exit.
let mut stdout = io::stdout(); let mut stdout = io::stdout();
let client = connect_async(url) let (mut ws_stream, _) = connect_async(url).await.expect("Failed to connect");
.and_then(move |(ws_stream, _)| { info!("WebSocket handshake has been successfully completed");
println!("WebSocket handshake has been successfully completed");
let addr = ws_stream
.peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr);
// `sink` is the stream of messages going out.
// `stream` is the stream of incoming messages.
let (sink, stream) = ws_stream.split();
// We forward all messages, composed out of the data, entered to while let Some(msg) = stdin_rx.next().await {
// the stdin, to the `sink`. ws_stream.send(msg).await.expect("Failed to send request");
let send_stdin = stdin_rx.forward(sink); if let Some(msg) = ws_stream.next().await {
let write_stdout = stream.for_each(move |message| { let msg = msg.expect("Failed to get response");
stdout.write_all(&message.into_data()).unwrap(); stdout.write_all(&msg.into_data()).unwrap();
Ok(()) }
}); }
// Wait for either of futures to complete.
send_stdin
.map(|_| ())
.select(write_stdout.map(|_| ()))
.then(|_| Ok(()))
})
.map_err(|e| {
println!("Error during the websocket handshake occurred: {}", e);
io::Error::new(io::ErrorKind::Other, e)
});
// And now that we've got our client, we execute it in the event loop!
tokio::runtime::run(client.map_err(|_e| ()));
} }
// Our helper method which will read data from stdin and send it along the // Our helper method which will read data from stdin and send it along the
// sender provided. // sender provided.
fn read_stdin(mut tx: mpsc::Sender<Message>) { async fn read_stdin(tx: futures::channel::mpsc::UnboundedSender<Message>) {
let mut stdin = io::stdin(); let mut stdin = tokio::io::stdin();
loop { loop {
let mut buf = vec![0; 1024]; let mut buf = vec![0; 1024];
let n = match stdin.read(&mut buf) { let n = match stdin.read(&mut buf).await {
Err(_) | Ok(0) => break, Err(_) | Ok(0) => break,
Ok(n) => n, Ok(n) => n,
}; };
buf.truncate(n); buf.truncate(n);
tx = tx.send(Message::binary(buf)).wait().unwrap(); tx.unbounded_send(Message::binary(buf)).unwrap();
} }
} }

@ -17,106 +17,87 @@
//! connected clients they'll all join the same room and see everyone else's //! connected clients they'll all join the same room and see everyone else's
//! messages. //! messages.
use std::collections::HashMap;
use std::env; use std::env;
use std::io::{Error, ErrorKind}; use std::io::Error;
use std::sync::{Arc, Mutex};
use futures::stream::Stream; use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures::Future; use futures::StreamExt;
use tokio::net::TcpListener; use log::*;
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::net::{TcpListener, TcpStream};
use tungstenite::protocol::Message; use tungstenite::protocol::Message;
use tokio_tungstenite::accept_async; struct Connection {
addr: SocketAddr,
rx: UnboundedReceiver<Message>,
tx: UnboundedSender<Message>,
}
async fn handle_connection(connection: Connection) {
let mut connection = connection;
while let Some(msg) = connection.rx.next().await {
info!("Received a message from {}: {}", connection.addr, msg);
connection
.tx
.unbounded_send(msg)
.expect("Failed to forward message");
}
}
fn main() { async fn accept_connection(stream: TcpStream) {
let addr = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", addr);
let mut ws_stream = tokio_tungstenite::accept_async(stream)
.await
.expect("Error during the websocket handshake occurred");
info!("New WebSocket connection: {}", addr);
// Create a channel for our stream, which other sockets will use to
// send us messages. Then register our address with the stream to send
// data to us.
let (msg_tx, msg_rx) = futures::channel::mpsc::unbounded();
let (response_tx, mut response_rx) = futures::channel::mpsc::unbounded();
let c = Connection {
addr: addr,
rx: msg_rx,
tx: response_tx,
};
tokio::spawn(handle_connection(c));
while let Some(message) = ws_stream.next().await {
let message = message.expect("Failed to get request");
msg_tx
.unbounded_send(message)
.expect("Failed to forward request");
if let Some(resp) = response_rx.next().await {
ws_stream.send(resp).await.expect("Failed to send response");
}
}
}
#[tokio::main]
async fn main() -> Result<(), Error> {
let _ = env_logger::try_init();
let addr = env::args().nth(1).unwrap_or("127.0.0.1:8080".to_string()); let addr = env::args().nth(1).unwrap_or("127.0.0.1:8080".to_string());
let addr = addr.parse().unwrap(); let addr = addr
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("Not a socket address");
// Create the event loop and TCP listener we'll accept connections on. // Create the event loop and TCP listener we'll accept connections on.
let socket = TcpListener::bind(&addr).unwrap(); let try_socket = TcpListener::bind(&addr).await;
println!("Listening on: {}", addr); let socket = try_socket.expect("Failed to bind");
let mut incoming = socket.incoming();
// Tokio Runtime uses a thread pool based executor by default, so we need info!("Listening on: {}", addr);
// to use Arc and Mutex to store the map of all connections we know about.
let connections = Arc::new(Mutex::new(HashMap::new()));
let srv = socket.incoming().for_each(move |stream| {
let addr = stream
.peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr);
// We have to clone both of these values, because the `and_then`
// function below constructs a new future, `and_then` requires
// `FnOnce`, so we construct a move closure to move the
// environment inside the future (AndThen future may overlive our
// `for_each` future).
let connections_inner = connections.clone();
accept_async(stream)
.and_then(move |ws_stream| {
println!("New WebSocket connection: {}", addr);
// Create a channel for our stream, which other sockets will use to
// send us messages. Then register our address with the stream to send
// data to us.
let (tx, rx) = futures::sync::mpsc::unbounded();
connections_inner.lock().unwrap().insert(addr, tx);
// Let's split the WebSocket stream, so we can work with the
// reading and writing halves separately.
let (sink, stream) = ws_stream.split();
// Whenever we receive a message from the client, we print it and
// send to other clients, excluding the sender.
let connections = connections_inner.clone();
let ws_reader = stream.for_each(move |message: Message| {
println!("Received a message from {}: {}", addr, message);
// For each open connection except the sender, send the
// string via the channel.
let mut conns = connections.lock().unwrap();
let iter = conns
.iter_mut()
.filter(|&(&k, _)| k != addr)
.map(|(_, v)| v);
for tx in iter {
tx.unbounded_send(message.clone()).unwrap();
}
Ok(())
});
// Whenever we receive a string on the Receiver, we write it to
// `WriteHalf<WebSocketStream>`.
let ws_writer = rx.fold(sink, |mut sink, msg| {
use futures::Sink;
sink.start_send(msg).unwrap();
Ok(sink)
});
// Now that we've got futures representing each half of the socket, we
// use the `select` combinator to wait for either half to be done to
// tear down the other. Then we spawn off the result.
let connection = ws_reader
.map(|_| ())
.map_err(|_| ())
.select(ws_writer.map(|_| ()).map_err(|_| ()));
tokio::spawn(connection.then(move |_| {
connections_inner.lock().unwrap().remove(&addr);
println!("Connection {} closed.", addr);
Ok(())
}));
Ok(()) while let Some(stream) = incoming.next().await {
}) let stream = stream.expect("Failed to accept stream");
.map_err(|e| { tokio::spawn(accept_connection(stream));
println!("Error during the websocket handshake occurred: {}", e); }
Error::new(ErrorKind::Other, e)
})
});
// Execute server. Ok(())
tokio::runtime::run(srv.map_err(|_e| ()));
} }

@ -0,0 +1,123 @@
use log::*;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket};
pub(crate) trait HasContext {
fn set_context(&mut self, context: *mut ());
}
#[derive(Debug)]
pub struct AllowStd<S> {
pub(crate) inner: S,
pub(crate) context: *mut (),
}
impl<S> HasContext for AllowStd<S> {
fn set_context(&mut self, context: *mut ()) {
self.context = context;
}
}
pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>);
impl<S> Drop for Guard<'_, S> {
fn drop(&mut self) {
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = std::ptr::null_mut();
}
}
// *mut () context is neither Send nor Sync
unsafe impl<S: Send> Send for AllowStd<S> {}
unsafe impl<S: Sync> Sync for AllowStd<S> {}
impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe {
assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
pub(crate) fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
pub(crate) fn get_ref(&self) -> &S {
&self.inner
}
}
impl<S> Read for AllowStd<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Read.with_context read -> poll_read",
file!(),
line!()
);
stream.poll_read(ctx, buf)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for AllowStd<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Write.with_context write -> poll_write",
file!(),
line!()
);
stream.poll_write(ctx, buf)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Write.with_context flush -> poll_flush",
file!(),
line!()
);
stream.poll_flush(ctx)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}

@ -1,82 +1,50 @@
//! Connection helper. //! Connection helper.
use std::io::Result as IoResult;
use std::net::SocketAddr;
use tokio_tcp::TcpStream;
use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tokio_net::tcp::TcpStream;
use tungstenite::client::url_mode; use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response; use tungstenite::handshake::client::Response;
use tungstenite::Error; use tungstenite::Error;
use super::{client_async, Request, WebSocketStream}; use super::{client_async, Request, WebSocketStream};
use crate::stream::{NoDelay, PeerAddr};
impl NoDelay for TcpStream {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
TcpStream::set_nodelay(self, nodelay)
}
}
impl PeerAddr for TcpStream {
fn peer_addr(&self) -> IoResult<SocketAddr> {
self.peer_addr()
}
}
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
mod encryption { pub(crate) mod encryption {
use native_tls::TlsConnector; use native_tls::TlsConnector;
use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream}; use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream};
use std::io::{Read, Result as IoResult, Write};
use std::net::SocketAddr;
use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
use crate::stream::{NoDelay, PeerAddr, Stream as StreamSwitcher}; use crate::stream::Stream as StreamSwitcher;
/// A stream that might be protected with TLS. /// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>; pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub type AutoStream<S> = MaybeTlsStream<S>; pub type AutoStream<S> = MaybeTlsStream<S>;
impl<T: Read + Write + NoDelay> NoDelay for TlsStream<T> { pub async fn wrap_stream<S>(
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().get_mut().set_nodelay(nodelay)
}
}
impl<S: Read + Write + PeerAddr> PeerAddr for TlsStream<S> {
fn peer_addr(&self) -> IoResult<SocketAddr> {
self.get_ref().get_ref().peer_addr()
}
}
pub fn wrap_stream<S>(
socket: S, socket: S,
domain: String, domain: String,
mode: Mode, mode: Mode,
) -> Box<dyn Future<Item = AutoStream<S>, Error = Error> + Send> ) -> Result<AutoStream<S>, Error>
where where
S: 'static + AsyncRead + AsyncWrite + Send, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{ {
match mode { match mode {
Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))), Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => Box::new( Mode::Tls => {
future::result(TlsConnector::new()) let try_connector = TlsConnector::new();
.map(TokioTlsConnector::from) let connector = try_connector.map_err(Error::Tls)?;
.and_then(move |connector| connector.connect(&domain, socket)) let stream = TokioTlsConnector::from(connector);
.map(StreamSwitcher::Tls) let connected = stream.connect(&domain, socket).await;
.map_err(Error::Tls), match connected {
), Err(e) => Err(Error::Tls(e)),
Ok(s) => Ok(StreamSwitcher::Tls(s)),
}
}
} }
} }
} }
@ -85,7 +53,7 @@ mod encryption {
pub use self::encryption::MaybeTlsStream; pub use self::encryption::MaybeTlsStream;
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
mod encryption { pub(crate) mod encryption {
use futures::{future, Future}; use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
@ -94,19 +62,17 @@ mod encryption {
pub type AutoStream<S> = S; pub type AutoStream<S> = S;
pub fn wrap_stream<S>( pub async fn wrap_stream<S>(
socket: S, socket: S,
_domain: String, _domain: String,
mode: Mode, mode: Mode,
) -> Box<Future<Item = AutoStream<S>, Error = Error> + Send> ) -> Result<AutoStream<S>, Error>
where where
S: 'static + AsyncRead + AsyncWrite + Send, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{ {
match mode { match mode {
Mode::Plain => Box::new(future::ok(socket)), Mode::Plain => Ok(socket),
Mode::Tls => Box::new(future::err(Error::Url( Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
"TLS support not compiled in.".into(),
))),
} }
} }
} }
@ -124,59 +90,42 @@ fn domain(request: &Request) -> Result<String, Error> {
/// Creates a WebSocket handshake from a request and a stream, /// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required. /// upgrading the stream to TLS if required.
pub fn client_async_tls<R, S>( pub async fn client_async_tls<R, S>(
request: R, request: R,
stream: S, stream: S,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<S>>, Response), Error = Error> + Send> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>>, R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + NoDelay + Send, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into();
let domain = match domain(&request) { let domain = domain(&request)?;
Ok(domain) => domain,
Err(err) => return Box::new(future::err(err)),
};
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = match url_mode(&request.url) { let mode = url_mode(&request.url)?;
Ok(m) => m,
Err(e) => return Box::new(future::err(e)), let stream = wrap_stream(stream, domain, mode).await?;
}; client_async(request, stream).await
Box::new(
wrap_stream(stream, domain, mode)
.and_then(|mut stream| {
NoDelay::set_nodelay(&mut stream, true)
.map(move |()| stream)
.map_err(|e| e.into())
})
.and_then(move |stream| client_async(request, stream)),
)
} }
/// Connect to a given URL. /// Connect to a given URL.
pub fn connect_async<R>( pub async fn connect_async<R>(
request: R, request: R,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<TcpStream>>, Response), Error = Error> + Send> ) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where where
R: Into<Request<'static>>, R: Into<Request<'static>> + Unpin,
{ {
let request: Request = request.into(); let request: Request = request.into();
let domain = match domain(&request) { let domain = domain(&request)?;
Ok(domain) => domain,
Err(err) => return Box::new(future::err(err)),
};
let port = request let port = request
.url .url
.port_or_known_default() .port_or_known_default()
.expect("Bug: port unknown"); .expect("Bug: port unknown");
Box::new( let try_socket = tokio_dns::TcpStream::connect((domain.as_str(), port)).await;
tokio_dns::TcpStream::connect((domain.as_str(), port)) let socket = try_socket.map_err(Error::Io)?;
.map_err(|e| e.into()) client_async_tls(request, socket).await
.and_then(move |socket| client_async_tls(request, socket)),
)
} }

@ -0,0 +1,181 @@
use crate::compat::{AllowStd, HasContext};
use crate::WebSocketStream;
use log::*;
use pin_project::pin_project;
use std::future::Future;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::handshake::client::Response;
use tungstenite::handshake::server::Callback;
use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake};
use tungstenite::{ClientHandshake, ServerHandshake, WebSocket};
pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
let ws = start.await;
WebSocketStream::new(ws)
}
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
struct SkippedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
impl<F, S> Future for SkippedHandshakeFuture<F, S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = WebSocket<AllowStd<S>>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self
.get_mut()
.0
.take()
.expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
};
Poll::Ready((inner.f)(stream))
}
}
#[pin_project]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
enum StartedHandshake<Role: HandshakeRole> {
Done(Role::FinalResult),
Mid(WsHandshake<Role>),
}
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
struct StartedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
match start.await? {
StartedHandshake::Done(r) => Ok(r),
StartedHandshake::Mid(s) => {
let res: Result<Role::FinalResult, Error<Role>> = MidHandshake::<Role>(Some(s)).await;
res
}
}
}
pub(crate) async fn client_handshake<F, S>(
stream: S,
f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
F: FnOnce(
AllowStd<S>,
) -> Result<
<ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
Error<ClientHandshake<AllowStd<S>>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let result = handshake(stream, f).await?;
let (s, r) = result;
Ok((WebSocketStream::new(s), r))
}
pub(crate) async fn server_handshake<C, F, S>(
stream: S,
f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
C: Callback + Unpin,
F: FnOnce(
AllowStd<S>,
) -> Result<
<ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
Error<ServerHandshake<AllowStd<S>, C>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let s: WebSocket<AllowStd<S>> = handshake(stream, f).await?;
Ok(WebSocketStream::new(s))
}
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = Result<StartedHandshake<Role>, Error<Role>>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
};
match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
}
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
{
type Output = Result<Role::FinalResult, Error<Role>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut s = this.0.take().expect("future polled after completion");
let machine = s.get_mut();
trace!("Setting context in handshake");
machine.get_mut().set_context(cx as *mut _ as *mut ());
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
*this.0 = Some(mid);
Poll::Pending
}
}
}
}

@ -18,26 +18,29 @@
pub use tungstenite; pub use tungstenite;
mod compat;
#[cfg(feature = "connect")] #[cfg(feature = "connect")]
mod connect; mod connect;
mod handshake;
#[cfg(feature = "stream")] #[cfg(feature = "stream")]
pub mod stream; pub mod stream;
use std::io::ErrorKind; use std::io::{Read, Write};
#[cfg(feature = "stream")]
use std::{io::Result as IoResult, net::SocketAddr};
use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; use compat::{cvt, AllowStd};
use futures::Stream;
use log::*;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{ use tungstenite::{
error::Error as WsError, error::Error as WsError,
handshake::{ handshake::{
client::{ClientHandshake, Request, Response}, client::{ClientHandshake, Request, Response},
server::{Callback, NoCallback, ServerHandshake}, server::{Callback, NoCallback},
HandshakeError, HandshakeRole,
}, },
protocol::{Message, Role, WebSocket, WebSocketConfig}, protocol::{Message, Role, WebSocket, WebSocketConfig},
server, server,
@ -46,11 +49,10 @@ use tungstenite::{
#[cfg(feature = "connect")] #[cfg(feature = "connect")]
pub use connect::{client_async_tls, connect_async}; pub use connect::{client_async_tls, connect_async};
#[cfg(feature = "stream")]
pub use stream::PeerAddr;
#[cfg(all(feature = "connect", feature = "tls"))] #[cfg(all(feature = "connect", feature = "tls"))]
pub use connect::MaybeTlsStream; pub use connect::MaybeTlsStream;
use std::error::Error;
use tungstenite::protocol::CloseFrame;
/// Creates a WebSocket handshake from a request and a stream. /// Creates a WebSocket handshake from a request and a stream.
/// For convenience, the user may call this with a url string, a URL, /// For convenience, the user may call this with a url string, a URL,
@ -64,30 +66,38 @@ pub use connect::MaybeTlsStream;
/// ///
/// This is typically used for clients who have already established, for /// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server. /// example, a TCP connection to the remote server.
pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync<S> pub async fn client_async<'a, R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where where
R: Into<Request<'a>>, R: Into<Request<'a>> + Unpin,
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
{ {
client_async_with_config(request, stream, None) client_async_with_config(request, stream, None).await
} }
/// The same as `client_async()` but the one can specify a websocket configuration. /// The same as `client_async()` but the one can specify a websocket configuration.
/// Please refer to `client_async()` for more details. /// Please refer to `client_async()` for more details.
pub fn client_async_with_config<'a, R, S>( pub async fn client_async_with_config<'a, R, S>(
request: R, request: R,
stream: S, stream: S,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> ConnectAsync<S> ) -> Result<(WebSocketStream<S>, Response), WsError>
where where
R: Into<Request<'a>>, R: Into<Request<'a>> + Unpin,
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
{ {
ConnectAsync { let f = handshake::client_handshake(stream, move |allow_std| {
inner: MidHandshake { let cli_handshake = ClientHandshake::start(allow_std, request.into(), config);
inner: Some(ClientHandshake::start(stream, request.into(), config).handshake()), cli_handshake.handshake()
}, });
} f.await.map_err(|e| {
WsError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
e.description(),
))
})
} }
/// Accepts a new WebSocket connection with the provided stream. /// Accepts a new WebSocket connection with the provided stream.
@ -101,23 +111,23 @@ where
/// This is typically used after a socket has been accepted from a /// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform /// `TcpListener`. That socket is then passed to this function to perform
/// the server half of the accepting a client's websocket connection. /// the server half of the accepting a client's websocket connection.
pub fn accept_async<S>(stream: S) -> AcceptAsync<S, NoCallback> pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
{ {
accept_hdr_async(stream, NoCallback) accept_hdr_async(stream, NoCallback).await
} }
/// The same as `accept_async()` but the one can specify a websocket configuration. /// The same as `accept_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_async()` for more details. /// Please refer to `accept_async()` for more details.
pub fn accept_async_with_config<S>( pub async fn accept_async_with_config<S>(
stream: S, stream: S,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> AcceptAsync<S, NoCallback> ) -> Result<WebSocketStream<S>, WsError>
where where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
{ {
accept_hdr_async_with_config(stream, NoCallback, config) accept_hdr_async_with_config(stream, NoCallback, config).await
} }
/// Accepts a new WebSocket connection with the provided stream. /// Accepts a new WebSocket connection with the provided stream.
@ -125,30 +135,34 @@ where
/// This function does the same as `accept_async()` but accepts an extra callback /// This function does the same as `accept_async()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming /// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply. /// requests and is able to add extra headers to the reply.
pub fn accept_hdr_async<S, C>(stream: S, callback: C) -> AcceptAsync<S, C> pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
C: Callback, C: Callback + Unpin,
{ {
accept_hdr_async_with_config(stream, callback, None) accept_hdr_async_with_config(stream, callback, None).await
} }
/// The same as `accept_hdr_async()` but the one can specify a websocket configuration. /// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_hdr_async()` for more details. /// Please refer to `accept_hdr_async()` for more details.
pub fn accept_hdr_async_with_config<S, C>( pub async fn accept_hdr_async_with_config<S, C>(
stream: S, stream: S,
callback: C, callback: C,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> AcceptAsync<S, C> ) -> Result<WebSocketStream<S>, WsError>
where where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite + Unpin,
C: Callback, C: Callback + Unpin,
{ {
AcceptAsync { let f = handshake::server_handshake(stream, move |allow_std| {
inner: MidHandshake { server::accept_hdr_with_config(allow_std, callback, config)
inner: Some(server::accept_hdr_with_config(stream, callback, config)), });
}, f.await.map_err(|e| {
} WsError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
e.description(),
))
})
} }
/// A wrapper around an underlying raw stream which implements the WebSocket /// A wrapper around an underlying raw stream which implements the WebSocket
@ -160,176 +174,187 @@ where
/// through the respective `Stream` and `Sink`. Check more information about /// through the respective `Stream` and `Sink`. Check more information about
/// them in `futures-rs` crate documentation or have a look on the examples /// them in `futures-rs` crate documentation or have a look on the examples
/// and unit tests for this crate. /// and unit tests for this crate.
#[pin_project]
pub struct WebSocketStream<S> { pub struct WebSocketStream<S> {
inner: WebSocket<S>, #[pin]
inner: WebSocket<AllowStd<S>>,
} }
impl<S> WebSocketStream<S> { impl<S> WebSocketStream<S> {
/// Convert a raw socket into a WebSocketStream without performing a /// Convert a raw socket into a WebSocketStream without performing a
/// handshake. /// handshake.
pub fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self { pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
Self::new(WebSocket::from_raw_socket(stream, role, config)) where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_raw_socket(allow_std, role, config)
})
.await
} }
/// Convert a raw socket into a WebSocketStream without performing a /// Convert a raw socket into a WebSocketStream without performing a
/// handshake. /// handshake.
pub fn from_partially_read( pub async fn from_partially_read(
stream: S, stream: S,
part: Vec<u8>, part: Vec<u8>,
role: Role, role: Role,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Self { ) -> Self
Self::new(WebSocket::from_partially_read(stream, part, role, config)) where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_partially_read(allow_std, part, role, config)
})
.await
} }
fn new(ws: WebSocket<S>) -> Self { pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
WebSocketStream { inner: ws } WebSocketStream { inner: ws }
} }
}
#[cfg(feature = "stream")] fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
impl<S: PeerAddr> PeerAddr for WebSocketStream<S> { where
fn peer_addr(&self) -> IoResult<SocketAddr> { S: Unpin,
self.inner.get_ref().peer_addr() F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
self.inner.get_mut().context = ctx as *mut _ as *mut ();
let mut g = compat::Guard(&mut self.inner);
f(&mut (g.0))
} }
}
impl<T> Stream for WebSocketStream<T> /// Returns a shared reference to the inner stream.
where pub fn get_ref(&self) -> &S
T: AsyncRead + AsyncWrite, where
{ S: AsyncRead + AsyncWrite + Unpin,
type Item = Message; {
type Error = WsError; &self.inner.get_ref().get_ref()
fn poll(&mut self) -> Poll<Option<Message>, WsError> {
self.inner
.read_message()
.map(Some)
.to_async()
.or_else(|err| match err {
WsError::ConnectionClosed => Ok(Async::Ready(None)),
err => Err(err),
})
} }
}
impl<T> Sink for WebSocketStream<T> /// Returns a mutable reference to the inner stream.
where pub fn get_mut(&mut self) -> &mut S
T: AsyncRead + AsyncWrite, where
{ S: AsyncRead + AsyncWrite + Unpin,
type SinkItem = Message; {
type SinkError = WsError; self.inner.get_mut().get_mut()
fn start_send(&mut self, item: Message) -> StartSend<Message, WsError> {
self.inner.write_message(item).to_start_send()
} }
fn poll_complete(&mut self) -> Poll<(), WsError> { /// Send a message to this websocket
self.inner.write_pending().to_async() pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
where
S: AsyncWrite + AsyncRead + Unpin,
{
let f = SendFuture {
stream: self,
message: Some(msg),
};
f.await
} }
fn close(&mut self) -> Poll<(), WsError> { /// Close the underlying web socket
self.inner.close(None).to_async() pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let f = CloseFuture {
stream: self,
message: Some(msg),
};
f.await
} }
} }
/// Future returned from client_async() which will resolve impl<T> Stream for WebSocketStream<T>
/// once the connection handshake has finished. where
pub struct ConnectAsync<S: AsyncRead + AsyncWrite> { T: AsyncRead + AsyncWrite + Unpin,
inner: MidHandshake<ClientHandshake<S>>, AllowStd<T>: Read + Write,
} {
type Item = Result<Message, WsError>;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
type Item = (WebSocketStream<S>, Response); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
type Error = WsError; trace!("{}:{} Stream.poll_next", file!(), line!());
match futures::ready!(self.with_context(cx, |s| {
fn poll(&mut self) -> Poll<Self::Item, WsError> { trace!(
match self.inner.poll()? { "{}:{} Stream.with_context poll_next -> read_message()",
Async::NotReady => Ok(Async::NotReady), file!(),
Async::Ready((ws, resp)) => Ok(Async::Ready((WebSocketStream::new(ws), resp))), line!()
);
cvt(s.read_message())
})) {
Ok(v) => Poll::Ready(Some(Ok(v))),
Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
} }
} }
} }
/// Future returned from accept_async() which will resolve #[pin_project]
/// once the connection handshake has finished. struct SendFuture<'a, T> {
pub struct AcceptAsync<S: AsyncRead + AsyncWrite, C: Callback> { stream: &'a mut WebSocketStream<T>,
inner: MidHandshake<ServerHandshake<S, C>>, message: Option<Message>,
} }
impl<S: AsyncRead + AsyncWrite, C: Callback> Future for AcceptAsync<S, C> { impl<'a, T> Future for SendFuture<'a, T>
type Item = WebSocketStream<S>; where
type Error = WsError; T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
{
type Output = Result<(), WsError>;
fn poll(&mut self) -> Poll<Self::Item, WsError> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.inner.poll()? { let this = self.project();
Async::NotReady => Ok(Async::NotReady), let message = this.message.take().expect("Cannot poll twice");
Async::Ready(ws) => Ok(Async::Ready(WebSocketStream::new(ws))), Poll::Ready(this.stream.with_context(cx, |s| s.write_message(message)))
}
} }
} }
struct MidHandshake<H: HandshakeRole> { #[pin_project]
inner: Option<Result<<H as HandshakeRole>::FinalResult, HandshakeError<H>>>, struct CloseFuture<'a, T> {
stream: &'a mut WebSocketStream<T>,
message: Option<Option<CloseFrame<'a>>>,
} }
impl<H: HandshakeRole> Future for MidHandshake<H> { impl<'a, T> Future for CloseFuture<'a, T>
type Item = <H as HandshakeRole>::FinalResult; where
type Error = WsError; T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
fn poll(&mut self) -> Poll<Self::Item, WsError> { {
match self.inner.take().expect("cannot poll MidHandshake twice") { type Output = Result<(), WsError>;
Ok(result) => Ok(Async::Ready(result)),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Interrupted(s)) => match s.handshake() {
Ok(result) => Ok(Async::Ready(result)),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Interrupted(s)) => {
self.inner = Some(Err(HandshakeError::Interrupted(s)));
Ok(Async::NotReady)
}
},
}
}
}
trait ToAsync {
type T;
type E;
fn to_async(self) -> Result<Async<Self::T>, Self::E>;
}
impl<T> ToAsync for Result<T, WsError> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
type T = T; let this = self.project();
type E = WsError; let message = this.message.take().expect("Cannot poll twice");
fn to_async(self) -> Result<Async<Self::T>, Self::E> { Poll::Ready(this.stream.with_context(cx, |s| s.close(message)))
match self {
Ok(x) => Ok(Async::Ready(x)),
Err(error) => match error {
WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(Async::NotReady),
err => Err(err),
},
}
} }
} }
trait ToStartSend { #[cfg(test)]
type T; mod tests {
type E; use crate::compat::AllowStd;
fn to_start_send(self) -> StartSend<Self::T, Self::E>; use crate::connect::encryption::AutoStream;
} use crate::WebSocketStream;
use std::io::{Read, Write};
impl ToStartSend for Result<(), WsError> { use tokio_io::{AsyncReadExt, AsyncWriteExt};
type T = Message;
type E = WsError; fn is_read<T: Read>() {}
fn to_start_send(self) -> StartSend<Self::T, Self::E> { fn is_write<T: Write>() {}
match self { fn is_async_read<T: AsyncReadExt>() {}
Ok(_) => Ok(AsyncSink::Ready), fn is_async_write<T: AsyncWriteExt>() {}
Err(error) => match error { fn is_unpin<T: Unpin>() {}
WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(AsyncSink::Ready),
WsError::SendQueueFull(msg) => Ok(AsyncSink::NotReady(msg)), #[test]
err => Err(err), fn web_socket_stream_has_traits() {
}, is_read::<AllowStd<tokio::net::TcpStream>>();
} is_write::<AllowStd<tokio::net::TcpStream>>();
is_async_read::<AutoStream<tokio::net::TcpStream>>();
is_async_write::<AutoStream<tokio::net::TcpStream>>();
is_unpin::<WebSocketStream<tokio::net::TcpStream>>();
is_unpin::<WebSocketStream<AutoStream<tokio::net::TcpStream>>>();
is_unpin::<WebSocketStream<AutoStream<tokio_dns::TcpStream>>>();
} }
} }

@ -3,26 +3,11 @@
//! There is no dependency on actual TLS implementations. Everything like //! There is no dependency on actual TLS implementations. Everything like
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits. //! `Read + Write` traits.
use std::pin::Pin;
use std::task::{Context, Poll};
use std::io::{Error as IoError, Read, Result as IoResult, Write};
use std::net::SocketAddr;
use bytes::{Buf, BufMut};
use futures::Poll;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
/// Trait to switch TCP_NODELAY.
pub trait NoDelay {
/// Set the TCP_NODELAY option to the given value.
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>;
}
/// Trait to get the remote address from the underlying stream.
pub trait PeerAddr {
/// Returns the remote address that this stream is connected to.
fn peer_addr(&self) -> IoResult<SocketAddr>;
}
/// Stream, either plain TCP or TLS. /// Stream, either plain TCP or TLS.
pub enum Stream<S, T> { pub enum Stream<S, T> {
/// Unencrypted socket stream. /// Unencrypted socket stream.
@ -31,74 +16,72 @@ pub enum Stream<S, T> {
Tls(T), Tls(T),
} }
impl<S: Read, T: Read> Read for Stream<S, T> { impl<S: AsyncRead + Unpin, T: AsyncRead + Unpin> AsyncRead for Stream<S, T> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { fn poll_read(
match *self { mut self: Pin<&mut Self>,
Stream::Plain(ref mut s) => s.read(buf), cx: &mut Context<'_>,
Stream::Tls(ref mut s) => s.read(buf), buf: &mut [u8],
} ) -> Poll<std::io::Result<usize>> {
}
}
impl<S: Write, T: Write> Write for Stream<S, T> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.write(buf),
Stream::Tls(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.flush(),
Stream::Tls(ref mut s) => s.flush(),
}
}
}
impl<S: NoDelay, T: NoDelay> NoDelay for Stream<S, T> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
match *self { match *self {
Stream::Plain(ref mut s) => s.set_nodelay(nodelay), Stream::Plain(ref mut s) => {
Stream::Tls(ref mut s) => s.set_nodelay(nodelay), let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_read(cx, buf)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_read(cx, buf)
}
} }
} }
} }
impl<S: PeerAddr, T: PeerAddr> PeerAddr for Stream<S, T> { impl<S: AsyncWrite + Unpin, T: AsyncWrite + Unpin> AsyncWrite for Stream<S, T> {
fn peer_addr(&self) -> IoResult<SocketAddr> { fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match *self { match *self {
Stream::Plain(ref s) => s.peer_addr(), Stream::Plain(ref mut s) => {
Stream::Tls(ref s) => s.peer_addr(), let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_write(cx, buf)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_write(cx, buf)
}
} }
} }
}
impl<S: AsyncRead, T: AsyncRead> AsyncRead for Stream<S, T> { fn poll_flush(
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match *self { match *self {
Stream::Plain(ref s) => s.prepare_uninitialized_buffer(buf), Stream::Plain(ref mut s) => {
Stream::Tls(ref s) => s.prepare_uninitialized_buffer(buf), let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_flush(cx)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_flush(cx)
}
} }
} }
fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, IoError> {
match *self {
Stream::Plain(ref mut s) => s.read_buf(buf),
Stream::Tls(ref mut s) => s.read_buf(buf),
}
}
}
impl<S: AsyncWrite, T: AsyncWrite> AsyncWrite for Stream<S, T> { fn poll_shutdown(
fn shutdown(&mut self) -> Poll<(), IoError> { mut self: Pin<&mut Self>,
match *self { cx: &mut Context<'_>,
Stream::Plain(ref mut s) => s.shutdown(), ) -> Poll<Result<(), std::io::Error>> {
Stream::Tls(ref mut s) => s.shutdown(),
}
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, IoError> {
match *self { match *self {
Stream::Plain(ref mut s) => s.write_buf(buf), Stream::Plain(ref mut s) => {
Stream::Tls(ref mut s) => s.write_buf(buf), let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_shutdown(cx)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_shutdown(cx)
}
} }
} }
} }

@ -0,0 +1,79 @@
use futures::StreamExt;
use log::*;
use std::net::ToSocketAddrs;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::tcp::{TcpListener, TcpStream};
use tokio_tungstenite::{accept_async, client_async, WebSocketStream};
use tungstenite::Message;
async fn run_connection<S>(
connection: WebSocketStream<S>,
msg_tx: futures::channel::oneshot::Sender<Vec<Message>>,
) where
S: AsyncRead + AsyncWrite + Unpin,
{
info!("Running connection");
let mut connection = connection;
let mut messages = vec![];
while let Some(message) = connection.next().await {
info!("Message received");
let message = message.expect("Failed to get message");
messages.push(message);
}
msg_tx.send(messages).expect("Failed to send results");
}
#[tokio::test]
async fn communication() {
let _ = env_logger::try_init();
let (con_tx, con_rx) = futures::channel::oneshot::channel();
let (msg_tx, msg_rx) = futures::channel::oneshot::channel();
let f = async move {
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let listener = TcpListener::bind(&address).await.unwrap();
let mut connections = listener.incoming();
info!("Server ready");
con_tx.send(()).unwrap();
info!("Waiting on next connection");
let connection = connections.next().await.expect("No connections to accept");
let connection = connection.expect("Failed to accept connection");
let stream = accept_async(connection).await;
let stream = stream.expect("Failed to handshake with connection");
run_connection(stream, msg_tx).await;
};
tokio::spawn(f);
info!("Waiting for server to be ready");
con_rx.await.expect("Server not ready");
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let tcp = TcpStream::connect(&address)
.await
.expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap();
let (mut stream, _) = client_async(url, tcp)
.await
.expect("Client failed to connect");
for i in 1..10 {
info!("Sending message");
stream.send(Message::Text(format!("{}", i))).await.expect("Failed to send message");
}
stream.close(None).await.expect("Failed to close");
info!("Waiting for response messages");
let messages = msg_rx.await.expect("Failed to receive messages");
assert_eq!(messages.len(), 10);
}

@ -1,36 +1,41 @@
use std::io; use futures::StreamExt;
use std::net::ToSocketAddrs;
use futures::{Future, Stream}; use tokio::net::tcp::{TcpListener, TcpStream};
use tokio_tcp::{TcpListener, TcpStream};
use tokio_tungstenite::{accept_async, client_async}; use tokio_tungstenite::{accept_async, client_async};
#[test] #[tokio::test]
fn handshakes() { async fn handshakes() {
use std::sync::mpsc::channel; let (tx, rx) = futures::channel::oneshot::channel();
use std::thread;
let (tx, rx) = channel();
thread::spawn(move || { let f = async move {
let address = "0.0.0.0:12345".parse().unwrap(); let address = "0.0.0.0:12345"
let listener = TcpListener::bind(&address).unwrap(); .to_socket_addrs()
let connections = listener.incoming(); .expect("Not a valid address")
.next()
.expect("No address resolved");
let listener = TcpListener::bind(&address).await.unwrap();
let mut connections = listener.incoming();
tx.send(()).unwrap(); tx.send(()).unwrap();
let handshakes = connections.and_then(|connection| { while let Some(connection) = connections.next().await {
accept_async(connection).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) let connection = connection.expect("Failed to accept connection");
}); let stream = accept_async(connection).await;
let server = handshakes.for_each(|_| Ok(())); stream.expect("Failed to handshake with connection");
}
};
server.wait().unwrap(); tokio::spawn(f);
});
rx.recv().unwrap(); rx.await.expect("Failed to wait for server to be ready");
let address = "0.0.0.0:12345".parse().unwrap(); let address = "0.0.0.0:12345"
let tcp = TcpStream::connect(&address); .to_socket_addrs()
let handshake = tcp.and_then(|stream| { .expect("Not a valid address")
let url = url::Url::parse("ws://localhost:12345/").unwrap(); .next()
client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .expect("No address resolved");
}); let tcp = TcpStream::connect(&address)
let client = handshake.and_then(|_| Ok(())); .await
client.wait().unwrap(); .expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap();
let _stream = client_async(url, tcp)
.await
.expect("Client failed to connect");
} }

Loading…
Cancel
Save