diff --git a/Cargo.toml b/Cargo.toml index c9fe75d..e5f9387 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,9 @@ sha1 = "0.4.0" url = "1.5.1" utf-8 = "0.7.1" +[patch.crates-io] +input_buffer = { git = "https://github.com/unv-annihilator/input_buffer", rev = "f940362f34afd61a34d126d211a9ad2bf2ec903a" } + [dependencies.native-tls] optional = true version = "0.1.5" diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 8a08e4b..b0d9a24 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -1,18 +1,17 @@ -#[macro_use] extern crate log; extern crate env_logger; +#[macro_use] +extern crate log; extern crate tungstenite; extern crate url; use url::Url; -use tungstenite::{connect, Error, Result, Message}; +use tungstenite::{connect, Error, Message, Result}; const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { - let (mut socket, _) = connect( - Url::parse("ws://localhost:9001/getCaseCount").unwrap() - )?; + let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let msg = socket.read_message()?; socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) @@ -20,7 +19,10 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() + Url::parse(&format!( + "ws://localhost:9001/updateReports?agent={}", + AGENT + )).unwrap(), )?; socket.close(None)?; Ok(()) @@ -28,18 +30,17 @@ fn update_reports() -> Result<()> { fn run_test(case: u32) -> Result<()> { info!("Running test case {}", case); - let case_url = Url::parse( - &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) - ).unwrap(); + let case_url = Url::parse(&format!( + "ws://localhost:9001/runCase?case={}&agent={}", + case, AGENT + )).unwrap(); let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { - msg @ Message::Text(_) | - msg @ Message::Binary(_) => { + msg @ Message::Text(_) | msg @ Message::Binary(_) => { socket.write_message(msg)?; } - Message::Ping(_) | - Message::Pong(_) => {} + Message::Ping(_) | Message::Pong(_) => {} } } } @@ -52,12 +53,13 @@ fn main() { for case in 1..(total + 1) { if let Err(e) = run_test(case) { match e { - Error::Protocol(_) => { } - err => { warn!("test: {}", err); } + Error::Protocol(_) => {} + err => { + warn!("test: {}", err); + } } } } update_reports().unwrap(); } - diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 697d880..a3a7616 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,11 +1,12 @@ -#[macro_use] extern crate log; extern crate env_logger; +#[macro_use] +extern crate log; extern crate tungstenite; use std::net::{TcpListener, TcpStream}; use std::thread::spawn; -use tungstenite::{accept, HandshakeError, Error, Result, Message}; +use tungstenite::{accept, Error, HandshakeError, Message, Result}; use tungstenite::handshake::HandshakeRole; fn must_not_block(err: HandshakeError) -> Error { @@ -19,12 +20,10 @@ fn handle_client(stream: TcpStream) -> Result<()> { let mut socket = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { - msg @ Message::Text(_) | - msg @ Message::Binary(_) => { + msg @ Message::Text(_) | msg @ Message::Binary(_) => { socket.write_message(msg)?; } - Message::Ping(_) | - Message::Pong(_) => {} + Message::Ping(_) | Message::Pong(_) => {} } } } @@ -35,14 +34,12 @@ fn main() { let server = TcpListener::bind("127.0.0.1:9001").unwrap(); for stream in server.incoming() { - spawn(move || { - match stream { - Ok(stream) => match handle_client(stream) { - Ok(_) => (), - Err(e) => warn!("Error in client: {}", e), - }, - Err(e) => warn!("Error accepting stream: {}", e), - } + spawn(move || match stream { + Ok(stream) => match handle_client(stream) { + Ok(_) => (), + Err(e) => warn!("Error in client: {}", e), + }, + Err(e) => warn!("Error accepting stream: {}", e), }); } } diff --git a/examples/client.rs b/examples/client.rs index 8e11038..3116ebc 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,15 +1,15 @@ +extern crate env_logger; extern crate tungstenite; extern crate url; -extern crate env_logger; use url::Url; -use tungstenite::{Message, connect}; +use tungstenite::{connect, Message}; fn main() { env_logger::init().unwrap(); - let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) - .expect("Can't connect"); + let (mut socket, response) = + connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect"); println!("Connected to the server"); println!("Response HTTP code: {}", response.code); @@ -18,11 +18,12 @@ fn main() { println!("* {}", header); } - socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + socket + .write_message(Message::Text("Hello WebSocket".into())) + .unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); } // socket.close(None); - } diff --git a/examples/server.rs b/examples/server.rs index 86df2fe..f7b4494 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -21,7 +21,10 @@ fn main() { // Let's add an additional header to our response to the client. let extra_headers = vec![ (String::from("MyCustomHeader"), String::from(":)")), - (String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")), + ( + String::from("SOME_TUNGSTENITE_HEADER"), + String::from("header_value"), + ), ]; Ok(Some(extra_headers)) }; diff --git a/src/client.rs b/src/client.rs index 2326ac7..da1eab7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,6 @@ //! Methods to connect to an WebSocket as a client. -use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; use std::io::{Read, Write}; @@ -8,10 +8,10 @@ use url::Url; use handshake::client::Response; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] mod encryption { use std::net::TcpStream; - use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError}; + use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector}; pub use native_tls::TlsStream; pub use stream::Stream as StreamSwitcher; @@ -26,10 +26,13 @@ mod encryption { Mode::Plain => Ok(StreamSwitcher::Plain(stream)), Mode::Tls => { let connector = TlsConnector::builder()?.build()?; - connector.connect(domain, stream) + connector + .connect(domain, stream) .map_err(|e| match e { TlsHandshakeError::Failure(f) => f.into(), - TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), + TlsHandshakeError::Interrupted(_) => { + panic!("Bug: TLS handshake not blocked") + } }) .map(StreamSwitcher::Tls) } @@ -37,7 +40,7 @@ mod encryption { } } -#[cfg(not(feature="tls"))] +#[cfg(not(feature = "tls"))] mod encryption { use std::net::TcpStream; @@ -61,10 +64,9 @@ use self::encryption::wrap_stream; use protocol::WebSocket; use handshake::HandshakeError; use handshake::client::{ClientHandshake, Request}; -use stream::{NoDelay, Mode}; +use stream::{Mode, NoDelay}; use error::{Error, Result}; - /// Connect to the given WebSocket in blocking mode. /// /// The URL may be either ws:// or wss://. @@ -77,30 +79,31 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) - -> Result<(WebSocket, Response)> -{ +pub fn connect<'t, Req: Into>>( + request: Req, +) -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; let addrs = request.url.to_socket_addrs()?; let mut stream = connect_to_some(addrs, &request.url, mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client(request, stream) - .map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + client(request, stream).map_err(|e| match e { + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) } fn connect_to_some(addrs: A, url: &Url, mode: Mode) -> Result - where A: Iterator +where + A: Iterator, { - let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = url.host_str() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", url, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { - return Ok(stream) + return Ok(stream); } } } @@ -115,7 +118,7 @@ pub fn url_mode(url: &Url) -> Result { match url.scheme() { "ws" => Ok(Mode::Plain), "wss" => Ok(Mode::Tls), - _ => Err(Error::Url("URL scheme not supported".into())) + _ => Err(Error::Url("URL scheme not supported".into())), } } @@ -126,8 +129,8 @@ pub fn url_mode(url: &Url) -> Result { /// Any stream supporting `Read + Write` will do. pub fn client<'t, Stream, Req>( request: Req, - stream: Stream - ) -> StdResult<(WebSocket, Response), HandshakeError>> + stream: Stream, +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: Into>, diff --git a/src/error.rs b/src/error.rs index 4b027fd..a1d9097 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,7 +13,7 @@ use httparse; use protocol::frame::CloseFrame; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] pub mod tls { //! TLS error wrapper module, feature-gated. pub use native_tls::Error; @@ -29,7 +29,7 @@ pub enum Error { ConnectionClosed(Option>), /// Input-output error Io(io::Error), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] /// TLS error Tls(tls::Error), /// Buffer capacity exhausted @@ -55,7 +55,7 @@ impl fmt::Display for Error { } } Error::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] Error::Tls(ref err) => write!(f, "TLS error: {}", err), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), @@ -71,7 +71,7 @@ impl ErrorTrait for Error { match *self { Error::ConnectionClosed(_) => "A close handshake is performed", Error::Io(ref err) => err.description(), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] Error::Tls(ref err) => err.description(), Error::Capacity(ref msg) => msg.borrow(), Error::Protocol(ref msg) => msg.borrow(), @@ -100,7 +100,7 @@ impl From for Error { } } -#[cfg(feature="tls")] +#[cfg(feature = "tls")] impl From for Error { fn from(err: tls::Error) -> Self { Error::Tls(err) diff --git a/src/handshake/client.rs b/src/handshake/client.rs index d6a5264..2800e42 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -11,12 +11,13 @@ use rand; use url::Url; use error::{Error, Result}; -use protocol::{WebSocket, Role}; -use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use protocol::{Role, WebSocket}; +use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; /// Client request. +#[derive(Debug)] pub struct Request<'t> { /// `ws://` or `wss://` URL to connect to. pub url: Url, @@ -67,6 +68,7 @@ impl From for Request<'static> { } /// Client handshake role. +#[derive(Debug)] pub struct ClientHandshake { verify_data: VerifyData, _marker: PhantomData, @@ -79,14 +81,19 @@ impl ClientHandshake { let machine = { let mut req = Vec::new(); - write!(req, "\ - GET {path} HTTP/1.1\r\n\ - Host: {host}\r\n\ - Connection: upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n", - host = request.get_host(), path = request.get_path(), key = key).unwrap(); + write!( + req, + "\ + GET {path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Connection: upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: {key}\r\n", + host = request.get_host(), + path = request.get_path(), + key = key + ).unwrap(); if let Some(eh) = request.extra_headers { for (k, v) in eh { write!(req, "{}: {}\r\n", k, v).unwrap(); @@ -105,7 +112,10 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { role: client, machine: machine } + MidHandshake { + role: client, + machine: machine, + } } } @@ -113,24 +123,32 @@ impl HandshakeRole for ClientHandshake { type IncomingData = Response; type InternalStream = S; type FinalResult = (WebSocket, Response); - fn stage_finished(&mut self, finish: StageResult) - -> Result> - { + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result> { Ok(match finish { StageResult::DoneWriting(stream) => { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - StageResult::DoneReading { stream, result, tail, } => { + StageResult::DoneReading { + stream, + result, + tail, + } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), - result)) + ProcessingResult::Done(( + WebSocket::from_partially_read(stream, tail, Role::Client), + result, + )) } }) } } /// Information for handshake verification. +#[derive(Debug)] struct VerifyData { /// Accepted server key. accept_key: String, @@ -147,22 +165,37 @@ impl VerifyData { // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) - if !response.headers.header_is_ignore_case("Upgrade", "websocket") { - return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + if !response + .headers + .header_is_ignore_case("Upgrade", "websocket") + { + return Err(Error::Protocol( + "No \"Upgrade: websocket\" in server reply".into(), + )); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if !response.headers.header_is_ignore_case("Connection", "Upgrade") { - return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + if !response + .headers + .header_is_ignore_case("Connection", "Upgrade") + { + return Err(Error::Protocol( + "No \"Connection: upgrade\" in server reply".into(), + )); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) { - return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + if !response + .headers + .header_is("Sec-WebSocket-Accept", &self.accept_key) + { + return Err(Error::Protocol( + "Key mismatch in Sec-WebSocket-Accept".into(), + )); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -183,6 +216,7 @@ impl VerifyData { } /// Server response. +#[derive(Debug)] pub struct Response { /// HTTP response code of the response. pub code: u16, @@ -204,7 +238,9 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); } Ok(Response { code: raw.code.expect("Bug: no HTTP response code"), @@ -223,7 +259,7 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::{Response, generate_key}; + use super::{generate_key, Response}; use super::super::machine::TryParse; #[test] @@ -247,6 +283,9 @@ mod tests { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.code, 200); - assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); + assert_eq!( + resp.headers.find_first("Content-Type"), + Some(&b"text/html"[..]) + ); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index dff0782..b2dfa60 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -19,7 +19,6 @@ pub struct Headers { } impl Headers { - /// Get first header with the given name, if any. pub fn find_first(&self, name: &str) -> Option<&[u8]> { self.find(name).next() @@ -29,7 +28,7 @@ impl Headers { pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { HeadersIter { name: name, - iter: self.data.iter() + iter: self.data.iter(), } } @@ -42,7 +41,8 @@ impl Headers { /// Check if the given header has the given value (case-insensitive). pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name).ok_or(()) + self.find_first(name) + .ok_or(()) .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) .map(|val| val.eq_ignore_ascii_case(value)) .unwrap_or(false) @@ -52,10 +52,10 @@ impl Headers { pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { self.data.iter() } - } /// The iterator over headers. +#[derive(Debug)] pub struct HeadersIter<'name, 'headers> { name: &'name str, iter: slice::Iter<'headers, (String, Box<[u8]>)>, @@ -66,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { fn next(&mut self) -> Option { while let Some(&(ref name, ref value)) = self.iter.next() { if name.eq_ignore_ascii_case(self.name) { - return Some(value) + return Some(value); } } None } } - /// Trait to convert raw objects into HTTP parseables. pub trait FromHttparse: Sized { /// Convert raw object into parsed HTTP headers. @@ -94,8 +93,8 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { Ok(Headers { data: raw.iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), + .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) + .collect(), }) } } @@ -108,8 +107,7 @@ mod tests { #[test] fn headers() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; @@ -125,8 +123,7 @@ mod tests { #[test] fn headers_iter() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ @@ -141,8 +138,7 @@ mod tests { #[test] fn headers_incomplete() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; let hdr = Headers::try_parse(DATA).unwrap(); diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index bacc866..1d74d68 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -6,6 +6,7 @@ use error::{Error, Result}; use util::NonBlockingResult; /// A generic handshake state machine. +#[derive(Debug)] pub struct HandshakeMachine { stream: Stream, state: HandshakeState, @@ -47,11 +48,9 @@ impl HandshakeMachine { .map_err(|_| Error::Capacity("Header too long".into()))? .read_from(&mut self.stream).no_block()?; match read { - Some(0) => { - Err(Error::Protocol("Handshake not finished".into())) - } - Some(_) => { - Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + Some(0) => Err(Error::Protocol("Handshake not finished".into())), + Some(_) => Ok( + if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { result: obj, @@ -63,14 +62,12 @@ impl HandshakeMachine { state: HandshakeState::Reading(buf), ..self }) - }) - } - None => { - Ok(RoundResult::WouldBlock(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - })) - } + }, + ), + None => Ok(RoundResult::WouldBlock(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + })), } } HandshakeState::Writing(mut buf) => { @@ -98,6 +95,7 @@ impl HandshakeMachine { } /// The result of the round. +#[derive(Debug)] pub enum RoundResult { /// Round not done, I/O would block. WouldBlock(HandshakeMachine), @@ -108,9 +106,14 @@ pub enum RoundResult { } /// The result of the stage. +#[derive(Debug)] pub enum StageResult { /// Reading round finished. - DoneReading { result: Obj, stream: Stream, tail: Vec }, + DoneReading { + result: Obj, + stream: Stream, + tail: Vec, + }, /// Writing round finished. DoneWriting(Stream), } @@ -122,6 +125,7 @@ pub trait TryParse: Sized { } /// The handshake state. +#[derive(Debug)] enum HandshakeState { /// Reading data from the peer. Reading(InputBuffer), diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index a6d9192..a4e59b6 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -17,6 +17,7 @@ use error::Error; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. +#[derive(Debug)] pub struct MidHandshake { role: Role, machine: HandshakeMachine, @@ -29,15 +30,16 @@ impl MidHandshake { loop { mach = match mach.single_round()? { RoundResult::WouldBlock(m) => { - return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) + return Err(HandshakeError::Interrupted(MidHandshake { + machine: m, + ..self + })) } RoundResult::Incomplete(m) => m, - RoundResult::StageFinished(s) => { - match self.role.stage_finished(s)? { - ProcessingResult::Continue(m) => m, - ProcessingResult::Done(result) => return Ok(result), - } - } + RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { + ProcessingResult::Continue(m) => m, + ProcessingResult::Done(result) => return Ok(result), + }, } } } @@ -93,12 +95,15 @@ pub trait HandshakeRole { #[doc(hidden)] type FinalResult; #[doc(hidden)] - fn stage_finished(&mut self, finish: StageResult) - -> Result, Error>; + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result, Error>; } /// Stage processing result. #[doc(hidden)] +#[derive(Debug)] pub enum ProcessingResult { Continue(HandshakeMachine), Done(FinalResult), @@ -122,8 +127,10 @@ mod tests { #[test] fn key_conversion() { // example from RFC 6455 - assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), - "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + assert_eq!( + convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + ); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a59eded..e1e1528 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -8,12 +8,13 @@ use httparse; use httparse::Status; use error::{Error, Result}; -use protocol::{WebSocket, Role}; -use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use protocol::{Role, WebSocket}; +use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; /// Request from the client. +#[derive(Debug)] pub struct Request { /// Path part of the URL. pub path: String, @@ -24,14 +25,15 @@ pub struct Request { impl Request { /// Reply to the response. pub fn reply(&self, extra_headers: Option>) -> Result> { - let key = self.headers.find_first("Sec-WebSocket-Key") + let key = self.headers + .find_first("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; let mut reply = format!( "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", convert_key(key)? ); if let Some(eh) = extra_headers { @@ -61,11 +63,13 @@ impl<'h, 'b: 'h> FromHttparse> for Request { return Err(Error::Protocol("Method is not GET".into())); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); } Ok(Request { path: raw.path.expect("Bug: no path in header").into(), - headers: Headers::from_httparse(raw.headers)? + headers: Headers::from_httparse(raw.headers)?, }) } } @@ -83,14 +87,17 @@ pub trait Callback: Sized { fn on_request(self, request: &Request) -> Result>>; } -impl Callback for F where F: FnOnce(&Request) -> Result>> { +impl Callback for F +where + F: FnOnce(&Request) -> Result>>, +{ fn on_request(self, request: &Request) -> Result>> { self(request) } } /// Stub for callback that does nothing. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct NoCallback; impl Callback for NoCallback { @@ -101,6 +108,7 @@ impl Callback for NoCallback { /// Server handshake role. #[allow(missing_copy_implementations)] +#[derive(Debug)] pub struct ServerHandshake { /// Callback which is called whenever the server read the request from the client and is ready /// to reply to it. The callback returns an optional headers which will be added to the reply @@ -119,7 +127,10 @@ impl ServerHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { callback: Some(callback), _marker: PhantomData }, + role: ServerHandshake { + callback: Some(callback), + _marker: PhantomData, + }, } } } @@ -129,13 +140,18 @@ impl HandshakeRole for ServerHandshake { type InternalStream = S; type FinalResult = WebSocket; - fn stage_finished(&mut self, finish: StageResult) - -> Result> - { + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result> { Ok(match finish { - StageResult::DoneReading { stream, result, tail } => { + StageResult::DoneReading { + stream, + result, + tail, + } => { if !tail.is_empty() { - return Err(Error::Protocol("Junk after client request".into())) + return Err(Error::Protocol("Junk after client request".into())); } let extra_headers = { if let Some(callback) = self.callback.take() { @@ -182,13 +198,19 @@ mod tests { let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let _ = req.reply(None).unwrap(); - let extra_headers = Some(vec![(String::from("MyCustomHeader"), - String::from("MyCustomValue")), - (String::from("MyVersion"), - String::from("LOL"))]); + let extra_headers = Some(vec![ + ( + String::from("MyCustomHeader"), + String::from("MyCustomValue"), + ), + (String::from("MyVersion"), String::from("LOL")), + ]); let reply = req.reply(extra_headers).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); - assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); + assert_eq!( + req.headers.find_first("MyCustomHeader"), + Some(b"MyCustomValue".as_ref()) + ); assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); } } diff --git a/src/lib.rs b/src/lib.rs index 6dd9fb5..0a444bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,25 +1,21 @@ //! Lightweight, flexible WebSockets for Rust. -#![deny( - missing_docs, - missing_copy_implementations, - trivial_casts, trivial_numeric_casts, - unstable_features, - unused_must_use, - unused_mut, - unused_imports, - unused_import_braces)] +#![deny(missing_docs, missing_copy_implementations, missing_debug_implementations, trivial_casts, + trivial_numeric_casts, unstable_features, unused_must_use, unused_mut, unused_imports, + unused_import_braces)] -#[macro_use] extern crate log; extern crate base64; extern crate byteorder; extern crate bytes; extern crate httparse; extern crate input_buffer; +#[macro_use] +extern crate log; +#[cfg(feature = "tls")] +extern crate native_tls; extern crate rand; extern crate sha1; extern crate url; extern crate utf8; -#[cfg(feature="tls")] extern crate native_tls; pub mod error; pub mod protocol; @@ -29,10 +25,10 @@ pub mod handshake; pub mod stream; pub mod util; -pub use client::{connect, client}; +pub use client::{client, connect}; pub use server::{accept, accept_hdr}; pub use error::{Error, Result}; -pub use protocol::{WebSocket, Message}; +pub use protocol::{Message, WebSocket}; pub use handshake::HandshakeError; pub use handshake::client::ClientHandshake; pub use handshake::server::ServerHandshake; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index 5fde5fb..f87f184 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -1,7 +1,7 @@ //! Various codes defined in RFC 6455. use std::fmt; -use std::convert::{Into, From}; +use std::convert::{From, Into}; /// WebSocket message opcode as in RFC 6455. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -42,8 +42,8 @@ impl fmt::Display for Data { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Data::Continue => write!(f, "CONTINUE"), - Data::Text => write!(f, "TEXT"), - Data::Binary => write!(f, "BINARY"), + Data::Text => write!(f, "TEXT"), + Data::Binary => write!(f, "BINARY"), Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x), } } @@ -53,8 +53,8 @@ impl fmt::Display for Control { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Control::Close => write!(f, "CLOSE"), - Control::Ping => write!(f, "PING"), - Control::Pong => write!(f, "PONG"), + Control::Ping => write!(f, "PING"), + Control::Pong => write!(f, "PONG"), Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x), } } @@ -71,18 +71,18 @@ impl fmt::Display for OpCode { impl Into for OpCode { fn into(self) -> u8 { - use self::Data::{Continue, Text, Binary}; + use self::Data::{Binary, Continue, Text}; use self::Control::{Close, Ping, Pong}; use self::OpCode::*; match self { Data(Continue) => 0, - Data(Text) => 1, - Data(Binary) => 2, + Data(Text) => 1, + Data(Binary) => 2, Data(self::Data::Reserved(i)) => i, Control(Close) => 8, - Control(Ping) => 9, - Control(Pong) => 10, + Control(Ping) => 9, + Control(Pong) => 10, Control(self::Control::Reserved(i)) => i, } } @@ -90,19 +90,19 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { - use self::Data::{Continue, Text, Binary}; + use self::Data::{Binary, Continue, Text}; use self::Control::{Close, Ping, Pong}; use self::OpCode::*; match byte { - 0 => Data(Continue), - 1 => Data(Text), - 2 => Data(Binary), - i @ 3 ... 7 => Data(self::Data::Reserved(i)), - 8 => Control(Close), - 9 => Control(Ping), - 10 => Control(Pong), - i @ 11 ... 15 => Control(self::Control::Reserved(i)), - _ => panic!("Bug: OpCode out of range"), + 0 => Data(Continue), + 1 => Data(Text), + 2 => Data(Binary), + i @ 3...7 => Data(self::Data::Reserved(i)), + 8 => Control(Close), + 9 => Control(Ping), + 10 => Control(Pong), + i @ 11...15 => Control(self::Control::Reserved(i)), + _ => panic!("Bug: OpCode out of range"), } } } @@ -169,27 +169,22 @@ pub enum CloseCode { /// to a different IP (when multiple targets exist), or reconnect to the same IP /// when a user has performed an action. Again, - #[doc(hidden)] - Tls, - #[doc(hidden)] - Reserved(u16), - #[doc(hidden)] - Iana(u16), - #[doc(hidden)] - Library(u16), - #[doc(hidden)] - Bad(u16), + #[doc(hidden)] Tls, + #[doc(hidden)] Reserved(u16), + #[doc(hidden)] Iana(u16), + #[doc(hidden)] Library(u16), + #[doc(hidden)] Bad(u16), } impl CloseCode { /// Check if this CloseCode is allowed. pub fn is_allowed(&self) -> bool { match *self { - Bad(_) => false, + Bad(_) => false, Reserved(_) => false, - Status => false, - Abnormal => false, - Tls => false, + Status => false, + Abnormal => false, + Tls => false, _ => true, } } @@ -205,24 +200,24 @@ impl fmt::Display for CloseCode { impl<'t> Into for &'t CloseCode { fn into(self) -> u16 { match *self { - Normal => 1000, - Away => 1001, - Protocol => 1002, - Unsupported => 1003, - Status => 1005, - Abnormal => 1006, - Invalid => 1007, - Policy => 1008, - Size => 1009, - Extension => 1010, - Error => 1011, - Restart => 1012, - Again => 1013, - Tls => 1015, - Reserved(code) => code, - Iana(code) => code, - Library(code) => code, - Bad(code) => code, + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Reserved(code) => code, + Iana(code) => code, + Library(code) => code, + Bad(code) => code, } } } @@ -250,11 +245,11 @@ impl From for CloseCode { 1012 => Restart, 1013 => Again, 1015 => Tls, - 1...999 => Bad(code), + 1...999 => Bad(code), 1000...2999 => Reserved(code), 3000...3999 => Iana(code), 4000...4999 => Library(code), - _ => Bad(code) + _ => Bad(code), } } } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 352efab..8ff6396 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,16 +1,16 @@ use std::fmt; use std::borrow::Cow; use std::mem::transmute; -use std::io::{Cursor, Read, Write, ErrorKind}; +use std::io::{Cursor, ErrorKind, Read, Write}; use std::default::Default; -use std::string::{String, FromUtf8Error}; +use std::string::{FromUtf8Error, String}; use std::result::Result as StdResult; -use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian}; +use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt}; use bytes::BufMut; use error::{Error, Result}; -use super::coding::{OpCode, Control, Data, CloseCode}; -use super::mask::{generate_mask, apply_mask}; +use super::coding::{CloseCode, Control, Data, OpCode}; +use super::mask::{apply_mask, generate_mask}; /// A struct representing the close command. #[derive(Debug, Clone)] @@ -52,7 +52,6 @@ pub struct Frame { } impl Frame { - /// Get the length of the frame. /// This is the length of the header + the length of the payload. #[inline] @@ -186,9 +185,8 @@ impl Frame { #[doc(hidden)] #[inline] pub fn remove_mask(&mut self) { - self.mask.and_then(|mask| { - Some(apply_mask(&mut self.payload, &mask)) - }); + self.mask + .and_then(|mask| Some(apply_mask(&mut self.payload, &mask))); self.mask = None; } @@ -204,7 +202,7 @@ impl Frame { String::from_utf8(self.payload) } - /// Consume the frame into a closing frame. + /// Consume the frame into a closing frame. #[inline] pub fn into_close(self) -> Result>> { match self.payload.len() { @@ -215,7 +213,10 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { code: code, reason: text.into() })) + Ok(Some(CloseFrame { + code: code, + reason: text.into(), + })) } } } @@ -223,16 +224,19 @@ impl Frame { /// Create a new data frame. #[inline] pub fn message(data: Vec, code: OpCode, finished: bool) -> Frame { - debug_assert!(match code { - OpCode::Data(_) => true, - _ => false, - }, "Invalid opcode for data frame."); + debug_assert!( + match code { + OpCode::Data(_) => true, + _ => false, + }, + "Invalid opcode for data frame." + ); Frame { finished: finished, opcode: code, payload: data, - .. Frame::default() + ..Frame::default() } } @@ -242,7 +246,7 @@ impl Frame { Frame { opcode: OpCode::Control(Control::Pong), payload: data, - .. Frame::default() + ..Frame::default() } } @@ -252,7 +256,7 @@ impl Frame { Frame { opcode: OpCode::Control(Control::Ping), payload: data, - .. Frame::default() + ..Frame::default() } } @@ -271,7 +275,7 @@ impl Frame { Frame { payload: payload, - .. Frame::default() + ..Frame::default() } } @@ -284,7 +288,7 @@ impl Frame { let mut head = [0u8; 2]; if try!(cursor.read(&mut head)) != 2 { cursor.set_position(initial); - return Ok(None) + return Ok(None); } trace!("Parsed headers {:?}", head); @@ -335,7 +339,7 @@ impl Frame { let mut mask_bytes = [0u8; 4]; if try!(cursor.read(&mut mask_bytes)) != 4 { cursor.set_position(initial); - return Ok(None) + return Ok(None); } else { header_length += 4; Some(mask_bytes) @@ -346,7 +350,7 @@ impl Frame { if size < length + header_length { cursor.set_position(initial); - return Ok(None) + return Ok(None); } let mut data = Vec::with_capacity(length as usize); @@ -360,9 +364,11 @@ impl Frame { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) + return Err(Error::Protocol( + format!("Encountered invalid opcode: {}", first & 0x0F).into(), + )) } - _ => () + _ => (), } let frame = Frame { @@ -375,13 +381,13 @@ impl Frame { payload: data, }; - Ok(Some(frame)) } /// Write a frame out to a buffer pub fn format(mut self, w: &mut W) -> Result<()> - where W: Write + where + W: Write, { let mut one = 0u8; let code: u8 = self.opcode.into(); @@ -461,7 +467,8 @@ impl Default for Frame { impl fmt::Display for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, + write!( + f, " final: {} @@ -479,7 +486,11 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.iter().map(|byte| format!("{:x}", byte)).collect::()) + self.payload + .iter() + .map(|byte| format!("{:x}", byte)) + .collect::() + ) } } @@ -487,16 +498,18 @@ payload: 0x{} mod tests { use super::*; - use super::super::coding::{OpCode, Data}; + use super::super::coding::{Data, OpCode}; use std::io::Cursor; #[test] fn parse() { - let mut raw: Cursor> = Cursor::new(vec![ - 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 - ]); + let mut raw: Cursor> = + Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let frame = Frame::parse(&mut raw).unwrap().unwrap(); - assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); + assert_eq!( + frame.into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); } #[test] diff --git a/src/protocol/frame/mask.rs b/src/protocol/frame/mask.rs index 32ca225..4aa3975 100644 --- a/src/protocol/frame/mask.rs +++ b/src/protocol/frame/mask.rs @@ -71,7 +71,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { // Possible last block. if len > 0 { - unsafe { xor_mem(ptr, mask_u32, len); } + unsafe { + xor_mem(ptr, mask_u32, len); + } } } @@ -94,12 +96,10 @@ mod tests { #[test] fn test_apply_mask() { - let mask = [ - 0x6d, 0xb6, 0xb2, 0x80, - ]; + let mask = [0x6d, 0xb6, 0xb2, 0x80]; let unmasked = vec![ - 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, - 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, + 0x12, 0x03, ]; // Check masking with proper alignment. @@ -126,4 +126,3 @@ mod tests { } } - diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 49f9c8c..8915f92 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -14,6 +14,7 @@ use input_buffer::{InputBuffer, MIN_READ}; use error::{Error, Result}; /// A reader and writer for WebSocket frames. +#[derive(Debug)] pub struct FrameSocket { stream: Stream, in_buffer: InputBuffer, @@ -52,7 +53,8 @@ impl FrameSocket { } impl FrameSocket - where Stream: Read +where + Stream: Read, { /// Read a frame from stream. pub fn read_frame(&mut self) -> Result> { @@ -62,21 +64,22 @@ impl FrameSocket return Ok(Some(frame)); } // No full frames in buffer. - let size = self.in_buffer.prepare_reserve(MIN_READ) + let size = self.in_buffer + .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .read_from(&mut self.stream)?; if size == 0 { trace!("no frame received"); - return Ok(None) + return Ok(None); } } } - } impl FrameSocket - where Stream: Write +where + Stream: Write, { /// Write a frame to stream. /// @@ -86,7 +89,9 @@ impl FrameSocket pub fn write_frame(&mut self, frame: Frame) -> Result<()> { trace!("writing frame {}", frame); self.out_buffer.reserve(frame.len()); - frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); + frame + .format(&mut self.out_buffer) + .expect("Bug: can't write to vector"); self.write_pending() } /// Complete pending write, if any. @@ -100,7 +105,6 @@ impl FrameSocket } } - #[cfg(test)] mod tests { @@ -111,16 +115,19 @@ mod tests { #[test] fn read_frames() { let raw = Cursor::new(vec![ - 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x82, 0x03, 0x03, 0x02, 0x01, + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01, 0x99, ]); let mut sock = FrameSocket::new(raw); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), - vec![0x03, 0x02, 0x01]); + assert_eq!( + sock.read_frame().unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); + assert_eq!( + sock.read_frame().unwrap().unwrap().into_data(), + vec![0x03, 0x02, 0x01] + ); assert!(sock.read_frame().unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -129,12 +136,12 @@ mod tests { #[test] fn from_partially_read() { - let raw = Cursor::new(vec![ - 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - ]); + let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!( + sock.read_frame().unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); } #[test] @@ -148,10 +155,7 @@ mod tests { sock.write_frame(frame).unwrap(); let (buf, _) = sock.into_inner(); - assert_eq!(buf, vec![ - 0x89, 0x02, 0x04, 0x05, - 0x8a, 0x01, 0x01 - ]); + assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 44eda20..8603714 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,4 +1,4 @@ -use std::convert::{From, Into, AsRef}; +use std::convert::{AsRef, From, Into}; use std::fmt; use std::result::Result as StdResult; use std::str; @@ -12,6 +12,7 @@ mod string_collect { use error::{Error, Result}; + #[derive(Debug)] pub struct StringCollector { data: String, incomplete: Option, @@ -34,7 +35,7 @@ mod string_collect { if let Ok(text) = result { self.data.push_str(text); } else { - return Err(Error::Utf8) + return Err(Error::Utf8); } true } else { @@ -52,7 +53,10 @@ mod string_collect { self.data.push_str(text); Ok(()) } - Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { + Err(DecodeError::Incomplete { + valid_prefix, + incomplete_suffix, + }) => { self.data.push_str(valid_prefix); self.incomplete = Some(incomplete_suffix); Ok(()) @@ -81,10 +85,12 @@ mod string_collect { use self::string_collect::StringCollector; /// A struct representing the incomplete message. +#[derive(Debug)] pub struct IncompleteMessage { collector: IncompleteMessageCollector, } +#[derive(Debug)] enum IncompleteMessageCollector { Text(StringCollector), Binary(Vec), @@ -95,11 +101,11 @@ impl IncompleteMessage { pub fn new(message_type: IncompleteMessageType) -> Self { IncompleteMessage { collector: match message_type { - IncompleteMessageType::Binary => - IncompleteMessageCollector::Binary(Vec::new()), - IncompleteMessageType::Text => - IncompleteMessageCollector::Text(StringCollector::new()), - } + IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()), + IncompleteMessageType::Text => { + IncompleteMessageCollector::Text(StringCollector::new()) + } + }, } } /// Add more data to an existing message. @@ -109,17 +115,13 @@ impl IncompleteMessage { v.extend(tail.as_ref()); Ok(()) } - IncompleteMessageCollector::Text(ref mut t) => { - t.extend(tail) - } + IncompleteMessageCollector::Text(ref mut t) => t.extend(tail), } } /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { - IncompleteMessageCollector::Binary(v) => { - Ok(Message::Binary(v)) - } + IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; Ok(Message::Text(text)) @@ -152,17 +154,18 @@ pub enum Message { } impl Message { - /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message - where S: Into + where + S: Into, { Message::Text(string.into()) } /// Create a new binary WebSocket message by converting to Vec. pub fn binary(bin: B) -> Message - where B: Into> + where + B: Into>, { Message::Binary(bin.into()) } @@ -203,9 +206,9 @@ impl Message { pub fn len(&self) -> usize { match *self { Message::Text(ref string) => string.len(), - Message::Binary(ref data) | - Message::Ping(ref data) | - Message::Pong(ref data) => data.len(), + Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { + data.len() + } } } @@ -219,9 +222,7 @@ impl Message { pub fn into_data(self) -> Vec { match self { Message::Text(string) => string.into_bytes(), - Message::Binary(data) | - Message::Ping(data) | - Message::Pong(data) => data, + Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, } } @@ -229,10 +230,9 @@ impl Message { pub fn into_text(self) -> Result { match self { Message::Text(string) => Ok(string), - Message::Binary(data) | - Message::Ping(data) | - Message::Pong(data) => Ok(try!( - String::from_utf8(data).map_err(|err| err.utf8_error()))), + Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => Ok(try!( + String::from_utf8(data).map_err(|err| err.utf8_error()) + )), } } @@ -241,12 +241,11 @@ impl Message { pub fn to_text(&self) -> Result<&str> { match *self { Message::Text(ref string) => Ok(string), - Message::Binary(ref data) | - Message::Ping(ref data) | - Message::Pong(ref data) => Ok(try!(str::from_utf8(data))), + Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { + Ok(try!(str::from_utf8(data))) + } } } - } impl From for Message { @@ -304,7 +303,6 @@ mod tests { assert!(msg.into_text().is_err()); } - #[test] fn binary_convert_vec() { let bin = vec![6u8, 7, 8, 9, 10, 241]; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 841cd38..7824829 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -8,13 +8,13 @@ pub use self::message::Message; pub use self::frame::CloseFrame; use std::collections::VecDeque; -use std::io::{Read, Write, ErrorKind as IoErrorKind}; +use std::io::{ErrorKind as IoErrorKind, Read, Write}; use std::mem::replace; use error::{Error, Result}; use self::message::{IncompleteMessage, IncompleteMessageType}; use self::frame::{Frame, FrameSocket}; -use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode}; +use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use util::NonBlockingResult; /// Indicates a Client or Server role of the websocket @@ -30,6 +30,7 @@ pub enum Role { /// /// This is THE structure you want to create to be able to speak the WebSocket protocol. /// It may be created by calling `connect`, `accept` or `client` functions. +#[derive(Debug)] pub struct WebSocket { /// Server or client? role: Role, @@ -93,7 +94,7 @@ impl WebSocket { let res = self.read_message_frame(); if let Some(message) = self.translate_close(res)? { trace!("Received message {}", message); - return Ok(message) + return Ok(message); } } } @@ -108,16 +109,12 @@ impl WebSocket { /// most recent pong frame is sent if multiple pong frames are queued up. pub fn write_message(&mut self, message: Message) -> Result<()> { let frame = match message { - Message::Text(data) => { - Frame::message(data.into(), OpCode::Data(OpData::Text), true) - } - Message::Binary(data) => { - Frame::message(data, OpCode::Data(OpData::Binary), true) - } + Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), + Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { self.pong = Some(Frame::pong(data)); - return self.write_pending() + return self.write_pending(); } }; self.send_queue.push_back(frame); @@ -182,14 +179,13 @@ impl WebSocket { /// Try to decode one message frame. May return None. fn read_message_frame(&mut self) -> Result> { if let Some(mut frame) = self.socket.read_frame()? { - // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of // the negotiated extensions defines the meaning of such a nonzero // value, the receiving endpoint MUST _Fail the WebSocket // Connection_. if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { - return Err(Error::Protocol("Reserved bits are non-zero".into())) + return Err(Error::Protocol("Reserved bits are non-zero".into())); } match self.role { @@ -201,19 +197,22 @@ impl WebSocket { } else { // The server MUST close the connection upon receiving a // frame that is not masked. (RFC 6455) - return Err(Error::Protocol("Received an unmasked frame from client".into())) + return Err(Error::Protocol( + "Received an unmasked frame from client".into(), + )); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol("Received a masked frame from server".into())) + return Err(Error::Protocol( + "Received a masked frame from server".into(), + )); } } } match frame.opcode() { - OpCode::Control(ctl) => { match ctl { // All control frames MUST have a payload length of 125 bytes or less @@ -224,12 +223,10 @@ impl WebSocket { _ if frame.payload().len() > 125 => { Err(Error::Protocol("Control frame too big".into())) } - OpCtl::Close => { - self.do_close(frame.into_close()?).map(|_| None) - } - OpCtl::Reserved(i) => { - Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) - } + OpCtl::Close => self.do_close(frame.into_close()?).map(|_| None), + OpCtl::Reserved(i) => Err(Error::Protocol( + format!("Unknown control frame type {}", i).into(), + )), OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => { // No ping processing while closing. Ok(None) @@ -239,9 +236,7 @@ impl WebSocket { self.pong = Some(Frame::pong(data.clone())); Ok(Some(Message::Ping(data))) } - OpCtl::Pong => { - Ok(Some(Message::Pong(frame.into_data()))) - } + OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), } } @@ -258,19 +253,21 @@ impl WebSocket { // TODO if msg too big msg.extend(frame.into_data())?; } else { - return Err(Error::Protocol("Continue frame but nothing to continue".into())) + return Err(Error::Protocol( + "Continue frame but nothing to continue".into(), + )); } if fin { - Ok(Some(replace(&mut self.incomplete, None).unwrap().complete()?)) + Ok(Some(replace(&mut self.incomplete, None) + .unwrap() + .complete()?)) } else { Ok(None) } } - c if self.incomplete.is_some() => { - Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into() - )) - } + c if self.incomplete.is_some() => Err(Error::Protocol( + format!("Received {} while waiting for more fragments", c).into(), + )), OpData::Text | OpData::Binary => { let msg = { let message_type = match data { @@ -289,22 +286,20 @@ impl WebSocket { Ok(None) } } - OpData::Reserved(i) => { - Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) - } + OpData::Reserved(i) => Err(Error::Protocol( + format!("Unknown data frame type {}", i).into(), + )), } } - } // match opcode - } else { match replace(&mut self.state, WebSocketState::Terminated) { WebSocketState::CloseAcknowledged(close) | WebSocketState::ClosedByPeer(close) => { Err(Error::ConnectionClosed(close)) } - _ => { - Err(Error::Protocol("Connection reset without closing handshake".into())) - } + _ => Err(Error::Protocol( + "Connection reset without closing handshake".into(), + )), } } } @@ -325,7 +320,7 @@ impl WebSocket { } else { Frame::close(Some(CloseFrame { code: CloseCode::Protocol, - reason: "Protocol violation".into() + reason: "Protocol violation".into(), })) } } else { @@ -361,8 +356,7 @@ impl WebSocket { /// Send a single pending frame. fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> { match self.role { - Role::Server => { - } + Role::Server => {} Role::Client => { // 5. If the data is being sent by the client, the frame(s) MUST be // masked as defined in Section 5.3. (RFC 6455) @@ -379,10 +373,12 @@ impl WebSocket { Err(Error::Io(err)) => Err({ if err.kind() == IoErrorKind::ConnectionReset { match self.state { - WebSocketState::ClosedByPeer(ref mut frame) => - Error::ConnectionClosed(replace(frame, None)), - WebSocketState::CloseAcknowledged(ref mut frame) => - Error::ConnectionClosed(replace(frame, None)), + WebSocketState::ClosedByPeer(ref mut frame) => { + Error::ConnectionClosed(replace(frame, None)) + } + WebSocketState::CloseAcknowledged(ref mut frame) => { + Error::ConnectionClosed(replace(frame, None)) + } _ => Error::Io(err), } } else { @@ -392,10 +388,10 @@ impl WebSocket { x => x, } } - } /// The current connection state. +#[derive(Debug)] enum WebSocketState { /// The connection is active. Active, @@ -421,7 +417,7 @@ impl WebSocketState { #[cfg(test)] mod tests { - use super::{WebSocket, Role, Message}; + use super::{Message, Role, WebSocket}; use std::io; use std::io::Cursor; @@ -443,24 +439,24 @@ mod tests { } } - #[test] fn receive_messages() { let incoming = Cursor::new(vec![ - 0x89, 0x02, 0x01, 0x02, - 0x8a, 0x01, 0x03, - 0x01, 0x07, - 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, - 0x80, 0x06, - 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, - 0x82, 0x03, - 0x01, 0x02, 0x03, + 0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02, + 0x03, ]); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!( + socket.read_message().unwrap(), + Message::Text("Hello, World!".into()) + ); + assert_eq!( + socket.read_message().unwrap(), + Message::Binary(vec![0x01, 0x02, 0x03]) + ); } } diff --git a/src/server.rs b/src/server.rs index 68e026f..7d7a138 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,9 +15,9 @@ use std::io::{Read, Write}; /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept(stream: S) - -> Result, HandshakeError>> -{ +pub fn accept( + stream: S, +) -> Result, HandshakeError>> { accept_hdr(stream, NoCallback) } @@ -26,8 +26,9 @@ pub fn accept(stream: S) /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr(stream: S, callback: C) - -> Result, HandshakeError>> -{ +pub fn accept_hdr( + stream: S, + callback: C, +) -> Result, HandshakeError>> { ServerHandshake::start(stream, callback).handshake() } diff --git a/src/stream.rs b/src/stream.rs index ab2c8e2..96d26d2 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,15 +4,15 @@ //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. -use std::io::{Read, Write, Result as IoResult}; +use std::io::{Read, Result as IoResult, Write}; use std::net::TcpStream; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] use native_tls::TlsStream; /// Stream mode, either plain TCP or TLS. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub enum Mode { /// Plain mode (`ws://` URL). Plain, @@ -32,7 +32,7 @@ impl NoDelay for TcpStream { } } -#[cfg(feature="tls")] +#[cfg(feature = "tls")] impl NoDelay for TlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.get_mut().set_nodelay(nodelay) @@ -40,6 +40,7 @@ impl NoDelay for TlsStream { } /// Stream, either plain TCP or TLS. +#[derive(Debug)] pub enum Stream { /// Unencrypted socket stream. Plain(S), diff --git a/src/util.rs b/src/util.rs index a784f9c..44ed0ec 100644 --- a/src/util.rs +++ b/src/util.rs @@ -40,7 +40,8 @@ pub trait NonBlockingResult { } impl NonBlockingResult for StdResult - where E : NonBlockingError +where + E: NonBlockingError, { type Result = StdResult, E>; fn no_block(self) -> Self::Result { @@ -49,7 +50,7 @@ impl NonBlockingResult for StdResult Err(e) => match e.into_non_blocking() { Some(e) => Err(e), None => Ok(None), - } + }, } } }