diff --git a/Cargo.toml b/Cargo.toml index 55e2863..d7c7f47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.2.4" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.3.0" +version = "0.4.0" [features] default = ["tls"] diff --git a/README.md b/README.md index 0811262..cf917c4 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,21 @@ Lightweight stream-based WebSocket implementation for [Rust](http://www.rust-lan let server = TcpListener::bind("127.0.0.1:9001").unwrap(); for stream in server.incoming() { spawn (move || { - let mut websocket = accept(stream.unwrap()).unwrap(); + let mut websocket = accept(stream.unwrap(), None).unwrap(); loop { let msg = websocket.read_message().unwrap(); - websocket.write_message(msg).unwrap(); + + // We do not want to send back ping/pong messages. + if msg.is_binary() || msg.is_text() { + websocket.write_message(msg).unwrap(); + } } }); } ``` +Take a look at the examples section to see how to write a simple client/server. + [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 700e2d7..8a08e4b 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message}; const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse("ws://localhost:9001/getCaseCount").unwrap() )?; let msg = socket.read_message()?; @@ -19,7 +19,7 @@ fn get_case_count() -> Result { } fn update_reports() -> Result<()> { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() )?; socket.close(None)?; @@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let mut socket = connect(case_url)?; + let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 458fa26..79fb5e2 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream}; use std::thread::spawn; use tungstenite::{accept, HandshakeError, Error, Result, Message}; +use tungstenite::handshake::HandshakeRole; -fn must_not_block(err: HandshakeError) -> Error { +fn must_not_block(err: HandshakeError) -> Error { match err { HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), HandshakeError::Failure(f) => f, @@ -15,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream).map_err(must_not_block)?; + let mut socket = accept(stream, None).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/client.rs b/examples/client.rs index 1c97aba..8e11038 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,9 +8,16 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init().unwrap(); - let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) .expect("Can't connect"); + println!("Connected to the server"); + println!("Response HTTP code: {}", response.code); + println!("Response contains the following headers:"); + for &(ref header, _ /*value*/) in response.headers.iter() { + println!("* {}", header); + } + socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); diff --git a/examples/server.rs b/examples/server.rs new file mode 100644 index 0000000..5b6bfee --- /dev/null +++ b/examples/server.rs @@ -0,0 +1,38 @@ +extern crate tungstenite; + +use std::thread::spawn; +use std::net::TcpListener; + +use tungstenite::accept; +use tungstenite::handshake::server::Request; + +fn main() { + let server = TcpListener::bind("127.0.0.1:3012").unwrap(); + for stream in server.incoming() { + spawn(move || { + let callback = |req: &Request| { + println!("Received a new ws handshake"); + println!("The request's path is: {}", req.path); + println!("The request's headers are:"); + for &(ref header, _ /* value */) in req.headers.iter() { + println!("* {}", header); + } + + // Let's add an additional header to our response to the client. + let extra_headers = vec![ + (String::from("MyCustomHeader"), String::from(":)")), + (String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")), + ]; + Ok(Some(extra_headers)) + }; + let mut websocket = accept(stream.unwrap(), Some(Box::new(callback))).unwrap(); + + loop { + let msg = websocket.read_message().unwrap(); + if msg.is_binary() || msg.is_text() { + websocket.write_message(msg).unwrap(); + } + } + }); + } +} diff --git a/src/client.rs b/src/client.rs index 3274ad0..2326ac7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,6 +6,8 @@ use std::io::{Read, Write}; use url::Url; +use handshake::client::Response; + #[cfg(feature="tls")] mod encryption { use std::net::TcpStream; @@ -75,7 +77,9 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) -> Result> { +pub fn connect<'t, Req: Into>>(request: Req) + -> Result<(WebSocket, Response)> +{ let request: Request = request.into(); let mode = url_mode(&request.url)?; let addrs = request.url.to_socket_addrs()?; @@ -120,10 +124,13 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult, HandshakeError> -where Stream: Read + Write, - Req: Into>, +pub fn client<'t, Stream, Req>( + request: Req, + stream: Stream + ) -> StdResult<(WebSocket, Response), HandshakeError>> +where + Stream: Read + Write, + Req: Into>, { ClientHandshake::start(stream, request.into()).handshake() } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index ff07d37..cd5e6d0 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,18 +1,19 @@ //! Client handshake machine. +use std::io::{Read, Write}; +use std::marker::PhantomData; + use base64; -use rand; -use httparse; use httparse::Status; -use std::io::Write; +use httparse; +use rand; use url::Url; use error::{Error, Result}; use protocol::{WebSocket, Role}; - use super::headers::{Headers, FromHttparse, MAX_HEADERS}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::machine::{HandshakeMachine, StageResult, TryParse}; +use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; /// Client request. pub struct Request<'t> { @@ -52,13 +53,14 @@ impl From for Request<'static> { } /// Client handshake role. -pub struct ClientHandshake { +pub struct ClientHandshake { verify_data: VerifyData, + _marker: PhantomData, } -impl ClientHandshake { +impl ClientHandshake { /// Initiate a client handshake. - pub fn start(stream: Stream, request: Request) -> MidHandshake { + pub fn start(stream: S, request: Request) -> MidHandshake { let key = generate_key(); let machine = { @@ -83,9 +85,8 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); ClientHandshake { - verify_data: VerifyData { - accept_key: accept_key, - }, + verify_data: VerifyData { accept_key }, + _marker: PhantomData, } }; @@ -94,10 +95,12 @@ impl ClientHandshake { } } -impl HandshakeRole for ClientHandshake { +impl HandshakeRole for ClientHandshake { type IncomingData = Response; - fn stage_finished(&self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = (WebSocket, Response); + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneWriting(stream) => { @@ -106,7 +109,8 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) + ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), + result)) } }) } @@ -166,8 +170,10 @@ impl VerifyData { /// Server response. pub struct Response { - code: u16, - headers: Headers, + /// HTTP response code of the response. + pub code: u16, + /// Received headers. + pub headers: Headers, } impl TryParse for Response { @@ -203,7 +209,6 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::{Response, generate_key}; use super::super::machine::TryParse; @@ -230,5 +235,4 @@ mod tests { assert_eq!(resp.code, 200); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); } - } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index c9e6bee..b5a9f62 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -49,6 +49,11 @@ impl Headers { .unwrap_or(false) } + /// Allows to iterate over available headers. + pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { + self.data.iter() + } + } /// The iterator over headers. diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 64d31f1..a6d9192 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -14,30 +14,17 @@ use base64; use sha1::Sha1; use error::Error; -use protocol::WebSocket; - use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. -pub struct MidHandshake { +pub struct MidHandshake { role: Role, - machine: HandshakeMachine, -} - -impl MidHandshake { - /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &Stream { - self.machine.get_ref() - } - /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut Stream { - self.machine.get_mut() - } + machine: HandshakeMachine, } -impl MidHandshake { +impl MidHandshake { /// Restarts the handshake process. - pub fn handshake(self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result> { let mut mach = self.machine; loop { mach = match mach.single_round()? { @@ -48,7 +35,7 @@ impl MidHandshake { RoundResult::StageFinished(s) => { match self.role.stage_finished(s)? { ProcessingResult::Continue(m) => m, - ProcessingResult::Done(ws) => return Ok(ws), + ProcessingResult::Done(result) => return Ok(result), } } } @@ -57,14 +44,14 @@ impl MidHandshake { } /// A handshake result. -pub enum HandshakeError { +pub enum HandshakeError { /// Handshake was interrupted (would block). - Interrupted(MidHandshake), + Interrupted(MidHandshake), /// Handshake failed. Failure(Error), } -impl fmt::Debug for HandshakeError { +impl fmt::Debug for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"), @@ -73,7 +60,7 @@ impl fmt::Debug for HandshakeError { } } -impl fmt::Display for HandshakeError { +impl fmt::Display for HandshakeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"), @@ -82,7 +69,7 @@ impl fmt::Display for HandshakeError { } } -impl ErrorTrait for HandshakeError { +impl ErrorTrait for HandshakeError { fn description(&self) -> &str { match *self { HandshakeError::Interrupted(_) => "Interrupted handshake", @@ -91,7 +78,7 @@ impl ErrorTrait for HandshakeError { } } -impl From for HandshakeError { +impl From for HandshakeError { fn from(err: Error) -> Self { HandshakeError::Failure(err) } @@ -102,15 +89,19 @@ pub trait HandshakeRole { #[doc(hidden)] type IncomingData: TryParse; #[doc(hidden)] - fn stage_finished(&self, finish: StageResult) - -> Result, Error>; + type InternalStream: Read + Write; + #[doc(hidden)] + type FinalResult; + #[doc(hidden)] + fn stage_finished(&mut self, finish: StageResult) + -> Result, Error>; } /// Stage processing result. #[doc(hidden)] -pub enum ProcessingResult { +pub enum ProcessingResult { Continue(HandshakeMachine), - Done(WebSocket), + Done(FinalResult), } /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. @@ -126,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result { #[cfg(test)] mod tests { - use super::convert_key; #[test] diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a3a5fab..510126c 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,5 +1,10 @@ //! Server handshake machine. +use std::fmt::Write as FmtWrite; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::mem::replace; + use httparse; use httparse::Status; @@ -19,15 +24,23 @@ pub struct Request { impl Request { /// Reply to the response. - pub fn reply(&self) -> Result> { + pub fn reply(&self, extra_headers: Option>) -> Result> { let key = self.headers.find_first("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let reply = format!("\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n\ - \r\n", convert_key(key)?); + let mut reply = format!( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", + convert_key(key)? + ); + if let Some(eh) = extra_headers { + for (k, v) in eh { + write!(reply, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(reply, "\r\n").unwrap(); Ok(reply.into()) } } @@ -58,32 +71,63 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } +/// The callback type, the callback is called when the server receives an incoming WebSocket +/// handshake request from the client, specifying a callback allows you to analyze incoming headers +/// and add additional headers to the response that server sends to the client and/or reject the +/// connection based on the incoming headers. Due to usability problems which are caused by a +/// static dispatch when using callbacks in such places, the callback is boxed. +/// +/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current +/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets +/// stable. +pub type Callback = Box Result>>>; + /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake; +pub struct ServerHandshake { + /// Callback which is called whenever the server read the request from the client and is ready + /// to reply to it. The callback returns an optional headers which will be added to the reply + /// which the server sends to the user. + callback: Option, + /// Internal stream type. + _marker: PhantomData, +} -impl ServerHandshake { - /// Start server handshake. - pub fn start(stream: Stream) -> MidHandshake { +impl ServerHandshake { + /// Start server handshake. `callback` specifies a custom callback which the user can pass to + /// the handshake, this callback will be called when the a websocket client connnects to the + /// server, you can specify the callback if you want to add additional header to the client + /// upon join based on the incoming headers. + pub fn start(stream: S, callback: Option) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake, + role: ServerHandshake { callback, _marker: PhantomData }, } } } -impl HandshakeRole for ServerHandshake { +impl HandshakeRole for ServerHandshake { type IncomingData = Request; - fn stage_finished(&self, finish: StageResult) - -> Result> + type InternalStream = S; + type FinalResult = WebSocket; + + fn stage_finished(&mut self, finish: StageResult) + -> Result> { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { - if ! tail.is_empty() { + if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())) } - let response = result.reply()?; + let extra_headers = { + if let Some(mut callback) = replace(&mut self.callback, None) { + callback(&result)? + } else { + None + } + }; + let response = result.reply(extra_headers)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } StageResult::DoneWriting(stream) => { @@ -96,9 +140,9 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::Request; use super::super::machine::TryParse; + use super::super::client::Response; #[test] fn request_parsing() { @@ -119,7 +163,15 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply().unwrap(); - } + let _ = req.reply(None).unwrap(); + let extra_headers = Some(vec![(String::from("MyCustomHeader"), + String::from("MyCustomValue")), + (String::from("MyVersion"), + String::from("LOL"))]); + let reply = req.reply(extra_headers).unwrap(); + let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); + assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); + assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); + } } diff --git a/src/server.rs b/src/server.rs index 3217006..b0d191d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; +use handshake::server::Callback; use protocol::WebSocket; use std::io::{Read, Write}; @@ -12,9 +13,10 @@ use std::io::{Read, Write}; /// This function starts a server WebSocket handshake over the given stream. /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. -pub fn accept(stream: Stream) - -> Result, HandshakeError> +/// those from `Mio` and others. You can also pass an optional `callback` which will +/// be called when the websocket request is received from an incoming client. +pub fn accept(stream: S, callback: Option) + -> Result, HandshakeError>> { - ServerHandshake::start(stream).handshake() + ServerHandshake::start(stream, callback).handshake() }