From 33f84642116b89471ba44dbf11cb69cc51cf6dc7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Oct 2020 06:16:40 +0000 Subject: [PATCH 01/10] Update base64 requirement from 0.12.0 to 0.13.0 Updates the requirements on [base64](https://github.com/marshallpierce/rust-base64) to permit the latest version. - [Release notes](https://github.com/marshallpierce/rust-base64/releases) - [Changelog](https://github.com/marshallpierce/rust-base64/blob/master/RELEASE-NOTES.md) - [Commits](https://github.com/marshallpierce/rust-base64/compare/v0.12.0...v0.13.0) Signed-off-by: dependabot[bot] --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 286dad3..0716a70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ tls = ["native-tls"] tls-vendored = ["native-tls", "native-tls/vendored"] [dependencies] -base64 = "0.12.0" +base64 = "0.13.0" byteorder = "1.3.2" bytes = "0.5" http = "0.2" From f62bfcba0e4e3fcbd48fdc44369ebb7d0f03681a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Oct 2020 06:46:10 +0000 Subject: [PATCH 02/10] Update env_logger requirement from 0.7.1 to 0.8.1 Updates the requirements on [env_logger](https://github.com/env-logger-rs/env_logger) to permit the latest version. - [Release notes](https://github.com/env-logger-rs/env_logger/releases) - [Changelog](https://github.com/env-logger-rs/env_logger/blob/master/CHANGELOG.md) - [Commits](https://github.com/env-logger-rs/env_logger/compare/v0.7.1...v0.8.1) Signed-off-by: dependabot[bot] --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 286dad3..9e56f74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,5 +35,5 @@ optional = true version = "0.2.3" [dev-dependencies] -env_logger = "0.7.1" +env_logger = "0.8.1" net2 = "0.2.33" From 88ff5f371fea8b1e829442acfdfa2a0329777456 Mon Sep 17 00:00:00 2001 From: Horki Date: Fri, 23 Oct 2020 00:51:48 +0200 Subject: [PATCH 03/10] matches!: use macros; remove unused imports --- src/error.rs | 3 --- src/handshake/headers.rs | 2 -- src/handshake/mod.rs | 1 - src/protocol/frame/frame.rs | 5 +---- src/protocol/message.rs | 27 +++++---------------------- src/protocol/mod.rs | 10 ++-------- 6 files changed, 8 insertions(+), 40 deletions(-) diff --git a/src/error.rs b/src/error.rs index ab24753..5c96dc7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,9 +8,6 @@ use std::result; use std::str; use std::string; -use http; -use httparse; - use crate::protocol::Message; #[cfg(feature = "tls")] diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 0c51008..8386f5a 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,8 +1,6 @@ //! HTTP Request and response header handling. -use http; use http::header::{HeaderMap, HeaderName, HeaderValue}; -use httparse; use httparse::Status; use super::machine::TryParse; diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index ce6b4dd..8350dc0 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -10,7 +10,6 @@ use std::error::Error as ErrorTrait; use std::fmt; use std::io::{Read, Write}; -use base64; use sha1::{Digest, Sha1}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index e6a0009..38ce61c 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -310,10 +310,7 @@ impl Frame { #[inline] pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { debug_assert!( - match opcode { - OpCode::Data(_) => true, - _ => false, - }, + matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame." ); diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 7019494..d1778f1 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -7,8 +7,6 @@ use super::frame::CloseFrame; use crate::error::{Error, Result}; mod string_collect { - - use utf8; use utf8::DecodeError; use crate::error::{Error, Result}; @@ -202,42 +200,27 @@ impl Message { /// Indicates whether a message is a text message. pub fn is_text(&self) -> bool { - match *self { - Message::Text(_) => true, - _ => false, - } + matches!(*self, Message::Text(_)) } /// Indicates whether a message is a binary message. pub fn is_binary(&self) -> bool { - match *self { - Message::Binary(_) => true, - _ => false, - } + matches!(*self, Message::Binary(_)) } /// Indicates whether a message is a ping message. pub fn is_ping(&self) -> bool { - match *self { - Message::Ping(_) => true, - _ => false, - } + matches!(*self, Message::Ping(_)) } /// Indicates whether a message is a pong message. pub fn is_pong(&self) -> bool { - match *self { - Message::Pong(_) => true, - _ => false, - } + matches!(*self, Message::Pong(_)) } /// Indicates whether a message ia s close message. pub fn is_close(&self) -> bool { - match *self { - Message::Close(_) => true, - _ => false, - } + matches!(*self, Message::Close(_)) } /// Get the length of the WebSocket message. diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 97b376e..8137393 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -626,20 +626,14 @@ enum WebSocketState { impl WebSocketState { /// Tell if we're allowed to process normal messages. fn is_active(self) -> bool { - match self { - WebSocketState::Active => true, - _ => false, - } + matches!(self, WebSocketState::Active) } /// Tell if we should process incoming data. Note that if we send a close frame /// but the remote hasn't confirmed, they might have sent data before they receive our /// close frame, so we should still pass those to client code, hence ClosedByUs is valid. fn can_read(self) -> bool { - match self { - WebSocketState::Active | WebSocketState::ClosedByUs => true, - _ => false, - } + matches!(self, WebSocketState::Active | WebSocketState::ClosedByUs) } /// Check if the state is active, return error if not. From 6bce14fa26f8b73363a2baafe55ce28614a66469 Mon Sep 17 00:00:00 2001 From: Redrield Date: Thu, 1 Oct 2020 18:27:51 -0400 Subject: [PATCH 04/10] Add facilities to allow clients to follow HTTP 3xx redirects * The connect() function defined in this crate will automatically follow redirecting responses. * Adds Error::Redirection, which is a special case of Error::Http that extracts the redirection target from the response headers, and stores it in the error object. Client implementations that build upon tungstenite can use this to implement redirecting. * A catch-all solution for redirects is not possible due to the abstraction transforming socket types to Read + Write, implementations that use the client_* methods need to handle redirections themselves. --- src/client.rs | 20 ++++++++++++++++++-- src/error.rs | 4 ++++ src/handshake/client.rs | 7 ++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/client.rs b/src/client.rs index cba6109..9f8516c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -91,6 +91,12 @@ pub fn connect_with_config( config: Option, ) -> Result<(WebSocket, Response)> { let request: Request = request.into_client_request()?; + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -104,10 +110,20 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { + match client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + }) { + Ok(r) => Ok(r), + Err(e) => match e { + Error::Redirection(uri) => { + debug!("Redirecting to {}", uri); + let request = request2.uri(uri).body(()).unwrap(); + connect_with_config(request, config) + } + _ => Err(e), + } + } } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 5c96dc7..9df1a45 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; +use http::Uri; #[cfg(feature = "tls")] pub mod tls { @@ -62,6 +63,8 @@ pub enum Error { Url(Cow<'static, str>), /// HTTP error. Http(http::StatusCode), + /// HTTP 3xx redirection response + Redirection(Uri), /// HTTP format error. HttpFormat(http::Error), } @@ -80,6 +83,7 @@ impl fmt::Display for Error { Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..8b00338 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -160,7 +160,12 @@ impl VerifyData { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); + if response.status().is_redirection() { + let value = response.headers().get("Location").unwrap(); + return Err(Error::Redirection(value.to_str()?.parse()?)) + } else { + return Err(Error::Http(response.status())); + } } let headers = response.headers(); From 60f7b0f02453435605007e03ce07ef18e438a827 Mon Sep 17 00:00:00 2001 From: Redrield Date: Mon, 26 Oct 2020 13:15:21 -0400 Subject: [PATCH 05/10] Fix some code-review issues * Replace Redirection error with a general Http error that owns the response * Make the default client connect function iterative instead of recursive * Add a limit to the amount of redirects a client will attempt to perform --- src/client.rs | 82 +++++++++++++++++++++++++---------------- src/error.rs | 15 +++++--- src/handshake/client.rs | 15 +++----- src/handshake/server.rs | 2 +- src/protocol/mod.rs | 4 ++ 5 files changed, 70 insertions(+), 48 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9f8516c..49ed656 100644 --- a/src/client.rs +++ b/src/client.rs @@ -90,40 +90,60 @@ pub fn connect_with_config( request: Req, config: Option, ) -> Result<(WebSocket, Response)> { - let request: Request = request.into_client_request()?; - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - let uri = request.uri(); - let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = uri.port_u16().unwrap_or(match mode { - Mode::Plain => 80, - Mode::Tls => 443, - }); - let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; - NoDelay::set_nodelay(&mut stream, true)?; - match client_with_config(request, stream, config).map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) { - Ok(r) => Ok(r), - Err(e) => match e { - Error::Redirection(uri) => { - debug!("Redirecting to {}", uri); - let request = request2.uri(uri).body(()).unwrap(); - connect_with_config(request, config) + let mut request: Request = request.into_client_request()?; + + fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + let uri = request.uri(); + let mode = uri_mode(uri)?; + let host = request + .uri() + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + NoDelay::set_nodelay(&mut stream, true)?; + client_with_config(request, stream, config).map_err(|e| match e{ + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) + } + + let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); + let mut redirects = 0; + + loop { + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); + + match inner(request, config) { + Ok(r) => return Ok(r), + Err(e) => match e { + Error::Http(res) => { + if res.status().is_redirection() { + let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; + debug!("Redirecting to {:?}", uri); + request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); + redirects += 1; + if redirects > max_redirects { + return Err(Error::Http(res)); + } + } else { + return Err(Error::Http(res)); + } + } + _ => return Err(e), } - _ => Err(e), } } + } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 9df1a45..01edcb0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::Uri; +use http::{Response, StatusCode}; #[cfg(feature = "tls")] pub mod tls { @@ -61,10 +61,12 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), + /// HTTP error (status only). + HttpStatus(StatusCode), /// HTTP error. - Http(http::StatusCode), - /// HTTP 3xx redirection response - Redirection(Uri), + Http(Response<()>), + /// No Location header in 3xx response + NoLocation, /// HTTP format error. HttpFormat(http::Error), } @@ -82,8 +84,9 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP error: {}", code), - Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), + Error::NoLocation => write!(f, "No Location header specified"), + Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), + Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8b00338..e2ca308 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,6 +90,11 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { + // If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if result.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(result)); + } self.verify_data.verify_response(&result)?; debug!("Client handshake done."); let websocket = @@ -157,16 +162,6 @@ struct VerifyData { impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { - // 1. If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - if response.status().is_redirection() { - let value = response.headers().get("Location").unwrap(); - return Err(Error::Redirection(value.to_str()?.parse()?)) - } else { - return Err(Error::Http(response.status())); - } - } let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 9412a6f..406dc24 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -274,7 +274,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(StatusCode::from_u16(err)?)); + return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8137393..505dddd 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,6 +43,9 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, + /// The max number of redirects the client should follow before aborting the connection. + /// The default value is 3. `None` here means that the client will not attempt to follow redirects. + pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -51,6 +54,7 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), + max_redirects: Some(3) } } } From 521f1a0767339a7e455cdc92580aa477900a50d7 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 16 Nov 2020 19:25:08 +0100 Subject: [PATCH 06/10] clean up the redirect logic a bit --- src/client.rs | 63 +++++++++++++++++++++------------------------ src/protocol/mod.rs | 4 --- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/src/client.rs b/src/client.rs index 49ed656..b28cf3c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; -use http::Uri; +use http::{Uri, request::Parts}; use log::*; use url::Url; @@ -89,10 +89,12 @@ use crate::stream::{Mode, NoDelay}; pub fn connect_with_config( request: Req, config: Option, + max_redirects: u8, ) -> Result<(WebSocket, Response)> { - let mut request: Request = request.into_client_request()?; - fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + fn try_client_handshake(request: Request, config: Option) + -> Result<(WebSocket, Response)> + { let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -106,44 +108,39 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e{ + client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), }) } - let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); - let mut redirects = 0; - - loop { - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - - match inner(request, config) { - Ok(r) => return Ok(r), - Err(e) => match e { - Error::Http(res) => { - if res.status().is_redirection() { - let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; - debug!("Redirecting to {:?}", uri); - request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); - redirects += 1; - if redirects > max_redirects { - return Err(Error::Http(res)); - } - } else { - return Err(Error::Http(res)); - } - } - _ => return Err(e), + fn create_request(parts: &Parts, uri: &Uri) -> Request { + let mut builder = Request::builder() + .uri(uri.clone()) + .method(parts.method.clone()) + .version(parts.version.clone()); + *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); + builder.body(()).expect("Failed to create `Request`") + } + + let (parts, _) = request.into_client_request()?.into_parts(); + let mut uri = parts.uri.clone(); + + for attempt in 0..(max_redirects + 1) { + let request = create_request(&parts, &uri); + + match try_client_handshake(request, config) { + Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { + let location = res.headers().get("Location").ok_or(Error::NoLocation)?; + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; } + other => return other, } } + unreachable!("Bug in a redirect handling logic") } /// Connect to the given WebSocket in blocking mode. @@ -159,7 +156,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) + connect_with_config(request, None, 0) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 505dddd..8137393 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,9 +43,6 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, - /// The max number of redirects the client should follow before aborting the connection. - /// The default value is 3. `None` here means that the client will not attempt to follow redirects. - pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -54,7 +51,6 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), - max_redirects: Some(3) } } } From a8e06d2b39c6c1f9091ce7af2adf2510575ab9d3 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 16 Nov 2020 19:49:17 +0100 Subject: [PATCH 07/10] clean up http error handling --- src/client.rs | 12 ++++++++---- src/error.rs | 10 ++-------- src/handshake/client.rs | 17 +++++++++-------- src/handshake/server.rs | 13 ++++++++----- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/client.rs b/src/client.rs index b28cf3c..0b70af3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -131,10 +131,14 @@ pub fn connect_with_config( match try_client_handshake(request, config) { Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { - let location = res.headers().get("Location").ok_or(Error::NoLocation)?; - uri = location.to_str()?.parse::()?; - debug!("Redirecting to {:?}", uri); - continue; + if let Some(location) = res.headers().get("Location") { + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; + } else { + warn!("No `Location` found in redirect"); + return Err(Error::Http(res)); + } } other => return other, } diff --git a/src/error.rs b/src/error.rs index 01edcb0..b2657cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::{Response, StatusCode}; +use http::Response; #[cfg(feature = "tls")] pub mod tls { @@ -61,12 +61,8 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), - /// HTTP error (status only). - HttpStatus(StatusCode), /// HTTP error. - Http(Response<()>), - /// No Location header in 3xx response - NoLocation, + Http(Response>), /// HTTP format error. HttpFormat(http::Error), } @@ -84,8 +80,6 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::NoLocation => write!(f, "No Location header specified"), - Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e2ca308..bb159d7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,12 +90,7 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { - // If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if result.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(result)); - } - self.verify_data.verify_response(&result)?; + let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, self.config); @@ -161,7 +156,13 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: Response) -> Result { + // 1. If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.map(|_| None))); + } + let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| @@ -219,7 +220,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok(()) + Ok(response) } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 406dc24..15f6b14 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -195,7 +195,7 @@ pub struct ServerHandshake { /// WebSocket configuration. config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. - error_code: Option, + error_response: Option, /// Internal stream type. _marker: PhantomData, } @@ -212,7 +212,7 @@ impl ServerHandshake { role: ServerHandshake { callback: Some(callback), config, - error_code: None, + error_response: None, _marker: PhantomData, }, } @@ -259,22 +259,25 @@ impl HandshakeRole for ServerHandshake { )); } - self.error_code = Some(resp.status().as_u16()); + self.error_response = Some(resp); + let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } StageResult::DoneWriting(stream) => { - if let Some(err) = self.error_code.take() { + if let Some(err) = self.error_response.take() { debug!("Server handshake failed."); - return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); + return Err(Error::Http(err)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); From 09f5d34899416d950f9f51611c1757b189cd1ebd Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 17 Nov 2020 11:17:56 +0100 Subject: [PATCH 08/10] use 3 redirects as default for `connect` --- src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 0b70af3..1b20980 100644 --- a/src/client.rs +++ b/src/client.rs @@ -160,7 +160,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None, 0) + connect_with_config(request, None, 3) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { From 96d9eb75e5d8a33dcb5d49a77088f48b7961d8ce Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 17 Nov 2020 11:50:03 +0100 Subject: [PATCH 09/10] chore: apply `fmt` to the whole project --- examples/autobahn-client.rs | 13 +---- examples/autobahn-server.rs | 21 ++++--- examples/callback-error.rs | 11 ++-- examples/client.rs | 4 +- examples/server.rs | 9 +-- rustfmt.toml | 7 +++ src/client.rs | 54 +++++++++--------- src/error.rs | 8 +-- src/handshake/client.rs | 94 +++++++++++-------------------- src/handshake/headers.rs | 3 +- src/handshake/machine.rs | 45 +++++++-------- src/handshake/mod.rs | 13 ++--- src/handshake/server.rs | 60 ++++++++------------ src/lib.rs | 14 ++--- src/protocol/frame/coding.rs | 22 +++++--- src/protocol/frame/frame.rs | 70 +++++++---------------- src/protocol/frame/mod.rs | 22 ++------ src/protocol/message.rs | 26 +++------ src/protocol/mod.rs | 85 +++++++++++----------------- src/server.rs | 6 +- src/util.rs | 6 +- tests/connection_reset.rs | 49 ++++++++-------- tests/no_send_after_close.rs | 10 ++-- tests/receive_after_init_close.rs | 14 ++--- 24 files changed, 282 insertions(+), 384 deletions(-) create mode 100644 rustfmt.toml diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 31aedb2..8143175 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -14,11 +14,7 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!( - "ws://localhost:9001/updateReports?agent={}", - AGENT - )) - .unwrap(), + Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), )?; socket.close(None)?; Ok(()) @@ -26,11 +22,8 @@ fn update_reports() -> Result<()> { fn run_test(case: u32) -> Result<()> { info!("Running test case {}", case); - let case_url = Url::parse(&format!( - "ws://localhost:9001/runCase?case={}&agent={}", - case, AGENT - )) - .unwrap(); + let case_url = + Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3c99545..3250b2c 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -1,9 +1,10 @@ -use std::net::{TcpListener, TcpStream}; -use std::thread::spawn; +use std::{ + net::{TcpListener, TcpStream}, + thread::spawn, +}; use log::*; -use tungstenite::handshake::HandshakeRole; -use tungstenite::{accept, Error, HandshakeError, Message, Result}; +use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result}; fn must_not_block(err: HandshakeError) -> Error { match err { @@ -32,12 +33,14 @@ fn main() { for stream in server.incoming() { spawn(move || match stream { - Ok(stream) => if let Err(err) = handle_client(stream) { - match err { - Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), - e => error!("test: {}", e), + Ok(stream) => { + if let Err(err) = handle_client(stream) { + match err { + Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), + e => error!("test: {}", e), + } } - }, + } Err(e) => error!("Error accepting stream: {}", e), }); } diff --git a/examples/callback-error.rs b/examples/callback-error.rs index 357d343..cf78a2e 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -1,9 +1,10 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; -use tungstenite::http::StatusCode; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, + http::StatusCode, +}; fn main() { let server = TcpListener::bind("127.0.0.1:3012").unwrap(); diff --git a/examples/client.rs b/examples/client.rs index 7938cfb..def6a3c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -14,9 +14,7 @@ fn main() { println!("* {}", header); } - socket - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); diff --git a/examples/server.rs b/examples/server.rs index def2f45..420e5db 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,9 @@ -use std::net::TcpListener; -use std::thread::spawn; +use std::{net::TcpListener, thread::spawn}; -use tungstenite::accept_hdr; -use tungstenite::handshake::server::{Request, Response}; +use tungstenite::{ + accept_hdr, + handshake::server::{Request, Response}, +}; fn main() { env_logger::init(); diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..db7f39d --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,7 @@ +# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check]. +# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md + +# Break complex but short statements a bit less. +use_small_heuristics = "Max" + +merge_imports = true diff --git a/src/client.rs b/src/client.rs index 1b20980..f9ae3a4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,20 @@ //! Methods to connect to a WebSocket as a client. -use std::io::{Read, Write}; -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -use std::result::Result as StdResult; +use std::{ + io::{Read, Write}, + net::{SocketAddr, TcpStream, ToSocketAddrs}, + result::Result as StdResult, +}; -use http::{Uri, request::Parts}; +use http::{request::Parts, Uri}; use log::*; use url::Url; -use crate::handshake::client::{Request, Response}; -use crate::protocol::WebSocketConfig; +use crate::{ + handshake::client::{Request, Response}, + protocol::WebSocketConfig, +}; #[cfg(feature = "tls")] mod encryption { @@ -22,8 +26,7 @@ mod encryption { /// TCP stream switcher (plain/TLS). pub type AutoStream = StreamSwitcher>; - use crate::error::Result; - use crate::stream::Mode; + use crate::{error::Result, stream::Mode}; pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { match mode { @@ -48,8 +51,10 @@ mod encryption { mod encryption { use std::net::TcpStream; - use crate::error::{Error, Result}; - use crate::stream::Mode; + use crate::{ + error::{Error, Result}, + stream::Mode, + }; /// TLS support is nod compiled in, this is just standard `TcpStream`. pub type AutoStream = TcpStream; @@ -65,11 +70,12 @@ mod encryption { use self::encryption::wrap_stream; pub use self::encryption::AutoStream; -use crate::error::{Error, Result}; -use crate::handshake::client::ClientHandshake; -use crate::handshake::HandshakeError; -use crate::protocol::WebSocket; -use crate::stream::{Mode, NoDelay}; +use crate::{ + error::{Error, Result}, + handshake::{client::ClientHandshake, HandshakeError}, + protocol::WebSocket, + stream::{Mode, NoDelay}, +}; /// Connect to the given WebSocket in blocking mode. /// @@ -91,16 +97,14 @@ pub fn connect_with_config( config: Option, max_redirects: u8, ) -> Result<(WebSocket, Response)> { - - fn try_client_handshake(request: Request, config: Option) - -> Result<(WebSocket, Response)> - { + fn try_client_handshake( + request: Request, + config: Option, + ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let host = + request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -164,9 +168,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { diff --git a/src/error.rs b/src/error.rs index b2657cf..c2becc7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,6 @@ //! Error handling. -use std::borrow::Cow; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io; -use std::result; -use std::str; -use std::string; +use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; use crate::protocol::Message; use http::Response; diff --git a/src/handshake/client.rs b/src/handshake/client.rs index bb159d7..ea011fd 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,17 +1,24 @@ //! Client handshake machine. -use std::io::{Read, Write}; -use std::marker::PhantomData; +use std::{ + io::{Read, Write}, + marker::PhantomData, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; /// Client request type. pub type Request = HttpRequest<()>; @@ -35,15 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol( - "Invalid HTTP method, only GET supported".into(), - )); + return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } // Check the URI scheme: only ws or wss are supported @@ -58,18 +61,11 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); - ClientHandshake { - verify_data: VerifyData { accept_key }, - config, - _marker: PhantomData, - } + ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData } }; trace!("Client handshake initiated."); - Ok(MidHandshake { - role: client, - machine, - }) + Ok(MidHandshake { role: client, machine }) } } @@ -85,11 +81,7 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneWriting(stream) => { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - StageResult::DoneReading { - stream, - result, - tail, - } => { + StageResult::DoneReading { stream, result, tail } => { let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = @@ -105,16 +97,16 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority() - .ok_or_else(|| Error::Url("No host name in the URL".into()))? - .as_str(); - let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ + let authority = + uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); + let host = if let Some(idx) = authority.find('@') { + // handle possible name:password@ authority.split_at(idx + 1).1 } else { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())) + return Err(Error::Url("URL contains empty host name".into())); } write!( @@ -128,17 +120,15 @@ fn generate_request(request: Request, key: &str) -> Result> { Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = uri - .path_and_query() - .ok_or_else(|| Error::Url("No path/query in URL".into()))? - .as_str(), + path = + uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), key = key ) .unwrap(); for (k, v) in request.headers() { let mut k = k.as_str(); - if k == "sec-websocket-protocol" { + if k == "sec-websocket-protocol" { k = "Sec-WebSocket-Protocol"; } writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); @@ -175,9 +165,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -189,22 +177,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in server reply".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !headers - .get("Sec-WebSocket-Accept") - .map(|h| h == &self.accept_key) - .unwrap_or(false) - { - return Err(Error::Protocol( - "Key mismatch in Sec-WebSocket-Accept".into(), - )); + if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { + return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -238,9 +218,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -266,9 +244,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { - use super::super::machine::TryParse; + use super::{super::machine::TryParse, generate_key, generate_request, Response}; use crate::client::IntoClientRequest; - use super::{generate_key, generate_request, Response}; #[test] fn random_keys() { @@ -342,9 +319,6 @@ mod tests { const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); - assert_eq!( - resp.headers().get("Content-Type").unwrap(), - &b"text/html"[..], - ); + assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 8386f5a..f336c65 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -41,8 +41,7 @@ impl TryParse for HeaderMap { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::HeaderMap; + use super::{super::machine::TryParse, HeaderMap}; #[test] fn headers() { diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 61090bb..b8416a6 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -2,8 +2,10 @@ use bytes::Buf; use log::*; use std::io::{Cursor, Read, Write}; -use crate::error::{Error, Result}; -use crate::util::NonBlockingResult; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; use input_buffer::{InputBuffer, MIN_READ}; /// A generic handshake state machine. @@ -23,10 +25,7 @@ impl HandshakeMachine { } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { - HandshakeMachine { - stream, - state: HandshakeState::Writing(Cursor::new(data.into())), - } + HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) } } /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &Stream { @@ -52,21 +51,19 @@ impl HandshakeMachine { .no_block()?; match read { Some(0) => Err(Error::Protocol("Handshake not finished".into())), - Some(_) => Ok( - if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { - buf.advance(size); - RoundResult::StageFinished(StageResult::DoneReading { - result: obj, - stream: self.stream, - tail: buf.into_vec(), - }) - } else { - RoundResult::Incomplete(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - }) - }, - ), + Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + buf.advance(size); + RoundResult::StageFinished(StageResult::DoneReading { + result: obj, + stream: self.stream, + tail: buf.into_vec(), + }) + } else { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + }) + }), None => Ok(RoundResult::WouldBlock(HandshakeMachine { state: HandshakeState::Reading(buf), ..self @@ -112,11 +109,7 @@ pub enum RoundResult { #[derive(Debug)] pub enum StageResult { /// Reading round finished. - DoneReading { - result: Obj, - stream: Stream, - tail: Vec, - }, + DoneReading { result: Obj, stream: Stream, tail: Vec }, /// Writing round finished. DoneWriting(Stream), } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 8350dc0..4714ee0 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -6,9 +6,11 @@ pub mod server; mod machine; -use std::error::Error as ErrorTrait; -use std::fmt; -use std::io::{Read, Write}; +use std::{ + error::Error as ErrorTrait, + fmt, + io::{Read, Write}, +}; use sha1::{Digest, Sha1}; @@ -39,10 +41,7 @@ impl MidHandshake { loop { mach = match mach.single_round()? { RoundResult::WouldBlock(m) => { - return Err(HandshakeError::Interrupted(MidHandshake { - machine: m, - ..self - })) + return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) } RoundResult::Incomplete(m) => m, RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 15f6b14..1b6eed8 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,18 +1,25 @@ //! Server handshake machine. -use std::io::{self, Read, Write}; -use std::marker::PhantomData; -use std::result::Result as StdResult; +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + result::Result as StdResult, +}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, MAX_HEADERS}; -use super::machine::{HandshakeMachine, StageResult, TryParse}; -use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; -use crate::error::{Error, Result}; -use crate::protocol::{Role, WebSocket, WebSocketConfig}; +use super::{ + convert_key, + headers::{FromHttparse, MAX_HEADERS}, + machine::{HandshakeMachine, StageResult, TryParse}, + HandshakeRole, MidHandshake, ProcessingResult, +}; +use crate::{ + error::{Error, Result}, + protocol::{Role, WebSocket, WebSocketConfig}, +}; /// Server request type. pub type Request = HttpRequest<()>; @@ -30,9 +37,7 @@ pub fn create_response(request: &Request) -> Result { } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } if !request @@ -42,9 +47,7 @@ pub fn create_response(request: &Request) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Connection: upgrade\" in client request".into(), - )); + return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); } if !request @@ -54,20 +57,11 @@ pub fn create_response(request: &Request) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol( - "No \"Upgrade: websocket\" in client request".into(), - )); + return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); } - if !request - .headers() - .get("Sec-WebSocket-Version") - .map(|h| h == "13") - .unwrap_or(false) - { - return Err(Error::Protocol( - "No \"Sec-WebSocket-Version: 13\" in client request".into(), - )); + if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { + return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); } let key = request @@ -121,9 +115,7 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol( - "HTTP version should be 1.1 or higher".into(), - )); + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -229,11 +221,7 @@ impl HandshakeRole for ServerHandshake { finish: StageResult, ) -> Result> { Ok(match finish { - StageResult::DoneReading { - stream, - result, - tail, - } => { + StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())); } @@ -290,9 +278,7 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::super::machine::TryParse; - use super::create_response; - use super::Request; + use super::{super::machine::TryParse, create_response, Request}; #[test] fn request_parsing() { diff --git a/src/lib.rs b/src/lib.rs index f965478..82f7822 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,10 +22,10 @@ pub mod server; pub mod stream; pub mod util; -pub use crate::client::{client, connect}; -pub use crate::error::{Error, Result}; -pub use crate::handshake::client::ClientHandshake; -pub use crate::handshake::server::ServerHandshake; -pub use crate::handshake::HandshakeError; -pub use crate::protocol::{Message, WebSocket}; -pub use crate::server::{accept, accept_hdr}; +pub use crate::{ + client::{client, connect}, + error::{Error, Result}, + handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, + protocol::{Message, WebSocket}, + server::{accept, accept_hdr}, +}; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index d3fcdeb..c69b90c 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -1,7 +1,9 @@ //! Various codes defined in RFC 6455. -use std::convert::{From, Into}; -use std::fmt; +use std::{ + convert::{From, Into}, + fmt, +}; /// WebSocket message opcode as in RFC 6455. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -71,9 +73,11 @@ impl fmt::Display for OpCode { impl Into for OpCode { fn into(self) -> u8 { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match self { Data(Continue) => 0, Data(Text) => 1, @@ -90,9 +94,11 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { - use self::Control::{Close, Ping, Pong}; - use self::Data::{Binary, Continue, Text}; - use self::OpCode::*; + use self::{ + Control::{Close, Ping, Pong}, + Data::{Binary, Continue, Text}, + OpCode::*, + }; match byte { 0 => Data(Continue), 1 => Data(Text), diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 38ce61c..ff64fa2 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,14 +1,18 @@ use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use log::*; -use std::borrow::Cow; -use std::default::Default; -use std::fmt; -use std::io::{Cursor, ErrorKind, Read, Write}; -use std::result::Result as StdResult; -use std::string::{FromUtf8Error, String}; - -use super::coding::{CloseCode, Control, Data, OpCode}; -use super::mask::{apply_mask, generate_mask}; +use std::{ + borrow::Cow, + default::Default, + fmt, + io::{Cursor, ErrorKind, Read, Write}, + result::Result as StdResult, + string::{FromUtf8Error, String}, +}; + +use super::{ + coding::{CloseCode, Control, Data, OpCode}, + mask::{apply_mask, generate_mask}, +}; use crate::error::{Error, Result}; /// A struct representing the close command. @@ -23,10 +27,7 @@ pub struct CloseFrame<'t> { impl<'t> CloseFrame<'t> { /// Convert into a owned string. pub fn into_owned(self) -> CloseFrame<'static> { - CloseFrame { - code: self.code, - reason: self.reason.into_owned().into(), - } + CloseFrame { code: self.code, reason: self.reason.into_owned().into() } } } @@ -192,14 +193,7 @@ impl FrameHeader { _ => (), } - let hdr = FrameHeader { - is_final, - rsv1, - rsv2, - rsv3, - opcode, - mask, - }; + let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask }; Ok(Some((hdr, length))) } @@ -298,10 +292,7 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { - code, - reason: text.into(), - })) + Ok(Some(CloseFrame { code, reason: text.into() })) } } } @@ -309,19 +300,9 @@ impl Frame { /// Create a new data frame. #[inline] pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { - debug_assert!( - matches!(opcode, OpCode::Data(_)), - "Invalid opcode for data frame." - ); + debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - Frame { - header: FrameHeader { - is_final, - opcode, - ..FrameHeader::default() - }, - payload: data, - } + Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } } /// Create a new Pong control frame. @@ -360,10 +341,7 @@ impl Frame { Vec::new() }; - Frame { - header: FrameHeader::default(), - payload, - } + Frame { header: FrameHeader::default(), payload } } /// Create a frame from given header and data. @@ -401,10 +379,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload - .iter() - .map(|byte| format!("{:x}", byte)) - .collect::() + self.payload.iter().map(|byte| format!("{:x}", byte)).collect::() ) } } @@ -476,10 +451,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload); - assert_eq!( - frame.into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] - ); + assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } #[test] diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 6756f0a..dfd0bd5 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,8 +6,7 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::CloseFrame; -pub use self::frame::{Frame, FrameHeader}; +pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; @@ -26,18 +25,12 @@ pub struct FrameSocket { impl FrameSocket { /// Create a new frame socket. pub fn new(stream: Stream) -> Self { - FrameSocket { - stream, - codec: FrameCodec::new(), - } + FrameSocket { stream, codec: FrameCodec::new() } } /// Create a new frame socket from partially read data. pub fn from_partially_read(stream: Stream, part: Vec) -> Self { - FrameSocket { - stream, - codec: FrameCodec::from_partially_read(part), - } + FrameSocket { stream, codec: FrameCodec::from_partially_read(part) } } /// Extract a stream from the socket. @@ -184,9 +177,7 @@ impl FrameCodec { { trace!("writing frame {}", frame); self.out_buffer.reserve(frame.len()); - frame - .format(&mut self.out_buffer) - .expect("Bug: can't write to vector"); + frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); self.write_pending(stream) } @@ -231,10 +222,7 @@ mod tests { sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); - assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), - vec![0x03, 0x02, 0x01] - ); + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); diff --git a/src/protocol/message.rs b/src/protocol/message.rs index d1778f1..f799dbf 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,7 +1,9 @@ -use std::convert::{AsRef, From, Into}; -use std::fmt; -use std::result::Result as StdResult; -use std::str; +use std::{ + convert::{AsRef, From, Into}, + fmt, + result::Result as StdResult, + str, +}; use super::frame::CloseFrame; use crate::error::{Error, Result}; @@ -19,10 +21,7 @@ mod string_collect { impl StringCollector { pub fn new() -> Self { - StringCollector { - data: String::new(), - incomplete: None, - } + StringCollector { data: String::new(), incomplete: None } } pub fn len(&self) -> usize { @@ -54,10 +53,7 @@ mod string_collect { self.data.push_str(text); Ok(()) } - Err(DecodeError::Incomplete { - valid_prefix, - incomplete_suffix, - }) => { + Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { self.data.push_str(valid_prefix); self.incomplete = Some(incomplete_suffix); Ok(()) @@ -127,11 +123,7 @@ impl IncompleteMessage { // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { return Err(Error::Capacity( - format!( - "Message too big: {} + {} > {}", - my_size, portion_size, max_size - ) - .into(), + format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), )); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8137393..72485e9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -4,19 +4,26 @@ pub mod frame; mod message; -pub use self::frame::CloseFrame; -pub use self::message::Message; +pub use self::{frame::CloseFrame, message::Message}; use log::*; -use std::collections::VecDeque; -use std::io::{ErrorKind as IoErrorKind, Read, Write}; -use std::mem::replace; - -use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; -use self::frame::{Frame, FrameCodec}; -use self::message::{IncompleteMessage, IncompleteMessageType}; -use crate::error::{Error, Result}; -use crate::util::NonBlockingResult; +use std::{ + collections::VecDeque, + io::{ErrorKind as IoErrorKind, Read, Write}, + mem::replace, +}; + +use self::{ + frame::{ + coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, + Frame, FrameCodec, + }, + message::{IncompleteMessage, IncompleteMessageType}, +}; +use crate::{ + error::{Error, Result}, + util::NonBlockingResult, +}; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -74,10 +81,7 @@ impl WebSocket { /// or together with an existing one. If you need an initial handshake, use /// `connect()` or `accept()` functions of the crate to construct a websocket. pub fn from_raw_socket(stream: Stream, role: Role, config: Option) -> Self { - WebSocket { - socket: stream, - context: WebSocketContext::new(role, config), - } + WebSocket { socket: stream, context: WebSocketContext::new(role, config) } } /// Convert a raw socket into a WebSocket without performing a handshake. @@ -320,9 +324,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol( - "Sending after closing is not allowed".into(), - )); + return Err(Error::Protocol("Sending after closing is not allowed".into())); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -455,9 +457,7 @@ impl WebSocketContext { Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol( - "Received a masked frame from server".into(), - )); + return Err(Error::Protocol("Received a masked frame from server".into())); } } } @@ -474,9 +474,9 @@ impl WebSocketContext { Err(Error::Protocol("Control frame too big".into())) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), - OpCtl::Reserved(i) => Err(Error::Protocol( - format!("Unknown control frame type {}", i).into(), - )), + OpCtl::Reserved(i) => { + Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + } OpCtl::Ping => { let data = frame.into_data(); // No ping processing after we sent a close frame. @@ -527,9 +527,9 @@ impl WebSocketContext { Ok(None) } } - OpData::Reserved(i) => Err(Error::Protocol( - format!("Unknown data frame type {}", i).into(), - )), + OpData::Reserved(i) => { + Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + } } } } // match opcode @@ -539,9 +539,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol( - "Connection reset without closing handshake".into(), - )), + _ => Err(Error::Protocol("Connection reset without closing handshake".into())), } } } @@ -602,9 +600,7 @@ impl WebSocketContext { } trace!("Sending frame: {:?}", frame); - self.frame - .write_frame(stream, frame) - .check_connection_reset(self.state) + self.frame.write_frame(stream, frame).check_connection_reset(self.state) } } @@ -669,8 +665,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use std::io; - use std::io::Cursor; + use std::{io, io::Cursor}; struct WriteMoc(Stream); @@ -699,14 +694,8 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!( - socket.read_message().unwrap(), - Message::Text("Hello, World!".into()) - ); - assert_eq!( - socket.read_message().unwrap(), - Message::Binary(vec![0x01, 0x02, 0x03]) - ); + assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); + assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); } #[test] @@ -715,10 +704,7 @@ mod tests { 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, ]); - let limit = WebSocketConfig { - max_message_size: Some(10), - ..WebSocketConfig::default() - }; + let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), @@ -729,10 +715,7 @@ mod tests { #[test] fn size_limiting_binary() { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); - let limit = WebSocketConfig { - max_message_size: Some(2), - ..WebSocketConfig::default() - }; + let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), diff --git a/src/server.rs b/src/server.rs index 725d892..53303ee 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,10 @@ pub use crate::handshake::server::ServerHandshake; -use crate::handshake::server::{Callback, NoCallback}; -use crate::handshake::HandshakeError; +use crate::handshake::{ + server::{Callback, NoCallback}, + HandshakeError, +}; use crate::protocol::{WebSocket, WebSocketConfig}; diff --git a/src/util.rs b/src/util.rs index cd03035..f40ca43 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,9 @@ //! Helper traits to ease non-blocking handling. -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::result::Result as StdResult; +use std::{ + io::{Error as IoError, ErrorKind as IoErrorKind}, + result::Result as StdResult, +}; use crate::error::Error; diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index d95ee81..7e3e33f 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,15 +1,17 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. -use std::net::{TcpStream, TcpListener}; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::{TcpListener, TcpStream}, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; -use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream}; use native_tls::TlsStream; -use url::Url; use net2::TcpStreamExt; +use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; +use url::Url; type Sock = WebSocket>>; @@ -26,8 +28,8 @@ where exit(1); }); - let server = TcpListener::bind(("127.0.0.1", port)) - .expect("Can't listen, is port already in use?"); + let server = + TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?"); let client_thread = spawn(move || { let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap()) @@ -46,11 +48,10 @@ where #[test] fn test_server_close() { - do_test(3012, + do_test( + 3012, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive close from server assert!(message.is_close()); @@ -75,16 +76,16 @@ fn test_server_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); + }, + ); } #[test] fn test_evil_server_close() { - do_test(3013, + do_test( + 3013, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); sleep(Duration::from_secs(1)); @@ -108,16 +109,16 @@ fn test_evil_server_close() { // and now just drop the connection without waiting for `ConnectionClosed` srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); drop(srv_sock); - }); + }, + ); } #[test] fn test_client_close() { - do_test(3014, + do_test( + 3014, |mut cli_sock| { - cli_sock - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read_message().unwrap(); // receive answer from server assert_eq!(message.into_data(), b"From Server"); @@ -147,6 +148,6 @@ fn test_client_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); - + }, + ); } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index d8e20e5..f348eca 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 352020e..87f8dda 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -1,10 +1,12 @@ //! Verifies that we can read data messages even if we have initiated a close handshake, //! but before we got confirmation. -use std::net::TcpListener; -use std::process::exit; -use std::thread::{sleep, spawn}; -use std::time::Duration; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; use tungstenite::{accept, connect, Error, Message}; use url::Url; @@ -24,9 +26,7 @@ fn test_receive_after_init_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); - client - .write_message(Message::Text("Hello WebSocket".into())) - .unwrap(); + client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); let message = client.read_message().unwrap(); // receive close from server assert!(message.is_close()); From fcacea7c9fbcdd9edb7c71a13c7fd494bd65bb6c Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 17 Nov 2020 11:56:15 +0100 Subject: [PATCH 10/10] chore: apply `clippy` --- src/client.rs | 2 +- src/protocol/frame/coding.rs | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/client.rs b/src/client.rs index f9ae3a4..9b13be9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -122,7 +122,7 @@ pub fn connect_with_config( let mut builder = Request::builder() .uri(uri.clone()) .method(parts.method.clone()) - .version(parts.version.clone()); + .version(parts.version); *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); builder.body(()).expect("Failed to create `Request`") } diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index c69b90c..e726161 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -190,14 +190,7 @@ pub enum CloseCode { impl CloseCode { /// Check if this CloseCode is allowed. pub fn is_allowed(self) -> bool { - match self { - Bad(_) => false, - Reserved(_) => false, - Status => false, - Abnormal => false, - Tls => false, - _ => true, - } + !matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls) } }