diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 31aedb2..8143175 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -14,11 +14,7 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!( - "ws://localhost:9001/updateReports?agent={}", - AGENT - )) - .unwrap(), + Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), )?; socket.close(None)?; Ok(()) @@ -26,11 +22,8 @@ fn update_reports() -> Result<()> { fn run_test(case: u32) -> Result<()> { info!("Running test case {}", case); - let case_url = Url::parse(&format!( - "ws://localhost:9001/runCase?case={}&agent={}", - case, AGENT - )) - .unwrap(); + let case_url = + Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3c99545..3250b2c 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,9 +1,10 @@ -use std::net::{TcpListener, TcpStream}; -use std::thread::spawn; +use std::{ + net::{TcpListener, TcpStream}, + thread::spawn, +}; use log::*; -use tungstenite::handshake::HandshakeRole; -use tungstenite::{accept, Error, HandshakeError, Message, Result}; +use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result}; fn must_not_block(err: HandshakeError) -> Error { match err { @@ -32,12 +33,14 @@ fn main() { for stream in server.incoming() { spawn(move || match stream { - Ok(stream) => if let Err(err) = handle_client(stream) { - match err { - Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), - e => error!("test: {}", e), + Ok(stream) => { + if let Err(err) = handle_client(stream) { + match err { + Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), + e => error!("test: {}", e), + } } - }, + } Err(e) => error!("Error accepting stream: {}", e), }); } diff --git a/examples/callback-error.rs b/examples/callback-error.rs index 357d343..cf78a2e 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -1,9 +1,10 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; -use tungstenite::http::StatusCode; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, + http::StatusCode, +}; fn main() { let server = TcpListener::bind("127.0.0.1:3012").unwrap(); diff --git a/examples/client.rs b/examples/client.rs index 7938cfb..def6a3c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -14,9 +14,7 @@ fn main() { println!("* {}", header); } - socket - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); diff --git a/examples/server.rs b/examples/server.rs index def2f45..420e5db 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,9 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, +}; fn main() { env_logger::init(); diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..db7f39d --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,7 @@ +# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check]. +# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md + +# Break complex but short statements a bit less. +use_small_heuristics = "Max" + +merge_imports = true diff --git a/src/client.rs b/src/client.rs index 1b20980..f9ae3a4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,20 @@ //! Methods to connect to a WebSocket as a client. -use std::io::{Read, Write}; -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -use std::result::Result as StdResult; +use std::{ + io::{Read, Write}, + net::{SocketAddr, TcpStream, ToSocketAddrs}, + result::Result as StdResult, +}; -use http::{Uri, request::Parts}; +use http::{request::Parts, Uri}; use log::*; use url::Url; -use crate::handshake::client::{Request, Response}; -use crate::protocol::WebSocketConfig; +use crate::{ + handshake::client::{Request, Response}, + protocol::WebSocketConfig, +}; #[cfg(feature = "tls")] mod encryption { @@ -22,8 +26,7 @@ mod encryption { /// TCP stream switcher (plain/TLS). pub type AutoStream = StreamSwitcher>; - use crate::error::Result; - use crate::stream::Mode; + use crate::{error::Result, stream::Mode}; pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { match mode { @@ -48,8 +51,10 @@ mod encryption { mod encryption { use std::net::TcpStream; - use crate::error::{Error, Result}; - use crate::stream::Mode; + use crate::{ + error::{Error, Result}, + stream::Mode, + }; /// TLS support is nod compiled in, this is just standard `TcpStream`. pub type AutoStream = TcpStream; @@ -65,11 +70,12 @@ mod encryption { use self::encryption::wrap_stream; pub use self::encryption::AutoStream; -use crate::error::{Error, Result}; -use crate::handshake::client::ClientHandshake; -use crate::handshake::HandshakeError; -use crate::protocol::WebSocket; -use crate::stream::{Mode, NoDelay}; +use crate::{ + error::{Error, Result}, + handshake::{client::ClientHandshake, HandshakeError}, + protocol::WebSocket, + stream::{Mode, NoDelay}, +}; /// Connect to the given WebSocket in blocking mode. /// @@ -91,16 +97,14 @@ pub fn connect_with_config( config: Option, max_redirects: u8, ) -> Result<(WebSocket, Response)> { - - fn try_client_handshake(request: Request, config: Option) - -> Result<(WebSocket, Response)> - { + fn try_client_handshake( + request: Request, + config: Option, + ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let host = + request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -164,9 +168,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { diff --git a/src/error.rs b/src/error.rs index b2657cf..c2becc7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,6 @@ //! Error handling. -use std::borrow::Cow; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io; -use std::result; -use std::str; -use std::string; +use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; use crate::protocol::Message; use http::Response; diff --git a/src/handshake/client.rs b/src/handshake/client.rs index bb159d7..ea011fd 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,17 +1,24 @@ //! Client handshake machine. -use std::io::{Read, Write}; -use std::marker::PhantomData; +use std::{ + io::{Read, Write}, + marker::PhantomData, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; /// Client request type. pub type Request = HttpRequest<()>; @@ -35,15 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol( - "Invalid HTTP method, only GET supported".into(), - )); + return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } // Check the URI scheme: only ws or wss are supported @@ -58,18 +61,11 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); - ClientHandshake { - verify_data: VerifyData { accept_key }, - config, - _marker: PhantomData, - } + ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData } }; trace!("Client handshake initiated."); - Ok(MidHandshake { - role: client, - machine, - }) + Ok(MidHandshake { role: client, machine }) } } @@ -85,11 +81,7 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneWriting(stream) => { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - StageResult::DoneReading { - stream, - result, - tail, - } => { + StageResult::DoneReading { stream, result, tail } => { let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = @@ -105,16 +97,16 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority() - .ok_or_else(|| Error::Url("No host name in the URL".into()))? - .as_str(); - let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ + let authority = + uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); + let host = if let Some(idx) = authority.find('@') { + // handle possible name:password@ authority.split_at(idx + 1).1 } else { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())) + return Err(Error::Url("URL contains empty host name".into())); } write!( @@ -128,17 +120,15 @@ fn generate_request(request: Request, key: &str) -> Result> { Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = uri - .path_and_query() - .ok_or_else(|| Error::Url("No path/query in URL".into()))? - .as_str(), + path = + uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), key = key ) .unwrap(); for (k, v) in request.headers() { let mut k = k.as_str(); - if k == "sec-websocket-protocol" { + if k == "sec-websocket-protocol" { k = "Sec-WebSocket-Protocol"; } writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); @@ -175,9 +165,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -189,22 +177,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !headers - .get("Sec-WebSocket-Accept") - .map(|h| h == &self.accept_key) - .unwrap_or(false) - { - return Err(Error::Protocol( - "Key mismatch in Sec-WebSocket-Accept".into(), - )); + if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { + return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -238,9 +218,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -266,9 +244,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::super::machine::TryParse; + use super::{super::machine::TryParse, generate_key, generate_request, Response}; use crate::client::IntoClientRequest; - use super::{generate_key, generate_request, Response}; #[test] fn random_keys() { @@ -342,9 +319,6 @@ mod tests { const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); - assert_eq!( - resp.headers().get("Content-Type").unwrap(), - &b"text/html"[..], - ); + assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 8386f5a..f336c65 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -41,8 +41,7 @@ impl TryParse for HeaderMap { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::HeaderMap; + use super::{super::machine::TryParse, HeaderMap}; #[test] fn headers() { diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 61090bb..b8416a6 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -2,8 +2,10 @@ use bytes::Buf; use log::*; use std::io::{Cursor, Read, Write}; -use crate::error::{Error, Result}; -use crate::util::NonBlockingResult; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; use input_buffer::{InputBuffer, MIN_READ}; /// A generic handshake state machine. @@ -23,10 +25,7 @@ impl HandshakeMachine { } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { - HandshakeMachine { - stream, - state: HandshakeState::Writing(Cursor::new(data.into())), - } + HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) } } /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &Stream { @@ -52,21 +51,19 @@ impl HandshakeMachine { .no_block()?; match read { Some(0) => Err(Error::Protocol("Handshake not finished".into())), - Some(_) => Ok( - if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { - buf.advance(size); - RoundResult::StageFinished(StageResult::DoneReading { - result: obj, - stream: self.stream, - tail: buf.into_vec(), - }) - } else { - RoundResult::Incomplete(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - }) - }, - ), + Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + buf.advance(size); + RoundResult::StageFinished(StageResult::DoneReading { + result: obj, + stream: self.stream, + tail: buf.into_vec(), + }) + } else { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + }) + }), None => Ok(RoundResult::WouldBlock(HandshakeMachine { state: HandshakeState::Reading(buf), ..self @@ -112,11 +109,7 @@ pub enum RoundResult { #[derive(Debug)] pub enum StageResult { /// Reading round finished. - DoneReading { - result: Obj, - stream: Stream, - tail: Vec, - }, + DoneReading { result: Obj, stream: Stream, tail: Vec }, /// Writing round finished. DoneWriting(Stream), } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 8350dc0..4714ee0 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -6,9 +6,11 @@ pub mod server; mod machine; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io::{Read, Write}; +use std::{ + error::Error as ErrorTrait, + fmt, + io::{Read, Write}, +}; use sha1::{Digest, Sha1}; @@ -39,10 +41,7 @@ impl MidHandshake { loop { mach = match mach.single_round()? { RoundResult::WouldBlock(m) => { - return Err(HandshakeError::Interrupted(MidHandshake { - machine: m, - ..self - })) + return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) } RoundResult::Incomplete(m) => m, RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 15f6b14..1b6eed8 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,18 +1,25 @@ //! Server handshake machine. -use std::io::{self, Read, Write}; -use std::marker::PhantomData; -use std::result::Result as StdResult; +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + result::Result as StdResult, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; /// Server request type. pub type Request = HttpRequest<()>; @@ -30,9 +37,7 @@ pub fn create_response(request: &Request) -> Result { } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } if !request @@ -42,9 +47,7 @@ pub fn create_response(request: &Request) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in client request".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); } if !request @@ -54,20 +57,11 @@ pub fn create_response(request: &Request) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in client request".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); } - if !request - .headers() - .get("Sec-WebSocket-Version") - .map(|h| h == "13") - .unwrap_or(false) - { - return Err(Error::Protocol( - "No \"Sec-WebSocket-Version: 13\" in client request".into(), - )); + if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { + return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); } let key = request @@ -121,9 +115,7 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -229,11 +221,7 @@ impl HandshakeRole for ServerHandshake { finish: StageResult, ) -> Result> { Ok(match finish { - StageResult::DoneReading { - stream, - result, - tail, - } => { + StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())); } @@ -290,9 +278,7 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::create_response; - use super::Request; + use super::{super::machine::TryParse, create_response, Request}; #[test] fn request_parsing() { diff --git a/src/lib.rs b/src/lib.rs index f965478..82f7822 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,10 +22,10 @@ pub mod server; pub mod stream; pub mod util; -pub use crate::client::{client, connect}; -pub use crate::error::{Error, Result}; -pub use crate::handshake::client::ClientHandshake; -pub use crate::handshake::server::ServerHandshake; -pub use crate::handshake::HandshakeError; -pub use crate::protocol::{Message, WebSocket}; -pub use crate::server::{accept, accept_hdr}; +pub use crate::{ + client::{client, connect}, + error::{Error, Result}, + handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, + protocol::{Message, WebSocket}, + server::{accept, accept_hdr}, +}; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index d3fcdeb..c69b90c 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -1,7 +1,9 @@ //! Various codes defined in RFC 6455. -use std::convert::{From, Into}; -use std::fmt; +use std::{ + convert::{From, Into}, + fmt, +}; /// WebSocket message opcode as in RFC 6455. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -71,9 +73,11 @@ impl fmt::Display for OpCode { impl Into for OpCode { fn into(self) -> u8 { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match self { Data(Continue) => 0, Data(Text) => 1, @@ -90,9 +94,11 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match byte { 0 => Data(Continue), 1 => Data(Text), diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 38ce61c..ff64fa2 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,14 +1,18 @@ use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use log::*; -use std::borrow::Cow; -use std::default::Default; -use std::fmt; -use std::io::{Cursor, ErrorKind, Read, Write}; -use std::result::Result as StdResult; -use std::string::{FromUtf8Error, String}; - -use super::coding::{CloseCode, Control, Data, OpCode}; -use super::mask::{apply_mask, generate_mask}; +use std::{ + borrow::Cow, + default::Default, + fmt, + io::{Cursor, ErrorKind, Read, Write}, + result::Result as StdResult, + string::{FromUtf8Error, String}, +}; + +use super::{ + coding::{CloseCode, Control, Data, OpCode}, + mask::{apply_mask, generate_mask}, +}; use crate::error::{Error, Result}; /// A struct representing the close command. @@ -23,10 +27,7 @@ pub struct CloseFrame<'t> { impl<'t> CloseFrame<'t> { /// Convert into a owned string. pub fn into_owned(self) -> CloseFrame<'static> { - CloseFrame { - code: self.code, - reason: self.reason.into_owned().into(), - } + CloseFrame { code: self.code, reason: self.reason.into_owned().into() } } } @@ -192,14 +193,7 @@ impl FrameHeader { _ => (), } - let hdr = FrameHeader { - is_final, - rsv1, - rsv2, - rsv3, - opcode, - mask, - }; + let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask }; Ok(Some((hdr, length))) } @@ -298,10 +292,7 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { - code, - reason: text.into(), - })) + Ok(Some(CloseFrame { code, reason: text.into() })) } } } @@ -309,19 +300,9 @@ impl Frame { /// Create a new data frame. #[inline] pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { - debug_assert!( - matches!(opcode, OpCode::Data(_)), - "Invalid opcode for data frame." - ); + debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - Frame { - header: FrameHeader { - is_final, - opcode, - ..FrameHeader::default() - }, - payload: data, - } + Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } } /// Create a new Pong control frame. @@ -360,10 +341,7 @@ impl Frame { Vec::new() }; - Frame { - header: FrameHeader::default(), - payload, - } + Frame { header: FrameHeader::default(), payload } } /// Create a frame from given header and data. @@ -401,10 +379,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload - .iter() - .map(|byte| format!("{:x}", byte)) - .collect::() + self.payload.iter().map(|byte| format!("{:x}", byte)).collect::() ) } } @@ -476,10 +451,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload); - assert_eq!( - frame.into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] - ); + assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } #[test] diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 6756f0a..dfd0bd5 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,8 +6,7 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::CloseFrame; -pub use self::frame::{Frame, FrameHeader}; +pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; @@ -26,18 +25,12 @@ pub struct FrameSocket { impl FrameSocket { /// Create a new frame socket. pub fn new(stream: Stream) -> Self { - FrameSocket { - stream, - codec: FrameCodec::new(), - } + FrameSocket { stream, codec: FrameCodec::new() } } /// Create a new frame socket from partially read data. pub fn from_partially_read(stream: Stream, part: Vec) -> Self { - FrameSocket { - stream, - codec: FrameCodec::from_partially_read(part), - } + FrameSocket { stream, codec: FrameCodec::from_partially_read(part) } } /// Extract a stream from the socket. @@ -184,9 +177,7 @@ impl FrameCodec { { trace!("writing frame {}", frame); self.out_buffer.reserve(frame.len()); - frame - .format(&mut self.out_buffer) - .expect("Bug: can't write to vector"); + frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); self.write_pending(stream) } @@ -231,10 +222,7 @@ mod tests { sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); - assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x03, 0x02, 0x01] - ); + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); diff --git a/src/protocol/message.rs b/src/protocol/message.rs index d1778f1..f799dbf 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,7 +1,9 @@ -use std::convert::{AsRef, From, Into}; -use std::fmt; -use std::result::Result as StdResult; -use std::str; +use std::{ + convert::{AsRef, From, Into}, + fmt, + result::Result as StdResult, + str, +}; use super::frame::CloseFrame; use crate::error::{Error, Result}; @@ -19,10 +21,7 @@ mod string_collect { impl StringCollector { pub fn new() -> Self { - StringCollector { - data: String::new(), - incomplete: None, - } + StringCollector { data: String::new(), incomplete: None } } pub fn len(&self) -> usize { @@ -54,10 +53,7 @@ mod string_collect { self.data.push_str(text); Ok(()) } - Err(DecodeError::Incomplete { - valid_prefix, - incomplete_suffix, - }) => { + Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { self.data.push_str(valid_prefix); self.incomplete = Some(incomplete_suffix); Ok(()) @@ -127,11 +123,7 @@ impl IncompleteMessage { // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { return Err(Error::Capacity( - format!( - "Message too big: {} + {} > {}", - my_size, portion_size, max_size - ) - .into(), + format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), )); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8137393..72485e9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -4,19 +4,26 @@ pub mod frame; mod message; -pub use self::frame::CloseFrame; -pub use self::message::Message; +pub use self::{frame::CloseFrame, message::Message}; use log::*; -use std::collections::VecDeque; -use std::io::{ErrorKind as IoErrorKind, Read, Write}; -use std::mem::replace; - -use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; -use self::frame::{Frame, FrameCodec}; -use self::message::{IncompleteMessage, IncompleteMessageType}; -use crate::error::{Error, Result}; -use crate::util::NonBlockingResult; +use std::{ + collections::VecDeque, + io::{ErrorKind as IoErrorKind, Read, Write}, + mem::replace, +}; + +use self::{ + frame::{ + coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, + Frame, FrameCodec, + }, + message::{IncompleteMessage, IncompleteMessageType}, +}; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -74,10 +81,7 @@ impl WebSocket { /// or together with an existing one. If you need an initial handshake, use /// `connect()` or `accept()` functions of the crate to construct a websocket. pub fn from_raw_socket(stream: Stream, role: Role, config: Option) -> Self { - WebSocket { - socket: stream, - context: WebSocketContext::new(role, config), - } + WebSocket { socket: stream, context: WebSocketContext::new(role, config) } } /// Convert a raw socket into a WebSocket without performing a handshake. @@ -320,9 +324,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol( - "Sending after closing is not allowed".into(), - )); + return Err(Error::Protocol("Sending after closing is not allowed".into())); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -455,9 +457,7 @@ impl WebSocketContext { Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol( - "Received a masked frame from server".into(), - )); + return Err(Error::Protocol("Received a masked frame from server".into())); } } } @@ -474,9 +474,9 @@ impl WebSocketContext { Err(Error::Protocol("Control frame too big".into())) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), - OpCtl::Reserved(i) => Err(Error::Protocol( - format!("Unknown control frame type {}", i).into(), - )), + OpCtl::Reserved(i) => { + Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + } OpCtl::Ping => { let data = frame.into_data(); // No ping processing after we sent a close frame. @@ -527,9 +527,9 @@ impl WebSocketContext { Ok(None) } } - OpData::Reserved(i) => Err(Error::Protocol( - format!("Unknown data frame type {}", i).into(), - )), + OpData::Reserved(i) => { + Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + } } } } // match opcode @@ -539,9 +539,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol( - "Connection reset without closing handshake".into(), - )), + _ => Err(Error::Protocol("Connection reset without closing handshake".into())), } } } @@ -602,9 +600,7 @@ impl WebSocketContext { } trace!("Sending frame: {:?}", frame); - self.frame - .write_frame(stream, frame) - .check_connection_reset(self.state) + self.frame.write_frame(stream, frame).check_connection_reset(self.state) } } @@ -669,8 +665,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use std::io; - use std::io::Cursor; + use std::{io, io::Cursor}; struct WriteMoc(Stream); @@ -699,14 +694,8 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!( - socket.read_message().unwrap(), - Message::Text("Hello, World!".into()) - ); - assert_eq!( - socket.read_message().unwrap(), - Message::Binary(vec![0x01, 0x02, 0x03]) - ); + assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); + assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); } #[test] @@ -715,10 +704,7 @@ mod tests { 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, ]); - let limit = WebSocketConfig { - max_message_size: Some(10), - ..WebSocketConfig::default() - }; + let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), @@ -729,10 +715,7 @@ mod tests { #[test] fn size_limiting_binary() { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); - let limit = WebSocketConfig { - max_message_size: Some(2), - ..WebSocketConfig::default() - }; + let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), diff --git a/src/server.rs b/src/server.rs index 725d892..53303ee 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,10 @@ pub use crate::handshake::server::ServerHandshake; -use crate::handshake::server::{Callback, NoCallback}; -use crate::handshake::HandshakeError; +use crate::handshake::{ + server::{Callback, NoCallback}, + HandshakeError, +}; use crate::protocol::{WebSocket, WebSocketConfig}; diff --git a/src/util.rs b/src/util.rs index cd03035..f40ca43 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,9 @@ //! Helper traits to ease non-blocking handling. -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::result::Result as StdResult; +use std::{ + io::{Error as IoError, ErrorKind as IoErrorKind}, + result::Result as StdResult, +}; use crate::error::Error; diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index d95ee81..7e3e33f 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,15 +1,17 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. -use std::net::{TcpStream, TcpListener}; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::{TcpListener, TcpStream}, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; -use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream}; use native_tls::TlsStream; -use url::Url; use net2::TcpStreamExt; +use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; +use url::Url; type Sock = WebSocket>>; @@ -26,8 +28,8 @@ where exit(1); }); - let server = TcpListener::bind(("127.0.0.1", port)) - .expect("Can't listen, is port already in use?"); + let server = + TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?"); let client_thread = spawn(move || { let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap()) @@ -46,11 +48,10 @@ where #[test] fn test_server_close() { - do_test(3012, + do_test( + 3012, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive close from server assert!(message.is_close()); @@ -75,16 +76,16 @@ fn test_server_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); + }, + ); } #[test] fn test_evil_server_close() { - do_test(3013, + do_test( + 3013, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); sleep(Duration::from_secs(1)); @@ -108,16 +109,16 @@ fn test_evil_server_close() { // and now just drop the connection without waiting for `ConnectionClosed` srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); drop(srv_sock); - }); + }, + ); } #[test] fn test_client_close() { - do_test(3014, + do_test( + 3014, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive answer from server assert_eq!(message.into_data(), b"From Server"); @@ -147,6 +148,6 @@ fn test_client_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); - + }, + ); } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index d8e20e5..f348eca 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 352020e..87f8dda 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; @@ -24,9 +26,7 @@ fn test_receive_after_init_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); - client - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = client.read_message().unwrap(); // receive close from server assert!(message.is_close());