From 3821e0952a3e42462e69af1ff5870d087c9ae2aa Mon Sep 17 00:00:00 2001 From: Danny Browning Date: Mon, 9 Sep 2019 17:39:14 -0600 Subject: [PATCH] Tokio 0.2 Conversion Update to use tokio 0.2 ecosystem to integrate with tungstenite. --- .travis.yml | 2 + Cargo.toml | 26 +-- examples/autobahn-client.rs | 58 +++--- examples/autobahn-server.rs | 66 +++---- examples/client.rs | 69 +++---- examples/server.rs | 165 ++++++++-------- src/compat.rs | 123 ++++++++++++ src/connect.rs | 131 ++++--------- src/handshake.rs | 181 ++++++++++++++++++ src/lib.rs | 363 +++++++++++++++++++----------------- src/stream.rs | 125 ++++++------- tests/communication.rs | 79 ++++++++ tests/handshakes.rs | 63 ++++--- 13 files changed, 876 insertions(+), 575 deletions(-) create mode 100644 src/compat.rs create mode 100644 src/handshake.rs create mode 100644 tests/communication.rs diff --git a/.travis.yml b/.travis.yml index 40f0edc..f441fe5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,6 @@ language: rust +rust: + - nightly-2019-09-05 before_script: - export PATH="$PATH:$HOME/.cargo/bin" diff --git a/Cargo.toml b/Cargo.toml index b787b41..a34d1d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,21 +8,25 @@ license = "MIT" homepage = "https://github.com/snapview/tokio-tungstenite" documentation = "https://docs.rs/tokio-tungstenite/0.9.0" repository = "https://github.com/snapview/tokio-tungstenite" -version = "0.9.0" +version = "0.10.0-alpha.1" edition = "2018" [features] default = ["connect", "tls"] -connect = ["tokio-dns-unofficial", "tokio-tcp", "stream"] +connect = ["tokio-dns-unofficial", "tokio-net", "stream"] tls = ["tokio-tls", "native-tls", "stream", "tungstenite/tls"] stream = ["bytes"] [dependencies] -futures = "0.1.23" -tokio-io = "0.1.7" +log = "0.4" +futures-preview = { version = "0.3.0-alpha.19", features = ["async-await"] } +pin-project = "0.4.0-alpha.9" +tokio-io = "0.2.0-alpha.6" [dependencies.tungstenite] -version = "0.9.1" +#version = "0.9.1" +git = "https://github.com/snapview/tungstenite-rs.git" +branch = "master" default-features = false [dependencies.bytes] @@ -35,18 +39,18 @@ version = "0.2.0" [dependencies.tokio-dns-unofficial] optional = true -version = "0.3.1" +#version = "0.4.0" +git = "https://github.com/sbstp/tokio-dns.git" -[dependencies.tokio-tcp] +[dependencies.tokio-net] optional = true -version = "0.1.0" +version = "0.2.0-alpha.6" [dependencies.tokio-tls] optional = true -version = "0.2.0" +version = "0.3.0-alpha.6" [dev-dependencies] -tokio = "0.1.7" +tokio = "0.2.0-alpha.6" url = "2.0.0" env_logger = "0.6.1" -log = "0.4.6" diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 172cd5a..a4b2684 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -1,33 +1,32 @@ -use futures::{Future, Stream}; +use futures::StreamExt; use log::*; -use tokio_tungstenite::{ - connect_async, - tungstenite::{connect, Error as WsError, Result}, -}; +use tokio_tungstenite::{connect_async, tungstenite::Result}; use url::Url; const AGENT: &'static str = "Tungstenite"; -fn get_case_count() -> Result { - let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; - let msg = socket.read_message()?; - socket.close(None)?; +async fn get_case_count() -> Result { + let (mut socket, _) = + connect_async(Url::parse("ws://localhost:9001/getCaseCount").unwrap()).await?; + let msg = socket.next().await.unwrap()?; + socket.close(None).await?; Ok(msg.into_text()?.parse::().unwrap()) } -fn update_reports() -> Result<()> { - let (mut socket, _) = connect( +async fn update_reports() -> Result<()> { + let (mut socket, _) = connect_async( Url::parse(&format!( "ws://localhost:9001/updateReports?agent={}", AGENT )) .unwrap(), - )?; - socket.close(None)?; + ) + .await?; + socket.close(None).await?; Ok(()) } -fn run_test(case: u32) { +async fn run_test(case: u32) { info!("Running test case {}", case); let case_url = Url::parse(&format!( "ws://localhost:9001/runCase?case={}&agent={}", @@ -35,31 +34,24 @@ fn run_test(case: u32) { )) .unwrap(); - let job = connect_async(case_url) - .map_err(|err| error!("Connect error: {}", err)) - .and_then(|(ws_stream, _)| { - let (sink, stream) = ws_stream.split(); - stream - .filter(|msg| msg.is_text() || msg.is_binary()) - .forward(sink) - .and_then(|(_stream, _sink)| Ok(())) - .map_err(|err| match err { - WsError::ConnectionClosed => (), - err => info!("WS error {}", err), - }) - }); - - tokio::run(job) + let (mut ws_stream, _) = connect_async(case_url).await.expect("Connect error"); + while let Some(msg) = ws_stream.next().await { + let msg = msg.expect("Failed to get message"); + if msg.is_text() || msg.is_binary() { + ws_stream.send(msg).await.expect("Write error"); + } + } } -fn main() { +#[tokio::main] +async fn main() { env_logger::init(); - let total = get_case_count().unwrap(); + let total = get_case_count().await.unwrap(); for case in 1..(total + 1) { - run_test(case) + run_test(case).await } - update_reports().unwrap(); + update_reports().await.unwrap(); } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 23505d2..f3b40ab 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,42 +1,42 @@ -use futures::{Future, Stream}; +use futures::StreamExt; use log::*; -use tokio::net::TcpListener; -use tokio_tungstenite::{accept_async, tungstenite::Error as WsError}; +use std::net::{SocketAddr, ToSocketAddrs}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::accept_async; -fn main() { - env_logger::init(); +async fn accept_connection(peer: SocketAddr, stream: TcpStream) { + let mut ws_stream = accept_async(stream).await.expect("Failed to accept"); - let mut runtime = tokio::runtime::Builder::new().build().unwrap(); + info!("New WebSocket connection: {}", peer); - let addr = "127.0.0.1:9002".parse().unwrap(); - let socket = TcpListener::bind(&addr).unwrap(); - info!("Listening on: {}", addr); + while let Some(msg) = ws_stream.next().await { + let msg = msg.expect("Failed to get request"); + if msg.is_text() || msg.is_binary() { + ws_stream.send(msg).await.expect("Failed to send response"); + } + } +} - let srv = socket - .incoming() - .map_err(Into::into) - .for_each(move |stream| { - let peer = stream - .peer_addr() - .expect("connected streams should have a peer address"); - info!("Peer address: {}", peer); +#[tokio::main] +async fn main() { + env_logger::init(); - accept_async(stream).and_then(move |ws_stream| { - info!("New WebSocket connection: {}", peer); - let (sink, stream) = ws_stream.split(); - let job = stream - .filter(|msg| msg.is_text() || msg.is_binary()) - .forward(sink) - .and_then(|(_stream, _sink)| Ok(())) - .map_err(|err| match err { - WsError::ConnectionClosed => (), - err => info!("WS error: {}", err), - }); + let addr = "127.0.0.1:9002" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("Not a socket address"); + let socket = TcpListener::bind(&addr).await.unwrap(); + let mut incoming = socket.incoming(); + info!("Listening on: {}", addr); - tokio::spawn(job); - Ok(()) - }) - }); + while let Some(stream) = incoming.next().await { + let stream = stream.expect("Failed to get stream"); + let peer = stream + .peer_addr() + .expect("connected streams should have a peer address"); + info!("Peer address: {}", peer); - runtime.block_on(srv).unwrap(); + tokio::spawn(accept_connection(peer, stream)); + } } diff --git a/examples/client.rs b/examples/client.rs index 00eaa35..f98e4c5 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -11,17 +11,19 @@ //! You can use this example together with the `server` example. use std::env; -use std::io::{self, Read, Write}; -use std::thread; +use std::io::{self, Write}; -use futures::sync::mpsc; -use futures::{Future, Sink, Stream}; +use futures::StreamExt; +use log::*; use tungstenite::protocol::Message; +use tokio::io::AsyncReadExt; use tokio_tungstenite::connect_async; -use tokio_tungstenite::stream::PeerAddr; -fn main() { +#[tokio::main] +async fn main() { + let _ = env_logger::try_init(); + // Specify the server address to which the client will be connecting. let connect_addr = env::args() .nth(1) @@ -33,9 +35,8 @@ fn main() { // loop, so we farm out that work to a separate thread. This thread will // read data from stdin and then send it to the event loop over a standard // futures channel. - let (stdin_tx, stdin_rx) = mpsc::channel(0); - thread::spawn(|| read_stdin(stdin_tx)); - let stdin_rx = stdin_rx.map_err(|_| panic!()); // errors not possible on rx + let (stdin_tx, mut stdin_rx) = futures::channel::mpsc::unbounded(); + tokio::spawn(read_stdin(stdin_tx)); // After the TCP connection has been established, we set up our client to // start forwarding data. @@ -53,53 +54,29 @@ fn main() { // finishes. If we don't have any more data to read or we won't receive any // more work from the remote then we can exit. let mut stdout = io::stdout(); - let client = connect_async(url) - .and_then(move |(ws_stream, _)| { - println!("WebSocket handshake has been successfully completed"); - - let addr = ws_stream - .peer_addr() - .expect("connected streams should have a peer address"); - println!("Peer address: {}", addr); - - // `sink` is the stream of messages going out. - // `stream` is the stream of incoming messages. - let (sink, stream) = ws_stream.split(); + let (mut ws_stream, _) = connect_async(url).await.expect("Failed to connect"); + info!("WebSocket handshake has been successfully completed"); - // We forward all messages, composed out of the data, entered to - // the stdin, to the `sink`. - let send_stdin = stdin_rx.forward(sink); - let write_stdout = stream.for_each(move |message| { - stdout.write_all(&message.into_data()).unwrap(); - Ok(()) - }); - - // Wait for either of futures to complete. - send_stdin - .map(|_| ()) - .select(write_stdout.map(|_| ())) - .then(|_| Ok(())) - }) - .map_err(|e| { - println!("Error during the websocket handshake occurred: {}", e); - io::Error::new(io::ErrorKind::Other, e) - }); - - // And now that we've got our client, we execute it in the event loop! - tokio::runtime::run(client.map_err(|_e| ())); + while let Some(msg) = stdin_rx.next().await { + ws_stream.send(msg).await.expect("Failed to send request"); + if let Some(msg) = ws_stream.next().await { + let msg = msg.expect("Failed to get response"); + stdout.write_all(&msg.into_data()).unwrap(); + } + } } // Our helper method which will read data from stdin and send it along the // sender provided. -fn read_stdin(mut tx: mpsc::Sender) { - let mut stdin = io::stdin(); +async fn read_stdin(tx: futures::channel::mpsc::UnboundedSender) { + let mut stdin = tokio::io::stdin(); loop { let mut buf = vec![0; 1024]; - let n = match stdin.read(&mut buf) { + let n = match stdin.read(&mut buf).await { Err(_) | Ok(0) => break, Ok(n) => n, }; buf.truncate(n); - tx = tx.send(Message::binary(buf)).wait().unwrap(); + tx.unbounded_send(Message::binary(buf)).unwrap(); } } diff --git a/examples/server.rs b/examples/server.rs index 82fe890..5ea221b 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -17,106 +17,87 @@ //! connected clients they'll all join the same room and see everyone else's //! messages. -use std::collections::HashMap; use std::env; -use std::io::{Error, ErrorKind}; -use std::sync::{Arc, Mutex}; +use std::io::Error; -use futures::stream::Stream; -use futures::Future; -use tokio::net::TcpListener; +use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; +use futures::StreamExt; +use log::*; +use std::net::{SocketAddr, ToSocketAddrs}; +use tokio::net::{TcpListener, TcpStream}; use tungstenite::protocol::Message; -use tokio_tungstenite::accept_async; +struct Connection { + addr: SocketAddr, + rx: UnboundedReceiver, + tx: UnboundedSender, +} + +async fn handle_connection(connection: Connection) { + let mut connection = connection; + while let Some(msg) = connection.rx.next().await { + info!("Received a message from {}: {}", connection.addr, msg); + connection + .tx + .unbounded_send(msg) + .expect("Failed to forward message"); + } +} -fn main() { +async fn accept_connection(stream: TcpStream) { + let addr = stream + .peer_addr() + .expect("connected streams should have a peer address"); + info!("Peer address: {}", addr); + + let mut ws_stream = tokio_tungstenite::accept_async(stream) + .await + .expect("Error during the websocket handshake occurred"); + + info!("New WebSocket connection: {}", addr); + + // Create a channel for our stream, which other sockets will use to + // send us messages. Then register our address with the stream to send + // data to us. + let (msg_tx, msg_rx) = futures::channel::mpsc::unbounded(); + let (response_tx, mut response_rx) = futures::channel::mpsc::unbounded(); + let c = Connection { + addr: addr, + rx: msg_rx, + tx: response_tx, + }; + tokio::spawn(handle_connection(c)); + + while let Some(message) = ws_stream.next().await { + let message = message.expect("Failed to get request"); + msg_tx + .unbounded_send(message) + .expect("Failed to forward request"); + if let Some(resp) = response_rx.next().await { + ws_stream.send(resp).await.expect("Failed to send response"); + } + } +} +#[tokio::main] +async fn main() -> Result<(), Error> { + let _ = env_logger::try_init(); let addr = env::args().nth(1).unwrap_or("127.0.0.1:8080".to_string()); - let addr = addr.parse().unwrap(); + let addr = addr + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("Not a socket address"); // Create the event loop and TCP listener we'll accept connections on. - let socket = TcpListener::bind(&addr).unwrap(); - println!("Listening on: {}", addr); - - // Tokio Runtime uses a thread pool based executor by default, so we need - // to use Arc and Mutex to store the map of all connections we know about. - let connections = Arc::new(Mutex::new(HashMap::new())); - - let srv = socket.incoming().for_each(move |stream| { - let addr = stream - .peer_addr() - .expect("connected streams should have a peer address"); - println!("Peer address: {}", addr); - - // We have to clone both of these values, because the `and_then` - // function below constructs a new future, `and_then` requires - // `FnOnce`, so we construct a move closure to move the - // environment inside the future (AndThen future may overlive our - // `for_each` future). - let connections_inner = connections.clone(); - - accept_async(stream) - .and_then(move |ws_stream| { - println!("New WebSocket connection: {}", addr); - - // Create a channel for our stream, which other sockets will use to - // send us messages. Then register our address with the stream to send - // data to us. - let (tx, rx) = futures::sync::mpsc::unbounded(); - connections_inner.lock().unwrap().insert(addr, tx); - - // Let's split the WebSocket stream, so we can work with the - // reading and writing halves separately. - let (sink, stream) = ws_stream.split(); - - // Whenever we receive a message from the client, we print it and - // send to other clients, excluding the sender. - let connections = connections_inner.clone(); - let ws_reader = stream.for_each(move |message: Message| { - println!("Received a message from {}: {}", addr, message); - - // For each open connection except the sender, send the - // string via the channel. - let mut conns = connections.lock().unwrap(); - let iter = conns - .iter_mut() - .filter(|&(&k, _)| k != addr) - .map(|(_, v)| v); - for tx in iter { - tx.unbounded_send(message.clone()).unwrap(); - } - Ok(()) - }); - - // Whenever we receive a string on the Receiver, we write it to - // `WriteHalf`. - let ws_writer = rx.fold(sink, |mut sink, msg| { - use futures::Sink; - sink.start_send(msg).unwrap(); - Ok(sink) - }); - - // Now that we've got futures representing each half of the socket, we - // use the `select` combinator to wait for either half to be done to - // tear down the other. Then we spawn off the result. - let connection = ws_reader - .map(|_| ()) - .map_err(|_| ()) - .select(ws_writer.map(|_| ()).map_err(|_| ())); - - tokio::spawn(connection.then(move |_| { - connections_inner.lock().unwrap().remove(&addr); - println!("Connection {} closed.", addr); - Ok(()) - })); + let try_socket = TcpListener::bind(&addr).await; + let socket = try_socket.expect("Failed to bind"); + let mut incoming = socket.incoming(); + info!("Listening on: {}", addr); - Ok(()) - }) - .map_err(|e| { - println!("Error during the websocket handshake occurred: {}", e); - Error::new(ErrorKind::Other, e) - }) - }); + while let Some(stream) = incoming.next().await { + let stream = stream.expect("Failed to accept stream"); + tokio::spawn(accept_connection(stream)); + } - // Execute server. - tokio::runtime::run(srv.map_err(|_e| ())); + Ok(()) } diff --git a/src/compat.rs b/src/compat.rs new file mode 100644 index 0000000..218a338 --- /dev/null +++ b/src/compat.rs @@ -0,0 +1,123 @@ +use log::*; +use std::io::{Read, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio_io::{AsyncRead, AsyncWrite}; +use tungstenite::{Error as WsError, WebSocket}; + +pub(crate) trait HasContext { + fn set_context(&mut self, context: *mut ()); +} +#[derive(Debug)] +pub struct AllowStd { + pub(crate) inner: S, + pub(crate) context: *mut (), +} + +impl HasContext for AllowStd { + fn set_context(&mut self, context: *mut ()) { + self.context = context; + } +} + +pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket>); + +impl Drop for Guard<'_, S> { + fn drop(&mut self) { + trace!("{}:{} Guard.drop", file!(), line!()); + (self.0).get_mut().context = std::ptr::null_mut(); + } +} + +// *mut () context is neither Send nor Sync +unsafe impl Send for AllowStd {} +unsafe impl Sync for AllowStd {} + +impl AllowStd +where + S: Unpin, +{ + fn with_context(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, + { + trace!("{}:{} AllowStd.with_context", file!(), line!()); + unsafe { + assert!(!self.context.is_null()); + let waker = &mut *(self.context as *mut _); + f(waker, Pin::new(&mut self.inner)) + } + } + + pub(crate) fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + pub(crate) fn get_ref(&self) -> &S { + &self.inner + } +} + +impl Read for AllowStd +where + S: AsyncRead + Unpin, +{ + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + trace!("{}:{} Read.read", file!(), line!()); + match self.with_context(|ctx, stream| { + trace!( + "{}:{} Read.with_context read -> poll_read", + file!(), + line!() + ); + stream.poll_read(ctx, buf) + }) { + Poll::Ready(r) => r, + Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + } + } +} + +impl Write for AllowStd +where + S: AsyncWrite + Unpin, +{ + fn write(&mut self, buf: &[u8]) -> std::io::Result { + trace!("{}:{} Write.write", file!(), line!()); + match self.with_context(|ctx, stream| { + trace!( + "{}:{} Write.with_context write -> poll_write", + file!(), + line!() + ); + stream.poll_write(ctx, buf) + }) { + Poll::Ready(r) => r, + Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + trace!("{}:{} Write.flush", file!(), line!()); + match self.with_context(|ctx, stream| { + trace!( + "{}:{} Write.with_context flush -> poll_flush", + file!(), + line!() + ); + stream.poll_flush(ctx) + }) { + Poll::Ready(r) => r, + Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + } + } +} + +pub(crate) fn cvt(r: Result) -> Poll> { + match r { + Ok(v) => Poll::Ready(Ok(v)), + Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } +} diff --git a/src/connect.rs b/src/connect.rs index 00e4fb3..458dc85 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,82 +1,50 @@ //! Connection helper. - -use std::io::Result as IoResult; -use std::net::SocketAddr; - -use tokio_tcp::TcpStream; - -use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_net::tcp::TcpStream; use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; use tungstenite::Error; use super::{client_async, Request, WebSocketStream}; -use crate::stream::{NoDelay, PeerAddr}; - -impl NoDelay for TcpStream { - fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { - TcpStream::set_nodelay(self, nodelay) - } -} - -impl PeerAddr for TcpStream { - fn peer_addr(&self) -> IoResult { - self.peer_addr() - } -} #[cfg(feature = "tls")] -mod encryption { +pub(crate) mod encryption { use native_tls::TlsConnector; use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream}; - use std::io::{Read, Result as IoResult, Write}; - use std::net::SocketAddr; - - use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::stream::Mode; use tungstenite::Error; - use crate::stream::{NoDelay, PeerAddr, Stream as StreamSwitcher}; + use crate::stream::Stream as StreamSwitcher; /// A stream that might be protected with TLS. pub type MaybeTlsStream = StreamSwitcher>; pub type AutoStream = MaybeTlsStream; - impl NoDelay for TlsStream { - fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - } - - impl PeerAddr for TlsStream { - fn peer_addr(&self) -> IoResult { - self.get_ref().get_ref().peer_addr() - } - } - - pub fn wrap_stream( + pub async fn wrap_stream( socket: S, domain: String, mode: Mode, - ) -> Box, Error = Error> + Send> + ) -> Result, Error> where - S: 'static + AsyncRead + AsyncWrite + Send, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, { match mode { - Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))), - Mode::Tls => Box::new( - future::result(TlsConnector::new()) - .map(TokioTlsConnector::from) - .and_then(move |connector| connector.connect(&domain, socket)) - .map(StreamSwitcher::Tls) - .map_err(Error::Tls), - ), + Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Tls => { + let try_connector = TlsConnector::new(); + let connector = try_connector.map_err(Error::Tls)?; + let stream = TokioTlsConnector::from(connector); + let connected = stream.connect(&domain, socket).await; + match connected { + Err(e) => Err(Error::Tls(e)), + Ok(s) => Ok(StreamSwitcher::Tls(s)), + } + } } } } @@ -85,7 +53,7 @@ mod encryption { pub use self::encryption::MaybeTlsStream; #[cfg(not(feature = "tls"))] -mod encryption { +pub(crate) mod encryption { use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -94,19 +62,17 @@ mod encryption { pub type AutoStream = S; - pub fn wrap_stream( + pub async fn wrap_stream( socket: S, _domain: String, mode: Mode, - ) -> Box, Error = Error> + Send> + ) -> Result, Error> where - S: 'static + AsyncRead + AsyncWrite + Send, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, { match mode { - Mode::Plain => Box::new(future::ok(socket)), - Mode::Tls => Box::new(future::err(Error::Url( - "TLS support not compiled in.".into(), - ))), + Mode::Plain => Ok(socket), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), } } } @@ -124,59 +90,42 @@ fn domain(request: &Request) -> Result { /// Creates a WebSocket handshake from a request and a stream, /// upgrading the stream to TLS if required. -pub fn client_async_tls( +pub async fn client_async_tls( request: R, stream: S, -) -> Box>, Response), Error = Error> + Send> +) -> Result<(WebSocketStream>, Response), Error> where - R: Into>, - S: 'static + AsyncRead + AsyncWrite + NoDelay + Send, + R: Into> + Unpin, + S: 'static + AsyncRead + AsyncWrite + Send + Unpin, + AutoStream: Unpin, { let request: Request = request.into(); - let domain = match domain(&request) { - Ok(domain) => domain, - Err(err) => return Box::new(future::err(err)), - }; + let domain = domain(&request)?; // Make sure we check domain and mode first. URL must be valid. - let mode = match url_mode(&request.url) { - Ok(m) => m, - Err(e) => return Box::new(future::err(e)), - }; - - Box::new( - wrap_stream(stream, domain, mode) - .and_then(|mut stream| { - NoDelay::set_nodelay(&mut stream, true) - .map(move |()| stream) - .map_err(|e| e.into()) - }) - .and_then(move |stream| client_async(request, stream)), - ) + let mode = url_mode(&request.url)?; + + let stream = wrap_stream(stream, domain, mode).await?; + client_async(request, stream).await } /// Connect to a given URL. -pub fn connect_async( +pub async fn connect_async( request: R, -) -> Box>, Response), Error = Error> + Send> +) -> Result<(WebSocketStream>, Response), Error> where - R: Into>, + R: Into> + Unpin, { let request: Request = request.into(); - let domain = match domain(&request) { - Ok(domain) => domain, - Err(err) => return Box::new(future::err(err)), - }; + let domain = domain(&request)?; let port = request .url .port_or_known_default() .expect("Bug: port unknown"); - Box::new( - tokio_dns::TcpStream::connect((domain.as_str(), port)) - .map_err(|e| e.into()) - .and_then(move |socket| client_async_tls(request, socket)), - ) + let try_socket = tokio_dns::TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls(request, socket).await } diff --git a/src/handshake.rs b/src/handshake.rs new file mode 100644 index 0000000..9c4738c --- /dev/null +++ b/src/handshake.rs @@ -0,0 +1,181 @@ +use crate::compat::{AllowStd, HasContext}; +use crate::WebSocketStream; +use log::*; +use pin_project::pin_project; +use std::future::Future; +use std::io::{Read, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tungstenite::handshake::client::Response; +use tungstenite::handshake::server::Callback; +use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake}; +use tungstenite::{ClientHandshake, ServerHandshake, WebSocket}; + +pub(crate) async fn without_handshake(stream: S, f: F) -> WebSocketStream +where + F: FnOnce(AllowStd) -> WebSocket> + Unpin, + S: AsyncRead + AsyncWrite + Unpin, +{ + let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream })); + + let ws = start.await; + + WebSocketStream::new(ws) +} + +struct SkippedHandshakeFuture(Option>); +struct SkippedHandshakeFutureInner { + f: F, + stream: S, +} + +impl Future for SkippedHandshakeFuture +where + F: FnOnce(AllowStd) -> WebSocket> + Unpin, + S: Unpin, + AllowStd: Read + Write, +{ + type Output = WebSocket>; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let inner = self + .get_mut() + .0 + .take() + .expect("future polled after completion"); + trace!("Setting context when skipping handshake"); + let stream = AllowStd { + inner: inner.stream, + context: ctx as *mut _ as *mut (), + }; + + Poll::Ready((inner.f)(stream)) + } +} + +#[pin_project] +struct MidHandshake(Option>); + +enum StartedHandshake { + Done(Role::FinalResult), + Mid(WsHandshake), +} + +struct StartedHandshakeFuture(Option>); +struct StartedHandshakeFutureInner { + f: F, + stream: S, +} + +async fn handshake(stream: S, f: F) -> Result> +where + Role: HandshakeRole + Unpin, + Role::InternalStream: HasContext, + F: FnOnce(AllowStd) -> Result> + Unpin, + S: AsyncRead + AsyncWrite + Unpin, +{ + let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); + + match start.await? { + StartedHandshake::Done(r) => Ok(r), + StartedHandshake::Mid(s) => { + let res: Result> = MidHandshake::(Some(s)).await; + res + } + } +} + +pub(crate) async fn client_handshake( + stream: S, + f: F, +) -> Result<(WebSocketStream, Response), Error>>> +where + F: FnOnce( + AllowStd, + ) -> Result< + > as HandshakeRole>::FinalResult, + Error>>, + > + Unpin, + S: AsyncRead + AsyncWrite + Unpin, +{ + let result = handshake(stream, f).await?; + let (s, r) = result; + Ok((WebSocketStream::new(s), r)) +} + +pub(crate) async fn server_handshake( + stream: S, + f: F, +) -> Result, Error, C>>> +where + C: Callback + Unpin, + F: FnOnce( + AllowStd, + ) -> Result< + , C> as HandshakeRole>::FinalResult, + Error, C>>, + > + Unpin, + S: AsyncRead + AsyncWrite + Unpin, +{ + let s: WebSocket> = handshake(stream, f).await?; + Ok(WebSocketStream::new(s)) +} + +impl Future for StartedHandshakeFuture +where + Role: HandshakeRole, + Role::InternalStream: HasContext, + F: FnOnce(AllowStd) -> Result> + Unpin, + S: Unpin, + AllowStd: Read + Write, +{ + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let inner = self.0.take().expect("future polled after completion"); + trace!("Setting ctx when starting handshake"); + let stream = AllowStd { + inner: inner.stream, + context: ctx as *mut _ as *mut (), + }; + + match (inner.f)(stream) { + Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), + Err(Error::Interrupted(mut mid)) => { + let machine = mid.get_mut(); + machine.get_mut().set_context(std::ptr::null_mut()); + Poll::Ready(Ok(StartedHandshake::Mid(mid))) + } + Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), + } + } +} + +impl Future for MidHandshake +where + Role: HandshakeRole + Unpin, + Role::InternalStream: HasContext, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut s = this.0.take().expect("future polled after completion"); + + let machine = s.get_mut(); + trace!("Setting context in handshake"); + machine.get_mut().set_context(cx as *mut _ as *mut ()); + + match s.handshake() { + Ok(stream) => Poll::Ready(Ok(stream)), + Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), + Err(Error::Interrupted(mut mid)) => { + let machine = mid.get_mut(); + machine.get_mut().set_context(std::ptr::null_mut()); + *this.0 = Some(mid); + Poll::Pending + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 469920b..bfade42 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,26 +18,29 @@ pub use tungstenite; +mod compat; #[cfg(feature = "connect")] mod connect; - +mod handshake; #[cfg(feature = "stream")] pub mod stream; -use std::io::ErrorKind; - -#[cfg(feature = "stream")] -use std::{io::Result as IoResult, net::SocketAddr}; +use std::io::{Read, Write}; -use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use compat::{cvt, AllowStd}; +use futures::Stream; +use log::*; +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::{ error::Error as WsError, handshake::{ client::{ClientHandshake, Request, Response}, - server::{Callback, NoCallback, ServerHandshake}, - HandshakeError, HandshakeRole, + server::{Callback, NoCallback}, }, protocol::{Message, Role, WebSocket, WebSocketConfig}, server, @@ -46,11 +49,10 @@ use tungstenite::{ #[cfg(feature = "connect")] pub use connect::{client_async_tls, connect_async}; -#[cfg(feature = "stream")] -pub use stream::PeerAddr; - #[cfg(all(feature = "connect", feature = "tls"))] pub use connect::MaybeTlsStream; +use std::error::Error; +use tungstenite::protocol::CloseFrame; /// Creates a WebSocket handshake from a request and a stream. /// For convenience, the user may call this with a url string, a URL, @@ -64,30 +66,38 @@ pub use connect::MaybeTlsStream; /// /// This is typically used for clients who have already established, for /// example, a TCP connection to the remote server. -pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync +pub async fn client_async<'a, R, S>( + request: R, + stream: S, +) -> Result<(WebSocketStream, Response), WsError> where - R: Into>, - S: AsyncRead + AsyncWrite, + R: Into> + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { - client_async_with_config(request, stream, None) + client_async_with_config(request, stream, None).await } /// The same as `client_async()` but the one can specify a websocket configuration. /// Please refer to `client_async()` for more details. -pub fn client_async_with_config<'a, R, S>( +pub async fn client_async_with_config<'a, R, S>( request: R, stream: S, config: Option, -) -> ConnectAsync +) -> Result<(WebSocketStream, Response), WsError> where - R: Into>, - S: AsyncRead + AsyncWrite, + R: Into> + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { - ConnectAsync { - inner: MidHandshake { - inner: Some(ClientHandshake::start(stream, request.into(), config).handshake()), - }, - } + let f = handshake::client_handshake(stream, move |allow_std| { + let cli_handshake = ClientHandshake::start(allow_std, request.into(), config); + cli_handshake.handshake() + }); + f.await.map_err(|e| { + WsError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + e.description(), + )) + }) } /// Accepts a new WebSocket connection with the provided stream. @@ -101,23 +111,23 @@ where /// This is typically used after a socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform /// the server half of the accepting a client's websocket connection. -pub fn accept_async(stream: S) -> AcceptAsync +pub async fn accept_async(stream: S) -> Result, WsError> where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { - accept_hdr_async(stream, NoCallback) + accept_hdr_async(stream, NoCallback).await } /// The same as `accept_async()` but the one can specify a websocket configuration. /// Please refer to `accept_async()` for more details. -pub fn accept_async_with_config( +pub async fn accept_async_with_config( stream: S, config: Option, -) -> AcceptAsync +) -> Result, WsError> where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin, { - accept_hdr_async_with_config(stream, NoCallback, config) + accept_hdr_async_with_config(stream, NoCallback, config).await } /// Accepts a new WebSocket connection with the provided stream. @@ -125,30 +135,34 @@ where /// This function does the same as `accept_async()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr_async(stream: S, callback: C) -> AcceptAsync +pub async fn accept_hdr_async(stream: S, callback: C) -> Result, WsError> where - S: AsyncRead + AsyncWrite, - C: Callback, + S: AsyncRead + AsyncWrite + Unpin, + C: Callback + Unpin, { - accept_hdr_async_with_config(stream, callback, None) + accept_hdr_async_with_config(stream, callback, None).await } /// The same as `accept_hdr_async()` but the one can specify a websocket configuration. /// Please refer to `accept_hdr_async()` for more details. -pub fn accept_hdr_async_with_config( +pub async fn accept_hdr_async_with_config( stream: S, callback: C, config: Option, -) -> AcceptAsync +) -> Result, WsError> where - S: AsyncRead + AsyncWrite, - C: Callback, + S: AsyncRead + AsyncWrite + Unpin, + C: Callback + Unpin, { - AcceptAsync { - inner: MidHandshake { - inner: Some(server::accept_hdr_with_config(stream, callback, config)), - }, - } + let f = handshake::server_handshake(stream, move |allow_std| { + server::accept_hdr_with_config(allow_std, callback, config) + }); + f.await.map_err(|e| { + WsError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + e.description(), + )) + }) } /// A wrapper around an underlying raw stream which implements the WebSocket @@ -160,176 +174,187 @@ where /// through the respective `Stream` and `Sink`. Check more information about /// them in `futures-rs` crate documentation or have a look on the examples /// and unit tests for this crate. +#[pin_project] pub struct WebSocketStream { - inner: WebSocket, + #[pin] + inner: WebSocket>, } impl WebSocketStream { /// Convert a raw socket into a WebSocketStream without performing a /// handshake. - pub fn from_raw_socket(stream: S, role: Role, config: Option) -> Self { - Self::new(WebSocket::from_raw_socket(stream, role, config)) + pub async fn from_raw_socket(stream: S, role: Role, config: Option) -> Self + where + S: AsyncRead + AsyncWrite + Unpin, + { + handshake::without_handshake(stream, move |allow_std| { + WebSocket::from_raw_socket(allow_std, role, config) + }) + .await } /// Convert a raw socket into a WebSocketStream without performing a /// handshake. - pub fn from_partially_read( + pub async fn from_partially_read( stream: S, part: Vec, role: Role, config: Option, - ) -> Self { - Self::new(WebSocket::from_partially_read(stream, part, role, config)) + ) -> Self + where + S: AsyncRead + AsyncWrite + Unpin, + { + handshake::without_handshake(stream, move |allow_std| { + WebSocket::from_partially_read(allow_std, part, role, config) + }) + .await } - fn new(ws: WebSocket) -> Self { + pub(crate) fn new(ws: WebSocket>) -> Self { WebSocketStream { inner: ws } } -} -#[cfg(feature = "stream")] -impl PeerAddr for WebSocketStream { - fn peer_addr(&self) -> IoResult { - self.inner.get_ref().peer_addr() + fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R + where + S: Unpin, + F: FnOnce(&mut WebSocket>) -> R, + AllowStd: Read + Write, + { + trace!("{}:{} WebSocketStream.with_context", file!(), line!()); + self.inner.get_mut().context = ctx as *mut _ as *mut (); + let mut g = compat::Guard(&mut self.inner); + f(&mut (g.0)) } -} -impl Stream for WebSocketStream -where - T: AsyncRead + AsyncWrite, -{ - type Item = Message; - type Error = WsError; - - fn poll(&mut self) -> Poll, WsError> { - self.inner - .read_message() - .map(Some) - .to_async() - .or_else(|err| match err { - WsError::ConnectionClosed => Ok(Async::Ready(None)), - err => Err(err), - }) + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S + where + S: AsyncRead + AsyncWrite + Unpin, + { + &self.inner.get_ref().get_ref() } -} -impl Sink for WebSocketStream -where - T: AsyncRead + AsyncWrite, -{ - type SinkItem = Message; - type SinkError = WsError; - - fn start_send(&mut self, item: Message) -> StartSend { - self.inner.write_message(item).to_start_send() + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S + where + S: AsyncRead + AsyncWrite + Unpin, + { + self.inner.get_mut().get_mut() } - fn poll_complete(&mut self) -> Poll<(), WsError> { - self.inner.write_pending().to_async() + /// Send a message to this websocket + pub async fn send(&mut self, msg: Message) -> Result<(), WsError> + where + S: AsyncWrite + AsyncRead + Unpin, + { + let f = SendFuture { + stream: self, + message: Some(msg), + }; + f.await } - fn close(&mut self) -> Poll<(), WsError> { - self.inner.close(None).to_async() + /// Close the underlying web socket + pub async fn close(&mut self, msg: Option>) -> Result<(), WsError> + where + S: AsyncRead + AsyncWrite + Unpin, + { + let f = CloseFuture { + stream: self, + message: Some(msg), + }; + f.await } } -/// Future returned from client_async() which will resolve -/// once the connection handshake has finished. -pub struct ConnectAsync { - inner: MidHandshake>, -} - -impl Future for ConnectAsync { - type Item = (WebSocketStream, Response); - type Error = WsError; - - fn poll(&mut self) -> Poll { - match self.inner.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready((ws, resp)) => Ok(Async::Ready((WebSocketStream::new(ws), resp))), +impl Stream for WebSocketStream +where + T: AsyncRead + AsyncWrite + Unpin, + AllowStd: Read + Write, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("{}:{} Stream.poll_next", file!(), line!()); + match futures::ready!(self.with_context(cx, |s| { + trace!( + "{}:{} Stream.with_context poll_next -> read_message()", + file!(), + line!() + ); + cvt(s.read_message()) + })) { + Ok(v) => Poll::Ready(Some(Ok(v))), + Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), } } } -/// Future returned from accept_async() which will resolve -/// once the connection handshake has finished. -pub struct AcceptAsync { - inner: MidHandshake>, +#[pin_project] +struct SendFuture<'a, T> { + stream: &'a mut WebSocketStream, + message: Option, } -impl Future for AcceptAsync { - type Item = WebSocketStream; - type Error = WsError; +impl<'a, T> Future for SendFuture<'a, T> +where + T: AsyncRead + AsyncWrite + Unpin, + AllowStd: Read + Write, +{ + type Output = Result<(), WsError>; - fn poll(&mut self) -> Poll { - match self.inner.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(ws) => Ok(Async::Ready(WebSocketStream::new(ws))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let message = this.message.take().expect("Cannot poll twice"); + Poll::Ready(this.stream.with_context(cx, |s| s.write_message(message))) } } -struct MidHandshake { - inner: Option::FinalResult, HandshakeError>>, +#[pin_project] +struct CloseFuture<'a, T> { + stream: &'a mut WebSocketStream, + message: Option>>, } -impl Future for MidHandshake { - type Item = ::FinalResult; - type Error = WsError; - - fn poll(&mut self) -> Poll { - match self.inner.take().expect("cannot poll MidHandshake twice") { - Ok(result) => Ok(Async::Ready(result)), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::Interrupted(s)) => match s.handshake() { - Ok(result) => Ok(Async::Ready(result)), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::Interrupted(s)) => { - self.inner = Some(Err(HandshakeError::Interrupted(s))); - Ok(Async::NotReady) - } - }, - } - } -} - -trait ToAsync { - type T; - type E; - fn to_async(self) -> Result, Self::E>; -} +impl<'a, T> Future for CloseFuture<'a, T> +where + T: AsyncRead + AsyncWrite + Unpin, + AllowStd: Read + Write, +{ + type Output = Result<(), WsError>; -impl ToAsync for Result { - type T = T; - type E = WsError; - fn to_async(self) -> Result, Self::E> { - match self { - Ok(x) => Ok(Async::Ready(x)), - Err(error) => match error { - WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(Async::NotReady), - err => Err(err), - }, - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let message = this.message.take().expect("Cannot poll twice"); + Poll::Ready(this.stream.with_context(cx, |s| s.close(message))) } } -trait ToStartSend { - type T; - type E; - fn to_start_send(self) -> StartSend; -} - -impl ToStartSend for Result<(), WsError> { - type T = Message; - type E = WsError; - fn to_start_send(self) -> StartSend { - match self { - Ok(_) => Ok(AsyncSink::Ready), - Err(error) => match error { - WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(AsyncSink::Ready), - WsError::SendQueueFull(msg) => Ok(AsyncSink::NotReady(msg)), - err => Err(err), - }, - } +#[cfg(test)] +mod tests { + use crate::compat::AllowStd; + use crate::connect::encryption::AutoStream; + use crate::WebSocketStream; + use std::io::{Read, Write}; + use tokio_io::{AsyncReadExt, AsyncWriteExt}; + + fn is_read() {} + fn is_write() {} + fn is_async_read() {} + fn is_async_write() {} + fn is_unpin() {} + + #[test] + fn web_socket_stream_has_traits() { + is_read::>(); + is_write::>(); + + is_async_read::>(); + is_async_write::>(); + + is_unpin::>(); + is_unpin::>>(); + is_unpin::>>(); } } diff --git a/src/stream.rs b/src/stream.rs index d961abe..d8c5b13 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -3,26 +3,11 @@ //! There is no dependency on actual TLS implementations. Everything like //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. +use std::pin::Pin; +use std::task::{Context, Poll}; -use std::io::{Error as IoError, Read, Result as IoResult, Write}; -use std::net::SocketAddr; - -use bytes::{Buf, BufMut}; -use futures::Poll; use tokio_io::{AsyncRead, AsyncWrite}; -/// Trait to switch TCP_NODELAY. -pub trait NoDelay { - /// Set the TCP_NODELAY option to the given value. - fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>; -} - -/// Trait to get the remote address from the underlying stream. -pub trait PeerAddr { - /// Returns the remote address that this stream is connected to. - fn peer_addr(&self) -> IoResult; -} - /// Stream, either plain TCP or TLS. pub enum Stream { /// Unencrypted socket stream. @@ -31,74 +16,72 @@ pub enum Stream { Tls(T), } -impl Read for Stream { - fn read(&mut self, buf: &mut [u8]) -> IoResult { - match *self { - Stream::Plain(ref mut s) => s.read(buf), - Stream::Tls(ref mut s) => s.read(buf), - } - } -} - -impl Write for Stream { - fn write(&mut self, buf: &[u8]) -> IoResult { - match *self { - Stream::Plain(ref mut s) => s.write(buf), - Stream::Tls(ref mut s) => s.write(buf), - } - } - fn flush(&mut self) -> IoResult<()> { - match *self { - Stream::Plain(ref mut s) => s.flush(), - Stream::Tls(ref mut s) => s.flush(), - } - } -} - -impl NoDelay for Stream { - fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { +impl AsyncRead for Stream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { match *self { - Stream::Plain(ref mut s) => s.set_nodelay(nodelay), - Stream::Tls(ref mut s) => s.set_nodelay(nodelay), + Stream::Plain(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_read(cx, buf) + } + Stream::Tls(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_read(cx, buf) + } } } } -impl PeerAddr for Stream { - fn peer_addr(&self) -> IoResult { +impl AsyncWrite for Stream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { match *self { - Stream::Plain(ref s) => s.peer_addr(), - Stream::Tls(ref s) => s.peer_addr(), + Stream::Plain(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_write(cx, buf) + } + Stream::Tls(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_write(cx, buf) + } } } -} -impl AsyncRead for Stream { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { match *self { - Stream::Plain(ref s) => s.prepare_uninitialized_buffer(buf), - Stream::Tls(ref s) => s.prepare_uninitialized_buffer(buf), + Stream::Plain(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_flush(cx) + } + Stream::Tls(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_flush(cx) + } } } - fn read_buf(&mut self, buf: &mut B) -> Poll { - match *self { - Stream::Plain(ref mut s) => s.read_buf(buf), - Stream::Tls(ref mut s) => s.read_buf(buf), - } - } -} -impl AsyncWrite for Stream { - fn shutdown(&mut self) -> Poll<(), IoError> { - match *self { - Stream::Plain(ref mut s) => s.shutdown(), - Stream::Tls(ref mut s) => s.shutdown(), - } - } - fn write_buf(&mut self, buf: &mut B) -> Poll { + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { match *self { - Stream::Plain(ref mut s) => s.write_buf(buf), - Stream::Tls(ref mut s) => s.write_buf(buf), + Stream::Plain(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_shutdown(cx) + } + Stream::Tls(ref mut s) => { + let pinned = unsafe { Pin::new_unchecked(s) }; + pinned.poll_shutdown(cx) + } } } } diff --git a/tests/communication.rs b/tests/communication.rs new file mode 100644 index 0000000..e9c3e5b --- /dev/null +++ b/tests/communication.rs @@ -0,0 +1,79 @@ +use futures::StreamExt; +use log::*; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::tcp::{TcpListener, TcpStream}; +use tokio_tungstenite::{accept_async, client_async, WebSocketStream}; +use tungstenite::Message; + +async fn run_connection( + connection: WebSocketStream, + msg_tx: futures::channel::oneshot::Sender>, +) where + S: AsyncRead + AsyncWrite + Unpin, +{ + info!("Running connection"); + let mut connection = connection; + let mut messages = vec![]; + while let Some(message) = connection.next().await { + info!("Message received"); + let message = message.expect("Failed to get message"); + messages.push(message); + } + msg_tx.send(messages).expect("Failed to send results"); +} + +#[tokio::test] +async fn communication() { + let _ = env_logger::try_init(); + + let (con_tx, con_rx) = futures::channel::oneshot::channel(); + let (msg_tx, msg_rx) = futures::channel::oneshot::channel(); + + let f = async move { + let address = "0.0.0.0:12345" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let listener = TcpListener::bind(&address).await.unwrap(); + let mut connections = listener.incoming(); + info!("Server ready"); + con_tx.send(()).unwrap(); + info!("Waiting on next connection"); + let connection = connections.next().await.expect("No connections to accept"); + let connection = connection.expect("Failed to accept connection"); + let stream = accept_async(connection).await; + let stream = stream.expect("Failed to handshake with connection"); + run_connection(stream, msg_tx).await; + }; + + tokio::spawn(f); + + info!("Waiting for server to be ready"); + + con_rx.await.expect("Server not ready"); + let address = "0.0.0.0:12345" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let tcp = TcpStream::connect(&address) + .await + .expect("Failed to connect"); + let url = url::Url::parse("ws://localhost:12345/").unwrap(); + let (mut stream, _) = client_async(url, tcp) + .await + .expect("Client failed to connect"); + + for i in 1..10 { + info!("Sending message"); + stream.send(Message::Text(format!("{}", i))).await.expect("Failed to send message"); + } + + stream.close(None).await.expect("Failed to close"); + + info!("Waiting for response messages"); + let messages = msg_rx.await.expect("Failed to receive messages"); + assert_eq!(messages.len(), 10); +} diff --git a/tests/handshakes.rs b/tests/handshakes.rs index 530e7f6..cadec84 100644 --- a/tests/handshakes.rs +++ b/tests/handshakes.rs @@ -1,36 +1,41 @@ -use std::io; - -use futures::{Future, Stream}; -use tokio_tcp::{TcpListener, TcpStream}; +use futures::StreamExt; +use std::net::ToSocketAddrs; +use tokio::net::tcp::{TcpListener, TcpStream}; use tokio_tungstenite::{accept_async, client_async}; -#[test] -fn handshakes() { - use std::sync::mpsc::channel; - use std::thread; - - let (tx, rx) = channel(); +#[tokio::test] +async fn handshakes() { + let (tx, rx) = futures::channel::oneshot::channel(); - thread::spawn(move || { - let address = "0.0.0.0:12345".parse().unwrap(); - let listener = TcpListener::bind(&address).unwrap(); - let connections = listener.incoming(); + let f = async move { + let address = "0.0.0.0:12345" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let listener = TcpListener::bind(&address).await.unwrap(); + let mut connections = listener.incoming(); tx.send(()).unwrap(); - let handshakes = connections.and_then(|connection| { - accept_async(connection).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }); - let server = handshakes.for_each(|_| Ok(())); + while let Some(connection) = connections.next().await { + let connection = connection.expect("Failed to accept connection"); + let stream = accept_async(connection).await; + stream.expect("Failed to handshake with connection"); + } + }; - server.wait().unwrap(); - }); + tokio::spawn(f); - rx.recv().unwrap(); - let address = "0.0.0.0:12345".parse().unwrap(); - let tcp = TcpStream::connect(&address); - let handshake = tcp.and_then(|stream| { - let url = url::Url::parse("ws://localhost:12345/").unwrap(); - client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }); - let client = handshake.and_then(|_| Ok(())); - client.wait().unwrap(); + rx.await.expect("Failed to wait for server to be ready"); + let address = "0.0.0.0:12345" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let tcp = TcpStream::connect(&address) + .await + .expect("Failed to connect"); + let url = url::Url::parse("ws://localhost:12345/").unwrap(); + let _stream = client_async(url, tcp) + .await + .expect("Client failed to connect"); }