From f34c4882172f04a5ba1310411413c6dd6fcb1c71 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 30 Jun 2017 15:15:24 +0200 Subject: [PATCH 1/5] Add basic support for examining headers (#6) --- examples/autobahn-client.rs | 6 +++--- examples/autobahn-server.rs | 2 +- examples/client.rs | 2 +- src/client.rs | 8 ++++++-- src/handshake/client.rs | 5 +++-- src/handshake/mod.rs | 9 +++++---- src/handshake/server.rs | 13 +++++++++---- src/server.rs | 3 ++- 8 files changed, 30 insertions(+), 18 deletions(-) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 700e2d7..8a08e4b 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message}; const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse("ws://localhost:9001/getCaseCount").unwrap() )?; let msg = socket.read_message()?; @@ -19,7 +19,7 @@ fn get_case_count() -> Result { } fn update_reports() -> Result<()> { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() )?; socket.close(None)?; @@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let mut socket = connect(case_url)?; + let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 458fa26..3492c7b 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -15,7 +15,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream).map_err(must_not_block)?; + let (mut socket, _) = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/client.rs b/examples/client.rs index 1c97aba..ca35805 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,7 +8,7 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init().unwrap(); - let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + let (mut socket, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) .expect("Can't connect"); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); diff --git a/src/client.rs b/src/client.rs index 3274ad0..33823ce 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,6 +6,8 @@ use std::io::{Read, Write}; use url::Url; +use handshake::headers::Headers; + #[cfg(feature="tls")] mod encryption { use std::net::TcpStream; @@ -75,7 +77,9 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) -> Result> { +pub fn connect<'t, Req: Into>>(request: Req) + -> Result<(WebSocket, Headers)> +{ let request: Request = request.into(); let mode = url_mode(&request.url)?; let addrs = request.url.to_socket_addrs()?; @@ -121,7 +125,7 @@ pub fn url_mode(url: &Url) -> Result { /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult, HandshakeError> + -> StdResult<(WebSocket, Headers), HandshakeError> where Stream: Read + Write, Req: Into>, { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index ff07d37..8fa9b82 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -96,7 +96,7 @@ impl ClientHandshake { impl HandshakeRole for ClientHandshake { type IncomingData = Response; - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result> { Ok(match finish { @@ -106,7 +106,8 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) + ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client), + result.headers) } }) } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 64d31f1..28aa0f6 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -16,6 +16,7 @@ use sha1::Sha1; use error::Error; use protocol::WebSocket; +use self::headers::Headers; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. @@ -37,7 +38,7 @@ impl MidHandshake { impl MidHandshake { /// Restarts the handshake process. - pub fn handshake(self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result<(WebSocket, Headers), HandshakeError> { let mut mach = self.machine; loop { mach = match mach.single_round()? { @@ -48,7 +49,7 @@ impl MidHandshake { RoundResult::StageFinished(s) => { match self.role.stage_finished(s)? { ProcessingResult::Continue(m) => m, - ProcessingResult::Done(ws) => return Ok(ws), + ProcessingResult::Done(ws, headers) => return Ok((ws, headers)), } } } @@ -102,7 +103,7 @@ pub trait HandshakeRole { #[doc(hidden)] type IncomingData: TryParse; #[doc(hidden)] - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result, Error>; } @@ -110,7 +111,7 @@ pub trait HandshakeRole { #[doc(hidden)] pub enum ProcessingResult { Continue(HandshakeMachine), - Done(WebSocket), + Done(WebSocket, Headers), } /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a3a5fab..a09ce65 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -60,7 +60,10 @@ impl<'h, 'b: 'h> FromHttparse> for Request { /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake; +pub struct ServerHandshake { + /// Incoming headers, received from the client. + headers: Option, +} impl ServerHandshake { /// Start server handshake. @@ -68,14 +71,14 @@ impl ServerHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake, + role: ServerHandshake { headers: None }, } } } impl HandshakeRole for ServerHandshake { type IncomingData = Request; - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result> { Ok(match finish { @@ -84,11 +87,13 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol("Junk after client request".into())) } let response = result.reply()?; + self.headers = Some(result.headers); ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } StageResult::DoneWriting(stream) => { debug!("Server handshake done."); - ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) + let headers = self.headers.take().expect("Bug: accepted client without headers"); + ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers) } }) } diff --git a/src/server.rs b/src/server.rs index 3217006..04083f5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; +use handshake::headers::Headers; use protocol::WebSocket; use std::io::{Read, Write}; @@ -14,7 +15,7 @@ use std::io::{Read, Write}; /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. pub fn accept(stream: Stream) - -> Result, HandshakeError> + -> Result<(WebSocket, Headers), HandshakeError> { ServerHandshake::start(stream).handshake() } From 44a15c9eabac16de1fcd3a7cee2ee8c39fe22a6f Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 25 Jul 2017 09:47:17 +0200 Subject: [PATCH 2/5] Implements #6 (client/server headers access) --- examples/autobahn-server.rs | 5 +- src/client.rs | 15 +++--- src/handshake/client.rs | 43 ++++++++-------- src/handshake/mod.rs | 49 ++++++++----------- src/handshake/server.rs | 97 +++++++++++++++++++++++++++---------- src/server.rs | 11 +++-- 6 files changed, 132 insertions(+), 88 deletions(-) diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3492c7b..79fb5e2 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream}; use std::thread::spawn; use tungstenite::{accept, HandshakeError, Error, Result, Message}; +use tungstenite::handshake::HandshakeRole; -fn must_not_block(err: HandshakeError) -> Error { +fn must_not_block(err: HandshakeError) -> Error { match err { HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), HandshakeError::Failure(f) => f, @@ -15,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let (mut socket, _) = accept(stream).map_err(must_not_block)?; + let mut socket = accept(stream, None).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/src/client.rs b/src/client.rs index 33823ce..2326ac7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,7 +6,7 @@ use std::io::{Read, Write}; use url::Url; -use handshake::headers::Headers; +use handshake::client::Response; #[cfg(feature="tls")] mod encryption { @@ -78,7 +78,7 @@ use error::{Error, Result}; /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect<'t, Req: Into>>(request: Req) - -> Result<(WebSocket, Headers)> + -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; @@ -124,10 +124,13 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult<(WebSocket, Headers), HandshakeError> -where Stream: Read + Write, - Req: Into>, +pub fn client<'t, Stream, Req>( + request: Req, + stream: Stream + ) -> StdResult<(WebSocket, Response), HandshakeError>> +where + Stream: Read + Write, + Req: Into>, { ClientHandshake::start(stream, request.into()).handshake() } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8fa9b82..cd5e6d0 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,18 +1,19 @@ //! Client handshake machine. +use std::io::{Read, Write}; +use std::marker::PhantomData; + use base64; -use rand; -use httparse; use httparse::Status; -use std::io::Write; +use httparse; +use rand; use url::Url; use error::{Error, Result}; use protocol::{WebSocket, Role}; - use super::headers::{Headers, FromHttparse, MAX_HEADERS}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::machine::{HandshakeMachine, StageResult, TryParse}; +use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; /// Client request. pub struct Request<'t> { @@ -52,13 +53,14 @@ impl From for Request<'static> { } /// Client handshake role. -pub struct ClientHandshake { +pub struct ClientHandshake { verify_data: VerifyData, + _marker: PhantomData, } -impl ClientHandshake { +impl ClientHandshake { /// Initiate a client handshake. - pub fn start(stream: Stream, request: Request) -> MidHandshake { + pub fn start(stream: S, request: Request) -> MidHandshake { let key = generate_key(); let machine = { @@ -83,9 +85,8 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); ClientHandshake { - verify_data: VerifyData { - accept_key: accept_key, - }, + verify_data: VerifyData { accept_key }, + _marker: PhantomData, } }; @@ -94,10 +95,12 @@ impl ClientHandshake { } } -impl HandshakeRole for ClientHandshake { +impl HandshakeRole for ClientHandshake { type IncomingData = Response; - fn stage_finished(&mut self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = (WebSocket, Response); + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneWriting(stream) => { @@ -106,8 +109,8 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client), - result.headers) + ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), + result)) } }) } @@ -167,8 +170,10 @@ impl VerifyData { /// Server response. pub struct Response { - code: u16, - headers: Headers, + /// HTTP response code of the response. + pub code: u16, + /// Received headers. + pub headers: Headers, } impl TryParse for Response { @@ -204,7 +209,6 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::{Response, generate_key}; use super::super::machine::TryParse; @@ -231,5 +235,4 @@ mod tests { assert_eq!(resp.code, 200); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); } - } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 28aa0f6..a6d9192 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -14,31 +14,17 @@ use base64; use sha1::Sha1; use error::Error; -use protocol::WebSocket; - -use self::headers::Headers; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. -pub struct MidHandshake { +pub struct MidHandshake { role: Role, - machine: HandshakeMachine, -} - -impl MidHandshake { - /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &Stream { - self.machine.get_ref() - } - /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut Stream { - self.machine.get_mut() - } + machine: HandshakeMachine, } -impl MidHandshake { +impl MidHandshake { /// Restarts the handshake process. - pub fn handshake(mut self) -> Result<(WebSocket, Headers), HandshakeError> { + pub fn handshake(mut self) -> Result> { let mut mach = self.machine; loop { mach = match mach.single_round()? { @@ -49,7 +35,7 @@ impl MidHandshake { RoundResult::StageFinished(s) => { match self.role.stage_finished(s)? { ProcessingResult::Continue(m) => m, - ProcessingResult::Done(ws, headers) => return Ok((ws, headers)), + ProcessingResult::Done(result) => return Ok(result), } } } @@ -58,14 +44,14 @@ impl MidHandshake { } /// A handshake result. -pub enum HandshakeError { +pub enum HandshakeError { /// Handshake was interrupted (would block). - Interrupted(MidHandshake), + Interrupted(MidHandshake), /// Handshake failed. Failure(Error), } -impl fmt::Debug for HandshakeError { +impl fmt::Debug for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"), @@ -74,7 +60,7 @@ impl fmt::Debug for HandshakeError { } } -impl fmt::Display for HandshakeError { +impl fmt::Display for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"), @@ -83,7 +69,7 @@ impl fmt::Display for HandshakeError { } } -impl ErrorTrait for HandshakeError { +impl ErrorTrait for HandshakeError { fn description(&self) -> &str { match *self { HandshakeError::Interrupted(_) => "Interrupted handshake", @@ -92,7 +78,7 @@ impl ErrorTrait for HandshakeError { } } -impl From for HandshakeError { +impl From for HandshakeError { fn from(err: Error) -> Self { HandshakeError::Failure(err) } @@ -103,15 +89,19 @@ pub trait HandshakeRole { #[doc(hidden)] type IncomingData: TryParse; #[doc(hidden)] - fn stage_finished(&mut self, finish: StageResult) - -> Result, Error>; + type InternalStream: Read + Write; + #[doc(hidden)] + type FinalResult; + #[doc(hidden)] + fn stage_finished(&mut self, finish: StageResult) + -> Result, Error>; } /// Stage processing result. #[doc(hidden)] -pub enum ProcessingResult { +pub enum ProcessingResult { Continue(HandshakeMachine), - Done(WebSocket, Headers), + Done(FinalResult), } /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. @@ -127,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result { #[cfg(test)] mod tests { - use super::convert_key; #[test] diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a09ce65..510126c 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,5 +1,10 @@ //! Server handshake machine. +use std::fmt::Write as FmtWrite; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::mem::replace; + use httparse; use httparse::Status; @@ -19,15 +24,23 @@ pub struct Request { impl Request { /// Reply to the response. - pub fn reply(&self) -> Result> { + pub fn reply(&self, extra_headers: Option>) -> Result> { let key = self.headers.find_first("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let reply = format!("\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n\ - \r\n", convert_key(key)?); + let mut reply = format!( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", + convert_key(key)? + ); + if let Some(eh) = extra_headers { + for (k, v) in eh { + write!(reply, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(reply, "\r\n").unwrap(); Ok(reply.into()) } } @@ -58,42 +71,68 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } +/// The callback type, the callback is called when the server receives an incoming WebSocket +/// handshake request from the client, specifying a callback allows you to analyze incoming headers +/// and add additional headers to the response that server sends to the client and/or reject the +/// connection based on the incoming headers. Due to usability problems which are caused by a +/// static dispatch when using callbacks in such places, the callback is boxed. +/// +/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current +/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets +/// stable. +pub type Callback = Box Result>>>; + /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake { - /// Incoming headers, received from the client. - headers: Option, +pub struct ServerHandshake { + /// Callback which is called whenever the server read the request from the client and is ready + /// to reply to it. The callback returns an optional headers which will be added to the reply + /// which the server sends to the user. + callback: Option, + /// Internal stream type. + _marker: PhantomData, } -impl ServerHandshake { - /// Start server handshake. - pub fn start(stream: Stream) -> MidHandshake { +impl ServerHandshake { + /// Start server handshake. `callback` specifies a custom callback which the user can pass to + /// the handshake, this callback will be called when the a websocket client connnects to the + /// server, you can specify the callback if you want to add additional header to the client + /// upon join based on the incoming headers. + pub fn start(stream: S, callback: Option) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { headers: None }, + role: ServerHandshake { callback, _marker: PhantomData }, } } } -impl HandshakeRole for ServerHandshake { +impl HandshakeRole for ServerHandshake { type IncomingData = Request; - fn stage_finished(&mut self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = WebSocket; + + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { - if ! tail.is_empty() { + if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())) } - let response = result.reply()?; - self.headers = Some(result.headers); + let extra_headers = { + if let Some(mut callback) = replace(&mut self.callback, None) { + callback(&result)? + } else { + None + } + }; + let response = result.reply(extra_headers)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } StageResult::DoneWriting(stream) => { debug!("Server handshake done."); - let headers = self.headers.take().expect("Bug: accepted client without headers"); - ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers) + ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) } }) } @@ -101,9 +140,9 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::Request; use super::super::machine::TryParse; + use super::super::client::Response; #[test] fn request_parsing() { @@ -124,7 +163,15 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply().unwrap(); - } + let _ = req.reply(None).unwrap(); + let extra_headers = Some(vec![(String::from("MyCustomHeader"), + String::from("MyCustomValue")), + (String::from("MyVersion"), + String::from("LOL"))]); + let reply = req.reply(extra_headers).unwrap(); + let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); + assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); + assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); + } } diff --git a/src/server.rs b/src/server.rs index 04083f5..b0d191d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; -use handshake::headers::Headers; +use handshake::server::Callback; use protocol::WebSocket; use std::io::{Read, Write}; @@ -13,9 +13,10 @@ use std::io::{Read, Write}; /// This function starts a server WebSocket handshake over the given stream. /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. -pub fn accept(stream: Stream) - -> Result<(WebSocket, Headers), HandshakeError> +/// those from `Mio` and others. You can also pass an optional `callback` which will +/// be called when the websocket request is received from an incoming client. +pub fn accept(stream: S, callback: Option) + -> Result, HandshakeError>> { - ServerHandshake::start(stream).handshake() + ServerHandshake::start(stream, callback).handshake() } From 5982d4094dd5d19283b1ca700654cc0e9cf6b6c2 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 28 Jul 2017 14:29:56 +0200 Subject: [PATCH 3/5] Update README and examples --- README.md | 10 ++++++++-- examples/client.rs | 9 ++++++++- src/handshake/headers.rs | 5 +++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0811262..cf917c4 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,21 @@ Lightweight stream-based WebSocket implementation for [Rust](http://www.rust-lan let server = TcpListener::bind("127.0.0.1:9001").unwrap(); for stream in server.incoming() { spawn (move || { - let mut websocket = accept(stream.unwrap()).unwrap(); + let mut websocket = accept(stream.unwrap(), None).unwrap(); loop { let msg = websocket.read_message().unwrap(); - websocket.write_message(msg).unwrap(); + + // We do not want to send back ping/pong messages. + if msg.is_binary() || msg.is_text() { + websocket.write_message(msg).unwrap(); + } } }); } ``` +Take a look at the examples section to see how to write a simple client/server. + [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) diff --git a/examples/client.rs b/examples/client.rs index ca35805..8e11038 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,9 +8,16 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init().unwrap(); - let (mut socket, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) .expect("Can't connect"); + println!("Connected to the server"); + println!("Response HTTP code: {}", response.code); + println!("Response contains the following headers:"); + for &(ref header, _ /*value*/) in response.headers.iter() { + println!("* {}", header); + } + socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index c9e6bee..b5a9f62 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -49,6 +49,11 @@ impl Headers { .unwrap_or(false) } + /// Allows to iterate over available headers. + pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { + self.data.iter() + } + } /// The iterator over headers. From e59169989ad17a4b846a316cec436fe435f02a3e Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 30 Jun 2017 15:22:59 +0200 Subject: [PATCH 4/5] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 55e2863..d7c7f47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.2.4" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.3.0" +version = "0.4.0" [features] default = ["tls"] From 41dfc3c506afc4d576c4f51f3b60c36ee98bb1d2 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 28 Jul 2017 15:24:16 +0200 Subject: [PATCH 5/5] Add server.rs example to the examples section --- examples/server.rs | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 examples/server.rs diff --git a/examples/server.rs b/examples/server.rs new file mode 100644 index 0000000..5b6bfee --- /dev/null +++ b/examples/server.rs @@ -0,0 +1,38 @@ +extern crate tungstenite; + +use std::thread::spawn; +use std::net::TcpListener; + +use tungstenite::accept; +use tungstenite::handshake::server::Request; + +fn main() { + let server = TcpListener::bind("127.0.0.1:3012").unwrap(); + for stream in server.incoming() { + spawn(move || { + let callback = |req: &Request| { + println!("Received a new ws handshake"); + println!("The request's path is: {}", req.path); + println!("The request's headers are:"); + for &(ref header, _ /* value */) in req.headers.iter() { + println!("* {}", header); + } + + // Let's add an additional header to our response to the client. + let extra_headers = vec![ + (String::from("MyCustomHeader"), String::from(":)")), + (String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")), + ]; + Ok(Some(extra_headers)) + }; + let mut websocket = accept(stream.unwrap(), Some(Box::new(callback))).unwrap(); + + loop { + let msg = websocket.read_message().unwrap(); + if msg.is_binary() || msg.is_text() { + websocket.write_message(msg).unwrap(); + } + } + }); + } +}