From cbf80ecc761f4990e39ac4b6066601e40fc9c131 Mon Sep 17 00:00:00 2001 From: Artem Vorotnikov Date: Mon, 26 Aug 2019 20:00:41 +0300 Subject: [PATCH] Edition 2018, formatting, clippy fixes --- Cargo.toml | 1 + examples/autobahn-client.rs | 40 +++++----- examples/autobahn-server.rs | 28 +++---- examples/callback-error.rs | 6 +- examples/client.rs | 15 ++-- examples/server.rs | 10 +-- fuzz/Cargo.toml | 1 - src/client.rs | 84 +++++++++++--------- src/error.rs | 12 +-- src/handshake/client.rs | 104 ++++++++++++++++--------- src/handshake/headers.rs | 33 ++++---- src/handshake/machine.rs | 41 +++++----- src/handshake/mod.rs | 34 ++++---- src/handshake/server.rs | 91 +++++++++++++--------- src/lib.rs | 38 ++++----- src/protocol/frame/coding.rs | 98 +++++++++++------------ src/protocol/frame/frame.rs | 109 ++++++++++++++------------ src/protocol/frame/mask.rs | 26 +++---- src/protocol/frame/mod.rs | 77 ++++++++++--------- src/protocol/message.rs | 75 +++++++++--------- src/protocol/mod.rs | 145 ++++++++++++++++------------------- src/server.rs | 30 ++++---- src/stream.rs | 6 +- src/util.rs | 7 +- tests/connection_reset.rs | 14 ++-- 25 files changed, 575 insertions(+), 550 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 92d87c1..83d669b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.9.1" repository = "https://github.com/snapview/tungstenite-rs" version = "0.9.1" +edition = "2018" [features] default = ["tls"] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 3e7b732..5ef0c24 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -1,18 +1,12 @@ -#[macro_use] extern crate log; -extern crate env_logger; -extern crate tungstenite; -extern crate url; - +use log::*; use url::Url; -use tungstenite::{connect, Error, Result, Message}; +use tungstenite::{connect, Error, Message, Result}; const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { - let (mut socket, _) = connect( - Url::parse("ws://localhost:9001/getCaseCount").unwrap(), - )?; + let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let msg = socket.read_message()?; socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) @@ -20,7 +14,11 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), + Url::parse(&format!( + "ws://localhost:9001/updateReports?agent={}", + AGENT + )) + .unwrap(), )?; socket.close(None)?; Ok(()) @@ -28,19 +26,18 @@ fn update_reports() -> Result<()> { fn run_test(case: u32) -> Result<()> { info!("Running test case {}", case); - let case_url = Url::parse( - &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) - ).unwrap(); + let case_url = Url::parse(&format!( + "ws://localhost:9001/runCase?case={}&agent={}", + case, AGENT + )) + .unwrap(); let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { - msg @ Message::Text(_) | - msg @ Message::Binary(_) => { + msg @ Message::Text(_) | msg @ Message::Binary(_) => { socket.write_message(msg)?; } - Message::Ping(_) | - Message::Pong(_) | - Message::Close(_) => {} + Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {} } } } @@ -53,12 +50,13 @@ fn main() { for case in 1..(total + 1) { if let Err(e) = run_test(case) { match e { - Error::Protocol(_) => { } - err => { warn!("test: {}", err); } + Error::Protocol(_) => {} + err => { + warn!("test: {}", err); + } } } } update_reports().unwrap(); } - diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 0a49f1f..31b842c 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,12 +1,9 @@ -#[macro_use] extern crate log; -extern crate env_logger; -extern crate tungstenite; - use std::net::{TcpListener, TcpStream}; use std::thread::spawn; -use tungstenite::{accept, HandshakeError, Error, Result, Message}; +use log::*; use tungstenite::handshake::HandshakeRole; +use tungstenite::{accept, Error, HandshakeError, Message, Result}; fn must_not_block(err: HandshakeError) -> Error { match err { @@ -19,13 +16,10 @@ fn handle_client(stream: TcpStream) -> Result<()> { let mut socket = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { - msg @ Message::Text(_) | - msg @ Message::Binary(_) => { + msg @ Message::Text(_) | msg @ Message::Binary(_) => { socket.write_message(msg)?; } - Message::Ping(_) | - Message::Pong(_) | - Message::Close(_) => {} + Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {} } } } @@ -36,14 +30,12 @@ fn main() { let server = TcpListener::bind("127.0.0.1:9002").unwrap(); for stream in server.incoming() { - spawn(move || { - match stream { - Ok(stream) => match handle_client(stream) { - Ok(_) => (), - Err(e) => warn!("Error in client: {}", e), - }, - Err(e) => warn!("Error accepting stream: {}", e), - } + spawn(move || match stream { + Ok(stream) => match handle_client(stream) { + Ok(_) => (), + Err(e) => warn!("Error in client: {}", e), + }, + Err(e) => warn!("Error accepting stream: {}", e), }); } } diff --git a/examples/callback-error.rs b/examples/callback-error.rs index 3072460..fdb5b8a 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -1,10 +1,8 @@ -extern crate tungstenite; - -use std::thread::spawn; use std::net::TcpListener; +use std::thread::spawn; use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, ErrorResponse}; +use tungstenite::handshake::server::{ErrorResponse, Request}; use tungstenite::http::StatusCode; fn main() { diff --git a/examples/client.rs b/examples/client.rs index 13b7d59..e3200d2 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,15 +1,11 @@ -extern crate tungstenite; -extern crate url; -extern crate env_logger; - +use tungstenite::{connect, Message}; use url::Url; -use tungstenite::{Message, connect}; fn main() { env_logger::init(); - let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) - .expect("Can't connect"); + let (mut socket, response) = + connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect"); println!("Connected to the server"); println!("Response HTTP code: {}", response.code); @@ -18,11 +14,12 @@ fn main() { println!("* {}", header); } - socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + socket + .write_message(Message::Text("Hello WebSocket".into())) + .unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); } // socket.close(None); - } diff --git a/examples/server.rs b/examples/server.rs index d1b5d95..70ba186 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,5 @@ -extern crate tungstenite; -extern crate env_logger; - -use std::thread::spawn; use std::net::TcpListener; +use std::thread::spawn; use tungstenite::accept_hdr; use tungstenite::handshake::server::Request; @@ -23,7 +20,10 @@ fn main() { // Let's add an additional header to our response to the client. let extra_headers = vec![ (String::from("MyCustomHeader"), String::from(":)")), - (String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")), + ( + String::from("SOME_TUNGSTENITE_HEADER"), + String::from("header_value"), + ), ]; Ok(Some(extra_headers)) }; diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 983fa62..6bf180e 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -1,4 +1,3 @@ - [package] name = "tungstenite-fuzz" version = "0.0.1" diff --git a/src/client.rs b/src/client.rs index 7e71130..cea3086 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,36 +1,40 @@ //! Methods to connect to an WebSocket as a client. -use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; -use std::result::Result as StdResult; use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::result::Result as StdResult; +use log::*; use url::Url; -use handshake::client::Response; -use protocol::WebSocketConfig; +use crate::handshake::client::Response; +use crate::protocol::WebSocketConfig; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] mod encryption { - use std::net::TcpStream; - use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError}; pub use native_tls::TlsStream; + use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector}; + use std::net::TcpStream; - pub use stream::Stream as StreamSwitcher; + pub use crate::stream::Stream as StreamSwitcher; /// TCP stream switcher (plain/TLS). pub type AutoStream = StreamSwitcher>; - use stream::Mode; - use error::Result; + use crate::error::Result; + use crate::stream::Mode; pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { match mode { Mode::Plain => Ok(StreamSwitcher::Plain(stream)), Mode::Tls => { let connector = TlsConnector::builder().build()?; - connector.connect(domain, stream) + connector + .connect(domain, stream) .map_err(|e| match e { TlsHandshakeError::Failure(f) => f.into(), - TlsHandshakeError::WouldBlock(_) => panic!("Bug: TLS handshake not blocked"), + TlsHandshakeError::WouldBlock(_) => { + panic!("Bug: TLS handshake not blocked") + } }) .map(StreamSwitcher::Tls) } @@ -38,12 +42,12 @@ mod encryption { } } -#[cfg(not(feature="tls"))] +#[cfg(not(feature = "tls"))] mod encryption { use std::net::TcpStream; - use stream::Mode; use error::{Error, Result}; + use stream::Mode; /// TLS support is nod compiled in, this is just standard `TcpStream`. pub type AutoStream = TcpStream; @@ -56,15 +60,14 @@ mod encryption { } } -pub use self::encryption::AutoStream; use self::encryption::wrap_stream; +pub use self::encryption::AutoStream; -use protocol::WebSocket; -use handshake::HandshakeError; -use handshake::client::{ClientHandshake, Request}; -use stream::{NoDelay, Mode}; -use error::{Error, Result}; - +use crate::error::{Error, Result}; +use crate::handshake::client::{ClientHandshake, Request}; +use crate::handshake::HandshakeError; +use crate::protocol::WebSocket; +use crate::stream::{Mode, NoDelay}; /// Connect to the given WebSocket in blocking mode. /// @@ -83,13 +86,17 @@ use error::{Error, Result}; /// `connect` since it's the only function that uses native_tls. pub fn connect_with_config<'t, Req: Into>>( request: Req, - config: Option + config: Option, ) -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; - let host = request.url.host() + let host = request + .url + .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = request.url.port_or_known_default() + let port = request + .url + .port_or_known_default() .ok_or_else(|| Error::Url("No port number in the URL".into()))?; let addrs; let addr; @@ -109,11 +116,10 @@ pub fn connect_with_config<'t, Req: Into>>( }; let mut stream = connect_to_some(addrs, &request.url, mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config) - .map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + client_with_config(request, stream, config).map_err(|e| match e { + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) } /// Connect to the given WebSocket in blocking mode. @@ -128,19 +134,21 @@ pub fn connect_with_config<'t, Req: Into>>( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) - -> Result<(WebSocket, Response)> -{ +pub fn connect<'t, Req: Into>>( + request: Req, +) -> Result<(WebSocket, Response)> { connect_with_config(request, None) } fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result { - let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = url + .host_str() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", url, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { - return Ok(stream) + return Ok(stream); } } } @@ -155,7 +163,7 @@ pub fn url_mode(url: &Url) -> Result { match url.scheme() { "ws" => Ok(Mode::Plain), "wss" => Ok(Mode::Tls), - _ => Err(Error::Url("URL scheme not supported".into())) + _ => Err(Error::Url("URL scheme not supported".into())), } } @@ -182,8 +190,10 @@ where /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult<(WebSocket, Response), HandshakeError>> +pub fn client<'t, Stream, Req>( + request: Req, + stream: Stream, +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: Into>, diff --git a/src/error.rs b/src/error.rs index 0483a82..ec86b0d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,9 +11,9 @@ use std::string; use httparse; -use protocol::Message; +use crate::protocol::Message; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] pub mod tls { //! TLS error wrapper module, feature-gated. pub use native_tls::Error; @@ -41,7 +41,7 @@ pub enum Error { AlreadyClosed, /// Input-output error Io(io::Error), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] /// TLS error Tls(tls::Error), /// Buffer capacity exhausted @@ -64,7 +64,7 @@ impl fmt::Display for Error { Error::ConnectionClosed => write!(f, "Connection closed normally"), Error::AlreadyClosed => write!(f, "Trying to work with closed connection"), Error::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] Error::Tls(ref err) => write!(f, "TLS error: {}", err), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), @@ -82,7 +82,7 @@ impl ErrorTrait for Error { Error::ConnectionClosed => "A close handshake is performed", Error::AlreadyClosed => "Trying to read or write after getting close notification", Error::Io(ref err) => err.description(), - #[cfg(feature="tls")] + #[cfg(feature = "tls")] Error::Tls(ref err) => err.description(), Error::Capacity(ref msg) => msg.borrow(), Error::Protocol(ref msg) => msg.borrow(), @@ -112,7 +112,7 @@ impl From for Error { } } -#[cfg(feature="tls")] +#[cfg(feature = "tls")] impl From for Error { fn from(err: tls::Error) -> Self { Error::Tls(err) diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 014af64..8203ee5 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -4,17 +4,15 @@ use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; -use base64; use httparse::Status; -use httparse; -use rand; +use log::*; use url::Url; -use error::{Error, Result}; -use protocol::{WebSocket, WebSocketConfig, Role}; -use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; +use crate::error::{Error, Result}; +use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Client request. #[derive(Debug)] @@ -80,26 +78,32 @@ impl ClientHandshake { pub fn start( stream: S, request: Request, - config: Option + config: Option, ) -> MidHandshake { let key = generate_key(); let machine = { let mut req = Vec::new(); - write!(req, "\ - GET {path} HTTP/1.1\r\n\ - Host: {host}\r\n\ - Connection: upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n", - host = request.get_host(), path = request.get_path(), key = key).unwrap(); + write!( + req, + "\ + GET {path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Connection: upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: {key}\r\n", + host = request.get_host(), + path = request.get_path(), + key = key + ) + .unwrap(); if let Some(eh) = request.extra_headers { for (k, v) in eh { - write!(req, "{}: {}\r\n", k, v).unwrap(); + writeln!(req, "{}: {}\r", k, v).unwrap(); } } - write!(req, "\r\n").unwrap(); + writeln!(req, "\r").unwrap(); HandshakeMachine::start_write(stream, req) }; @@ -113,7 +117,10 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { role: client, machine } + MidHandshake { + role: client, + machine, + } } } @@ -121,22 +128,23 @@ impl HandshakeRole for ClientHandshake { type IncomingData = Response; type InternalStream = S; type FinalResult = (WebSocket, Response); - fn stage_finished(&mut self, finish: StageResult) - -> Result> - { + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result> { Ok(match finish { StageResult::DoneWriting(stream) => { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - StageResult::DoneReading { stream, result, tail, } => { + StageResult::DoneReading { + stream, + result, + tail, + } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - let websocket = WebSocket::from_partially_read( - stream, - tail, - Role::Client, - self.config.clone(), - ); + let websocket = + WebSocket::from_partially_read(stream, tail, Role::Client, self.config); ProcessingResult::Done((websocket, result)) } }) @@ -161,22 +169,37 @@ impl VerifyData { // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) - if !response.headers.header_is_ignore_case("Upgrade", "websocket") { - return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + if !response + .headers + .header_is_ignore_case("Upgrade", "websocket") + { + return Err(Error::Protocol( + "No \"Upgrade: websocket\" in server reply".into(), + )); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if !response.headers.header_is_ignore_case("Connection", "Upgrade") { - return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + if !response + .headers + .header_is_ignore_case("Connection", "Upgrade") + { + return Err(Error::Protocol( + "No \"Connection: upgrade\" in server reply".into(), + )); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) { - return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + if !response + .headers + .header_is("Sec-WebSocket-Accept", &self.accept_key) + { + return Err(Error::Protocol( + "Key mismatch in Sec-WebSocket-Accept".into(), + )); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -219,7 +242,9 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); } Ok(Response { code: raw.code.expect("Bug: no HTTP response code"), @@ -238,8 +263,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::{Response, generate_key}; use super::super::machine::TryParse; + use super::{generate_key, Response}; #[test] fn random_keys() { @@ -262,6 +287,9 @@ mod tests { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.code, 200); - assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); + assert_eq!( + resp.headers.find_first("Content-Type"), + Some(&b"text/html"[..]) + ); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 23f0d77..097b22b 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,13 +1,13 @@ //! HTTP Request and response header handling. -use std::str::from_utf8; use std::slice; +use std::str::from_utf8; use httparse; use httparse::Status; -use error::Result; use super::machine::TryParse; +use crate::error::Result; /// Limit for the number of header lines. pub const MAX_HEADERS: usize = 124; @@ -19,7 +19,6 @@ pub struct Headers { } impl Headers { - /// Get first header with the given name, if any. pub fn find_first(&self, name: &str) -> Option<&[u8]> { self.find(name).next() @@ -29,7 +28,7 @@ impl Headers { pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { HeadersIter { name, - iter: self.data.iter() + iter: self.data.iter(), } } @@ -42,7 +41,8 @@ impl Headers { /// Check if the given header has the given value (case-insensitive). pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name).ok_or(()) + self.find_first(name) + .ok_or(()) .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) .map(|val| val.eq_ignore_ascii_case(value)) .unwrap_or(false) @@ -52,7 +52,6 @@ impl Headers { pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { self.data.iter() } - } /// The iterator over headers. @@ -67,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { fn next(&mut self) -> Option { while let Some(&(ref name, ref value)) = self.iter.next() { if name.eq_ignore_ascii_case(self.name) { - return Some(value) + return Some(value); } } None } } - /// Trait to convert raw objects into HTTP parseables. pub trait FromHttparse: Sized { /// Convert raw object into parsed HTTP headers. @@ -94,9 +92,10 @@ impl TryParse for Headers { impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { Ok(Headers { - data: raw.iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), + data: raw + .iter() + .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) + .collect(), }) } } @@ -104,13 +103,12 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { #[cfg(test)] mod tests { - use super::Headers; use super::super::machine::TryParse; + use super::Headers; #[test] fn headers() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; @@ -126,8 +124,7 @@ mod tests { #[test] fn headers_iter() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ @@ -142,12 +139,10 @@ mod tests { #[test] fn headers_incomplete() { - const DATA: &'static [u8] = - b"Host: foo.com\r\n\ + const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; let hdr = Headers::try_parse(DATA).unwrap(); assert!(hdr.is_none()); } - } diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 39f37a5..61090bb 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -1,9 +1,10 @@ -use std::io::{Cursor, Read, Write}; use bytes::Buf; +use log::*; +use std::io::{Cursor, Read, Write}; +use crate::error::{Error, Result}; +use crate::util::NonBlockingResult; use input_buffer::{InputBuffer, MIN_READ}; -use error::{Error, Result}; -use util::NonBlockingResult; /// A generic handshake state machine. #[derive(Debug)] @@ -43,16 +44,16 @@ impl HandshakeMachine { trace!("Doing handshake round."); match self.state { HandshakeState::Reading(mut buf) => { - let read = buf.prepare_reserve(MIN_READ) + let read = buf + .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) // TODO limit size .map_err(|_| Error::Capacity("Header too long".into()))? - .read_from(&mut self.stream).no_block()?; + .read_from(&mut self.stream) + .no_block()?; match read { - Some(0) => { - Err(Error::Protocol("Handshake not finished".into())) - } - Some(_) => { - Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + Some(0) => Err(Error::Protocol("Handshake not finished".into())), + Some(_) => Ok( + if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { result: obj, @@ -64,14 +65,12 @@ impl HandshakeMachine { state: HandshakeState::Reading(buf), ..self }) - }) - } - None => { - Ok(RoundResult::WouldBlock(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - })) - } + }, + ), + None => Ok(RoundResult::WouldBlock(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + })), } } HandshakeState::Writing(mut buf) => { @@ -113,7 +112,11 @@ pub enum RoundResult { #[derive(Debug)] pub enum StageResult { /// Reading round finished. - DoneReading { result: Obj, stream: Stream, tail: Vec }, + DoneReading { + result: Obj, + stream: Stream, + tail: Vec, + }, /// Writing round finished. DoneWriting(Stream), } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index f28445d..ce862d9 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -1,7 +1,7 @@ //! WebSocket handshake control. -pub mod headers; pub mod client; +pub mod headers; pub mod server; mod machine; @@ -11,10 +11,10 @@ use std::fmt; use std::io::{Read, Write}; use base64; -use sha1::{Sha1, Digest}; +use sha1::{Digest, Sha1}; -use error::Error; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; +use crate::error::Error; /// A WebSocket handshake. #[derive(Debug)] @@ -30,15 +30,16 @@ impl MidHandshake { loop { mach = match mach.single_round()? { RoundResult::WouldBlock(m) => { - return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) + return Err(HandshakeError::Interrupted(MidHandshake { + machine: m, + ..self + })) } RoundResult::Incomplete(m) => m, - RoundResult::StageFinished(s) => { - match self.role.stage_finished(s)? { - ProcessingResult::Continue(m) => m, - ProcessingResult::Done(result) => return Ok(result), - } - } + RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { + ProcessingResult::Continue(m) => m, + ProcessingResult::Done(result) => return Ok(result), + }, } } } @@ -94,8 +95,10 @@ pub trait HandshakeRole { #[doc(hidden)] type FinalResult; #[doc(hidden)] - fn stage_finished(&mut self, finish: StageResult) - -> Result, Error>; + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result, Error>; } /// Stage processing result. @@ -124,8 +127,9 @@ mod tests { #[test] fn key_conversion() { // example from RFC 6455 - assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), - "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + assert_eq!( + convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + ); } - } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 94649e1..615a755 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,15 +5,15 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use httparse; -use httparse::Status; use http::StatusCode; +use httparse::Status; +use log::*; -use error::{Error, Result}; -use protocol::{WebSocket, WebSocketConfig, Role}; -use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; +use crate::error::{Error, Result}; +use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Request from the client. #[derive(Debug)] @@ -27,14 +27,16 @@ pub struct Request { impl Request { /// Reply to the response. pub fn reply(&self, extra_headers: Option>) -> Result> { - let key = self.headers.find_first("Sec-WebSocket-Key") + let key = self + .headers + .find_first("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; let mut reply = format!( "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", convert_key(key)? ); add_headers(&mut reply, extra_headers); @@ -45,13 +47,12 @@ impl Request { fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) { if let Some(eh) = extra_headers { for (k, v) in eh { - write!(reply, "{}: {}\r\n", k, v).unwrap(); + writeln!(reply, "{}: {}\r", k, v).unwrap(); } } - write!(reply, "\r\n").unwrap(); + writeln!(reply, "\r").unwrap(); } - impl TryParse for Request { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; @@ -69,11 +70,13 @@ impl<'h, 'b: 'h> FromHttparse> for Request { return Err(Error::Protocol("Method is not GET".into())); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); } Ok(Request { path: raw.path.expect("Bug: no path in header").into(), - headers: Headers::from_httparse(raw.headers)? + headers: Headers::from_httparse(raw.headers)?, }) } } @@ -115,7 +118,10 @@ pub trait Callback: Sized { fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; } -impl Callback for F where F: FnOnce(&Request) -> StdResult, ErrorResponse> { +impl Callback for F +where + F: FnOnce(&Request) -> StdResult, ErrorResponse>, +{ fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { self(request) } @@ -160,7 +166,7 @@ impl ServerHandshake { callback: Some(callback), config, error_code: None, - _marker: PhantomData + _marker: PhantomData, }, } } @@ -171,13 +177,18 @@ impl HandshakeRole for ServerHandshake { type InternalStream = S; type FinalResult = WebSocket; - fn stage_finished(&mut self, finish: StageResult) - -> Result> - { + fn stage_finished( + &mut self, + finish: StageResult, + ) -> Result> { Ok(match finish { - StageResult::DoneReading { stream, result, tail } => { + StageResult::DoneReading { + stream, + result, + tail, + } => { if !tail.is_empty() { - return Err(Error::Protocol("Junk after client request".into())) + return Err(Error::Protocol("Junk after client request".into())); } let callback_result = if let Some(callback) = self.callback.take() { @@ -192,8 +203,12 @@ impl HandshakeRole for ServerHandshake { ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } - Err(ErrorResponse { error_code, headers, body }) => { - self.error_code= Some(error_code.as_u16()); + Err(ErrorResponse { + error_code, + headers, + body, + }) => { + self.error_code = Some(error_code.as_u16()); let mut response = format!( "HTTP/1.1 {} {}\r\n", error_code.as_str(), @@ -214,11 +229,7 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Http(err)); } else { debug!("Server handshake done."); - let websocket = WebSocket::from_raw_socket( - stream, - Role::Server, - self.config.clone(), - ); + let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); ProcessingResult::Done(websocket) } } @@ -228,9 +239,9 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::Request; - use super::super::machine::TryParse; use super::super::client::Response; + use super::super::machine::TryParse; + use super::Request; #[test] fn request_parsing() { @@ -253,13 +264,19 @@ mod tests { let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let _ = req.reply(None).unwrap(); - let extra_headers = Some(vec![(String::from("MyCustomHeader"), - String::from("MyCustomValue")), - (String::from("MyVersion"), - String::from("LOL"))]); + let extra_headers = Some(vec![ + ( + String::from("MyCustomHeader"), + String::from("MyCustomValue"), + ), + (String::from("MyVersion"), String::from("LOL")), + ]); let reply = req.reply(extra_headers).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); - assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); + assert_eq!( + req.headers.find_first("MyCustomHeader"), + Some(b"MyCustomValue".as_ref()) + ); assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); } } diff --git a/src/lib.rs b/src/lib.rs index 547c454..f965478 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,39 +3,29 @@ missing_docs, missing_copy_implementations, missing_debug_implementations, - trivial_casts, trivial_numeric_casts, + trivial_casts, + trivial_numeric_casts, unstable_features, unused_must_use, unused_mut, unused_imports, - unused_import_braces)] + unused_import_braces +)] -#[macro_use] extern crate log; -extern crate base64; -extern crate byteorder; -extern crate bytes; -extern crate httparse; -extern crate input_buffer; -extern crate rand; -extern crate sha1; -extern crate url; -extern crate utf8; -#[cfg(feature="tls")] extern crate native_tls; - -pub extern crate http; +pub use http; +pub mod client; pub mod error; +pub mod handshake; pub mod protocol; -pub mod client; pub mod server; -pub mod handshake; pub mod stream; pub mod util; -pub use client::{connect, client}; -pub use server::{accept, accept_hdr}; -pub use error::{Error, Result}; -pub use protocol::{WebSocket, Message}; -pub use handshake::HandshakeError; -pub use handshake::client::ClientHandshake; -pub use handshake::server::ServerHandshake; +pub use crate::client::{client, connect}; +pub use crate::error::{Error, Result}; +pub use crate::handshake::client::ClientHandshake; +pub use crate::handshake::server::ServerHandshake; +pub use crate::handshake::HandshakeError; +pub use crate::protocol::{Message, WebSocket}; +pub use crate::server::{accept, accept_hdr}; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index 5fde5fb..8380d9f 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -1,7 +1,7 @@ //! Various codes defined in RFC 6455. +use std::convert::{From, Into}; use std::fmt; -use std::convert::{Into, From}; /// WebSocket message opcode as in RFC 6455. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -42,8 +42,8 @@ impl fmt::Display for Data { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Data::Continue => write!(f, "CONTINUE"), - Data::Text => write!(f, "TEXT"), - Data::Binary => write!(f, "BINARY"), + Data::Text => write!(f, "TEXT"), + Data::Binary => write!(f, "BINARY"), Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x), } } @@ -53,8 +53,8 @@ impl fmt::Display for Control { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Control::Close => write!(f, "CLOSE"), - Control::Ping => write!(f, "PING"), - Control::Pong => write!(f, "PONG"), + Control::Ping => write!(f, "PING"), + Control::Pong => write!(f, "PONG"), Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x), } } @@ -71,18 +71,18 @@ impl fmt::Display for OpCode { impl Into for OpCode { fn into(self) -> u8 { - use self::Data::{Continue, Text, Binary}; use self::Control::{Close, Ping, Pong}; + use self::Data::{Binary, Continue, Text}; use self::OpCode::*; match self { Data(Continue) => 0, - Data(Text) => 1, - Data(Binary) => 2, + Data(Text) => 1, + Data(Binary) => 2, Data(self::Data::Reserved(i)) => i, Control(Close) => 8, - Control(Ping) => 9, - Control(Pong) => 10, + Control(Ping) => 9, + Control(Pong) => 10, Control(self::Control::Reserved(i)) => i, } } @@ -90,19 +90,19 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { - use self::Data::{Continue, Text, Binary}; use self::Control::{Close, Ping, Pong}; + use self::Data::{Binary, Continue, Text}; use self::OpCode::*; match byte { - 0 => Data(Continue), - 1 => Data(Text), - 2 => Data(Binary), - i @ 3 ... 7 => Data(self::Data::Reserved(i)), - 8 => Control(Close), - 9 => Control(Ping), - 10 => Control(Pong), - i @ 11 ... 15 => Control(self::Control::Reserved(i)), - _ => panic!("Bug: OpCode out of range"), + 0 => Data(Continue), + 1 => Data(Text), + 2 => Data(Binary), + i @ 3..=7 => Data(self::Data::Reserved(i)), + 8 => Control(Close), + 9 => Control(Ping), + 10 => Control(Pong), + i @ 11..=15 => Control(self::Control::Reserved(i)), + _ => panic!("Bug: OpCode out of range"), } } } @@ -183,13 +183,13 @@ pub enum CloseCode { impl CloseCode { /// Check if this CloseCode is allowed. - pub fn is_allowed(&self) -> bool { - match *self { - Bad(_) => false, + pub fn is_allowed(self) -> bool { + match self { + Bad(_) => false, Reserved(_) => false, - Status => false, - Abnormal => false, - Tls => false, + Status => false, + Abnormal => false, + Tls => false, _ => true, } } @@ -205,24 +205,24 @@ impl fmt::Display for CloseCode { impl<'t> Into for &'t CloseCode { fn into(self) -> u16 { match *self { - Normal => 1000, - Away => 1001, - Protocol => 1002, - Unsupported => 1003, - Status => 1005, - Abnormal => 1006, - Invalid => 1007, - Policy => 1008, - Size => 1009, - Extension => 1010, - Error => 1011, - Restart => 1012, - Again => 1013, - Tls => 1015, - Reserved(code) => code, - Iana(code) => code, - Library(code) => code, - Bad(code) => code, + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Reserved(code) => code, + Iana(code) => code, + Library(code) => code, + Bad(code) => code, } } } @@ -250,11 +250,11 @@ impl From for CloseCode { 1012 => Restart, 1013 => Again, 1015 => Tls, - 1...999 => Bad(code), - 1000...2999 => Reserved(code), - 3000...3999 => Iana(code), - 4000...4999 => Library(code), - _ => Bad(code) + 1..=999 => Bad(code), + 1016..=2999 => Reserved(code), + 3000..=3999 => Iana(code), + 4000..=4999 => Library(code), + _ => Bad(code), } } } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index d016c38..5992932 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,14 +1,15 @@ -use std::fmt; +use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; +use log::*; use std::borrow::Cow; -use std::io::{Cursor, Read, Write, ErrorKind}; use std::default::Default; -use std::string::{String, FromUtf8Error}; +use std::fmt; +use std::io::{Cursor, ErrorKind, Read, Write}; use std::result::Result as StdResult; -use byteorder::{ByteOrder, ReadBytesExt, WriteBytesExt, NetworkEndian}; +use std::string::{FromUtf8Error, String}; -use error::{Error, Result}; -use super::coding::{OpCode, Control, Data, CloseCode}; -use super::mask::{generate_mask, apply_mask}; +use super::coding::{CloseCode, Control, Data, OpCode}; +use super::mask::{apply_mask, generate_mask}; +use crate::error::{Error, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -77,15 +78,13 @@ impl FrameHeader { cursor.set_position(initial); ret } - ret => ret + ret => ret, } } /// Get the size of the header formatted with given payload length. pub fn len(&self, length: u64) -> usize { - 2 - + LengthFormat::for_length(length).extra_bytes() - + if self.mask.is_some() { 4 } else { 0 } + 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 } } /// Format a header for given payload size. @@ -93,19 +92,15 @@ impl FrameHeader { let code: u8 = self.opcode.into(); let one = { - code - | if self.is_final { 0x80 } else { 0 } - | if self.rsv1 { 0x40 } else { 0 } - | if self.rsv2 { 0x20 } else { 0 } - | if self.rsv3 { 0x10 } else { 0 } + code | if self.is_final { 0x80 } else { 0 } + | if self.rsv1 { 0x40 } else { 0 } + | if self.rsv2 { 0x20 } else { 0 } + | if self.rsv3 { 0x10 } else { 0 } }; let lenfmt = LengthFormat::for_length(length); - let two = { - lenfmt.length_byte() - | if self.mask.is_some() { 0x80 } else { 0 } - }; + let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } }; output.write_all(&[one, two])?; match lenfmt { @@ -137,7 +132,7 @@ impl FrameHeader { let (first, second) = { let mut head = [0u8; 2]; if cursor.read(&mut head)? != 2 { - return Ok(None) + return Ok(None); } trace!("Parsed headers {:?}", head); (head[0], head[1]) @@ -169,17 +164,17 @@ impl FrameHeader { Err(err) => { return Err(err.into()); } - Ok(read) => read + Ok(read) => read, } } else { - length_byte as u64 + u64::from(length_byte) } }; let mask = if masked { let mut mask_bytes = [0u8; 4]; if cursor.read(&mut mask_bytes)? != 4 { - return Ok(None) + return Ok(None); } else { Some(mask_bytes) } @@ -190,9 +185,11 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) + return Err(Error::Protocol( + format!("Encountered invalid opcode: {}", first & 0x0F).into(), + )) } - _ => () + _ => (), } let hdr = FrameHeader { @@ -216,7 +213,6 @@ pub struct Frame { } impl Frame { - /// Get the length of the frame. /// This is the length of the header + the length of the payload. #[inline] @@ -225,6 +221,12 @@ impl Frame { self.header.len(length as u64) + length } + /// Check if the frame is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Get a reference to the frame's header. #[inline] pub fn header(&self) -> &FrameHeader { @@ -285,7 +287,7 @@ impl Frame { String::from_utf8(self.payload) } - /// Consume the frame into a closing frame. + /// Consume the frame into a closing frame. #[inline] pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { @@ -296,7 +298,10 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { code, reason: text.into() })) + Ok(Some(CloseFrame { + code, + reason: text.into(), + })) } } } @@ -304,16 +309,19 @@ impl Frame { /// Create a new data frame. #[inline] pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { - debug_assert!(match opcode { - OpCode::Data(_) => true, - _ => false, - }, "Invalid opcode for data frame."); + debug_assert!( + match opcode { + OpCode::Data(_) => true, + _ => false, + }, + "Invalid opcode for data frame." + ); Frame { header: FrameHeader { is_final, opcode, - .. FrameHeader::default() + ..FrameHeader::default() }, payload: data, } @@ -325,7 +333,7 @@ impl Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Pong), - .. FrameHeader::default() + ..FrameHeader::default() }, payload: data, } @@ -337,7 +345,7 @@ impl Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Ping), - .. FrameHeader::default() + ..FrameHeader::default() }, payload: data, } @@ -363,10 +371,7 @@ impl Frame { /// Create a frame from given header and data. pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { - Frame { - header, - payload, - } + Frame { header, payload } } /// Write a frame out to a buffer @@ -380,7 +385,8 @@ impl Frame { impl fmt::Display for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, + write!( + f, " final: {} @@ -398,7 +404,11 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.iter().map(|byte| format!("{:x}", byte)).collect::()) + self.payload + .iter() + .map(|byte| format!("{:x}", byte)) + .collect::() + ) } } @@ -448,7 +458,7 @@ impl LengthFormat { match byte & 0x7F { 126 => LengthFormat::U16, 127 => LengthFormat::U64, - b => LengthFormat::U8(b) + b => LengthFormat::U8(b), } } } @@ -457,20 +467,22 @@ impl LengthFormat { mod tests { use super::*; - use super::super::coding::{OpCode, Data}; + use super::super::coding::{Data, OpCode}; use std::io::Cursor; #[test] fn parse() { - let mut raw: Cursor> = Cursor::new(vec![ - 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 - ]); + let mut raw: Cursor> = + Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); assert_eq!(length, 7); let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload); - assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); + assert_eq!( + frame.into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); } #[test] @@ -487,5 +499,4 @@ mod tests { let view = format!("{}", f); assert!(view.contains("payload:")); } - } diff --git a/src/protocol/frame/mask.rs b/src/protocol/frame/mask.rs index 8b39d76..b357795 100644 --- a/src/protocol/frame/mask.rs +++ b/src/protocol/frame/mask.rs @@ -1,7 +1,8 @@ +use rand; use std::cmp::min; +#[allow(deprecated)] use std::mem::uninitialized; use std::ptr::{copy_nonoverlapping, read_unaligned}; -use rand; /// Generate a random frame mask. #[inline] @@ -26,11 +27,9 @@ fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { /// Faster version of `apply_mask()` which operates on 4-byte blocks. #[inline] -#[allow(dead_code)] +#[allow(dead_code, clippy::cast_ptr_alignment)] fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { - let mask_u32: u32 = unsafe { - read_unaligned(mask.as_ptr() as *const u32) - }; + let mask_u32: u32 = unsafe { read_unaligned(mask.as_ptr() as *const u32) }; let mut ptr = buf.as_mut_ptr(); let mut len = buf.len(); @@ -40,7 +39,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { let mask_u32 = if head > 0 { unsafe { xor_mem(ptr, mask_u32, head); - ptr = ptr.offset(head as isize); + ptr = ptr.add(head); } len -= head; if cfg!(target_endian = "big") { @@ -67,7 +66,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { // Possible last block. if len > 0 { - unsafe { xor_mem(ptr, mask_u32, len); } + unsafe { + xor_mem(ptr, mask_u32, len); + } } } @@ -75,6 +76,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { // TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient, // it could be done better. The compiler does not see that len is limited to 3. unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) { + #[allow(deprecated)] let mut b: u32 = uninitialized(); #[allow(trivial_casts)] copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); @@ -90,12 +92,10 @@ mod tests { #[test] fn test_apply_mask() { - let mask = [ - 0x6d, 0xb6, 0xb2, 0x80, - ]; + let mask = [0x6d, 0xb6, 0xb2, 0x80]; let unmasked = vec![ - 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, - 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, + 0x12, 0x03, ]; // Check masking with proper alignment. @@ -120,6 +120,4 @@ mod tests { assert_eq!(masked, masked_fast); } } - } - diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 830146f..44f3ec3 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -2,16 +2,17 @@ pub mod coding; +#[allow(clippy::module_inception)] mod frame; mod mask; -pub use self::frame::{Frame, FrameHeader}; pub use self::frame::CloseFrame; +pub use self::frame::{Frame, FrameHeader}; -use std::io::{Read, Write}; - +use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; -use error::{Error, Result}; +use log::*; +use std::io::{Read, Write}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -56,7 +57,8 @@ impl FrameSocket { } impl FrameSocket - where Stream: Read +where + Stream: Read, { /// Read a frame from stream. pub fn read_frame(&mut self, max_size: Option) -> Result> { @@ -65,7 +67,8 @@ impl FrameSocket } impl FrameSocket - where Stream: Write +where + Stream: Write, { /// Write a frame to stream. /// @@ -138,8 +141,8 @@ impl FrameCodec { // is not too big (fits into `usize`). if length > max_size as u64 { return Err(Error::Capacity( - format!("Message length too big: {} > {}", length, max_size).into() - )) + format!("Message length too big: {} > {}", length, max_size).into(), + )); } let input_size = cursor.get_ref().len() as u64 - cursor.position(); @@ -149,19 +152,21 @@ impl FrameCodec { if length > 0 { cursor.take(length).read_to_end(&mut payload)?; } - break payload + break payload; } } } // Not enough data in buffer. - let size = self.in_buffer.prepare_reserve(MIN_READ) + let size = self + .in_buffer + .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .read_from(stream)?; if size == 0 { trace!("no frame received"); - return Ok(None) + return Ok(None); } }; @@ -173,17 +178,15 @@ impl FrameCodec { } /// Write a frame to the provided stream. - pub(super) fn write_frame( - &mut self, - stream: &mut Stream, - frame: Frame, - ) -> Result<()> + pub(super) fn write_frame(&mut self, stream: &mut Stream, frame: Frame) -> Result<()> where Stream: Write, { trace!("writing frame {}", frame); self.out_buffer.reserve(frame.len()); - frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); + frame + .format(&mut self.out_buffer) + .expect("Bug: can't write to vector"); self.write_pending(stream) } @@ -211,16 +214,19 @@ mod tests { #[test] fn read_frames() { let raw = Cursor::new(vec![ - 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x82, 0x03, 0x03, 0x02, 0x01, + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01, 0x99, ]); let mut sock = FrameSocket::new(raw); - assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); - assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x03, 0x02, 0x01]); + assert_eq!( + sock.read_frame(None).unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); + assert_eq!( + sock.read_frame(None).unwrap().unwrap().into_data(), + vec![0x03, 0x02, 0x01] + ); assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -229,12 +235,12 @@ mod tests { #[test] fn from_partially_read() { - let raw = Cursor::new(vec![ - 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - ]); + let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); - assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!( + sock.read_frame(None).unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); } #[test] @@ -248,17 +254,13 @@ mod tests { sock.write_frame(frame).unwrap(); let (buf, _) = sock.into_inner(); - assert_eq!(buf, vec![ - 0x89, 0x02, 0x04, 0x05, - 0x8a, 0x01, 0x01 - ]); + assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); } #[test] fn parse_overflow() { let raw = Cursor::new(vec![ - 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, ]); let mut sock = FrameSocket::new(raw); let _ = sock.read_frame(None); // should not crash @@ -266,11 +268,10 @@ mod tests { #[test] fn size_limit_hit() { - let raw = Cursor::new(vec![ - 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - ]); + let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); - assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(), + assert_eq!( + sock.read_frame(Some(5)).unwrap_err().to_string(), "Space limit exceeded: Message length too big: 7 > 5" ); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index c98d22f..ba00765 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,17 +1,17 @@ -use std::convert::{From, Into, AsRef}; +use std::convert::{AsRef, From, Into}; use std::fmt; use std::result::Result as StdResult; use std::str; -use error::{Result, Error}; use super::frame::CloseFrame; +use crate::error::{Error, Result}; mod string_collect { use utf8; use utf8::DecodeError; - use error::{Error, Result}; + use crate::error::{Error, Result}; #[derive(Debug)] pub struct StringCollector { @@ -28,7 +28,8 @@ mod string_collect { } pub fn len(&self) -> usize { - self.data.len() + self.data + .len() .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0)) } @@ -41,7 +42,7 @@ mod string_collect { if let Ok(text) = result { self.data.push_str(text); } else { - return Err(Error::Utf8) + return Err(Error::Utf8); } true } else { @@ -59,7 +60,10 @@ mod string_collect { self.data.push_str(text); Ok(()) } - Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { + Err(DecodeError::Incomplete { + valid_prefix, + incomplete_suffix, + }) => { self.data.push_str(valid_prefix); self.incomplete = Some(incomplete_suffix); Ok(()) @@ -82,7 +86,6 @@ mod string_collect { } } } - } use self::string_collect::StringCollector; @@ -104,11 +107,11 @@ impl IncompleteMessage { pub fn new(message_type: IncompleteMessageType) -> Self { IncompleteMessage { collector: match message_type { - IncompleteMessageType::Binary => - IncompleteMessageCollector::Binary(Vec::new()), - IncompleteMessageType::Text => - IncompleteMessageCollector::Text(StringCollector::new()), - } + IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()), + IncompleteMessageType::Text => { + IncompleteMessageCollector::Text(StringCollector::new()) + } + }, } } @@ -130,8 +133,12 @@ impl IncompleteMessage { // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { return Err(Error::Capacity( - format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into() - )) + format!( + "Message too big: {} + {} > {}", + my_size, portion_size, max_size + ) + .into(), + )); } match self.collector { @@ -139,18 +146,14 @@ impl IncompleteMessage { v.extend(tail.as_ref()); Ok(()) } - IncompleteMessageCollector::Text(ref mut t) => { - t.extend(tail) - } + IncompleteMessageCollector::Text(ref mut t) => t.extend(tail), } } /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { - IncompleteMessageCollector::Binary(v) => { - Ok(Message::Binary(v)) - } + IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; Ok(Message::Text(text)) @@ -185,17 +188,18 @@ pub enum Message { } impl Message { - /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message - where S: Into + where + S: Into, { Message::Text(string.into()) } /// Create a new binary WebSocket message by converting to Vec. pub fn binary(bin: B) -> Message - where B: Into> + where + B: Into>, { Message::Binary(bin.into()) } @@ -244,9 +248,9 @@ impl Message { pub fn len(&self) -> usize { match *self { Message::Text(ref string) => string.len(), - Message::Binary(ref data) | - Message::Ping(ref data) | - Message::Pong(ref data) => data.len(), + Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { + data.len() + } Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), } } @@ -261,9 +265,7 @@ impl Message { pub fn into_data(self) -> Vec { match self { Message::Text(string) => string.into_bytes(), - Message::Binary(data) | - Message::Ping(data) | - Message::Pong(data) => data, + Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => Vec::new(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } @@ -273,10 +275,9 @@ impl Message { pub fn into_text(self) -> Result { match self { Message::Text(string) => Ok(string), - Message::Binary(data) | - Message::Ping(data) | - Message::Pong(data) => Ok(try!( - String::from_utf8(data).map_err(|err| err.utf8_error()))), + Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { + Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) + } Message::Close(None) => Ok(String::new()), Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), } @@ -287,14 +288,13 @@ impl Message { pub fn to_text(&self) -> Result<&str> { match *self { Message::Text(ref string) => Ok(string), - Message::Binary(ref data) | - Message::Ping(ref data) | - Message::Pong(ref data) => Ok(try!(str::from_utf8(data))), + Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { + Ok(str::from_utf8(data)?) + } Message::Close(None) => Ok(""), Message::Close(Some(ref frame)) => Ok(&frame.reason), } } - } impl From for Message { @@ -358,7 +358,6 @@ mod tests { assert!(msg.into_text().is_err()); } - #[test] fn binary_convert_vec() { let bin = vec![6u8, 7, 8, 9, 10, 241]; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 0d1e5f7..a853f94 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -4,18 +4,19 @@ pub mod frame; mod message; -pub use self::message::Message; pub use self::frame::CloseFrame; +pub use self::message::Message; +use log::*; use std::collections::VecDeque; -use std::io::{Read, Write, ErrorKind as IoErrorKind}; +use std::io::{ErrorKind as IoErrorKind, Read, Write}; use std::mem::replace; -use error::{Error, Result}; -use self::message::{IncompleteMessage, IncompleteMessageType}; +use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use self::frame::{Frame, FrameCodec}; -use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode}; -use util::NonBlockingResult; +use self::message::{IncompleteMessage, IncompleteMessageType}; +use crate::error::{Error, Result}; +use crate::util::NonBlockingResult; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -147,7 +148,6 @@ impl WebSocket { } } - /// A context for managing WebSocket stream. #[derive(Debug)] pub struct WebSocketContext { @@ -182,11 +182,7 @@ impl WebSocketContext { } /// Create a WebSocket context that manages an post-handshake stream. - pub fn from_partially_read( - part: Vec, - role: Role, - config: Option, - ) -> Self { + pub fn from_partially_read(part: Vec, role: Role, config: Option) -> Self { WebSocketContext { frame: FrameCodec::from_partially_read(part), ..WebSocketContext::new(role, config) @@ -217,7 +213,7 @@ impl WebSocketContext { // Thus if read blocks, just let it return WouldBlock. if let Some(message) = self.read_message_frame(stream)? { trace!("Received message {}", message); - return Ok(message) + return Ok(message); } } } @@ -251,20 +247,14 @@ impl WebSocketContext { } let frame = match message { - Message::Text(data) => { - Frame::message(data.into(), OpCode::Data(OpData::Text), true) - } - Message::Binary(data) => { - Frame::message(data, OpCode::Data(OpData::Binary), true) - } + Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), + Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { self.pong = Some(Frame::pong(data)); - return self.write_pending(stream) - } - Message::Close(code) => { - return self.close(stream, code) + return self.write_pending(stream); } + Message::Close(code) => return self.close(stream, code), }; self.send_queue.push_back(frame); @@ -342,7 +332,6 @@ impl WebSocketContext { Stream: Read + Write, { if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { - // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of // the negotiated extensions defines the meaning of such a nonzero @@ -351,7 +340,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol("Reserved bits are non-zero".into())) + return Err(Error::Protocol("Reserved bits are non-zero".into())); } } @@ -364,19 +353,22 @@ impl WebSocketContext { } else { // The server MUST close the connection upon receiving a // frame that is not masked. (RFC 6455) - return Err(Error::Protocol("Received an unmasked frame from client".into())) + return Err(Error::Protocol( + "Received an unmasked frame from client".into(), + )); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol("Received a masked frame from server".into())) + return Err(Error::Protocol( + "Received a masked frame from server".into(), + )); } } } match frame.header().opcode { - OpCode::Control(ctl) => { match ctl { // All control frames MUST have a payload length of 125 bytes or less @@ -387,12 +379,10 @@ impl WebSocketContext { _ if frame.payload().len() > 125 => { Err(Error::Protocol("Control frame too big".into())) } - OpCtl::Close => { - Ok(self.do_close(frame.into_close()?).map(Message::Close)) - } - OpCtl::Reserved(i) => { - Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) - } + OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), + OpCtl::Reserved(i) => Err(Error::Protocol( + format!("Unknown control frame type {}", i).into(), + )), OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => { // No ping processing while closing. Ok(None) @@ -402,9 +392,7 @@ impl WebSocketContext { self.pong = Some(Frame::pong(data.clone())); Ok(Some(Message::Ping(data))) } - OpCtl::Pong => { - Ok(Some(Message::Pong(frame.into_data()))) - } + OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), } } @@ -420,7 +408,9 @@ impl WebSocketContext { if let Some(ref mut msg) = self.incomplete { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { - return Err(Error::Protocol("Continue frame but nothing to continue".into())) + return Err(Error::Protocol( + "Continue frame but nothing to continue".into(), + )); } if fin { Ok(Some(self.incomplete.take().unwrap().complete()?)) @@ -428,11 +418,9 @@ impl WebSocketContext { Ok(None) } } - c if self.incomplete.is_some() => { - Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into() - )) - } + c if self.incomplete.is_some() => Err(Error::Protocol( + format!("Received {} while waiting for more fragments", c).into(), + )), OpData::Text | OpData::Binary => { let msg = { let message_type = match data { @@ -451,28 +439,27 @@ impl WebSocketContext { Ok(None) } } - OpData::Reserved(i) => { - Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) - } + OpData::Reserved(i) => Err(Error::Protocol( + format!("Unknown data frame type {}", i).into(), + )), } } - } // match opcode - } else { // Connection closed by peer match replace(&mut self.state, WebSocketState::Terminated) { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => { - Err(Error::Protocol("Connection reset without closing handshake".into())) - } + _ => Err(Error::Protocol( + "Connection reset without closing handshake".into(), + )), } } } /// Received a close frame. Tells if we need to return a close frame to the user. + #[allow(clippy::option_option)] fn do_close<'t>(&mut self, close: Option>) -> Option>> { debug!("Received close frame: {:?}", close); match self.state { @@ -488,7 +475,7 @@ impl WebSocketContext { } else { Frame::close(Some(CloseFrame { code: CloseCode::Protocol, - reason: "Protocol violation".into() + reason: "Protocol violation".into(), })) } } else { @@ -518,8 +505,7 @@ impl WebSocketContext { Stream: Read + Write, { match self.role { - Role::Server => { - } + Role::Server => {} Role::Client => { // 5. If the data is being sent by the client, the frame(s) MUST be // masked as defined in Section 5.3. (RFC 6455) @@ -535,7 +521,9 @@ impl WebSocketContext { match self.state { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged if err.kind() == IoErrorKind::ConnectionReset => - Error::ConnectionClosed, + { + Error::ConnectionClosed + } _ => Error::Io(err), } }), @@ -544,7 +532,6 @@ impl WebSocketContext { } } - /// The current connection state. #[derive(Debug)] enum WebSocketState { @@ -580,7 +567,7 @@ impl WebSocketState { #[cfg(test)] mod tests { - use super::{WebSocket, Role, Message, WebSocketConfig}; + use super::{Message, Role, WebSocket, WebSocketConfig}; use std::io; use std::io::Cursor; @@ -602,57 +589,53 @@ mod tests { } } - #[test] fn receive_messages() { let incoming = Cursor::new(vec![ - 0x89, 0x02, 0x01, 0x02, - 0x8a, 0x01, 0x03, - 0x01, 0x07, - 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, - 0x80, 0x06, - 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, - 0x82, 0x03, - 0x01, 0x02, 0x03, + 0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02, + 0x03, ]); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!( + socket.read_message().unwrap(), + Message::Text("Hello, World!".into()) + ); + assert_eq!( + socket.read_message().unwrap(), + Message::Binary(vec![0x01, 0x02, 0x03]) + ); } - #[test] fn size_limiting_text_fragmented() { let incoming = Cursor::new(vec![ - 0x01, 0x07, - 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, - 0x80, 0x06, - 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, + 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, + 0x6c, 0x64, 0x21, ]); let limit = WebSocketConfig { max_message_size: Some(10), - .. WebSocketConfig::default() + ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!(socket.read_message().unwrap_err().to_string(), + assert_eq!( + socket.read_message().unwrap_err().to_string(), "Space limit exceeded: Message too big: 7 + 6 > 10" ); } #[test] fn size_limiting_binary() { - let incoming = Cursor::new(vec![ - 0x82, 0x03, - 0x01, 0x02, 0x03, - ]); + let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let limit = WebSocketConfig { max_message_size: Some(2), - .. WebSocketConfig::default() + ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!(socket.read_message().unwrap_err().to_string(), + assert_eq!( + socket.read_message().unwrap_err().to_string(), "Space limit exceeded: Message too big: 0 + 3 > 2" ); } diff --git a/src/server.rs b/src/server.rs index ba93508..725d892 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,11 +1,11 @@ //! Methods to accept an incoming WebSocket connection on a server. -pub use handshake::server::ServerHandshake; +pub use crate::handshake::server::ServerHandshake; -use handshake::HandshakeError; -use handshake::server::{Callback, NoCallback}; +use crate::handshake::server::{Callback, NoCallback}; +use crate::handshake::HandshakeError; -use protocol::{WebSocket, WebSocketConfig}; +use crate::protocol::{WebSocket, WebSocketConfig}; use std::io::{Read, Write}; @@ -18,9 +18,10 @@ use std::io::{Read, Write}; /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept_with_config(stream: S, config: Option) - -> Result, HandshakeError>> -{ +pub fn accept_with_config( + stream: S, + config: Option, +) -> Result, HandshakeError>> { accept_hdr_with_config(stream, NoCallback, config) } @@ -30,9 +31,9 @@ pub fn accept_with_config(stream: S, config: Option(stream: S) - -> Result, HandshakeError>> -{ +pub fn accept( + stream: S, +) -> Result, HandshakeError>> { accept_with_config(stream, None) } @@ -47,7 +48,7 @@ pub fn accept(stream: S) pub fn accept_hdr_with_config( stream: S, callback: C, - config: Option + config: Option, ) -> Result, HandshakeError>> { ServerHandshake::start(stream, callback, config).handshake() } @@ -57,8 +58,9 @@ pub fn accept_hdr_with_config( /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr(stream: S, callback: C) - -> Result, HandshakeError>> -{ +pub fn accept_hdr( + stream: S, + callback: C, +) -> Result, HandshakeError>> { accept_hdr_with_config(stream, callback, None) } diff --git a/src/stream.rs b/src/stream.rs index 24324ca..96d26d2 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,11 +4,11 @@ //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. -use std::io::{Read, Write, Result as IoResult}; +use std::io::{Read, Result as IoResult, Write}; use std::net::TcpStream; -#[cfg(feature="tls")] +#[cfg(feature = "tls")] use native_tls::TlsStream; /// Stream mode, either plain TCP or TLS. @@ -32,7 +32,7 @@ impl NoDelay for TcpStream { } } -#[cfg(feature="tls")] +#[cfg(feature = "tls")] impl NoDelay for TlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.get_mut().set_nodelay(nodelay) diff --git a/src/util.rs b/src/util.rs index a784f9c..cd03035 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,7 +3,7 @@ use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::result::Result as StdResult; -use error::Error; +use crate::error::Error; /// Non-blocking IO handling. pub trait NonBlockingError: Sized { @@ -40,7 +40,8 @@ pub trait NonBlockingResult { } impl NonBlockingResult for StdResult - where E : NonBlockingError +where + E: NonBlockingError, { type Result = StdResult, E>; fn no_block(self) -> Self::Result { @@ -49,7 +50,7 @@ impl NonBlockingResult for StdResult Err(e) => match e.into_non_blocking() { Some(e) => Err(e), None => Ok(None), - } + }, } } } diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 47a0740..b94e8d2 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,13 +1,9 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. -extern crate env_logger; -extern crate tungstenite; -extern crate url; - use std::net::TcpListener; use std::process::exit; -use std::thread::{spawn, sleep}; +use std::thread::{sleep, spawn}; use std::time::Duration; use tungstenite::{accept, connect, Error, Message}; @@ -28,14 +24,16 @@ fn test_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()).unwrap(); - client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + client + .write_message(Message::Text("Hello WebSocket".into())) + .unwrap(); let message = client.read_message().unwrap(); // receive close from server assert!(message.is_close()); let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed match err { - Error::ConnectionClosed => { }, + Error::ConnectionClosed => {} _ => panic!("unexpected error"), } }); @@ -52,7 +50,7 @@ fn test_close() { let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed match err { - Error::ConnectionClosed => { }, + Error::ConnectionClosed => {} _ => panic!("unexpected error"), }