From b0ad00230b14af6abfee72a16abb8febc095a14a Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 2 Aug 2017 13:44:57 +0200 Subject: [PATCH] Adapts tokio-tungstenite to tungstenite/headers Basically changes the code of tokio-tungstenite to match the latest (current) status of tungstenite-rs. Fixes #13, fixes #9, fixes #6. --- Cargo.toml | 3 +- examples/client.rs | 2 +- examples/server.rs | 2 +- src/connect.rs | 22 +++++++------- src/lib.rs | 70 ++++++++++++++++++++++++++------------------- tests/handshakes.rs | 6 ++-- 6 files changed, 59 insertions(+), 46 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6bf64bd..3f3a99d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ tokio-io = "0.1.2" url = "1.4.0" [dependencies.tungstenite] -version = "0.2.4" +git = "https://github.com/snapview/tungstenite-rs" +branch = "headers" default-features = false [dependencies.bytes] diff --git a/examples/client.rs b/examples/client.rs index f02afd6..62dbe57 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -63,7 +63,7 @@ fn main() { // finishes. If we don't have any more data to read or we won't receive any // more work from the remote then we can exit. let mut stdout = io::stdout(); - let client = connect_async(url, handle.remote().clone()).and_then(|ws_stream| { + let client = connect_async(url, handle.remote().clone()).and_then(|(ws_stream, _)| { println!("WebSocket handshake has been successfully completed"); // `sink` is the stream of messages going out. diff --git a/examples/server.rs b/examples/server.rs index c757829..148d0a2 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -60,7 +60,7 @@ fn main() { let connections_inner = connections.clone(); let handle_inner = handle.clone(); - accept_async(stream).and_then(move |ws_stream| { + accept_async(stream, None).and_then(move |ws_stream| { println!("New WebSocket connection: {}", addr); // Create a channel for our stream, which other sockets will use to diff --git a/src/connect.rs b/src/connect.rs index 4d43694..24e442b 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -3,20 +3,20 @@ extern crate tokio_dns; extern crate tokio_core; -use self::tokio_dns::tcp_connect; -use self::tokio_core::reactor::Remote; - use std::io::Result as IoResult; -use futures::{Future, BoxFuture}; -use futures::future; +use self::tokio_core::net::TcpStream; +use self::tokio_core::reactor::Remote; +use self::tokio_dns::tcp_connect; -use super::{WebSocketStream, Request, client_async}; +use futures::future; +use futures::{Future, BoxFuture}; use tungstenite::Error; use tungstenite::client::url_mode; -use stream::NoDelay; +use tungstenite::handshake::client::Response; -use self::tokio_core::net::TcpStream; +use stream::NoDelay; +use super::{WebSocketStream, Request, client_async}; impl NoDelay for TcpStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { @@ -91,8 +91,10 @@ mod encryption { use self::encryption::{AutoStream, wrap_stream}; /// Connect to a given URL. -pub fn connect_async(request: R, handle: Remote) -> BoxFuture, Error> -where R: Into> +pub fn connect_async(request: R, handle: Remote) + -> BoxFuture<(WebSocketStream, Response), Error> +where + R: Into> { let request: Request = request.into(); diff --git a/src/lib.rs b/src/lib.rs index 203e1a7..59539ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,11 +7,6 @@ //! //! Each WebSocket stream implements the required `Stream` and `Sink` traits, //! so the socket is just a stream of messages coming in and going out. -//! -//! This crate primarily exports this ability through two extension traits, -//! `ClientHandshakeExt` and `ServerHandshakeExt`. These traits augment the -//! functionality provided by the `tungestenite` crate, on which this crate is -//! built. Configuration is done through `tungestenite` crate as well. #![deny( missing_docs, @@ -38,8 +33,8 @@ use tokio_io::{AsyncRead, AsyncWrite}; use url::Url; -use tungstenite::handshake::client::ClientHandshake; -use tungstenite::handshake::server::ServerHandshake; +use tungstenite::handshake::client::{ClientHandshake, Response}; +use tungstenite::handshake::server::{ServerHandshake, Callback}; use tungstenite::handshake::{HandshakeRole, HandshakeError}; use tungstenite::protocol::{WebSocket, Message}; use tungstenite::error::Error as WsError; @@ -90,11 +85,16 @@ impl<'a, U: Into> From for Request<'a> { /// depending on whether the handshake is successful. /// /// This is typically used for clients who have already established, for -/// example, a TCP connection to the remove server. +/// example, a TCP connection to the remote server. pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync - where R: Into>, S: AsyncRead + AsyncWrite { - let Request{url, headers} = request.into(); - let tungstenite_request = tungstenite::handshake::client::Request{url: url, extra_headers: Some(&headers)}; +where + R: Into>, + S: AsyncRead + AsyncWrite +{ + let Request{ url, headers } = request.into(); + let tungstenite_request = { + tungstenite::handshake::client::Request { url, extra_headers: Some(&headers) } + }; let handshake = ClientHandshake::start(stream, tungstenite_request).handshake(); ConnectAsync { @@ -115,10 +115,16 @@ pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync /// This is typically used after a socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform /// the server half of the accepting a client's websocket connection. -pub fn accept_async(stream: S) -> AcceptAsync { +/// +/// You can also pass an optional `callback` which will +/// be called when the websocket request is received from an incoming client. +pub fn accept_async(stream: S, callback: Option) -> AcceptAsync +where + S: AsyncRead + AsyncWrite, +{ AcceptAsync { inner: MidHandshake { - inner: Some(server::accept(stream)) + inner: Some(server::accept(stream, callback)) } } } @@ -161,49 +167,55 @@ impl Sink for WebSocketStream where T: AsyncRead + AsyncWrite { /// Future returned from client_async() which will resolve /// once the connection handshake has finished. -pub struct ConnectAsync { - inner: MidHandshake, +pub struct ConnectAsync { + inner: MidHandshake>, } impl Future for ConnectAsync { - type Item = WebSocketStream; + type Item = (WebSocketStream, Response); type Error = WsError; - fn poll(&mut self) -> Poll, WsError> { - self.inner.poll() + fn poll(&mut self) -> Poll { + match self.inner.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready((ws, resp)) => Ok(Async::Ready((WebSocketStream { inner: ws }, resp))), + } } } /// Future returned from accept_async() which will resolve /// once the connection handshake has finished. -pub struct AcceptAsync { - inner: MidHandshake, +pub struct AcceptAsync { + inner: MidHandshake>, } impl Future for AcceptAsync { type Item = WebSocketStream; type Error = WsError; - fn poll(&mut self) -> Poll, WsError> { - self.inner.poll() + fn poll(&mut self) -> Poll { + match self.inner.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(ws) => Ok(Async::Ready(WebSocketStream { inner: ws })), + } } } -struct MidHandshake { - inner: Option, HandshakeError>>, +struct MidHandshake { + inner: Option::FinalResult, HandshakeError>>, } -impl Future for MidHandshake { - type Item = WebSocketStream; +impl Future for MidHandshake { + type Item = ::FinalResult; type Error = WsError; - fn poll(&mut self) -> Poll, WsError> { + fn poll(&mut self) -> Poll { match self.inner.take().expect("cannot poll MidHandshake twice") { - Ok(stream) => Ok(WebSocketStream { inner: stream }.into()), + Ok(result) => Ok(Async::Ready(result)), Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Interrupted(s)) => { match s.handshake() { - Ok(stream) => Ok(WebSocketStream { inner: stream }.into()), + Ok(result) => Ok(Async::Ready(result)), Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Interrupted(s)) => { self.inner = Some(Err(HandshakeError::Interrupted(s))); diff --git a/tests/handshakes.rs b/tests/handshakes.rs index becebb3..fdb0b3c 100644 --- a/tests/handshakes.rs +++ b/tests/handshakes.rs @@ -25,8 +25,7 @@ fn handshakes() { let connections = listener.incoming(); tx.send(()).unwrap(); let handshakes = connections.and_then(|(connection, _)| { - accept_async(connection) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + accept_async(connection, None).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) }); let server = handshakes.for_each(|_| { Ok(()) @@ -42,8 +41,7 @@ fn handshakes() { let tcp = TcpStream::connect(&address, &handle); let handshake = tcp.and_then(|stream| { let url = url::Url::parse("ws://localhost:12345/").unwrap(); - client_async(url, stream) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) }); let client = handshake.and_then(|_| { Ok(())