diff --git a/Cargo.toml b/Cargo.toml index 4dc875c..2059e15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ tls-vendored = ["native-tls", "native-tls/vendored"] deflate = ["flate2"] [dependencies] -base64 = "0.12.0" +base64 = "0.13.0" byteorder = "1.3.2" bytes = "0.5" http = "0.2" @@ -42,7 +42,7 @@ optional = true version = "0.2.3" [dev-dependencies] -env_logger = "0.7.1" +env_logger = "0.8.1" net2 = "0.2.33" [[example]] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index fb0a839..44e26d7 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -18,11 +18,7 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!( - "ws://localhost:9001/updateReports?agent={}", - AGENT - )) - .unwrap(), + Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), )?; socket.close(None)?; Ok(()) diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 5e9fd2b..0417be3 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,7 +1,10 @@ -use std::net::{TcpListener, TcpStream}; -use std::thread::spawn; +use std::{ + net::{TcpListener, TcpStream}, + thread::spawn, +}; use log::*; +use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result}; use tungstenite::extensions::compression::deflate::DeflateConfigBuilder; use tungstenite::extensions::compression::WsCompression; use tungstenite::handshake::HandshakeRole; diff --git a/examples/callback-error.rs b/examples/callback-error.rs index 357d343..cf78a2e 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -1,9 +1,10 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; -use tungstenite::http::StatusCode; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, + http::StatusCode, +}; fn main() { let server = TcpListener::bind("127.0.0.1:3012").unwrap(); diff --git a/examples/client.rs b/examples/client.rs index 7938cfb..def6a3c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -14,9 +14,7 @@ fn main() { println!("* {}", header); } - socket - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); diff --git a/examples/server.rs b/examples/server.rs index def2f45..420e5db 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,9 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, +}; fn main() { env_logger::init(); diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..db7f39d --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,7 @@ +# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check]. +# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md + +# Break complex but short statements a bit less. +use_small_heuristics = "Max" + +merge_imports = true diff --git a/src/client.rs b/src/client.rs index 31213c1..9b13be9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,20 @@ //! Methods to connect to a WebSocket as a client. -use std::io::{Read, Write}; -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -use std::result::Result as StdResult; +use std::{ + io::{Read, Write}, + net::{SocketAddr, TcpStream, ToSocketAddrs}, + result::Result as StdResult, +}; -use http::Uri; +use http::{request::Parts, Uri}; use log::*; use url::Url; -use crate::handshake::client::{Request, Response}; -use crate::protocol::WebSocketConfig; +use crate::{ + handshake::client::{Request, Response}, + protocol::WebSocketConfig, +}; #[cfg(feature = "tls")] mod encryption { @@ -22,8 +26,7 @@ mod encryption { /// TCP stream switcher (plain/TLS). pub type AutoStream = StreamSwitcher>; - use crate::error::Result; - use crate::stream::Mode; + use crate::{error::Result, stream::Mode}; pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { match mode { @@ -48,8 +51,10 @@ mod encryption { mod encryption { use std::net::TcpStream; - use crate::error::{Error, Result}; - use crate::stream::Mode; + use crate::{ + error::{Error, Result}, + stream::Mode, + }; /// TLS support is nod compiled in, this is just standard `TcpStream`. pub type AutoStream = TcpStream; @@ -65,11 +70,12 @@ mod encryption { use self::encryption::wrap_stream; pub use self::encryption::AutoStream; -use crate::error::{Error, Result}; -use crate::handshake::client::ClientHandshake; -use crate::handshake::HandshakeError; -use crate::protocol::WebSocket; -use crate::stream::{Mode, NoDelay}; +use crate::{ + error::{Error, Result}, + handshake::{client::ClientHandshake, HandshakeError}, + protocol::WebSocket, + stream::{Mode, NoDelay}, +}; /// Connect to the given WebSocket in blocking mode. /// @@ -86,31 +92,63 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config( +pub fn connect_with_config( request: Req, config: Option, -) -> Result<(WebSocket, Response)> -where - Req: IntoClientRequest, -{ - let request: Request = request.into_client_request()?; - let uri = request.uri(); - let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = uri.port_u16().unwrap_or(match mode { - Mode::Plain => 80, - Mode::Tls => 443, - }); - let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; - NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + max_redirects: u8, +) -> Result<(WebSocket, Response)> { + fn try_client_handshake( + request: Request, + config: Option, + ) -> Result<(WebSocket, Response)> { + let uri = request.uri(); + let mode = uri_mode(uri)?; + let host = + request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + NoDelay::set_nodelay(&mut stream, true)?; + client_with_config(request, stream, config).map_err(|e| match e { + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) + } + + fn create_request(parts: &Parts, uri: &Uri) -> Request { + let mut builder = Request::builder() + .uri(uri.clone()) + .method(parts.method.clone()) + .version(parts.version); + *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); + builder.body(()).expect("Failed to create `Request`") + } + + let (parts, _) = request.into_client_request()?.into_parts(); + let mut uri = parts.uri.clone(); + + for attempt in 0..(max_redirects + 1) { + let request = create_request(&parts, &uri); + + match try_client_handshake(request, config) { + Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { + if let Some(location) = res.headers().get("Location") { + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; + } else { + warn!("No `Location` found in redirect"); + return Err(Error::Http(res)); + } + } + other => return other, + } + } + + unreachable!("Bug in a redirect handling logic") } /// Connect to the given WebSocket in blocking mode. @@ -126,13 +164,11 @@ where /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) + connect_with_config(request, None, 3) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { - let domain = uri - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { diff --git a/src/error.rs b/src/error.rs index 547a931..acee6f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,17 +1,9 @@ //! Error handling. -use std::borrow::Cow; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io; -use std::result; -use std::str; -use std::string; - -use http; -use httparse; +use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; use crate::protocol::Message; +use http::Response; #[cfg(feature = "tls")] pub mod tls { @@ -64,7 +56,7 @@ pub enum Error { /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. - Http(http::StatusCode), + Http(Response>), /// HTTP format error. HttpFormat(http::Error), /// An error from a WebSocket extension. @@ -84,7 +76,7 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), Error::ExtensionError(ref e) => write!(f, "Extension error: {}", e), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 350347c..f3c92f7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,18 +1,25 @@ //! Client handshake machine. -use std::io::{Read, Write}; -use std::marker::PhantomData; +use std::{ + io::{Read, Write}, + marker::PhantomData, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers}; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Client request type. pub type Request = HttpRequest<()>; @@ -28,26 +35,19 @@ pub struct ClientHandshake { _marker: PhantomData, } -impl ClientHandshake -where - Stream: Read + Write, -{ +impl ClientHandshake { /// Initiate a client handshake. pub fn start( - stream: Stream, + stream: S, request: Request, mut config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol( - "Invalid HTTP method, only GET supported".into(), - )); + return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } // Check the URI scheme: only ws or wss are supported @@ -62,29 +62,18 @@ where let client = { let accept_key = convert_key(key.as_ref()).unwrap(); - ClientHandshake { - verify_data: VerifyData { accept_key }, - config: Some(config), - _marker: PhantomData, - } + ClientHandshake { verify_data: VerifyData { accept_key }, config: Some(config), _marker: PhantomData } }; trace!("Client handshake initiated."); - Ok(MidHandshake { - role: client, - machine, - }) + Ok(MidHandshake { role: client, machine }) } } -impl HandshakeRole for ClientHandshake -where - Stream: Read + Write, -{ +impl HandshakeRole for ClientHandshake { type IncomingData = Response; - type InternalStream = Stream; - type FinalResult = (WebSocket, Response); - + type InternalStream = S; + type FinalResult = (WebSocket, Response); fn stage_finished( &mut self, finish: StageResult, @@ -93,16 +82,11 @@ where StageResult::DoneWriting(stream) => { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - StageResult::DoneReading { - stream, - result, - tail, - } => { - let mut config = self.config.take().unwrap(); - - self.verify_data.verify_response(&result, &mut config)?; + StageResult::DoneReading { stream, result, tail } => { + let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); - let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config); + let websocket = + WebSocket::from_partially_read(stream, tail, Role::Client, self.config); ProcessingResult::Done((websocket, result)) } }) @@ -119,10 +103,8 @@ fn generate_request( let mut req = Vec::new(); let uri = request.uri(); - let authority = uri - .authority() - .ok_or_else(|| Error::Url("No host name in the URL".into()))? - .as_str(); + let authority = + uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -144,10 +126,8 @@ fn generate_request( Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = uri - .path_and_query() - .ok_or_else(|| Error::Url("No path/query in URL".into()))? - .as_str(), + path = + uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), key = key ) .unwrap(); @@ -176,12 +156,13 @@ impl VerifyData { &self, response: &Response, config: &mut Option, - ) -> Result<()> { + ) -> Result { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); + return Err(Error::Http(response.map(|_| None))); } + let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| @@ -194,9 +175,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -208,22 +187,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !headers - .get("Sec-WebSocket-Accept") - .map(|h| h == &self.accept_key) - .unwrap_or(false) - { - return Err(Error::Protocol( - "Key mismatch in Sec-WebSocket-Accept".into(), - )); + if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { + return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); } // 5. If the response includes a |Sec-WebSocket-Extensions| header @@ -231,7 +202,6 @@ impl VerifyData { // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - verify_compression_resp_headers(response, config)?; // 6. If the response includes a |Sec-WebSocket-Protocol| header field @@ -241,7 +211,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok(()) + Ok(response) } } @@ -259,9 +229,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -287,8 +255,7 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::{generate_key, generate_request, Response}; + use super::{super::machine::TryParse, generate_key, generate_request, Response}; use crate::client::IntoClientRequest; #[test] @@ -367,9 +334,6 @@ mod tests { const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); - assert_eq!( - resp.headers().get("Content-Type").unwrap(), - &b"text/html"[..], - ); + assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 0c51008..f336c65 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,8 +1,6 @@ //! HTTP Request and response header handling. -use http; use http::header::{HeaderMap, HeaderName, HeaderValue}; -use httparse; use httparse::Status; use super::machine::TryParse; @@ -43,8 +41,7 @@ impl TryParse for HeaderMap { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::HeaderMap; + use super::{super::machine::TryParse, HeaderMap}; #[test] fn headers() { diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 61090bb..b8416a6 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -2,8 +2,10 @@ use bytes::Buf; use log::*; use std::io::{Cursor, Read, Write}; -use crate::error::{Error, Result}; -use crate::util::NonBlockingResult; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; use input_buffer::{InputBuffer, MIN_READ}; /// A generic handshake state machine. @@ -23,10 +25,7 @@ impl HandshakeMachine { } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { - HandshakeMachine { - stream, - state: HandshakeState::Writing(Cursor::new(data.into())), - } + HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) } } /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &Stream { @@ -52,21 +51,19 @@ impl HandshakeMachine { .no_block()?; match read { Some(0) => Err(Error::Protocol("Handshake not finished".into())), - Some(_) => Ok( - if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { - buf.advance(size); - RoundResult::StageFinished(StageResult::DoneReading { - result: obj, - stream: self.stream, - tail: buf.into_vec(), - }) - } else { - RoundResult::Incomplete(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - }) - }, - ), + Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + buf.advance(size); + RoundResult::StageFinished(StageResult::DoneReading { + result: obj, + stream: self.stream, + tail: buf.into_vec(), + }) + } else { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + }) + }), None => Ok(RoundResult::WouldBlock(HandshakeMachine { state: HandshakeState::Reading(buf), ..self @@ -112,11 +109,7 @@ pub enum RoundResult { #[derive(Debug)] pub enum StageResult { /// Reading round finished. - DoneReading { - result: Obj, - stream: Stream, - tail: Vec, - }, + DoneReading { result: Obj, stream: Stream, tail: Vec }, /// Writing round finished. DoneWriting(Stream), } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index ce6b4dd..4714ee0 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -6,11 +6,12 @@ pub mod server; mod machine; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io::{Read, Write}; +use std::{ + error::Error as ErrorTrait, + fmt, + io::{Read, Write}, +}; -use base64; use sha1::{Digest, Sha1}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; @@ -40,10 +41,7 @@ impl MidHandshake { loop { mach = match mach.single_round()? { RoundResult::WouldBlock(m) => { - return Err(HandshakeError::Interrupted(MidHandshake { - machine: m, - ..self - })) + return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) } RoundResult::Incomplete(m) => m, RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 110a574..35fb73d 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,19 +1,26 @@ //! Server handshake machine. -use std::io::{self, Read, Write}; -use std::marker::PhantomData; -use std::result::Result as StdResult; +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + result::Result as StdResult, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; -use crate::extensions::compression::verify_compression_req_headers; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, + extensions::verify_compression_req_headers +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; /// Server request type. pub type Request = HttpRequest<()>; @@ -31,24 +38,17 @@ pub fn create_response(request: &Request) -> Result { } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } if !request .headers() .get("Connection") .and_then(|h| h.to_str().ok()) - .map(|h| { - h.split(|c| c == ' ' || c == ',') - .any(|p| p.eq_ignore_ascii_case("Upgrade")) - }) + .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in client request".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); } if !request @@ -58,20 +58,11 @@ pub fn create_response(request: &Request) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in client request".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); } - if !request - .headers() - .get("Sec-WebSocket-Version") - .map(|h| h == "13") - .unwrap_or(false) - { - return Err(Error::Protocol( - "No \"Sec-WebSocket-Version: 13\" in client request".into(), - )); + if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { + return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); } let key = request @@ -125,9 +116,7 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -199,16 +188,12 @@ pub struct ServerHandshake { /// WebSocket configuration. config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. - error_code: Option, + error_response: Option, /// Internal stream type. _marker: PhantomData, } -impl ServerHandshake -where - S: Read + Write, - C: Callback, -{ +impl ServerHandshake { /// Start server handshake. `callback` specifies a custom callback which the user can pass to /// the handshake, this callback will be called when the a websocket client connnects to the /// server, you can specify the callback if you want to add additional header to the client @@ -220,18 +205,14 @@ where role: ServerHandshake { callback: Some(callback), config, - error_code: None, + error_response: None, _marker: PhantomData, }, } } } -impl HandshakeRole for ServerHandshake -where - S: Read + Write, - C: Callback, -{ +impl HandshakeRole for ServerHandshake { type IncomingData = Request; type InternalStream = S; type FinalResult = WebSocket; @@ -241,20 +222,16 @@ where finish: StageResult, ) -> Result> { Ok(match finish { - StageResult::DoneReading { - stream, - result: request, - tail, - } => { + StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())); } - let mut response = create_response(&request)?; + let mut response = create_response(&result)?; verify_compression_req_headers(&request, &mut response, &mut self.config)?; let callback_result = if let Some(callback) = self.callback.take() { - callback.on_request(&request, response) + callback.on_request(&result, response) } else { Ok(response) }; @@ -273,22 +250,25 @@ where )); } - self.error_code = Some(resp.status().as_u16()); + self.error_response = Some(resp); + let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } StageResult::DoneWriting(stream) => { - if let Some(err) = self.error_code.take() { + if let Some(err) = self.error_response.take() { debug!("Server handshake failed."); - return Err(Error::Http(StatusCode::from_u16(err)?)); + return Err(Error::Http(err)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); @@ -301,9 +281,7 @@ where #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::create_response; - use super::Request; + use super::{super::machine::TryParse, create_response, Request}; #[test] fn request_parsing() { diff --git a/src/lib.rs b/src/lib.rs index 36b947d..0a466c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,18 +16,17 @@ pub use http; pub mod client; pub mod error; +pub mod extensions; pub mod handshake; pub mod protocol; pub mod server; pub mod stream; pub mod util; -pub mod extensions; - -pub use crate::client::{client, connect}; -pub use crate::error::{Error, Result}; -pub use crate::handshake::client::ClientHandshake; -pub use crate::handshake::server::ServerHandshake; -pub use crate::handshake::HandshakeError; -pub use crate::protocol::{Message, WebSocket}; -pub use crate::server::{accept, accept_hdr}; +pub use crate::{ + client::{client, connect}, + error::{Error, Result}, + handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, + protocol::{Message, WebSocket}, + server::{accept, accept_hdr}, +}; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index d3fcdeb..e726161 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -1,7 +1,9 @@ //! Various codes defined in RFC 6455. -use std::convert::{From, Into}; -use std::fmt; +use std::{ + convert::{From, Into}, + fmt, +}; /// WebSocket message opcode as in RFC 6455. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -71,9 +73,11 @@ impl fmt::Display for OpCode { impl Into for OpCode { fn into(self) -> u8 { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match self { Data(Continue) => 0, Data(Text) => 1, @@ -90,9 +94,11 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match byte { 0 => Data(Continue), 1 => Data(Text), @@ -184,14 +190,7 @@ pub enum CloseCode { impl CloseCode { /// Check if this CloseCode is allowed. pub fn is_allowed(self) -> bool { - match self { - Bad(_) => false, - Reserved(_) => false, - Status => false, - Abnormal => false, - Tls => false, - _ => true, - } + !matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls) } } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index e067d51..3aca767 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,14 +1,18 @@ use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use log::*; -use std::borrow::Cow; -use std::default::Default; -use std::fmt; -use std::io::{Cursor, ErrorKind, Read, Write}; -use std::result::Result as StdResult; -use std::string::{FromUtf8Error, String}; - -use super::coding::{CloseCode, Control, Data, OpCode}; -use super::mask::{apply_mask, generate_mask}; +use std::{ + borrow::Cow, + default::Default, + fmt, + io::{Cursor, ErrorKind, Read, Write}, + result::Result as StdResult, + string::{FromUtf8Error, String}, +}; + +use super::{ + coding::{CloseCode, Control, Data, OpCode}, + mask::{apply_mask, generate_mask}, +}; use crate::error::{Error, Result}; /// A struct representing the close command. @@ -23,10 +27,7 @@ pub struct CloseFrame<'t> { impl<'t> CloseFrame<'t> { /// Convert into a owned string. pub fn into_owned(self) -> CloseFrame<'static> { - CloseFrame { - code: self.code, - reason: self.reason.into_owned().into(), - } + CloseFrame { code: self.code, reason: self.reason.into_owned().into() } } } @@ -313,10 +314,7 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { - code, - reason: text.into(), - })) + Ok(Some(CloseFrame { code, reason: text.into() })) } } } @@ -324,22 +322,9 @@ impl Frame { /// Create a new data frame. #[inline] pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { - debug_assert!( - match opcode { - OpCode::Data(_) => true, - _ => false, - }, - "Invalid opcode for data frame." - ); + debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - Frame { - header: FrameHeader { - is_final, - opcode, - ..FrameHeader::default() - }, - payload: data, - } + Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } } /// Create a new Pong control frame. @@ -378,10 +363,7 @@ impl Frame { Vec::new() }; - Frame { - header: FrameHeader::default(), - payload, - } + Frame { header: FrameHeader::default(), payload } } /// Create a frame from given header and data. @@ -425,10 +407,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload - .iter() - .map(|byte| format!("{:x}", byte)) - .collect::() + self.payload.iter().map(|byte| format!("{:x}", byte)).collect::() ) } } @@ -500,10 +479,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload); - assert_eq!( - frame.into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] - ); + assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } #[test] diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 7145d4e..cd44624 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,8 +6,7 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::CloseFrame; -pub use self::frame::{ExtensionHeaders, Frame, FrameHeader}; +pub use self::frame::{CloseFrame, ExtensionHeaders, Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; @@ -26,18 +25,12 @@ pub struct FrameSocket { impl FrameSocket { /// Create a new frame socket. pub fn new(stream: Stream) -> Self { - FrameSocket { - stream, - codec: FrameCodec::new(), - } + FrameSocket { stream, codec: FrameCodec::new() } } /// Create a new frame socket from partially read data. pub fn from_partially_read(stream: Stream, part: Vec) -> Self { - FrameSocket { - stream, - codec: FrameCodec::from_partially_read(part), - } + FrameSocket { stream, codec: FrameCodec::from_partially_read(part) } } /// Extract a stream from the socket. @@ -184,9 +177,7 @@ impl FrameCodec { { trace!("writing frame {}", frame); self.out_buffer.reserve(frame.len()); - frame - .format(&mut self.out_buffer) - .expect("Bug: can't write to vector"); + frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); self.write_pending(stream) } @@ -231,10 +222,7 @@ mod tests { sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); - assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x03, 0x02, 0x01] - ); + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); diff --git a/src/protocol/message.rs b/src/protocol/message.rs index c948d19..b6451b9 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,14 +1,14 @@ -use std::convert::{AsRef, From, Into}; -use std::fmt; -use std::result::Result as StdResult; -use std::str; +use std::{ + convert::{AsRef, From, Into}, + fmt, + result::Result as StdResult, + str, +}; use super::frame::CloseFrame; use crate::error::{Error, Result}; mod string_collect { - - use utf8; use utf8::DecodeError; use crate::error::{Error, Result}; @@ -21,10 +21,7 @@ mod string_collect { impl StringCollector { pub fn new() -> Self { - StringCollector { - data: String::new(), - incomplete: None, - } + StringCollector { data: String::new(), incomplete: None } } pub fn len(&self) -> usize { @@ -56,10 +53,7 @@ mod string_collect { self.data.push_str(text); Ok(()) } - Err(DecodeError::Incomplete { - valid_prefix, - incomplete_suffix, - }) => { + Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { self.data.push_str(valid_prefix); self.incomplete = Some(incomplete_suffix); Ok(()) @@ -129,11 +123,7 @@ impl IncompleteMessage { // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { return Err(Error::Capacity( - format!( - "Message too big: {} + {} > {}", - my_size, portion_size, max_size - ) - .into(), + format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), )); } @@ -203,42 +193,27 @@ impl Message { /// Indicates whether a message is a text message. pub fn is_text(&self) -> bool { - match *self { - Message::Text(_) => true, - _ => false, - } + matches!(*self, Message::Text(_)) } /// Indicates whether a message is a binary message. pub fn is_binary(&self) -> bool { - match *self { - Message::Binary(_) => true, - _ => false, - } + matches!(*self, Message::Binary(_)) } /// Indicates whether a message is a ping message. pub fn is_ping(&self) -> bool { - match *self { - Message::Ping(_) => true, - _ => false, - } + matches!(*self, Message::Ping(_)) } /// Indicates whether a message is a pong message. pub fn is_pong(&self) -> bool { - match *self { - Message::Pong(_) => true, - _ => false, - } + matches!(*self, Message::Pong(_)) } /// Indicates whether a message ia s close message. pub fn is_close(&self) -> bool { - match *self { - Message::Close(_) => true, - _ => false, - } + matches!(*self, Message::Close(_)) } /// Get the length of the WebSocket message. diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 67cf18a..0b4b5ac 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -2,24 +2,29 @@ pub mod frame; -pub(crate) mod message; +mod message; -pub use self::frame::CloseFrame; -pub use self::message::Message; +pub use self::{frame::CloseFrame, message::Message}; use log::*; -use std::collections::VecDeque; -use std::io::{ErrorKind as IoErrorKind, Read, Write}; -use std::mem::replace; - -use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; -use self::frame::{Frame, FrameCodec}; -use self::message::IncompleteMessage; -use crate::error::{Error, Result}; -use crate::extensions::compression::{CompressionSwitcher, WsCompression}; -use crate::extensions::WebSocketExtension; -use crate::protocol::frame::coding::Data; -use crate::util::NonBlockingResult; +use std::{ + collections::VecDeque, + io::{ErrorKind as IoErrorKind, Read, Write}, + mem::replace, +}; + +use self::{ + frame::{ + coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, + Frame, FrameCodec, + }, + message::{IncompleteMessage, IncompleteMessageType}, + extensions::{WebSocketExtension, compression::{CompressionSwitcher, WsCompression}}; +}; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; @@ -33,7 +38,7 @@ pub enum Role { } /// The configuration for WebSocket connection. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone, Copy)] pub struct WebSocketConfig { /// The size of the send queue. You can use it to turn on/off the backpressure features. `None` /// means here that the size of the queue is unlimited. The default value is the unlimited @@ -77,10 +82,7 @@ impl WebSocket { /// or together with an existing one. If you need an initial handshake, use /// `connect()` or `accept()` functions of the crate to construct a websocket. pub fn from_raw_socket(stream: Stream, role: Role, config: Option) -> Self { - WebSocket { - socket: stream, - context: WebSocketContext::new(role, config), - } + WebSocket { socket: stream, context: WebSocketContext::new(role, config) } } /// Convert a raw socket into a WebSocket without performing a handshake. @@ -136,10 +138,7 @@ impl WebSocket { } } -impl WebSocket -where - Stream: Read + Write, -{ +impl WebSocket { /// Read a message from stream, if possible. /// /// This will queue responses to ping and close messages to be sent. It will call @@ -333,9 +332,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol( - "Sending after closing is not allowed".into(), - )); + return Err(Error::Protocol("Sending after closing is not allowed".into())); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -457,9 +454,7 @@ impl WebSocketContext { Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol( - "Received a masked frame from server".into(), - )); + return Err(Error::Protocol("Received a masked frame from server".into())); } } } @@ -476,9 +471,9 @@ impl WebSocketContext { Err(Error::Protocol("Control frame too big".into())) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), - OpCtl::Reserved(i) => Err(Error::Protocol( - format!("Unknown control frame type {}", i).into(), - )), + OpCtl::Reserved(i) => { + Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + } OpCtl::Ping => { let data = frame.into_data(); // No ping processing after we sent a close frame. @@ -568,43 +563,8 @@ impl WebSocketContext { } } - if frame.header().is_final { - frame = self.decoder.on_send_frame(frame)?; - } - - let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value); - if frame.payload().len() > max_frame_size { - let mut chunks = frame.payload().chunks(max_frame_size).peekable(); - let data_frame = Frame::message( - Vec::from(chunks.next().unwrap()), - frame.header().opcode, - false, - ); - self.frame - .write_frame(stream, data_frame) - .check_connection_reset(self.state)?; - - while let Some(chunk) = chunks.next() { - let frame = Frame::message( - Vec::from(chunk), - OpCode::Data(Data::Continue), - chunks.peek().is_none(), - ); - - trace!("Sending frame: {:?}", frame); - - self.frame - .write_frame(stream, frame) - .check_connection_reset(self.state)?; - } - - Ok(()) - } else { - trace!("Sending frame: {:?}", frame); - self.frame - .write_frame(stream, frame) - .check_connection_reset(self.state) - } + trace!("Sending frame: {:?}", frame); + self.frame.write_frame(stream, frame).check_connection_reset(self.state) } } @@ -626,20 +586,14 @@ enum WebSocketState { impl WebSocketState { /// Tell if we're allowed to process normal messages. fn is_active(self) -> bool { - match self { - WebSocketState::Active => true, - _ => false, - } + matches!(self, WebSocketState::Active) } /// Tell if we should process incoming data. Note that if we send a close frame /// but the remote hasn't confirmed, they might have sent data before they receive our /// close frame, so we should still pass those to client code, hence ClosedByUs is valid. fn can_read(self) -> bool { - match self { - WebSocketState::Active | WebSocketState::ClosedByUs => true, - _ => false, - } + matches!(self, WebSocketState::Active | WebSocketState::ClosedByUs) } /// Check if the state is active, return error if not. @@ -675,11 +629,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use crate::extensions::compression::WsCompression; - use crate::protocol::frame::coding::{Data, OpCode}; - use crate::protocol::frame::Frame; - use std::io; - use std::io::Cursor; + use std::{io, io::Cursor}; struct WriteMoc(Stream); @@ -708,14 +658,8 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!( - socket.read_message().unwrap(), - Message::Text("Hello, World!".into()) - ); - assert_eq!( - socket.read_message().unwrap(), - Message::Binary(vec![0x01, 0x02, 0x03]) - ); + assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); + assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); } #[test] @@ -724,11 +668,7 @@ mod tests { 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, ]); - let limit = WebSocketConfig { - max_send_queue: None, - max_frame_size: Some(16 << 20), - compression: WsCompression::None(Some(10)), - }; + let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), @@ -739,80 +679,11 @@ mod tests { #[test] fn size_limiting_binary() { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); - let limit = WebSocketConfig { - max_send_queue: None, - max_frame_size: Some(16 << 20), - compression: WsCompression::None(Some(2)), - }; + let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), "Space limit exceeded: Message too big: 0 + 3 > 2" ); } - - #[test] - fn fragmented_tx() { - let max_message_size = 2; - let input_str = "hello unit test"; - - let limit = WebSocketConfig { - max_send_queue: None, - max_frame_size: Some(2), - compression: WsCompression::None(Some(max_message_size)), - }; - - let mut socket = - WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit)); - - socket.write_message(Message::text(input_str)).unwrap(); - socket.socket.set_position(0); - - let WebSocket { - mut socket, - mut context, - } = socket; - - let vec = input_str.chars().collect::>(); - let mut iter = vec - .chunks(max_message_size) - .map(|c| c.iter().collect::()) - .into_iter() - .peekable(); - - let frame_eq = |expected: Frame, actual: Frame| { - assert_eq!(expected.payload(), actual.payload()); - assert_eq!(expected.header().opcode, actual.header().opcode); - assert_eq!( - expected.header().ext_headers.rsv1, - actual.header().ext_headers.rsv1 - ); - }; - - let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false); - frame_eq( - expected, - context - .frame - .read_frame(&mut socket, Some(max_message_size)) - .unwrap() - .unwrap(), - ); - - while let Some(chars) = iter.next() { - let expected = Frame::message( - chars.into(), - OpCode::Data(Data::Continue), - iter.peek().is_none(), - ); - frame_eq( - expected, - context - .frame - .read_frame(&mut socket, Some(max_message_size)) - .unwrap() - .unwrap(), - ); - } - } } diff --git a/src/server.rs b/src/server.rs index 2415254..34cd83b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,10 @@ pub use crate::handshake::server::ServerHandshake; -use crate::handshake::server::{Callback, NoCallback}; -use crate::handshake::HandshakeError; +use crate::handshake::{ + server::{Callback, NoCallback}, + HandshakeError, +}; use crate::protocol::{WebSocket, WebSocketConfig}; diff --git a/src/util.rs b/src/util.rs index cd03035..f40ca43 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,9 @@ //! Helper traits to ease non-blocking handling. -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::result::Result as StdResult; +use std::{ + io::{Error as IoError, ErrorKind as IoErrorKind}, + result::Result as StdResult, +}; use crate::error::Error; diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index b86bfe0..7e3e33f 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,10 +1,12 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. -use std::net::{TcpListener, TcpStream}; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::{TcpListener, TcpStream}, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use native_tls::TlsStream; use net2::TcpStreamExt; @@ -49,9 +51,7 @@ fn test_server_close() { do_test( 3012, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive close from server assert!(message.is_close()); @@ -85,9 +85,7 @@ fn test_evil_server_close() { do_test( 3013, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); sleep(Duration::from_secs(1)); @@ -109,10 +107,7 @@ fn test_evil_server_close() { let message = srv_sock.read_message().unwrap(); // receive acknowledgement assert!(message.is_close()); // and now just drop the connection without waiting for `ConnectionClosed` - srv_sock - .get_mut() - .set_linger(Some(Duration::from_secs(0))) - .unwrap(); + srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); drop(srv_sock); }, ); @@ -123,9 +118,7 @@ fn test_client_close() { do_test( 3014, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive answer from server assert_eq!(message.into_data(), b"From Server"); @@ -145,9 +138,7 @@ fn test_client_close() { let message = srv_sock.read_message().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); - srv_sock - .write_message(Message::Text("From Server".into())) - .unwrap(); + srv_sock.write_message(Message::Text("From Server".into())).unwrap(); let message = srv_sock.read_message().unwrap(); // receive close from client assert!(message.is_close()); diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index d8e20e5..f348eca 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 352020e..87f8dda 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; @@ -24,9 +26,7 @@ fn test_receive_after_init_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); - client - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = client.read_message().unwrap(); // receive close from server assert!(message.is_close());