Tokio 0.2 Conversion

Update to use tokio 0.2 ecosystem to integrate with tungstenite.
pull/1/head
Danny Browning 5 years ago committed by Danny Browning
parent 59ca2c885e
commit 3821e0952a
  1. 2
      .travis.yml
  2. 26
      Cargo.toml
  3. 58
      examples/autobahn-client.rs
  4. 66
      examples/autobahn-server.rs
  5. 69
      examples/client.rs
  6. 165
      examples/server.rs
  7. 123
      src/compat.rs
  8. 131
      src/connect.rs
  9. 181
      src/handshake.rs
  10. 363
      src/lib.rs
  11. 125
      src/stream.rs
  12. 79
      tests/communication.rs
  13. 63
      tests/handshakes.rs

@ -1,4 +1,6 @@
language: rust
rust:
- nightly-2019-09-05
before_script:
- export PATH="$PATH:$HOME/.cargo/bin"

@ -8,21 +8,25 @@ license = "MIT"
homepage = "https://github.com/snapview/tokio-tungstenite"
documentation = "https://docs.rs/tokio-tungstenite/0.9.0"
repository = "https://github.com/snapview/tokio-tungstenite"
version = "0.9.0"
version = "0.10.0-alpha.1"
edition = "2018"
[features]
default = ["connect", "tls"]
connect = ["tokio-dns-unofficial", "tokio-tcp", "stream"]
connect = ["tokio-dns-unofficial", "tokio-net", "stream"]
tls = ["tokio-tls", "native-tls", "stream", "tungstenite/tls"]
stream = ["bytes"]
[dependencies]
futures = "0.1.23"
tokio-io = "0.1.7"
log = "0.4"
futures-preview = { version = "0.3.0-alpha.19", features = ["async-await"] }
pin-project = "0.4.0-alpha.9"
tokio-io = "0.2.0-alpha.6"
[dependencies.tungstenite]
version = "0.9.1"
#version = "0.9.1"
git = "https://github.com/snapview/tungstenite-rs.git"
branch = "master"
default-features = false
[dependencies.bytes]
@ -35,18 +39,18 @@ version = "0.2.0"
[dependencies.tokio-dns-unofficial]
optional = true
version = "0.3.1"
#version = "0.4.0"
git = "https://github.com/sbstp/tokio-dns.git"
[dependencies.tokio-tcp]
[dependencies.tokio-net]
optional = true
version = "0.1.0"
version = "0.2.0-alpha.6"
[dependencies.tokio-tls]
optional = true
version = "0.2.0"
version = "0.3.0-alpha.6"
[dev-dependencies]
tokio = "0.1.7"
tokio = "0.2.0-alpha.6"
url = "2.0.0"
env_logger = "0.6.1"
log = "0.4.6"

@ -1,33 +1,32 @@
use futures::{Future, Stream};
use futures::StreamExt;
use log::*;
use tokio_tungstenite::{
connect_async,
tungstenite::{connect, Error as WsError, Result},
};
use tokio_tungstenite::{connect_async, tungstenite::Result};
use url::Url;
const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read_message()?;
socket.close(None)?;
async fn get_case_count() -> Result<u32> {
let (mut socket, _) =
connect_async(Url::parse("ws://localhost:9001/getCaseCount").unwrap()).await?;
let msg = socket.next().await.unwrap()?;
socket.close(None).await?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
}
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
async fn update_reports() -> Result<()> {
let (mut socket, _) = connect_async(
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?;
socket.close(None)?;
)
.await?;
socket.close(None).await?;
Ok(())
}
fn run_test(case: u32) {
async fn run_test(case: u32) {
info!("Running test case {}", case);
let case_url = Url::parse(&format!(
"ws://localhost:9001/runCase?case={}&agent={}",
@ -35,31 +34,24 @@ fn run_test(case: u32) {
))
.unwrap();
let job = connect_async(case_url)
.map_err(|err| error!("Connect error: {}", err))
.and_then(|(ws_stream, _)| {
let (sink, stream) = ws_stream.split();
stream
.filter(|msg| msg.is_text() || msg.is_binary())
.forward(sink)
.and_then(|(_stream, _sink)| Ok(()))
.map_err(|err| match err {
WsError::ConnectionClosed => (),
err => info!("WS error {}", err),
})
});
tokio::run(job)
let (mut ws_stream, _) = connect_async(case_url).await.expect("Connect error");
while let Some(msg) = ws_stream.next().await {
let msg = msg.expect("Failed to get message");
if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await.expect("Write error");
}
}
}
fn main() {
#[tokio::main]
async fn main() {
env_logger::init();
let total = get_case_count().unwrap();
let total = get_case_count().await.unwrap();
for case in 1..(total + 1) {
run_test(case)
run_test(case).await
}
update_reports().unwrap();
update_reports().await.unwrap();
}

@ -1,42 +1,42 @@
use futures::{Future, Stream};
use futures::StreamExt;
use log::*;
use tokio::net::TcpListener;
use tokio_tungstenite::{accept_async, tungstenite::Error as WsError};
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::accept_async;
fn main() {
env_logger::init();
async fn accept_connection(peer: SocketAddr, stream: TcpStream) {
let mut ws_stream = accept_async(stream).await.expect("Failed to accept");
let mut runtime = tokio::runtime::Builder::new().build().unwrap();
info!("New WebSocket connection: {}", peer);
let addr = "127.0.0.1:9002".parse().unwrap();
let socket = TcpListener::bind(&addr).unwrap();
info!("Listening on: {}", addr);
while let Some(msg) = ws_stream.next().await {
let msg = msg.expect("Failed to get request");
if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await.expect("Failed to send response");
}
}
}
let srv = socket
.incoming()
.map_err(Into::into)
.for_each(move |stream| {
let peer = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", peer);
#[tokio::main]
async fn main() {
env_logger::init();
accept_async(stream).and_then(move |ws_stream| {
info!("New WebSocket connection: {}", peer);
let (sink, stream) = ws_stream.split();
let job = stream
.filter(|msg| msg.is_text() || msg.is_binary())
.forward(sink)
.and_then(|(_stream, _sink)| Ok(()))
.map_err(|err| match err {
WsError::ConnectionClosed => (),
err => info!("WS error: {}", err),
});
let addr = "127.0.0.1:9002"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("Not a socket address");
let socket = TcpListener::bind(&addr).await.unwrap();
let mut incoming = socket.incoming();
info!("Listening on: {}", addr);
tokio::spawn(job);
Ok(())
})
});
while let Some(stream) = incoming.next().await {
let stream = stream.expect("Failed to get stream");
let peer = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", peer);
runtime.block_on(srv).unwrap();
tokio::spawn(accept_connection(peer, stream));
}
}

@ -11,17 +11,19 @@
//! You can use this example together with the `server` example.
use std::env;
use std::io::{self, Read, Write};
use std::thread;
use std::io::{self, Write};
use futures::sync::mpsc;
use futures::{Future, Sink, Stream};
use futures::StreamExt;
use log::*;
use tungstenite::protocol::Message;
use tokio::io::AsyncReadExt;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::stream::PeerAddr;
fn main() {
#[tokio::main]
async fn main() {
let _ = env_logger::try_init();
// Specify the server address to which the client will be connecting.
let connect_addr = env::args()
.nth(1)
@ -33,9 +35,8 @@ fn main() {
// loop, so we farm out that work to a separate thread. This thread will
// read data from stdin and then send it to the event loop over a standard
// futures channel.
let (stdin_tx, stdin_rx) = mpsc::channel(0);
thread::spawn(|| read_stdin(stdin_tx));
let stdin_rx = stdin_rx.map_err(|_| panic!()); // errors not possible on rx
let (stdin_tx, mut stdin_rx) = futures::channel::mpsc::unbounded();
tokio::spawn(read_stdin(stdin_tx));
// After the TCP connection has been established, we set up our client to
// start forwarding data.
@ -53,53 +54,29 @@ fn main() {
// finishes. If we don't have any more data to read or we won't receive any
// more work from the remote then we can exit.
let mut stdout = io::stdout();
let client = connect_async(url)
.and_then(move |(ws_stream, _)| {
println!("WebSocket handshake has been successfully completed");
let addr = ws_stream
.peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr);
// `sink` is the stream of messages going out.
// `stream` is the stream of incoming messages.
let (sink, stream) = ws_stream.split();
let (mut ws_stream, _) = connect_async(url).await.expect("Failed to connect");
info!("WebSocket handshake has been successfully completed");
// We forward all messages, composed out of the data, entered to
// the stdin, to the `sink`.
let send_stdin = stdin_rx.forward(sink);
let write_stdout = stream.for_each(move |message| {
stdout.write_all(&message.into_data()).unwrap();
Ok(())
});
// Wait for either of futures to complete.
send_stdin
.map(|_| ())
.select(write_stdout.map(|_| ()))
.then(|_| Ok(()))
})
.map_err(|e| {
println!("Error during the websocket handshake occurred: {}", e);
io::Error::new(io::ErrorKind::Other, e)
});
// And now that we've got our client, we execute it in the event loop!
tokio::runtime::run(client.map_err(|_e| ()));
while let Some(msg) = stdin_rx.next().await {
ws_stream.send(msg).await.expect("Failed to send request");
if let Some(msg) = ws_stream.next().await {
let msg = msg.expect("Failed to get response");
stdout.write_all(&msg.into_data()).unwrap();
}
}
}
// Our helper method which will read data from stdin and send it along the
// sender provided.
fn read_stdin(mut tx: mpsc::Sender<Message>) {
let mut stdin = io::stdin();
async fn read_stdin(tx: futures::channel::mpsc::UnboundedSender<Message>) {
let mut stdin = tokio::io::stdin();
loop {
let mut buf = vec![0; 1024];
let n = match stdin.read(&mut buf) {
let n = match stdin.read(&mut buf).await {
Err(_) | Ok(0) => break,
Ok(n) => n,
};
buf.truncate(n);
tx = tx.send(Message::binary(buf)).wait().unwrap();
tx.unbounded_send(Message::binary(buf)).unwrap();
}
}

@ -17,106 +17,87 @@
//! connected clients they'll all join the same room and see everyone else's
//! messages.
use std::collections::HashMap;
use std::env;
use std::io::{Error, ErrorKind};
use std::sync::{Arc, Mutex};
use std::io::Error;
use futures::stream::Stream;
use futures::Future;
use tokio::net::TcpListener;
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures::StreamExt;
use log::*;
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::net::{TcpListener, TcpStream};
use tungstenite::protocol::Message;
use tokio_tungstenite::accept_async;
struct Connection {
addr: SocketAddr,
rx: UnboundedReceiver<Message>,
tx: UnboundedSender<Message>,
}
async fn handle_connection(connection: Connection) {
let mut connection = connection;
while let Some(msg) = connection.rx.next().await {
info!("Received a message from {}: {}", connection.addr, msg);
connection
.tx
.unbounded_send(msg)
.expect("Failed to forward message");
}
}
fn main() {
async fn accept_connection(stream: TcpStream) {
let addr = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", addr);
let mut ws_stream = tokio_tungstenite::accept_async(stream)
.await
.expect("Error during the websocket handshake occurred");
info!("New WebSocket connection: {}", addr);
// Create a channel for our stream, which other sockets will use to
// send us messages. Then register our address with the stream to send
// data to us.
let (msg_tx, msg_rx) = futures::channel::mpsc::unbounded();
let (response_tx, mut response_rx) = futures::channel::mpsc::unbounded();
let c = Connection {
addr: addr,
rx: msg_rx,
tx: response_tx,
};
tokio::spawn(handle_connection(c));
while let Some(message) = ws_stream.next().await {
let message = message.expect("Failed to get request");
msg_tx
.unbounded_send(message)
.expect("Failed to forward request");
if let Some(resp) = response_rx.next().await {
ws_stream.send(resp).await.expect("Failed to send response");
}
}
}
#[tokio::main]
async fn main() -> Result<(), Error> {
let _ = env_logger::try_init();
let addr = env::args().nth(1).unwrap_or("127.0.0.1:8080".to_string());
let addr = addr.parse().unwrap();
let addr = addr
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("Not a socket address");
// Create the event loop and TCP listener we'll accept connections on.
let socket = TcpListener::bind(&addr).unwrap();
println!("Listening on: {}", addr);
// Tokio Runtime uses a thread pool based executor by default, so we need
// to use Arc and Mutex to store the map of all connections we know about.
let connections = Arc::new(Mutex::new(HashMap::new()));
let srv = socket.incoming().for_each(move |stream| {
let addr = stream
.peer_addr()
.expect("connected streams should have a peer address");
println!("Peer address: {}", addr);
// We have to clone both of these values, because the `and_then`
// function below constructs a new future, `and_then` requires
// `FnOnce`, so we construct a move closure to move the
// environment inside the future (AndThen future may overlive our
// `for_each` future).
let connections_inner = connections.clone();
accept_async(stream)
.and_then(move |ws_stream| {
println!("New WebSocket connection: {}", addr);
// Create a channel for our stream, which other sockets will use to
// send us messages. Then register our address with the stream to send
// data to us.
let (tx, rx) = futures::sync::mpsc::unbounded();
connections_inner.lock().unwrap().insert(addr, tx);
// Let's split the WebSocket stream, so we can work with the
// reading and writing halves separately.
let (sink, stream) = ws_stream.split();
// Whenever we receive a message from the client, we print it and
// send to other clients, excluding the sender.
let connections = connections_inner.clone();
let ws_reader = stream.for_each(move |message: Message| {
println!("Received a message from {}: {}", addr, message);
// For each open connection except the sender, send the
// string via the channel.
let mut conns = connections.lock().unwrap();
let iter = conns
.iter_mut()
.filter(|&(&k, _)| k != addr)
.map(|(_, v)| v);
for tx in iter {
tx.unbounded_send(message.clone()).unwrap();
}
Ok(())
});
// Whenever we receive a string on the Receiver, we write it to
// `WriteHalf<WebSocketStream>`.
let ws_writer = rx.fold(sink, |mut sink, msg| {
use futures::Sink;
sink.start_send(msg).unwrap();
Ok(sink)
});
// Now that we've got futures representing each half of the socket, we
// use the `select` combinator to wait for either half to be done to
// tear down the other. Then we spawn off the result.
let connection = ws_reader
.map(|_| ())
.map_err(|_| ())
.select(ws_writer.map(|_| ()).map_err(|_| ()));
tokio::spawn(connection.then(move |_| {
connections_inner.lock().unwrap().remove(&addr);
println!("Connection {} closed.", addr);
Ok(())
}));
let try_socket = TcpListener::bind(&addr).await;
let socket = try_socket.expect("Failed to bind");
let mut incoming = socket.incoming();
info!("Listening on: {}", addr);
Ok(())
})
.map_err(|e| {
println!("Error during the websocket handshake occurred: {}", e);
Error::new(ErrorKind::Other, e)
})
});
while let Some(stream) = incoming.next().await {
let stream = stream.expect("Failed to accept stream");
tokio::spawn(accept_connection(stream));
}
// Execute server.
tokio::runtime::run(srv.map_err(|_e| ()));
Ok(())
}

@ -0,0 +1,123 @@
use log::*;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket};
pub(crate) trait HasContext {
fn set_context(&mut self, context: *mut ());
}
#[derive(Debug)]
pub struct AllowStd<S> {
pub(crate) inner: S,
pub(crate) context: *mut (),
}
impl<S> HasContext for AllowStd<S> {
fn set_context(&mut self, context: *mut ()) {
self.context = context;
}
}
pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>);
impl<S> Drop for Guard<'_, S> {
fn drop(&mut self) {
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = std::ptr::null_mut();
}
}
// *mut () context is neither Send nor Sync
unsafe impl<S: Send> Send for AllowStd<S> {}
unsafe impl<S: Sync> Sync for AllowStd<S> {}
impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe {
assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
pub(crate) fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
pub(crate) fn get_ref(&self) -> &S {
&self.inner
}
}
impl<S> Read for AllowStd<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Read.with_context read -> poll_read",
file!(),
line!()
);
stream.poll_read(ctx, buf)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for AllowStd<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Write.with_context write -> poll_write",
file!(),
line!()
);
stream.poll_write(ctx, buf)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(|ctx, stream| {
trace!(
"{}:{} Write.with_context flush -> poll_flush",
file!(),
line!()
);
stream.poll_flush(ctx)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}

@ -1,82 +1,50 @@
//! Connection helper.
use std::io::Result as IoResult;
use std::net::SocketAddr;
use tokio_tcp::TcpStream;
use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_net::tcp::TcpStream;
use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response;
use tungstenite::Error;
use super::{client_async, Request, WebSocketStream};
use crate::stream::{NoDelay, PeerAddr};
impl NoDelay for TcpStream {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
TcpStream::set_nodelay(self, nodelay)
}
}
impl PeerAddr for TcpStream {
fn peer_addr(&self) -> IoResult<SocketAddr> {
self.peer_addr()
}
}
#[cfg(feature = "tls")]
mod encryption {
pub(crate) mod encryption {
use native_tls::TlsConnector;
use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream};
use std::io::{Read, Result as IoResult, Write};
use std::net::SocketAddr;
use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::stream::Mode;
use tungstenite::Error;
use crate::stream::{NoDelay, PeerAddr, Stream as StreamSwitcher};
use crate::stream::Stream as StreamSwitcher;
/// A stream that might be protected with TLS.
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub type AutoStream<S> = MaybeTlsStream<S>;
impl<T: Read + Write + NoDelay> NoDelay for TlsStream<T> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().get_mut().set_nodelay(nodelay)
}
}
impl<S: Read + Write + PeerAddr> PeerAddr for TlsStream<S> {
fn peer_addr(&self) -> IoResult<SocketAddr> {
self.get_ref().get_ref().peer_addr()
}
}
pub fn wrap_stream<S>(
pub async fn wrap_stream<S>(
socket: S,
domain: String,
mode: Mode,
) -> Box<dyn Future<Item = AutoStream<S>, Error = Error> + Send>
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))),
Mode::Tls => Box::new(
future::result(TlsConnector::new())
.map(TokioTlsConnector::from)
.and_then(move |connector| connector.connect(&domain, socket))
.map(StreamSwitcher::Tls)
.map_err(Error::Tls),
),
Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
Mode::Tls => {
let try_connector = TlsConnector::new();
let connector = try_connector.map_err(Error::Tls)?;
let stream = TokioTlsConnector::from(connector);
let connected = stream.connect(&domain, socket).await;
match connected {
Err(e) => Err(Error::Tls(e)),
Ok(s) => Ok(StreamSwitcher::Tls(s)),
}
}
}
}
}
@ -85,7 +53,7 @@ mod encryption {
pub use self::encryption::MaybeTlsStream;
#[cfg(not(feature = "tls"))]
mod encryption {
pub(crate) mod encryption {
use futures::{future, Future};
use tokio_io::{AsyncRead, AsyncWrite};
@ -94,19 +62,17 @@ mod encryption {
pub type AutoStream<S> = S;
pub fn wrap_stream<S>(
pub async fn wrap_stream<S>(
socket: S,
_domain: String,
mode: Mode,
) -> Box<Future<Item = AutoStream<S>, Error = Error> + Send>
) -> Result<AutoStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Box::new(future::ok(socket)),
Mode::Tls => Box::new(future::err(Error::Url(
"TLS support not compiled in.".into(),
))),
Mode::Plain => Ok(socket),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
}
@ -124,59 +90,42 @@ fn domain(request: &Request) -> Result<String, Error> {
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub fn client_async_tls<R, S>(
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<S>>, Response), Error = Error> + Send>
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>>,
S: 'static + AsyncRead + AsyncWrite + NoDelay + Send,
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
let request: Request = request.into();
let domain = match domain(&request) {
Ok(domain) => domain,
Err(err) => return Box::new(future::err(err)),
};
let domain = domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = match url_mode(&request.url) {
Ok(m) => m,
Err(e) => return Box::new(future::err(e)),
};
Box::new(
wrap_stream(stream, domain, mode)
.and_then(|mut stream| {
NoDelay::set_nodelay(&mut stream, true)
.map(move |()| stream)
.map_err(|e| e.into())
})
.and_then(move |stream| client_async(request, stream)),
)
let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, mode).await?;
client_async(request, stream).await
}
/// Connect to a given URL.
pub fn connect_async<R>(
pub async fn connect_async<R>(
request: R,
) -> Box<dyn Future<Item = (WebSocketStream<AutoStream<TcpStream>>, Response), Error = Error> + Send>
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>>,
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = match domain(&request) {
Ok(domain) => domain,
Err(err) => return Box::new(future::err(err)),
};
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
Box::new(
tokio_dns::TcpStream::connect((domain.as_str(), port))
.map_err(|e| e.into())
.and_then(move |socket| client_async_tls(request, socket)),
)
let try_socket = tokio_dns::TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls(request, socket).await
}

@ -0,0 +1,181 @@
use crate::compat::{AllowStd, HasContext};
use crate::WebSocketStream;
use log::*;
use pin_project::pin_project;
use std::future::Future;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::handshake::client::Response;
use tungstenite::handshake::server::Callback;
use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake};
use tungstenite::{ClientHandshake, ServerHandshake, WebSocket};
pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
let ws = start.await;
WebSocketStream::new(ws)
}
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
struct SkippedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
impl<F, S> Future for SkippedHandshakeFuture<F, S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = WebSocket<AllowStd<S>>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self
.get_mut()
.0
.take()
.expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
};
Poll::Ready((inner.f)(stream))
}
}
#[pin_project]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
enum StartedHandshake<Role: HandshakeRole> {
Done(Role::FinalResult),
Mid(WsHandshake<Role>),
}
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
struct StartedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
match start.await? {
StartedHandshake::Done(r) => Ok(r),
StartedHandshake::Mid(s) => {
let res: Result<Role::FinalResult, Error<Role>> = MidHandshake::<Role>(Some(s)).await;
res
}
}
}
pub(crate) async fn client_handshake<F, S>(
stream: S,
f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
F: FnOnce(
AllowStd<S>,
) -> Result<
<ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
Error<ClientHandshake<AllowStd<S>>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let result = handshake(stream, f).await?;
let (s, r) = result;
Ok((WebSocketStream::new(s), r))
}
pub(crate) async fn server_handshake<C, F, S>(
stream: S,
f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
C: Callback + Unpin,
F: FnOnce(
AllowStd<S>,
) -> Result<
<ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
Error<ServerHandshake<AllowStd<S>, C>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let s: WebSocket<AllowStd<S>> = handshake(stream, f).await?;
Ok(WebSocketStream::new(s))
}
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = Result<StartedHandshake<Role>, Error<Role>>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
};
match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
}
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
{
type Output = Result<Role::FinalResult, Error<Role>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut s = this.0.take().expect("future polled after completion");
let machine = s.get_mut();
trace!("Setting context in handshake");
machine.get_mut().set_context(cx as *mut _ as *mut ());
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
*this.0 = Some(mid);
Poll::Pending
}
}
}
}

@ -18,26 +18,29 @@
pub use tungstenite;
mod compat;
#[cfg(feature = "connect")]
mod connect;
mod handshake;
#[cfg(feature = "stream")]
pub mod stream;
use std::io::ErrorKind;
#[cfg(feature = "stream")]
use std::{io::Result as IoResult, net::SocketAddr};
use std::io::{Read, Write};
use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
use compat::{cvt, AllowStd};
use futures::Stream;
use log::*;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{
error::Error as WsError,
handshake::{
client::{ClientHandshake, Request, Response},
server::{Callback, NoCallback, ServerHandshake},
HandshakeError, HandshakeRole,
server::{Callback, NoCallback},
},
protocol::{Message, Role, WebSocket, WebSocketConfig},
server,
@ -46,11 +49,10 @@ use tungstenite::{
#[cfg(feature = "connect")]
pub use connect::{client_async_tls, connect_async};
#[cfg(feature = "stream")]
pub use stream::PeerAddr;
#[cfg(all(feature = "connect", feature = "tls"))]
pub use connect::MaybeTlsStream;
use std::error::Error;
use tungstenite::protocol::CloseFrame;
/// Creates a WebSocket handshake from a request and a stream.
/// For convenience, the user may call this with a url string, a URL,
@ -64,30 +66,38 @@ pub use connect::MaybeTlsStream;
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
pub fn client_async<'a, R, S>(request: R, stream: S) -> ConnectAsync<S>
pub async fn client_async<'a, R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
R: Into<Request<'a>>,
S: AsyncRead + AsyncWrite,
R: Into<Request<'a>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
client_async_with_config(request, stream, None)
client_async_with_config(request, stream, None).await
}
/// The same as `client_async()` but the one can specify a websocket configuration.
/// Please refer to `client_async()` for more details.
pub fn client_async_with_config<'a, R, S>(
pub async fn client_async_with_config<'a, R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> ConnectAsync<S>
) -> Result<(WebSocketStream<S>, Response), WsError>
where
R: Into<Request<'a>>,
S: AsyncRead + AsyncWrite,
R: Into<Request<'a>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
ConnectAsync {
inner: MidHandshake {
inner: Some(ClientHandshake::start(stream, request.into(), config).handshake()),
},
}
let f = handshake::client_handshake(stream, move |allow_std| {
let cli_handshake = ClientHandshake::start(allow_std, request.into(), config);
cli_handshake.handshake()
});
f.await.map_err(|e| {
WsError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
e.description(),
))
})
}
/// Accepts a new WebSocket connection with the provided stream.
@ -101,23 +111,23 @@ where
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of the accepting a client's websocket connection.
pub fn accept_async<S>(stream: S) -> AcceptAsync<S, NoCallback>
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite,
S: AsyncRead + AsyncWrite + Unpin,
{
accept_hdr_async(stream, NoCallback)
accept_hdr_async(stream, NoCallback).await
}
/// The same as `accept_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_async()` for more details.
pub fn accept_async_with_config<S>(
pub async fn accept_async_with_config<S>(
stream: S,
config: Option<WebSocketConfig>,
) -> AcceptAsync<S, NoCallback>
) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite,
S: AsyncRead + AsyncWrite + Unpin,
{
accept_hdr_async_with_config(stream, NoCallback, config)
accept_hdr_async_with_config(stream, NoCallback, config).await
}
/// Accepts a new WebSocket connection with the provided stream.
@ -125,30 +135,34 @@ where
/// This function does the same as `accept_async()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_async<S, C>(stream: S, callback: C) -> AcceptAsync<S, C>
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite,
C: Callback,
S: AsyncRead + AsyncWrite + Unpin,
C: Callback + Unpin,
{
accept_hdr_async_with_config(stream, callback, None)
accept_hdr_async_with_config(stream, callback, None).await
}
/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_hdr_async()` for more details.
pub fn accept_hdr_async_with_config<S, C>(
pub async fn accept_hdr_async_with_config<S, C>(
stream: S,
callback: C,
config: Option<WebSocketConfig>,
) -> AcceptAsync<S, C>
) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite,
C: Callback,
S: AsyncRead + AsyncWrite + Unpin,
C: Callback + Unpin,
{
AcceptAsync {
inner: MidHandshake {
inner: Some(server::accept_hdr_with_config(stream, callback, config)),
},
}
let f = handshake::server_handshake(stream, move |allow_std| {
server::accept_hdr_with_config(allow_std, callback, config)
});
f.await.map_err(|e| {
WsError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
e.description(),
))
})
}
/// A wrapper around an underlying raw stream which implements the WebSocket
@ -160,176 +174,187 @@ where
/// through the respective `Stream` and `Sink`. Check more information about
/// them in `futures-rs` crate documentation or have a look on the examples
/// and unit tests for this crate.
#[pin_project]
pub struct WebSocketStream<S> {
inner: WebSocket<S>,
#[pin]
inner: WebSocket<AllowStd<S>>,
}
impl<S> WebSocketStream<S> {
/// Convert a raw socket into a WebSocketStream without performing a
/// handshake.
pub fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::new(WebSocket::from_raw_socket(stream, role, config))
pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_raw_socket(allow_std, role, config)
})
.await
}
/// Convert a raw socket into a WebSocketStream without performing a
/// handshake.
pub fn from_partially_read(
pub async fn from_partially_read(
stream: S,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
Self::new(WebSocket::from_partially_read(stream, part, role, config))
) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_partially_read(allow_std, part, role, config)
})
.await
}
fn new(ws: WebSocket<S>) -> Self {
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
WebSocketStream { inner: ws }
}
}
#[cfg(feature = "stream")]
impl<S: PeerAddr> PeerAddr for WebSocketStream<S> {
fn peer_addr(&self) -> IoResult<SocketAddr> {
self.inner.get_ref().peer_addr()
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
where
S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
self.inner.get_mut().context = ctx as *mut _ as *mut ();
let mut g = compat::Guard(&mut self.inner);
f(&mut (g.0))
}
}
impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite,
{
type Item = Message;
type Error = WsError;
fn poll(&mut self) -> Poll<Option<Message>, WsError> {
self.inner
.read_message()
.map(Some)
.to_async()
.or_else(|err| match err {
WsError::ConnectionClosed => Ok(Async::Ready(None)),
err => Err(err),
})
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S
where
S: AsyncRead + AsyncWrite + Unpin,
{
&self.inner.get_ref().get_ref()
}
}
impl<T> Sink for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite,
{
type SinkItem = Message;
type SinkError = WsError;
fn start_send(&mut self, item: Message) -> StartSend<Message, WsError> {
self.inner.write_message(item).to_start_send()
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.inner.get_mut().get_mut()
}
fn poll_complete(&mut self) -> Poll<(), WsError> {
self.inner.write_pending().to_async()
/// Send a message to this websocket
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
where
S: AsyncWrite + AsyncRead + Unpin,
{
let f = SendFuture {
stream: self,
message: Some(msg),
};
f.await
}
fn close(&mut self) -> Poll<(), WsError> {
self.inner.close(None).to_async()
/// Close the underlying web socket
pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let f = CloseFuture {
stream: self,
message: Some(msg),
};
f.await
}
}
/// Future returned from client_async() which will resolve
/// once the connection handshake has finished.
pub struct ConnectAsync<S: AsyncRead + AsyncWrite> {
inner: MidHandshake<ClientHandshake<S>>,
}
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
type Item = (WebSocketStream<S>, Response);
type Error = WsError;
fn poll(&mut self) -> Poll<Self::Item, WsError> {
match self.inner.poll()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready((ws, resp)) => Ok(Async::Ready((WebSocketStream::new(ws), resp))),
impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
{
type Item = Result<Message, WsError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("{}:{} Stream.poll_next", file!(), line!());
match futures::ready!(self.with_context(cx, |s| {
trace!(
"{}:{} Stream.with_context poll_next -> read_message()",
file!(),
line!()
);
cvt(s.read_message())
})) {
Ok(v) => Poll::Ready(Some(Ok(v))),
Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
}
}
}
/// Future returned from accept_async() which will resolve
/// once the connection handshake has finished.
pub struct AcceptAsync<S: AsyncRead + AsyncWrite, C: Callback> {
inner: MidHandshake<ServerHandshake<S, C>>,
#[pin_project]
struct SendFuture<'a, T> {
stream: &'a mut WebSocketStream<T>,
message: Option<Message>,
}
impl<S: AsyncRead + AsyncWrite, C: Callback> Future for AcceptAsync<S, C> {
type Item = WebSocketStream<S>;
type Error = WsError;
impl<'a, T> Future for SendFuture<'a, T>
where
T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
{
type Output = Result<(), WsError>;
fn poll(&mut self) -> Poll<Self::Item, WsError> {
match self.inner.poll()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(ws) => Ok(Async::Ready(WebSocketStream::new(ws))),
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(cx, |s| s.write_message(message)))
}
}
struct MidHandshake<H: HandshakeRole> {
inner: Option<Result<<H as HandshakeRole>::FinalResult, HandshakeError<H>>>,
#[pin_project]
struct CloseFuture<'a, T> {
stream: &'a mut WebSocketStream<T>,
message: Option<Option<CloseFrame<'a>>>,
}
impl<H: HandshakeRole> Future for MidHandshake<H> {
type Item = <H as HandshakeRole>::FinalResult;
type Error = WsError;
fn poll(&mut self) -> Poll<Self::Item, WsError> {
match self.inner.take().expect("cannot poll MidHandshake twice") {
Ok(result) => Ok(Async::Ready(result)),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Interrupted(s)) => match s.handshake() {
Ok(result) => Ok(Async::Ready(result)),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::Interrupted(s)) => {
self.inner = Some(Err(HandshakeError::Interrupted(s)));
Ok(Async::NotReady)
}
},
}
}
}
trait ToAsync {
type T;
type E;
fn to_async(self) -> Result<Async<Self::T>, Self::E>;
}
impl<'a, T> Future for CloseFuture<'a, T>
where
T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
{
type Output = Result<(), WsError>;
impl<T> ToAsync for Result<T, WsError> {
type T = T;
type E = WsError;
fn to_async(self) -> Result<Async<Self::T>, Self::E> {
match self {
Ok(x) => Ok(Async::Ready(x)),
Err(error) => match error {
WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(Async::NotReady),
err => Err(err),
},
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(cx, |s| s.close(message)))
}
}
trait ToStartSend {
type T;
type E;
fn to_start_send(self) -> StartSend<Self::T, Self::E>;
}
impl ToStartSend for Result<(), WsError> {
type T = Message;
type E = WsError;
fn to_start_send(self) -> StartSend<Self::T, Self::E> {
match self {
Ok(_) => Ok(AsyncSink::Ready),
Err(error) => match error {
WsError::Io(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(AsyncSink::Ready),
WsError::SendQueueFull(msg) => Ok(AsyncSink::NotReady(msg)),
err => Err(err),
},
}
#[cfg(test)]
mod tests {
use crate::compat::AllowStd;
use crate::connect::encryption::AutoStream;
use crate::WebSocketStream;
use std::io::{Read, Write};
use tokio_io::{AsyncReadExt, AsyncWriteExt};
fn is_read<T: Read>() {}
fn is_write<T: Write>() {}
fn is_async_read<T: AsyncReadExt>() {}
fn is_async_write<T: AsyncWriteExt>() {}
fn is_unpin<T: Unpin>() {}
#[test]
fn web_socket_stream_has_traits() {
is_read::<AllowStd<tokio::net::TcpStream>>();
is_write::<AllowStd<tokio::net::TcpStream>>();
is_async_read::<AutoStream<tokio::net::TcpStream>>();
is_async_write::<AutoStream<tokio::net::TcpStream>>();
is_unpin::<WebSocketStream<tokio::net::TcpStream>>();
is_unpin::<WebSocketStream<AutoStream<tokio::net::TcpStream>>>();
is_unpin::<WebSocketStream<AutoStream<tokio_dns::TcpStream>>>();
}
}

@ -3,26 +3,11 @@
//! There is no dependency on actual TLS implementations. Everything like
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use std::pin::Pin;
use std::task::{Context, Poll};
use std::io::{Error as IoError, Read, Result as IoResult, Write};
use std::net::SocketAddr;
use bytes::{Buf, BufMut};
use futures::Poll;
use tokio_io::{AsyncRead, AsyncWrite};
/// Trait to switch TCP_NODELAY.
pub trait NoDelay {
/// Set the TCP_NODELAY option to the given value.
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>;
}
/// Trait to get the remote address from the underlying stream.
pub trait PeerAddr {
/// Returns the remote address that this stream is connected to.
fn peer_addr(&self) -> IoResult<SocketAddr>;
}
/// Stream, either plain TCP or TLS.
pub enum Stream<S, T> {
/// Unencrypted socket stream.
@ -31,74 +16,72 @@ pub enum Stream<S, T> {
Tls(T),
}
impl<S: Read, T: Read> Read for Stream<S, T> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.read(buf),
Stream::Tls(ref mut s) => s.read(buf),
}
}
}
impl<S: Write, T: Write> Write for Stream<S, T> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.write(buf),
Stream::Tls(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.flush(),
Stream::Tls(ref mut s) => s.flush(),
}
}
}
impl<S: NoDelay, T: NoDelay> NoDelay for Stream<S, T> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
impl<S: AsyncRead + Unpin, T: AsyncRead + Unpin> AsyncRead for Stream<S, T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
match *self {
Stream::Plain(ref mut s) => s.set_nodelay(nodelay),
Stream::Tls(ref mut s) => s.set_nodelay(nodelay),
Stream::Plain(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_read(cx, buf)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_read(cx, buf)
}
}
}
}
impl<S: PeerAddr, T: PeerAddr> PeerAddr for Stream<S, T> {
fn peer_addr(&self) -> IoResult<SocketAddr> {
impl<S: AsyncWrite + Unpin, T: AsyncWrite + Unpin> AsyncWrite for Stream<S, T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match *self {
Stream::Plain(ref s) => s.peer_addr(),
Stream::Tls(ref s) => s.peer_addr(),
Stream::Plain(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_write(cx, buf)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_write(cx, buf)
}
}
}
}
impl<S: AsyncRead, T: AsyncRead> AsyncRead for Stream<S, T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match *self {
Stream::Plain(ref s) => s.prepare_uninitialized_buffer(buf),
Stream::Tls(ref s) => s.prepare_uninitialized_buffer(buf),
Stream::Plain(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_flush(cx)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_flush(cx)
}
}
}
fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, IoError> {
match *self {
Stream::Plain(ref mut s) => s.read_buf(buf),
Stream::Tls(ref mut s) => s.read_buf(buf),
}
}
}
impl<S: AsyncWrite, T: AsyncWrite> AsyncWrite for Stream<S, T> {
fn shutdown(&mut self) -> Poll<(), IoError> {
match *self {
Stream::Plain(ref mut s) => s.shutdown(),
Stream::Tls(ref mut s) => s.shutdown(),
}
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, IoError> {
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match *self {
Stream::Plain(ref mut s) => s.write_buf(buf),
Stream::Tls(ref mut s) => s.write_buf(buf),
Stream::Plain(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_shutdown(cx)
}
Stream::Tls(ref mut s) => {
let pinned = unsafe { Pin::new_unchecked(s) };
pinned.poll_shutdown(cx)
}
}
}
}

@ -0,0 +1,79 @@
use futures::StreamExt;
use log::*;
use std::net::ToSocketAddrs;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::tcp::{TcpListener, TcpStream};
use tokio_tungstenite::{accept_async, client_async, WebSocketStream};
use tungstenite::Message;
async fn run_connection<S>(
connection: WebSocketStream<S>,
msg_tx: futures::channel::oneshot::Sender<Vec<Message>>,
) where
S: AsyncRead + AsyncWrite + Unpin,
{
info!("Running connection");
let mut connection = connection;
let mut messages = vec![];
while let Some(message) = connection.next().await {
info!("Message received");
let message = message.expect("Failed to get message");
messages.push(message);
}
msg_tx.send(messages).expect("Failed to send results");
}
#[tokio::test]
async fn communication() {
let _ = env_logger::try_init();
let (con_tx, con_rx) = futures::channel::oneshot::channel();
let (msg_tx, msg_rx) = futures::channel::oneshot::channel();
let f = async move {
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let listener = TcpListener::bind(&address).await.unwrap();
let mut connections = listener.incoming();
info!("Server ready");
con_tx.send(()).unwrap();
info!("Waiting on next connection");
let connection = connections.next().await.expect("No connections to accept");
let connection = connection.expect("Failed to accept connection");
let stream = accept_async(connection).await;
let stream = stream.expect("Failed to handshake with connection");
run_connection(stream, msg_tx).await;
};
tokio::spawn(f);
info!("Waiting for server to be ready");
con_rx.await.expect("Server not ready");
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let tcp = TcpStream::connect(&address)
.await
.expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap();
let (mut stream, _) = client_async(url, tcp)
.await
.expect("Client failed to connect");
for i in 1..10 {
info!("Sending message");
stream.send(Message::Text(format!("{}", i))).await.expect("Failed to send message");
}
stream.close(None).await.expect("Failed to close");
info!("Waiting for response messages");
let messages = msg_rx.await.expect("Failed to receive messages");
assert_eq!(messages.len(), 10);
}

@ -1,36 +1,41 @@
use std::io;
use futures::{Future, Stream};
use tokio_tcp::{TcpListener, TcpStream};
use futures::StreamExt;
use std::net::ToSocketAddrs;
use tokio::net::tcp::{TcpListener, TcpStream};
use tokio_tungstenite::{accept_async, client_async};
#[test]
fn handshakes() {
use std::sync::mpsc::channel;
use std::thread;
let (tx, rx) = channel();
#[tokio::test]
async fn handshakes() {
let (tx, rx) = futures::channel::oneshot::channel();
thread::spawn(move || {
let address = "0.0.0.0:12345".parse().unwrap();
let listener = TcpListener::bind(&address).unwrap();
let connections = listener.incoming();
let f = async move {
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let listener = TcpListener::bind(&address).await.unwrap();
let mut connections = listener.incoming();
tx.send(()).unwrap();
let handshakes = connections.and_then(|connection| {
accept_async(connection).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
});
let server = handshakes.for_each(|_| Ok(()));
while let Some(connection) = connections.next().await {
let connection = connection.expect("Failed to accept connection");
let stream = accept_async(connection).await;
stream.expect("Failed to handshake with connection");
}
};
server.wait().unwrap();
});
tokio::spawn(f);
rx.recv().unwrap();
let address = "0.0.0.0:12345".parse().unwrap();
let tcp = TcpStream::connect(&address);
let handshake = tcp.and_then(|stream| {
let url = url::Url::parse("ws://localhost:12345/").unwrap();
client_async(url, stream).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
});
let client = handshake.and_then(|_| Ok(()));
client.wait().unwrap();
rx.await.expect("Failed to wait for server to be ready");
let address = "0.0.0.0:12345"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let tcp = TcpStream::connect(&address)
.await
.expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap();
let _stream = client_async(url, tcp)
.await
.expect("Client failed to connect");
}

Loading…
Cancel
Save