Upgrade to Rust 2018, format the code

pull/1/head
Artem Vorotnikov 5 years ago
parent 9cf7243860
commit 9bd5f01784
No known key found for this signature in database
GPG Key ID: E0148C3F2FBB7A20
  1. 1
      Cargo.toml
  2. 8
      README.md
  3. 42
      examples/autobahn-client.rs
  4. 55
      examples/autobahn-server.rs
  5. 64
      examples/client.rs
  6. 127
      examples/server.rs
  7. 111
      src/connect.rs
  8. 97
      src/lib.rs
  9. 6
      src/stream.rs
  10. 17
      tests/handshakes.rs

@ -9,6 +9,7 @@ homepage = "https://github.com/snapview/tokio-tungstenite"
documentation = "https://docs.rs/tokio-tungstenite/0.9.0" documentation = "https://docs.rs/tokio-tungstenite/0.9.0"
repository = "https://github.com/snapview/tokio-tungstenite" repository = "https://github.com/snapview/tokio-tungstenite"
version = "0.9.0" version = "0.9.0"
edition = "2018"
[features] [features]
default = ["connect", "tls"] default = ["connect", "tls"]

@ -10,19 +10,13 @@ Asynchronous WebSockets for Tokio stack.
## Usage ## Usage
First, you need to add this in your `Cargo.toml`: Add this in your `Cargo.toml`:
```toml ```toml
[dependencies] [dependencies]
tokio-tungstenite = "*" tokio-tungstenite = "*"
``` ```
Next, add this to your crate:
```rust
extern crate tokio_tungstenite;
```
Take a look at the `examples/` directory for client and server examples. You may also want to get familiar with Take a look at the `examples/` directory for client and server examples. You may also want to get familiar with
[tokio](https://tokio.rs/) if you don't have any experience with it. [tokio](https://tokio.rs/) if you don't have any experience with it.

@ -1,27 +1,15 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate futures;
extern crate tokio;
extern crate tokio_tungstenite;
extern crate url;
use url::Url;
use futures::{Future, Stream}; use futures::{Future, Stream};
use log::*;
use tokio_tungstenite::{ use tokio_tungstenite::{
connect_async, connect_async,
tungstenite::{ tungstenite::{connect, Error as WsError, Result},
connect,
Result,
Error as WsError,
},
}; };
use url::Url;
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect( let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
)?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
@ -29,7 +17,11 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())
@ -37,9 +29,11 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) { fn run_test(case: u32) {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = Url::parse( let case_url = Url::parse(&format!(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) "ws://localhost:9001/runCase?case={}&agent={}",
).unwrap(); case, AGENT
))
.unwrap();
let job = connect_async(case_url) let job = connect_async(case_url)
.map_err(|err| error!("Connect error: {}", err)) .map_err(|err| error!("Connect error: {}", err))
@ -49,11 +43,9 @@ fn run_test(case: u32) {
.filter(|msg| msg.is_text() || msg.is_binary()) .filter(|msg| msg.is_text() || msg.is_binary())
.forward(sink) .forward(sink)
.and_then(|(_stream, _sink)| Ok(())) .and_then(|(_stream, _sink)| Ok(()))
.map_err(|err| { .map_err(|err| match err {
match err { WsError::ConnectionClosed => (),
WsError::ConnectionClosed => (), err => info!("WS error {}", err),
err => info!("WS error {}", err),
}
}) })
}); });

@ -1,15 +1,7 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate futures;
extern crate tokio;
extern crate tokio_tungstenite;
use futures::{Future, Stream}; use futures::{Future, Stream};
use log::*;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_tungstenite::{ use tokio_tungstenite::{accept_async, tungstenite::Error as WsError};
accept_async,
tungstenite::Error as WsError,
};
fn main() { fn main() {
env_logger::init(); env_logger::init();
@ -20,30 +12,31 @@ fn main() {
let socket = TcpListener::bind(&addr).unwrap(); let socket = TcpListener::bind(&addr).unwrap();
info!("Listening on: {}", addr); info!("Listening on: {}", addr);
let srv = socket.incoming().map_err(Into::into).for_each(move |stream| { let srv = socket
.incoming()
let peer = stream.peer_addr().expect("connected streams should have a peer address"); .map_err(Into::into)
info!("Peer address: {}", peer); .for_each(move |stream| {
let peer = stream
accept_async(stream).and_then(move |ws_stream| { .peer_addr()
info!("New WebSocket connection: {}", peer); .expect("connected streams should have a peer address");
let (sink, stream) = ws_stream.split(); info!("Peer address: {}", peer);
let job = stream
.filter(|msg| msg.is_text() || msg.is_binary()) accept_async(stream).and_then(move |ws_stream| {
.forward(sink) info!("New WebSocket connection: {}", peer);
.and_then(|(_stream, _sink)| Ok(())) let (sink, stream) = ws_stream.split();
.map_err(|err| { let job = stream
match err { .filter(|msg| msg.is_text() || msg.is_binary())
.forward(sink)
.and_then(|(_stream, _sink)| Ok(()))
.map_err(|err| match err {
WsError::ConnectionClosed => (), WsError::ConnectionClosed => (),
err => info!("WS error: {}", err), err => info!("WS error: {}", err),
} });
});
tokio::spawn(job);
Ok(())
})
});
tokio::spawn(job);
Ok(())
})
});
runtime.block_on(srv).unwrap(); runtime.block_on(srv).unwrap();
} }

@ -10,12 +10,6 @@
//! //!
//! You can use this example together with the `server` example. //! You can use this example together with the `server` example.
extern crate futures;
extern crate tokio;
extern crate tokio_tungstenite;
extern crate tungstenite;
extern crate url;
use std::env; use std::env;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::thread; use std::thread;
@ -29,9 +23,9 @@ use tokio_tungstenite::stream::PeerAddr;
fn main() { fn main() {
// Specify the server address to which the client will be connecting. // Specify the server address to which the client will be connecting.
let connect_addr = env::args().nth(1).unwrap_or_else(|| { let connect_addr = env::args()
panic!("this program requires at least one argument") .nth(1)
}); .unwrap_or_else(|| panic!("this program requires at least one argument"));
let url = url::Url::parse(&connect_addr).unwrap(); let url = url::Url::parse(&connect_addr).unwrap();
@ -59,32 +53,37 @@ fn main() {
// finishes. If we don't have any more data to read or we won't receive any // finishes. If we don't have any more data to read or we won't receive any
// more work from the remote then we can exit. // more work from the remote then we can exit.
let mut stdout = io::stdout(); let mut stdout = io::stdout();
let client = connect_async(url).and_then(move |(ws_stream, _)| { let client = connect_async(url)
println!("WebSocket handshake has been successfully completed"); .and_then(move |(ws_stream, _)| {
println!("WebSocket handshake has been successfully completed");
let addr = ws_stream.peer_addr().expect("connected streams should have a peer address"); let addr = ws_stream
println!("Peer address: {}", addr); .peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr);
// `sink` is the stream of messages going out. // `sink` is the stream of messages going out.
// `stream` is the stream of incoming messages. // `stream` is the stream of incoming messages.
let (sink, stream) = ws_stream.split(); let (sink, stream) = ws_stream.split();
// We forward all messages, composed out of the data, entered to // We forward all messages, composed out of the data, entered to
// the stdin, to the `sink`. // the stdin, to the `sink`.
let send_stdin = stdin_rx.forward(sink); let send_stdin = stdin_rx.forward(sink);
let write_stdout = stream.for_each(move |message| { let write_stdout = stream.for_each(move |message| {
stdout.write_all(&message.into_data()).unwrap(); stdout.write_all(&message.into_data()).unwrap();
Ok(()) Ok(())
}); });
// Wait for either of futures to complete. // Wait for either of futures to complete.
send_stdin.map(|_| ()) send_stdin
.select(write_stdout.map(|_| ())) .map(|_| ())
.then(|_| Ok(())) .select(write_stdout.map(|_| ()))
}).map_err(|e| { .then(|_| Ok(()))
println!("Error during the websocket handshake occurred: {}", e); })
io::Error::new(io::ErrorKind::Other, e) .map_err(|e| {
}); println!("Error during the websocket handshake occurred: {}", e);
io::Error::new(io::ErrorKind::Other, e)
});
// And now that we've got our client, we execute it in the event loop! // And now that we've got our client, we execute it in the event loop!
tokio::runtime::run(client.map_err(|_e| ())); tokio::runtime::run(client.map_err(|_e| ()));
@ -97,8 +96,7 @@ fn read_stdin(mut tx: mpsc::Sender<Message>) {
loop { loop {
let mut buf = vec![0; 1024]; let mut buf = vec![0; 1024];
let n = match stdin.read(&mut buf) { let n = match stdin.read(&mut buf) {
Err(_) | Err(_) | Ok(0) => break,
Ok(0) => break,
Ok(n) => n, Ok(n) => n,
}; };
buf.truncate(n); buf.truncate(n);

@ -17,15 +17,10 @@
//! connected clients they'll all join the same room and see everyone else's //! connected clients they'll all join the same room and see everyone else's
//! messages. //! messages.
extern crate futures;
extern crate tokio;
extern crate tokio_tungstenite;
extern crate tungstenite;
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::sync::{Arc,Mutex}; use std::sync::{Arc, Mutex};
use futures::stream::Stream; use futures::stream::Stream;
use futures::Future; use futures::Future;
@ -47,8 +42,9 @@ fn main() {
let connections = Arc::new(Mutex::new(HashMap::new())); let connections = Arc::new(Mutex::new(HashMap::new()));
let srv = socket.incoming().for_each(move |stream| { let srv = socket.incoming().for_each(move |stream| {
let addr = stream
let addr = stream.peer_addr().expect("connected streams should have a peer address"); .peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr); println!("Peer address: {}", addr);
// We have to clone both of these values, because the `and_then` // We have to clone both of these values, because the `and_then`
@ -58,62 +54,67 @@ fn main() {
// `for_each` future). // `for_each` future).
let connections_inner = connections.clone(); let connections_inner = connections.clone();
accept_async(stream).and_then(move |ws_stream| { accept_async(stream)
println!("New WebSocket connection: {}", addr); .and_then(move |ws_stream| {
println!("New WebSocket connection: {}", addr);
// Create a channel for our stream, which other sockets will use to
// send us messages. Then register our address with the stream to send // Create a channel for our stream, which other sockets will use to
// data to us. // send us messages. Then register our address with the stream to send
let (tx, rx) = futures::sync::mpsc::unbounded(); // data to us.
connections_inner.lock().unwrap().insert(addr, tx); let (tx, rx) = futures::sync::mpsc::unbounded();
connections_inner.lock().unwrap().insert(addr, tx);
// Let's split the WebSocket stream, so we can work with the
// reading and writing halves separately. // Let's split the WebSocket stream, so we can work with the
let (sink, stream) = ws_stream.split(); // reading and writing halves separately.
let (sink, stream) = ws_stream.split();
// Whenever we receive a message from the client, we print it and
// send to other clients, excluding the sender. // Whenever we receive a message from the client, we print it and
let connections = connections_inner.clone(); // send to other clients, excluding the sender.
let ws_reader = stream.for_each(move |message: Message| { let connections = connections_inner.clone();
println!("Received a message from {}: {}", addr, message); let ws_reader = stream.for_each(move |message: Message| {
println!("Received a message from {}: {}", addr, message);
// For each open connection except the sender, send the
// string via the channel. // For each open connection except the sender, send the
let mut conns = connections.lock().unwrap(); // string via the channel.
let iter = conns.iter_mut() let mut conns = connections.lock().unwrap();
.filter(|&(&k, _)| k != addr) let iter = conns
.map(|(_, v)| v); .iter_mut()
for tx in iter { .filter(|&(&k, _)| k != addr)
tx.unbounded_send(message.clone()).unwrap(); .map(|(_, v)| v);
} for tx in iter {
Ok(()) tx.unbounded_send(message.clone()).unwrap();
}); }
Ok(())
// Whenever we receive a string on the Receiver, we write it to });
// `WriteHalf<WebSocketStream>`.
let ws_writer = rx.fold(sink, |mut sink, msg| { // Whenever we receive a string on the Receiver, we write it to
use futures::Sink; // `WriteHalf<WebSocketStream>`.
sink.start_send(msg).unwrap(); let ws_writer = rx.fold(sink, |mut sink, msg| {
Ok(sink) use futures::Sink;
}); sink.start_send(msg).unwrap();
Ok(sink)
// Now that we've got futures representing each half of the socket, we });
// use the `select` combinator to wait for either half to be done to
// tear down the other. Then we spawn off the result. // Now that we've got futures representing each half of the socket, we
let connection = ws_reader.map(|_| ()).map_err(|_| ()) // use the `select` combinator to wait for either half to be done to
.select(ws_writer.map(|_| ()).map_err(|_| ())); // tear down the other. Then we spawn off the result.
let connection = ws_reader
tokio::spawn(connection.then(move |_| { .map(|_| ())
connections_inner.lock().unwrap().remove(&addr); .map_err(|_| ())
println!("Connection {} closed.", addr); .select(ws_writer.map(|_| ()).map_err(|_| ()));
Ok(())
})); tokio::spawn(connection.then(move |_| {
connections_inner.lock().unwrap().remove(&addr);
println!("Connection {} closed.", addr);
Ok(())
}));
Ok(()) Ok(())
}).map_err(|e| { })
println!("Error during the websocket handshake occurred: {}", e); .map_err(|e| {
Error::new(ErrorKind::Other, e) println!("Error during the websocket handshake occurred: {}", e);
}) Error::new(ErrorKind::Other, e)
})
}); });
// Execute server. // Execute server.

@ -1,22 +1,19 @@
//! Connection helper. //! Connection helper.
extern crate tokio_dns;
extern crate tokio_tcp;
use std::net::SocketAddr;
use std::io::Result as IoResult; use std::io::Result as IoResult;
use std::net::SocketAddr;
use self::tokio_tcp::TcpStream; use tokio_tcp::TcpStream;
use futures::{future, Future}; use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::Error;
use tungstenite::client::url_mode; use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response; use tungstenite::handshake::client::Response;
use tungstenite::Error;
use stream::{NoDelay, PeerAddr}; use super::{client_async, Request, WebSocketStream};
use super::{WebSocketStream, Request, client_async}; use crate::stream::{NoDelay, PeerAddr};
impl NoDelay for TcpStream { impl NoDelay for TcpStream {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
@ -30,24 +27,21 @@ impl PeerAddr for TcpStream {
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
mod encryption { mod encryption {
extern crate native_tls; use native_tls::TlsConnector;
extern crate tokio_tls; use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream};
use self::native_tls::TlsConnector;
use self::tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream};
use std::io::{Read, Result as IoResult, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::io::{Read, Write, Result as IoResult};
use futures::{future, Future}; use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::Error;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error;
use stream::{NoDelay, PeerAddr, Stream as StreamSwitcher}; use crate::stream::{NoDelay, PeerAddr, Stream as StreamSwitcher};
/// A stream that might be protected with TLS. /// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>; pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
@ -66,50 +60,58 @@ mod encryption {
} }
} }
pub fn wrap_stream<S>(socket: S, domain: String, mode: Mode) pub fn wrap_stream<S>(
-> Box<dyn Future<Item=AutoStream<S>, Error=Error> + Send> socket: S,
domain: String,
mode: Mode,
) -> Box<dyn Future<Item = AutoStream<S>, Error = Error> + Send>
where where
S: 'static + AsyncRead + AsyncWrite + Send, S: 'static + AsyncRead + AsyncWrite + Send,
{ {
match mode { match mode {
Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))), Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))),
Mode::Tls => { Mode::Tls => Box::new(
Box::new(future::result(TlsConnector::new()) future::result(TlsConnector::new())
.map(TokioTlsConnector::from) .map(TokioTlsConnector::from)
.and_then(move |connector| connector.connect(&domain, socket)) .and_then(move |connector| connector.connect(&domain, socket))
.map(StreamSwitcher::Tls) .map(StreamSwitcher::Tls)
.map_err(Error::Tls)) .map_err(Error::Tls),
} ),
} }
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
pub use self::encryption::MaybeTlsStream; pub use self::encryption::MaybeTlsStream;
#[cfg(not(feature="tls"))] #[cfg(not(feature = "tls"))]
mod encryption { mod encryption {
use futures::{future, Future}; use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::Error;
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error;
pub type AutoStream<S> = S; pub type AutoStream<S> = S;
pub fn wrap_stream<S>(socket: S, _domain: String, mode: Mode) pub fn wrap_stream<S>(
-> Box<Future<Item=AutoStream<S>, Error=Error> + Send> socket: S,
_domain: String,
mode: Mode,
) -> Box<Future<Item = AutoStream<S>, Error = Error> + Send>
where where
S: 'static + AsyncRead + AsyncWrite + Send, S: 'static + AsyncRead + AsyncWrite + Send,
{ {
match mode { match mode {
Mode::Plain => Box::new(future::ok(socket)), Mode::Plain => Box::new(future::ok(socket)),
Mode::Tls => Box::new(future::err(Error::Url("TLS support not compiled in.".into()))), Mode::Tls => Box::new(future::err(Error::Url(
"TLS support not compiled in.".into(),
))),
} }
} }
} }
use self::encryption::{AutoStream, wrap_stream}; use self::encryption::{wrap_stream, AutoStream};
/// Get a domain from an URL. /// Get a domain from an URL.
#[inline] #[inline]
@ -122,8 +124,10 @@ fn domain(request: &Request) -> Result<String, Error> {
/// Creates a WebSocket handshake from a request and a stream, /// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required. /// upgrading the stream to TLS if required.
pub fn client_async_tls<R, S>(request: R, stream: S) pub fn client_async_tls<R, S>(
-> Box<dyn Future<Item=(WebSocketStream<AutoStream<S>>, Response), Error=Error> + Send> request: R,
stream: S,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<S>>, Response), Error = Error> + Send>
where where
R: Into<Request<'static>>, R: Into<Request<'static>>,
S: 'static + AsyncRead + AsyncWrite + NoDelay + Send, S: 'static + AsyncRead + AsyncWrite + NoDelay + Send,
@ -141,20 +145,23 @@ where
Err(e) => return Box::new(future::err(e)), Err(e) => return Box::new(future::err(e)),
}; };
Box::new(wrap_stream(stream, domain, mode) Box::new(
.and_then(|mut stream| { wrap_stream(stream, domain, mode)
NoDelay::set_nodelay(&mut stream, true) .and_then(|mut stream| {
.map(move |()| stream) NoDelay::set_nodelay(&mut stream, true)
.map_err(|e| e.into()) .map(move |()| stream)
}) .map_err(|e| e.into())
.and_then(move |stream| client_async(request, stream))) })
.and_then(move |stream| client_async(request, stream)),
)
} }
/// Connect to a given URL. /// Connect to a given URL.
pub fn connect_async<R>(request: R) pub fn connect_async<R>(
-> Box<dyn Future<Item=(WebSocketStream<AutoStream<TcpStream>>, Response), Error=Error> + Send> request: R,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<TcpStream>>, Response), Error = Error> + Send>
where where
R: Into<Request<'static>> R: Into<Request<'static>>,
{ {
let request: Request = request.into(); let request: Request = request.into();
@ -162,8 +169,14 @@ where
Ok(domain) => domain, Ok(domain) => domain,
Err(err) => return Box::new(future::err(err)), Err(err) => return Box::new(future::err(err)),
}; };
let port = request.url.port_or_known_default().expect("Bug: port unknown"); let port = request
.url
Box::new(tokio_dns::TcpStream::connect((domain.as_str(), port)).map_err(|e| e.into()) .port_or_known_default()
.and_then(move |socket| client_async_tls(request, socket))) .expect("Bug: port unknown");
Box::new(
tokio_dns::TcpStream::connect((domain.as_str(), port))
.map_err(|e| e.into())
.and_then(move |socket| client_async_tls(request, socket)),
)
} }

@ -13,48 +13,43 @@
unused_must_use, unused_must_use,
unused_mut, unused_mut,
unused_imports, unused_imports,
unused_import_braces)] unused_import_braces
)]
extern crate futures; pub use tungstenite;
extern crate tokio_io;
pub extern crate tungstenite; #[cfg(feature = "connect")]
#[cfg(feature="connect")]
mod connect; mod connect;
#[cfg(feature="stream")] #[cfg(feature = "stream")]
pub mod stream; pub mod stream;
use std::io::ErrorKind; use std::io::ErrorKind;
#[cfg(feature="stream")] #[cfg(feature = "stream")]
use std::{ use std::{io::Result as IoResult, net::SocketAddr};
net::SocketAddr,
io::Result as IoResult,
};
use futures::{Poll, Future, Async, AsyncSink, Stream, Sink, StartSend}; use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{ use tungstenite::{
error::Error as WsError, error::Error as WsError,
handshake::{ handshake::{
HandshakeRole, HandshakeError, client::{ClientHandshake, Request, Response},
client::{ClientHandshake, Response, Request}, server::{Callback, NoCallback, ServerHandshake},
server::{ServerHandshake, Callback, NoCallback}, HandshakeError, HandshakeRole,
}, },
protocol::{WebSocket, Message, Role, WebSocketConfig}, protocol::{Message, Role, WebSocket, WebSocketConfig},
server, server,
}; };
#[cfg(feature="connect")] #[cfg(feature = "connect")]
pub use connect::{connect_async, client_async_tls}; pub use connect::{client_async_tls, connect_async};
#[cfg(feature="stream")] #[cfg(feature = "stream")]
pub use stream::PeerAddr; pub use stream::PeerAddr;
#[cfg(all(feature="connect", feature="tls"))] #[cfg(all(feature = "connect", feature = "tls"))]
pub use connect::MaybeTlsStream; pub use connect::MaybeTlsStream;
/// Creates a WebSocket handshake from a request and a stream. /// Creates a WebSocket handshake from a request and a stream.
@ -69,13 +64,10 @@ pub use connect::MaybeTlsStream;
/// ///
/// This is typically used for clients who have already established, for /// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server. /// example, a TCP connection to the remote server.
pub fn client_async<'a, R, S>( pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync<S>
request: R,
stream: S,
) -> ConnectAsync<S>
where where
R: Into<Request<'a>>, R: Into<Request<'a>>,
S: AsyncRead + AsyncWrite S: AsyncRead + AsyncWrite,
{ {
client_async_with_config(request, stream, None) client_async_with_config(request, stream, None)
} }
@ -89,12 +81,12 @@ pub fn client_async_with_config<'a, R, S>(
) -> ConnectAsync<S> ) -> ConnectAsync<S>
where where
R: Into<Request<'a>>, R: Into<Request<'a>>,
S: AsyncRead + AsyncWrite S: AsyncRead + AsyncWrite,
{ {
ConnectAsync { ConnectAsync {
inner: MidHandshake { inner: MidHandshake {
inner: Some(ClientHandshake::start(stream, request.into(), config).handshake()) inner: Some(ClientHandshake::start(stream, request.into(), config).handshake()),
} },
} }
} }
@ -154,8 +146,8 @@ where
{ {
AcceptAsync { AcceptAsync {
inner: MidHandshake { inner: MidHandshake {
inner: Some(server::accept_hdr_with_config(stream, callback, config)) inner: Some(server::accept_hdr_with_config(stream, callback, config)),
} },
} }
} }
@ -175,11 +167,7 @@ pub struct WebSocketStream<S> {
impl<S> WebSocketStream<S> { impl<S> WebSocketStream<S> {
/// Convert a raw socket into a WebSocketStream without performing a /// Convert a raw socket into a WebSocketStream without performing a
/// handshake. /// handshake.
pub fn from_raw_socket( pub fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self {
stream: S,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
Self::new(WebSocket::from_raw_socket(stream, role, config)) Self::new(WebSocket::from_raw_socket(stream, role, config))
} }
@ -195,20 +183,21 @@ impl<S> WebSocketStream<S> {
} }
fn new(ws: WebSocket<S>) -> Self { fn new(ws: WebSocket<S>) -> Self {
WebSocketStream { WebSocketStream { inner: ws }
inner: ws,
}
} }
} }
#[cfg(feature="stream")] #[cfg(feature = "stream")]
impl<S: PeerAddr> PeerAddr for WebSocketStream<S> { impl<S: PeerAddr> PeerAddr for WebSocketStream<S> {
fn peer_addr(&self) -> IoResult<SocketAddr> { fn peer_addr(&self) -> IoResult<SocketAddr> {
self.inner.get_ref().peer_addr() self.inner.get_ref().peer_addr()
} }
} }
impl<T> Stream for WebSocketStream<T> where T: AsyncRead + AsyncWrite { impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite,
{
type Item = Message; type Item = Message;
type Error = WsError; type Error = WsError;
@ -219,12 +208,15 @@ impl<T> Stream for WebSocketStream<T> where T: AsyncRead + AsyncWrite {
.to_async() .to_async()
.or_else(|err| match err { .or_else(|err| match err {
WsError::ConnectionClosed => Ok(Async::Ready(None)), WsError::ConnectionClosed => Ok(Async::Ready(None)),
err => Err(err) err => Err(err),
}) })
} }
} }
impl<T> Sink for WebSocketStream<T> where T: AsyncRead + AsyncWrite { impl<T> Sink for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite,
{
type SinkItem = Message; type SinkItem = Message;
type SinkError = WsError; type SinkError = WsError;
@ -289,16 +281,14 @@ impl<H: HandshakeRole> Future for MidHandshake<H> {
match self.inner.take().expect("cannot poll MidHandshake twice") { match self.inner.take().expect("cannot poll MidHandshake twice") {
Ok(result) => Ok(Async::Ready(result)), Ok(result) => Ok(Async::Ready(result)),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Interrupted(s)) => { Err(HandshakeError::Interrupted(s)) => match s.handshake() {
match s.handshake() { Ok(result) => Ok(Async::Ready(result)),
Ok(result) => Ok(Async::Ready(result)), Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Interrupted(s)) => {
Err(HandshakeError::Interrupted(s)) => { self.inner = Some(Err(HandshakeError::Interrupted(s)));
self.inner = Some(Err(HandshakeError::Interrupted(s))); Ok(Async::NotReady)
Ok(Async::NotReady)
}
} }
} },
} }
} }
} }
@ -339,8 +329,7 @@ impl ToStartSend for Result<(), WsError> {
WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(AsyncSink::Ready), WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(AsyncSink::Ready),
WsError::SendQueueFull(msg) => Ok(AsyncSink::NotReady(msg)), WsError::SendQueueFull(msg) => Ok(AsyncSink::NotReady(msg)),
err => Err(err), err => Err(err),
} },
} }
} }
} }

@ -4,12 +4,10 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits. //! `Read + Write` traits.
extern crate bytes; use std::io::{Error as IoError, Read, Result as IoResult, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::io::{Read, Write, Result as IoResult, Error as IoError};
use self::bytes::{Buf, BufMut}; use bytes::{Buf, BufMut};
use futures::Poll; use futures::Poll;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};

@ -1,13 +1,8 @@
extern crate futures;
extern crate tokio_tcp;
extern crate tokio_tungstenite;
extern crate url;
use std::io; use std::io;
use futures::{Future, Stream}; use futures::{Future, Stream};
use tokio_tcp::{TcpStream, TcpListener}; use tokio_tcp::{TcpListener, TcpStream};
use tokio_tungstenite::{client_async, accept_async}; use tokio_tungstenite::{accept_async, client_async};
#[test] #[test]
fn handshakes() { fn handshakes() {
@ -24,9 +19,7 @@ fn handshakes() {
let handshakes = connections.and_then(|connection| { let handshakes = connections.and_then(|connection| {
accept_async(connection).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) accept_async(connection).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}); });
let server = handshakes.for_each(|_| { let server = handshakes.for_each(|_| Ok(()));
Ok(())
});
server.wait().unwrap(); server.wait().unwrap();
}); });
@ -38,8 +31,6 @@ fn handshakes() {
let url = url::Url::parse("ws://localhost:12345/").unwrap(); let url = url::Url::parse("ws://localhost:12345/").unwrap();
client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}); });
let client = handshake.and_then(|_| { let client = handshake.and_then(|_| Ok(()));
Ok(())
});
client.wait().unwrap(); client.wait().unwrap();
} }

Loading…
Cancel
Save