From 3091d1156679eef073fafa35b0b1761f752c6e8d Mon Sep 17 00:00:00 2001 From: Alexey Galakhov Date: Tue, 5 Sep 2017 21:50:06 +0200 Subject: [PATCH] callback static dispatch Signed-off-by: Alexey Galakhov --- Cargo.toml | 2 +- examples/autobahn-server.rs | 2 +- examples/server.rs | 4 +-- src/handshake/server.rs | 53 ++++++++++++++++++++++++------------- src/lib.rs | 2 +- src/server.rs | 21 +++++++++++---- 6 files changed, 56 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a294e81..2c58e0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ authors = ["Alexey Galakhov"] license = "MIT/Apache-2.0" readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" -documentation = "https://docs.rs/tungstenite/0.4.0" +documentation = "https://docs.rs/tungstenite/0.5.0" repository = "https://github.com/snapview/tungstenite-rs" version = "0.5.0" diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 79fb5e2..697d880 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -16,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream, None).map_err(must_not_block)?; + let mut socket = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/server.rs b/examples/server.rs index 5b6bfee..86df2fe 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -3,7 +3,7 @@ extern crate tungstenite; use std::thread::spawn; use std::net::TcpListener; -use tungstenite::accept; +use tungstenite::accept_hdr; use tungstenite::handshake::server::Request; fn main() { @@ -25,7 +25,7 @@ fn main() { ]; Ok(Some(extra_headers)) }; - let mut websocket = accept(stream.unwrap(), Some(Box::new(callback))).unwrap(); + let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); loop { let msg = websocket.read_message().unwrap(); diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 510126c..a59eded 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -3,7 +3,6 @@ use std::fmt::Write as FmtWrite; use std::io::{Read, Write}; use std::marker::PhantomData; -use std::mem::replace; use httparse; use httparse::Status; @@ -71,43 +70,61 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } -/// The callback type, the callback is called when the server receives an incoming WebSocket -/// handshake request from the client, specifying a callback allows you to analyze incoming headers -/// and add additional headers to the response that server sends to the client and/or reject the -/// connection based on the incoming headers. Due to usability problems which are caused by a -/// static dispatch when using callbacks in such places, the callback is boxed. +/// The callback trait. /// -/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current -/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets -/// stable. -pub type Callback = Box Result>>>; +/// The callback is called when the server receives an incoming WebSocket +/// handshake request from the client. Specifying a callback allows you to analyze incoming headers +/// and add additional headers to the response that server sends to the client and/or reject the +/// connection based on the incoming headers. +pub trait Callback: Sized { + /// Called whenever the server read the request from the client and is ready to reply to it. + /// May return additional reply headers. + /// Returning an error resulting in rejecting the incoming connection. + fn on_request(self, request: &Request) -> Result>>; +} + +impl Callback for F where F: FnOnce(&Request) -> Result>> { + fn on_request(self, request: &Request) -> Result>> { + self(request) + } +} + +/// Stub for callback that does nothing. +#[derive(Clone, Copy)] +pub struct NoCallback; + +impl Callback for NoCallback { + fn on_request(self, _request: &Request) -> Result>> { + Ok(None) + } +} /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake { +pub struct ServerHandshake { /// Callback which is called whenever the server read the request from the client and is ready /// to reply to it. The callback returns an optional headers which will be added to the reply /// which the server sends to the user. - callback: Option, + callback: Option, /// Internal stream type. _marker: PhantomData, } -impl ServerHandshake { +impl ServerHandshake { /// Start server handshake. `callback` specifies a custom callback which the user can pass to /// the handshake, this callback will be called when the a websocket client connnects to the /// server, you can specify the callback if you want to add additional header to the client /// upon join based on the incoming headers. - pub fn start(stream: S, callback: Option) -> MidHandshake { + pub fn start(stream: S, callback: C) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { callback, _marker: PhantomData }, + role: ServerHandshake { callback: Some(callback), _marker: PhantomData }, } } } -impl HandshakeRole for ServerHandshake { +impl HandshakeRole for ServerHandshake { type IncomingData = Request; type InternalStream = S; type FinalResult = WebSocket; @@ -121,8 +138,8 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol("Junk after client request".into())) } let extra_headers = { - if let Some(mut callback) = replace(&mut self.callback, None) { - callback(&result)? + if let Some(callback) = self.callback.take() { + callback.on_request(&result)? } else { None } diff --git a/src/lib.rs b/src/lib.rs index 824934c..e3ca3d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,7 @@ pub mod util; mod input_buffer; pub use client::{connect, client}; -pub use server::accept; +pub use server::{accept, accept_hdr}; pub use error::{Error, Result}; pub use protocol::{WebSocket, Message}; pub use handshake::HandshakeError; diff --git a/src/server.rs b/src/server.rs index b0d191d..68e026f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,8 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; -use handshake::server::Callback; +use handshake::server::{Callback, NoCallback}; + use protocol::WebSocket; use std::io::{Read, Write}; @@ -13,10 +14,20 @@ use std::io::{Read, Write}; /// This function starts a server WebSocket handshake over the given stream. /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. You can also pass an optional `callback` which will -/// be called when the websocket request is received from an incoming client. -pub fn accept(stream: S, callback: Option) - -> Result, HandshakeError>> +/// those from `Mio` and others. +pub fn accept(stream: S) + -> Result, HandshakeError>> +{ + accept_hdr(stream, NoCallback) +} + +/// Accept the given Stream as a WebSocket. +/// +/// This function does the same as `accept()` but accepts an extra callback +/// for header processing. The callback receives headers of the incoming +/// requests and is able to add extra headers to the reply. +pub fn accept_hdr(stream: S, callback: C) + -> Result, HandshakeError>> { ServerHandshake::start(stream, callback).handshake() }