diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3492c7b..79fb5e2 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream}; use std::thread::spawn; use tungstenite::{accept, HandshakeError, Error, Result, Message}; +use tungstenite::handshake::HandshakeRole; -fn must_not_block(err: HandshakeError) -> Error { +fn must_not_block(err: HandshakeError) -> Error { match err { HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), HandshakeError::Failure(f) => f, @@ -15,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let (mut socket, _) = accept(stream).map_err(must_not_block)?; + let mut socket = accept(stream, None).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/src/client.rs b/src/client.rs index 33823ce..2326ac7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,7 +6,7 @@ use std::io::{Read, Write}; use url::Url; -use handshake::headers::Headers; +use handshake::client::Response; #[cfg(feature="tls")] mod encryption { @@ -78,7 +78,7 @@ use error::{Error, Result}; /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect<'t, Req: Into>>(request: Req) - -> Result<(WebSocket, Headers)> + -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; @@ -124,10 +124,13 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult<(WebSocket, Headers), HandshakeError> -where Stream: Read + Write, - Req: Into>, +pub fn client<'t, Stream, Req>( + request: Req, + stream: Stream + ) -> StdResult<(WebSocket, Response), HandshakeError>> +where + Stream: Read + Write, + Req: Into>, { ClientHandshake::start(stream, request.into()).handshake() } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8fa9b82..cd5e6d0 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,18 +1,19 @@ //! Client handshake machine. +use std::io::{Read, Write}; +use std::marker::PhantomData; + use base64; -use rand; -use httparse; use httparse::Status; -use std::io::Write; +use httparse; +use rand; use url::Url; use error::{Error, Result}; use protocol::{WebSocket, Role}; - use super::headers::{Headers, FromHttparse, MAX_HEADERS}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::machine::{HandshakeMachine, StageResult, TryParse}; +use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; /// Client request. pub struct Request<'t> { @@ -52,13 +53,14 @@ impl From for Request<'static> { } /// Client handshake role. -pub struct ClientHandshake { +pub struct ClientHandshake { verify_data: VerifyData, + _marker: PhantomData, } -impl ClientHandshake { +impl ClientHandshake { /// Initiate a client handshake. - pub fn start(stream: Stream, request: Request) -> MidHandshake { + pub fn start(stream: S, request: Request) -> MidHandshake { let key = generate_key(); let machine = { @@ -83,9 +85,8 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); ClientHandshake { - verify_data: VerifyData { - accept_key: accept_key, - }, + verify_data: VerifyData { accept_key }, + _marker: PhantomData, } }; @@ -94,10 +95,12 @@ impl ClientHandshake { } } -impl HandshakeRole for ClientHandshake { +impl HandshakeRole for ClientHandshake { type IncomingData = Response; - fn stage_finished(&mut self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = (WebSocket, Response); + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneWriting(stream) => { @@ -106,8 +109,8 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client), - result.headers) + ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), + result)) } }) } @@ -167,8 +170,10 @@ impl VerifyData { /// Server response. pub struct Response { - code: u16, - headers: Headers, + /// HTTP response code of the response. + pub code: u16, + /// Received headers. + pub headers: Headers, } impl TryParse for Response { @@ -204,7 +209,6 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::{Response, generate_key}; use super::super::machine::TryParse; @@ -231,5 +235,4 @@ mod tests { assert_eq!(resp.code, 200); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); } - } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 28aa0f6..a6d9192 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -14,31 +14,17 @@ use base64; use sha1::Sha1; use error::Error; -use protocol::WebSocket; - -use self::headers::Headers; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. -pub struct MidHandshake { +pub struct MidHandshake { role: Role, - machine: HandshakeMachine, -} - -impl MidHandshake { - /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &Stream { - self.machine.get_ref() - } - /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut Stream { - self.machine.get_mut() - } + machine: HandshakeMachine, } -impl MidHandshake { +impl MidHandshake { /// Restarts the handshake process. - pub fn handshake(mut self) -> Result<(WebSocket, Headers), HandshakeError> { + pub fn handshake(mut self) -> Result> { let mut mach = self.machine; loop { mach = match mach.single_round()? { @@ -49,7 +35,7 @@ impl MidHandshake { RoundResult::StageFinished(s) => { match self.role.stage_finished(s)? { ProcessingResult::Continue(m) => m, - ProcessingResult::Done(ws, headers) => return Ok((ws, headers)), + ProcessingResult::Done(result) => return Ok(result), } } } @@ -58,14 +44,14 @@ impl MidHandshake { } /// A handshake result. -pub enum HandshakeError { +pub enum HandshakeError { /// Handshake was interrupted (would block). - Interrupted(MidHandshake), + Interrupted(MidHandshake), /// Handshake failed. Failure(Error), } -impl fmt::Debug for HandshakeError { +impl fmt::Debug for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"), @@ -74,7 +60,7 @@ impl fmt::Debug for HandshakeError { } } -impl fmt::Display for HandshakeError { +impl fmt::Display for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"), @@ -83,7 +69,7 @@ impl fmt::Display for HandshakeError { } } -impl ErrorTrait for HandshakeError { +impl ErrorTrait for HandshakeError { fn description(&self) -> &str { match *self { HandshakeError::Interrupted(_) => "Interrupted handshake", @@ -92,7 +78,7 @@ impl ErrorTrait for HandshakeError { } } -impl From for HandshakeError { +impl From for HandshakeError { fn from(err: Error) -> Self { HandshakeError::Failure(err) } @@ -103,15 +89,19 @@ pub trait HandshakeRole { #[doc(hidden)] type IncomingData: TryParse; #[doc(hidden)] - fn stage_finished(&mut self, finish: StageResult) - -> Result, Error>; + type InternalStream: Read + Write; + #[doc(hidden)] + type FinalResult; + #[doc(hidden)] + fn stage_finished(&mut self, finish: StageResult) + -> Result, Error>; } /// Stage processing result. #[doc(hidden)] -pub enum ProcessingResult { +pub enum ProcessingResult { Continue(HandshakeMachine), - Done(WebSocket, Headers), + Done(FinalResult), } /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. @@ -127,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result { #[cfg(test)] mod tests { - use super::convert_key; #[test] diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a09ce65..510126c 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,5 +1,10 @@ //! Server handshake machine. +use std::fmt::Write as FmtWrite; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::mem::replace; + use httparse; use httparse::Status; @@ -19,15 +24,23 @@ pub struct Request { impl Request { /// Reply to the response. - pub fn reply(&self) -> Result> { + pub fn reply(&self, extra_headers: Option>) -> Result> { let key = self.headers.find_first("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let reply = format!("\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n\ - \r\n", convert_key(key)?); + let mut reply = format!( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", + convert_key(key)? + ); + if let Some(eh) = extra_headers { + for (k, v) in eh { + write!(reply, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(reply, "\r\n").unwrap(); Ok(reply.into()) } } @@ -58,42 +71,68 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } +/// The callback type, the callback is called when the server receives an incoming WebSocket +/// handshake request from the client, specifying a callback allows you to analyze incoming headers +/// and add additional headers to the response that server sends to the client and/or reject the +/// connection based on the incoming headers. Due to usability problems which are caused by a +/// static dispatch when using callbacks in such places, the callback is boxed. +/// +/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current +/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets +/// stable. +pub type Callback = Box Result>>>; + /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake { - /// Incoming headers, received from the client. - headers: Option, +pub struct ServerHandshake { + /// Callback which is called whenever the server read the request from the client and is ready + /// to reply to it. The callback returns an optional headers which will be added to the reply + /// which the server sends to the user. + callback: Option, + /// Internal stream type. + _marker: PhantomData, } -impl ServerHandshake { - /// Start server handshake. - pub fn start(stream: Stream) -> MidHandshake { +impl ServerHandshake { + /// Start server handshake. `callback` specifies a custom callback which the user can pass to + /// the handshake, this callback will be called when the a websocket client connnects to the + /// server, you can specify the callback if you want to add additional header to the client + /// upon join based on the incoming headers. + pub fn start(stream: S, callback: Option) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { headers: None }, + role: ServerHandshake { callback, _marker: PhantomData }, } } } -impl HandshakeRole for ServerHandshake { +impl HandshakeRole for ServerHandshake { type IncomingData = Request; - fn stage_finished(&mut self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = WebSocket; + + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { - if ! tail.is_empty() { + if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())) } - let response = result.reply()?; - self.headers = Some(result.headers); + let extra_headers = { + if let Some(mut callback) = replace(&mut self.callback, None) { + callback(&result)? + } else { + None + } + }; + let response = result.reply(extra_headers)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } StageResult::DoneWriting(stream) => { debug!("Server handshake done."); - let headers = self.headers.take().expect("Bug: accepted client without headers"); - ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers) + ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) } }) } @@ -101,9 +140,9 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::Request; use super::super::machine::TryParse; + use super::super::client::Response; #[test] fn request_parsing() { @@ -124,7 +163,15 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply().unwrap(); - } + let _ = req.reply(None).unwrap(); + let extra_headers = Some(vec![(String::from("MyCustomHeader"), + String::from("MyCustomValue")), + (String::from("MyVersion"), + String::from("LOL"))]); + let reply = req.reply(extra_headers).unwrap(); + let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); + assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); + assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); + } } diff --git a/src/server.rs b/src/server.rs index 04083f5..b0d191d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; -use handshake::headers::Headers; +use handshake::server::Callback; use protocol::WebSocket; use std::io::{Read, Write}; @@ -13,9 +13,10 @@ use std::io::{Read, Write}; /// This function starts a server WebSocket handshake over the given stream. /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. -pub fn accept(stream: Stream) - -> Result<(WebSocket, Headers), HandshakeError> +/// those from `Mio` and others. You can also pass an optional `callback` which will +/// be called when the websocket request is received from an incoming client. +pub fn accept(stream: S, callback: Option) + -> Result, HandshakeError>> { - ServerHandshake::start(stream).handshake() + ServerHandshake::start(stream, callback).handshake() }