From 34c6e63d876372b50b12b96796f7c8973fa7503a Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sun, 3 Jan 2021 22:15:06 +0000 Subject: [PATCH 01/15] Add specific URL error types --- src/client.rs | 14 +++++++------- src/error.rs | 34 ++++++++++++++++++++++++++++++++-- src/handshake/client.rs | 8 ++++---- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1741fa2..074c7ff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ mod encryption { use std::net::TcpStream; use crate::{ - error::{Error, Result}, + error::{Error, UrlErrorType, Result}, stream::Mode, }; @@ -62,7 +62,7 @@ mod encryption { pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { match mode { Mode::Plain => Ok(stream), - Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + Mode::Tls => Err(Error::Url(UrlErrorType::TlsFeatureNotEnabled)), } } } @@ -71,7 +71,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::{ - error::{Error, Result}, + error::{Error, UrlErrorType, Result}, handshake::{client::ClientHandshake, HandshakeError}, protocol::WebSocket, stream::{Mode, NoDelay}, @@ -104,7 +104,7 @@ pub fn connect_with_config( let uri = request.uri(); let mode = uri_mode(uri)?; let host = - request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -166,7 +166,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = uri.host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { @@ -175,7 +175,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result Result { match uri.scheme_str() { Some("ws") => Ok(Mode::Plain), Some("wss") => Ok(Mode::Tls), - _ => Err(Error::Url("URL scheme not supported".into())), + _ => Err(Error::Url(UrlErrorType::UnsupportedUrlScheme)), } } diff --git a/src/error.rs b/src/error.rs index c2becc7..f9891f5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,7 +14,7 @@ pub mod tls { /// Result type of all Tungstenite library calls. pub type Result = result::Result; -/// Possible WebSocket errors +/// Possible WebSocket errors. #[derive(Debug)] pub enum Error { /// WebSocket connection closed normally. This informs you of the close. @@ -54,7 +54,7 @@ pub enum Error { /// UTF coding error Utf8, /// Invalid URL. - Url(Cow<'static, str>), + Url(UrlErrorType), /// HTTP error. Http(Response>), /// HTTP format error. @@ -151,3 +151,33 @@ impl From for Error { } } } + +/// Indicates the specific type/cause of URL error. +#[derive(Debug)] +pub enum UrlErrorType { + /// TLS is used despite not being compiled with the TLS feature enabled. + TlsFeatureNotEnabled, + /// The URL does not include a host name. + NoHostName, + /// Failed to connect with this URL. + UnableToConnect(String), + /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). + UnsupportedUrlScheme, + /// The URL host name, though included, is empty. + EmptyHostName, + /// The URL does not include a path/query. + NoPathOrQuery +} + +impl fmt::Display for UrlErrorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + UrlErrorType::TlsFeatureNotEnabled => write!(f, "TLS support not compiled in"), + UrlErrorType::NoHostName => write!(f, "No host name in the URL"), + UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), + UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), + UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), + UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL") + } + } +} \ No newline at end of file diff --git a/src/handshake/client.rs b/src/handshake/client.rs index ea011fd..558eb85 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, Result}, + error::{Error, UrlErrorType, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -98,7 +98,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let uri = request.uri(); let authority = - uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); + uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -106,7 +106,7 @@ fn generate_request(request: Request, key: &str) -> Result> { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())); + return Err(Error::Url(UrlErrorType::EmptyHostName)); } write!( @@ -121,7 +121,7 @@ fn generate_request(request: Request, key: &str) -> Result> { version = request.version(), host = host, path = - uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), + uri.path_and_query().ok_or_else(|| Error::Url(UrlErrorType::NoPathOrQuery))?.as_str(), key = key ) .unwrap(); From 6f846da0e327275420418ff51f951e29c4055695 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 4 Jan 2021 12:17:22 +0000 Subject: [PATCH 02/15] Add protocol error types --- src/client.rs | 7 +- src/error.rs | 137 +++++++++++++++++++++++++++++++++-- src/handshake/client.rs | 17 ++--- src/handshake/machine.rs | 4 +- src/handshake/server.rs | 22 +++--- src/protocol/frame/frame.rs | 8 +- src/protocol/mod.rs | 34 ++++----- tests/no_send_after_close.rs | 4 +- 8 files changed, 173 insertions(+), 60 deletions(-) diff --git a/src/client.rs b/src/client.rs index 074c7ff..f4c5f8d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ mod encryption { use std::net::TcpStream; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, Result, UrlErrorType}, stream::Mode, }; @@ -71,7 +71,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, Result, UrlErrorType}, handshake::{client::ClientHandshake, HandshakeError}, protocol::WebSocket, stream::{Mode, NoDelay}, @@ -103,8 +103,7 @@ pub fn connect_with_config( ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = - request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; + let host = request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, diff --git a/src/error.rs b/src/error.rs index f9891f5..e11aef3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; -use crate::protocol::Message; +use crate::protocol::{frame::coding::Data, Message}; use http::Response; #[cfg(feature = "tls")] @@ -41,14 +41,14 @@ pub enum Error { /// underlying connection and you should probably consider them fatal. Io(io::Error), #[cfg(feature = "tls")] - /// TLS error + /// TLS error. Tls(tls::Error), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). Capacity(Cow<'static, str>), /// Protocol violation. - Protocol(Cow<'static, str>), + Protocol(ProtocolErrorType), /// Message send queue full. SendQueueFull(Message), /// UTF coding error @@ -147,13 +147,13 @@ impl From for Error { fn from(err: httparse::Error) -> Self { match err { httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), - e => Error::Protocol(e.to_string().into()), + e => Error::Protocol(ProtocolErrorType::HttparseError(e)), } } } /// Indicates the specific type/cause of URL error. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum UrlErrorType { /// TLS is used despite not being compiled with the TLS feature enabled. TlsFeatureNotEnabled, @@ -166,7 +166,7 @@ pub enum UrlErrorType { /// The URL host name, though included, is empty. EmptyHostName, /// The URL does not include a path/query. - NoPathOrQuery + NoPathOrQuery, } impl fmt::Display for UrlErrorType { @@ -177,7 +177,128 @@ impl fmt::Display for UrlErrorType { UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), - UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL") + UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"), + } + } +} + +/// Indicates the specific type/cause of a protocol error. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolErrorType { + /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + WrongHttpMethod, + /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + WrongHttpVersion, + /// Missing `Connection: upgrade` HTTP header. + MissingConnectionUpgradeHeader, + /// Missing `Upgrade: websocket` HTTP header. + MissingUpgradeWebSocketHeader, + /// Missing `Sec-WebSocket-Version: 13` HTTP header. + MissingSecWebSocketVersionHeader, + /// Missing `Sec-WebSocket-Key` HTTP header. + MissingSecWebSocketKey, + /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + SecWebSocketAcceptKeyMismatch, + /// Garbage data encountered after client request. + JunkAfterRequest, + /// Custom responses must be unsuccessful. + CustomResponseSuccessful, + /// No more data while still performing handshake. + HandshakeIncomplete, + /// Wrapper around a [`httparse::Error`] value. + HttparseError(httparse::Error), + /// Not allowed to send after having sent a closing frame. + SendAfterClosing, + /// Remote sent data after sending a closing frame. + ReceivedAfterClosing, + /// Reserved bits in frame header are non-zero. + NonZeroReservedBits, + /// The server must close the connection when an unmasked frame is received. + UnmaskedFrameFromClient, + /// The client must close the connection when a masked frame is received. + MaskedFrameFromServer, + /// Control frames must not be fragmented. + FragmentedControlFrame, + /// Control frames must have a payload of 125 bytes or less. + ControlFrameTooBig, + /// Type of control frame not recognised. + UnknownControlFrameType(u8), + /// Type of data frame not recognised. + UnknownDataFrameType(u8), + /// Received a continue frame despite there being nothing to continue. + UnexpectedContinueFrame, + /// Received data while waiting for more fragments. + ExpectedFragment(Data), + /// Connection closed without performing the closing handshake. + ResetWithoutClosingHandshake, + /// Encountered an invalid opcode. + InvalidOpcode(u8), + /// The payload for the closing frame is invalid. + InvalidCloseSequence, +} + +impl fmt::Display for ProtocolErrorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ProtocolErrorType::WrongHttpMethod => { + write!(f, "Unsupported HTTP method used, only GET is allowed") + } + ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"), + ProtocolErrorType::MissingConnectionUpgradeHeader => { + write!(f, "No \"Connection: upgrade\" header") + } + ProtocolErrorType::MissingUpgradeWebSocketHeader => { + write!(f, "No \"Upgrade: websocket\" header") + } + ProtocolErrorType::MissingSecWebSocketVersionHeader => { + write!(f, "No \"Sec-WebSocket-Version: 13\" header") + } + ProtocolErrorType::MissingSecWebSocketKey => { + write!(f, "No \"Sec-WebSocket-Key\" header") + } + ProtocolErrorType::SecWebSocketAcceptKeyMismatch => { + write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header") + } + ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"), + ProtocolErrorType::CustomResponseSuccessful => { + write!(f, "Custom response must not be successful") + } + ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"), + ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e), + ProtocolErrorType::SendAfterClosing => { + write!(f, "Sending after closing is not allowed") + } + ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"), + ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"), + ProtocolErrorType::UnmaskedFrameFromClient => { + write!(f, "Received an unmasked frame from client") + } + ProtocolErrorType::MaskedFrameFromServer => { + write!(f, "Received a masked frame from server") + } + ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"), + ProtocolErrorType::ControlFrameTooBig => { + write!(f, "Control frame too big (payload must be 125 bytes or less)") + } + ProtocolErrorType::UnknownControlFrameType(i) => { + write!(f, "Unknown control frame type: {}", i) + } + ProtocolErrorType::UnknownDataFrameType(i) => { + write!(f, "Unknown data frame type: {}", i) + } + ProtocolErrorType::UnexpectedContinueFrame => { + write!(f, "Continue frame but nothing to continue") + } + ProtocolErrorType::ExpectedFragment(c) => { + write!(f, "While waiting for more fragments received: {}", c) + } + ProtocolErrorType::ResetWithoutClosingHandshake => { + write!(f, "Connection reset without closing handshake") + } + ProtocolErrorType::InvalidOpcode(opcode) => { + write!(f, "Encountered invalid opcode: {}", opcode) + } + ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"), } } -} \ No newline at end of file +} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 558eb85..e7247d9 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, ProtocolErrorType, Result, UrlErrorType}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -42,11 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } // Check the URI scheme: only ws or wss are supported @@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = - uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); + let authority = uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -165,7 +164,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -177,14 +176,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { - return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch)); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -218,7 +217,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } let headers = HeaderMap::from_httparse(raw.headers)?; diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 5c7e000..e23fb15 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -50,7 +50,7 @@ impl HandshakeMachine { .read_from(&mut self.stream) .no_block()?; match read { - Some(0) => Err(Error::Protocol("Handshake not finished".into())), + Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 53227ab..c7877bc 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -19,7 +19,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse>; fn create_parts(request: &HttpRequest) -> Result { if request.method() != http::Method::GET { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } if !request @@ -48,7 +48,7 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); } if !request @@ -58,17 +58,17 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); } if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { - return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader)); } let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + .ok_or_else(|| Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) @@ -125,11 +125,11 @@ impl TryParse for Request { impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -237,7 +237,7 @@ impl HandshakeRole for ServerHandshake { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { - return Err(Error::Protocol("Junk after client request".into())); + return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest)); } let response = create_response(&result)?; @@ -257,7 +257,7 @@ impl HandshakeRole for ServerHandshake { Err(resp) => { if resp.status().is_success() { return Err(Error::Protocol( - "Custom response must not be successful".into(), + ProtocolErrorType::CustomResponseSuccessful, )); } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index ff64fa2..27281b6 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -13,7 +13,7 @@ use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, }; -use crate::error::{Error, Result}; +use crate::error::{Error, ProtocolErrorType, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,9 +186,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol( - format!("Encountered invalid opcode: {}", first & 0x0F).into(), - )) + return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F))) } _ => (), } @@ -286,7 +284,7 @@ impl Frame { pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), - 1 => Err(Error::Protocol("Invalid close sequence".into())), + 1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)), _ => { let mut data = self.payload; let code = NetworkEndian::read_u16(&data[0..2]).into(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 63763f0..be002e8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,7 +21,7 @@ use self::{ message::{IncompleteMessage, IncompleteMessageType}, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; @@ -331,7 +331,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol("Sending after closing is not allowed".into())); + return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing)); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -431,9 +431,7 @@ impl WebSocketContext { .check_connection_reset(self.state)? { if !self.state.can_read() { - return Err(Error::Protocol( - "Remote sent frame after having sent a Close Frame".into(), - )); + return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing)); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -443,7 +441,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol("Reserved bits are non-zero".into())); + return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits)); } } @@ -458,15 +456,13 @@ impl WebSocketContext { // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol( - "Received an unmasked frame from client".into(), - )); + return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient)); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol("Received a masked frame from server".into())); + return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer)); } } } @@ -477,14 +473,14 @@ impl WebSocketContext { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) _ if !frame.header().is_final => { - Err(Error::Protocol("Fragmented control frame".into())) + Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame)) } _ if frame.payload().len() > 125 => { - Err(Error::Protocol("Control frame too big".into())) + Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig)) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Reserved(i) => { - Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i))) } OpCtl::Ping => { let data = frame.into_data(); @@ -506,7 +502,7 @@ impl WebSocketContext { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol( - "Continue frame but nothing to continue".into(), + ProtocolErrorType::UnexpectedContinueFrame, )); } if fin { @@ -515,9 +511,9 @@ impl WebSocketContext { Ok(None) } } - c if self.incomplete.is_some() => Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into(), - )), + c if self.incomplete.is_some() => { + Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c))) + } OpData::Text | OpData::Binary => { let msg = { let message_type = match data { @@ -537,7 +533,7 @@ impl WebSocketContext { } } OpData::Reserved(i) => { - Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i))) } } } @@ -548,7 +544,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol("Connection reset without closing handshake".into())), + _ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)), } } } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index f348eca..d3b6943 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -8,7 +8,7 @@ use std::{ time::Duration, }; -use tungstenite::{accept, connect, Error, Message}; +use tungstenite::{accept, connect, error::ProtocolErrorType, Error, Message}; use url::Url; #[test] @@ -46,7 +46,7 @@ fn test_no_send_after_close() { assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), + Error::Protocol(s) => assert_eq!(s, ProtocolErrorType::SendAfterClosing), e => panic!("unexpected error: {:?}", e), } From aaebb432f01405b3006c1a16a293bd3f6f999957 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 4 Jan 2021 12:20:36 +0000 Subject: [PATCH 03/15] Fix clippy warnings --- src/client.rs | 4 ++-- src/handshake/client.rs | 5 ++--- src/handshake/server.rs | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/client.rs b/src/client.rs index f4c5f8d..41a813d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -103,7 +103,7 @@ pub fn connect_with_config( ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; + let host = request.uri().host().ok_or(Error::Url(UrlErrorType::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -165,7 +165,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri.host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; + let domain = uri.host().ok_or(Error::Url(UrlErrorType::NoHostName))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e7247d9..5b9f934 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -97,7 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); + let authority = uri.authority().ok_or(Error::Url(UrlErrorType::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -119,8 +119,7 @@ fn generate_request(request: Request, key: &str) -> Result> { Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = - uri.path_and_query().ok_or_else(|| Error::Url(UrlErrorType::NoPathOrQuery))?.as_str(), + path = uri.path_and_query().ok_or(Error::Url(UrlErrorType::NoPathOrQuery))?.as_str(), key = key ) .unwrap(); diff --git a/src/handshake/server.rs b/src/handshake/server.rs index c7877bc..372bff7 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -68,7 +68,7 @@ fn create_parts(request: &HttpRequest) -> Result { let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; + .ok_or(Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) From 0b34bee94f035bc5aea9545c8679eee008bb1212 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 4 Jan 2021 15:16:29 +0000 Subject: [PATCH 04/15] Add capacity error types --- src/error.rs | 81 +++++++++++++++++++++++++++------------ src/handshake/machine.rs | 4 +- src/protocol/frame/mod.rs | 21 +++++----- src/protocol/message.rs | 9 +++-- src/protocol/mod.rs | 19 +++++---- 5 files changed, 86 insertions(+), 48 deletions(-) diff --git a/src/error.rs b/src/error.rs index e11aef3..0f00064 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ //! Error handling. -use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; +use std::{error::Error as ErrorTrait, fmt, io, result, str, string}; use crate::protocol::{frame::coding::Data, Message}; use http::Response; @@ -46,7 +46,7 @@ pub enum Error { /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). - Capacity(Cow<'static, str>), + Capacity(CapacityErrorType), /// Protocol violation. Protocol(ProtocolErrorType), /// Message send queue full. @@ -146,38 +146,39 @@ impl From for Error { impl From for Error { fn from(err: httparse::Error) -> Self { match err { - httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), + httparse::Error::TooManyHeaders => Error::Capacity(CapacityErrorType::TooManyHeaders), e => Error::Protocol(ProtocolErrorType::HttparseError(e)), } } } -/// Indicates the specific type/cause of URL error. -#[derive(Debug, PartialEq, Eq)] -pub enum UrlErrorType { - /// TLS is used despite not being compiled with the TLS feature enabled. - TlsFeatureNotEnabled, - /// The URL does not include a host name. - NoHostName, - /// Failed to connect with this URL. - UnableToConnect(String), - /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). - UnsupportedUrlScheme, - /// The URL host name, though included, is empty. - EmptyHostName, - /// The URL does not include a path/query. - NoPathOrQuery, +/// Indicates the specific type/cause of a capacity error. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum CapacityErrorType { + /// Too many headers provided (see [`httparse::Error::TooManyHeaders`]). + TooManyHeaders, + /// Received header is too long. + HeaderTooLong, + /// Message is bigger than the maximum allowed size. + MessageTooLong { + /// The size of the message. + size: usize, + /// The maximum allowed message size. + max_size: usize, + }, + /// TCP buffer is full. + TcpBufferFull, } -impl fmt::Display for UrlErrorType { +impl fmt::Display for CapacityErrorType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - UrlErrorType::TlsFeatureNotEnabled => write!(f, "TLS support not compiled in"), - UrlErrorType::NoHostName => write!(f, "No host name in the URL"), - UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), - UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), - UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), - UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"), + CapacityErrorType::TooManyHeaders => write!(f, "Too many headers"), + CapacityErrorType::HeaderTooLong => write!(f, "Header too long"), + CapacityErrorType::MessageTooLong { size, max_size } => { + write!(f, "Message too long: {} > {}", size, max_size) + } + CapacityErrorType::TcpBufferFull => write!(f, "Incoming TCP buffer is full"), } } } @@ -302,3 +303,33 @@ impl fmt::Display for ProtocolErrorType { } } } + +/// Indicates the specific type/cause of URL error. +#[derive(Debug, PartialEq, Eq)] +pub enum UrlErrorType { + /// TLS is used despite not being compiled with the TLS feature enabled. + TlsFeatureNotEnabled, + /// The URL does not include a host name. + NoHostName, + /// Failed to connect with this URL. + UnableToConnect(String), + /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). + UnsupportedUrlScheme, + /// The URL host name, though included, is empty. + EmptyHostName, + /// The URL does not include a path/query. + NoPathOrQuery, +} + +impl fmt::Display for UrlErrorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + UrlErrorType::TlsFeatureNotEnabled => write!(f, "TLS support not compiled in"), + UrlErrorType::NoHostName => write!(f, "No host name in the URL"), + UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), + UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), + UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), + UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"), + } + } +} diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index e23fb15..05d3a6d 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{Error, ProtocolErrorType, Result}, + error::{CapacityErrorType, Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -46,7 +46,7 @@ impl HandshakeMachine { let read = buf .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) // TODO limit size - .map_err(|_| Error::Capacity("Header too long".into()))? + .map_err(|_| Error::Capacity(CapacityErrorType::HeaderTooLong))? .read_from(&mut self.stream) .no_block()?; match read { diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index dfd0bd5..c435762 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -8,7 +8,7 @@ mod mask; pub use self::frame::{CloseFrame, Frame, FrameHeader}; -use crate::error::{Error, Result}; +use crate::error::{CapacityErrorType, Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; use log::*; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; @@ -133,9 +133,10 @@ impl FrameCodec { // Enforce frame size limit early and make sure `length` // is not too big (fits into `usize`). if length > max_size as u64 { - return Err(Error::Capacity( - format!("Message length too big: {} > {}", length, max_size).into(), - )); + return Err(Error::Capacity(CapacityErrorType::MessageTooLong { + size: length as usize, + max_size, + })); } let input_size = cursor.get_ref().len() as u64 - cursor.position(); @@ -155,7 +156,7 @@ impl FrameCodec { .in_buffer .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) - .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? + .map_err(|_| Error::Capacity(CapacityErrorType::TcpBufferFull))? .read_from(stream)?; if size == 0 { trace!("no frame received"); @@ -206,6 +207,8 @@ impl FrameCodec { #[cfg(test)] mod tests { + use crate::error::{CapacityErrorType, Error}; + use super::{Frame, FrameSocket}; use std::io::Cursor; @@ -266,9 +269,9 @@ mod tests { fn size_limit_hit() { let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); - assert_eq!( - sock.read_frame(Some(5)).unwrap_err().to_string(), - "Space limit exceeded: Message length too big: 7 > 5" - ); + match sock.read_frame(Some(5)) { + Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 7, max_size: 5 })) => {} + _ => panic!(), + } } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index f799dbf..1e9ce42 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -6,7 +6,7 @@ use std::{ }; use super::frame::CloseFrame; -use crate::error::{Error, Result}; +use crate::error::{CapacityErrorType, Error, Result}; mod string_collect { use utf8::DecodeError; @@ -122,9 +122,10 @@ impl IncompleteMessage { let portion_size = tail.as_ref().len(); // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { - return Err(Error::Capacity( - format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), - )); + return Err(Error::Capacity(CapacityErrorType::MessageTooLong { + size: my_size + portion_size, + max_size, + })); } match self.collector { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index be002e8..21465b5 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -669,6 +669,7 @@ impl CheckConnectionReset for Result { #[cfg(test)] mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; + use crate::error::{CapacityErrorType, Error}; use std::{io, io::Cursor}; @@ -711,10 +712,11 @@ mod tests { ]); let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!( - socket.read_message().unwrap_err().to_string(), - "Space limit exceeded: Message too big: 7 + 6 > 10" - ); + + match socket.read_message() { + Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 13, max_size: 10 })) => {} + _ => panic!(), + } } #[test] @@ -722,9 +724,10 @@ mod tests { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!( - socket.read_message().unwrap_err().to_string(), - "Space limit exceeded: Message too big: 0 + 3 > 2" - ); + + match socket.read_message() { + Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 3, max_size: 2 })) => {} + _ => panic!(), + } } } From b85a66196721a1e093111ffb597d8c594a31105e Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 4 Jan 2021 15:20:35 +0000 Subject: [PATCH 05/15] Minor change to wording of README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a0351db..091eac2 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ fn main () { let mut websocket = accept(stream.unwrap()).unwrap(); loop { let msg = websocket.read_message().unwrap(); - + // We do not want to send back ping/pong messages. if msg.is_binary() || msg.is_text() { websocket.write_message(msg).unwrap(); @@ -64,7 +64,7 @@ Testing ------- Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for -WebSockets. It is also covered by internal unit tests as good as possible. +WebSockets. It is also covered by internal unit tests as well as possible. Contributing ------------ From 78d59f92661d80ccccdbf72308e68de7f8be058d Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 4 Jan 2021 15:29:56 +0000 Subject: [PATCH 06/15] Escape square brackets in doc comment --- src/protocol/frame/coding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index e726161..a37dcd2 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -143,7 +143,7 @@ pub enum CloseCode { Abnormal, /// Indicates that an endpoint is terminating the connection /// because it has received data within a message that was not - /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] /// data within a text message). Invalid, /// Indicates that an endpoint is terminating the connection From a1b4b2de61e09baf844e4d19687ed0b9263c2e5f Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 20:42:05 +0000 Subject: [PATCH 07/15] Bump version to 0.13.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9ec4595..3e6d205 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.12.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.12.0" +version = "0.13.0" edition = "2018" [features] From e6d66698a325cb97ea2a1c01c7ccb8baee549bfd Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:04:41 +0000 Subject: [PATCH 08/15] Use thiserror to streamline the implementation of the main Error type --- Cargo.toml | 1 + src/error.rs | 62 ++++++++++++++-------------------------------------- 2 files changed, 18 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3e6d205..cbab804 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rand = "0.8.0" sha-1 = "0.9" url = "2.1.0" utf-8 = "0.7.5" +thiserror = "1.0.23" [dependencies.native-tls] optional = true diff --git a/src/error.rs b/src/error.rs index 0f00064..220aed4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,10 @@ //! Error handling. -use std::{error::Error as ErrorTrait, fmt, io, result, str, string}; +use std::{fmt, io, result, str, string}; use crate::protocol::{frame::coding::Data, Message}; use http::Response; +use thiserror::Error; #[cfg(feature = "tls")] pub mod tls { @@ -15,7 +16,7 @@ pub mod tls { pub type Result = result::Result; /// Possible WebSocket errors. -#[derive(Debug)] +#[derive(Error, Debug)] pub enum Error { /// WebSocket connection closed normally. This informs you of the close. /// It's not an error as such and nothing wrong happened. @@ -28,6 +29,7 @@ pub enum Error { /// /// Receiving this error means that the WebSocket object is not usable anymore and the /// only meaningful action with it is dropping it. + #[error("Connection closed normally")] ConnectionClosed, /// Trying to work with already closed connection. /// @@ -36,56 +38,39 @@ pub enum Error { /// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the /// connection when it really shouldn't anymore, so this really indicates a programmer /// error on your part. + #[error("Trying to work with closed connection")] AlreadyClosed, /// Input-output error. Apart from WouldBlock, these are generally errors with the /// underlying connection and you should probably consider them fatal. - Io(io::Error), + #[error("IO error: {0}")] + Io(#[from] io::Error), #[cfg(feature = "tls")] /// TLS error. - Tls(tls::Error), + #[error("TLS error: {0}")] + Tls(#[from] tls::Error), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). + #[error("Space limit exceeded: {0}")] Capacity(CapacityErrorType), /// Protocol violation. + #[error("WebSocket protocol error: {0}")] Protocol(ProtocolErrorType), /// Message send queue full. + #[error("Send queue is full")] SendQueueFull(Message), /// UTF coding error + #[error("UTF-8 encoding error")] Utf8, /// Invalid URL. + #[error("URL error: {0}")] Url(UrlErrorType), /// HTTP error. + #[error("HTTP error: {}", .0.status())] Http(Response>), /// HTTP format error. - HttpFormat(http::Error), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::ConnectionClosed => write!(f, "Connection closed normally"), - Error::AlreadyClosed => write!(f, "Trying to work with closed connection"), - Error::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(feature = "tls")] - Error::Tls(ref err) => write!(f, "TLS error: {}", err), - Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), - Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), - Error::SendQueueFull(_) => write!(f, "Send queue is full"), - Error::Utf8 => write!(f, "UTF-8 encoding error"), - Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), - Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), - } - } -} - -impl ErrorTrait for Error {} - -impl From for Error { - fn from(err: io::Error) -> Self { - Error::Io(err) - } + #[error("HTTP format error: {0}")] + HttpFormat(#[from] http::Error), } impl From for Error { @@ -130,19 +115,6 @@ impl From for Error { } } -impl From for Error { - fn from(err: http::Error) -> Self { - Error::HttpFormat(err) - } -} - -#[cfg(feature = "tls")] -impl From for Error { - fn from(err: tls::Error) -> Self { - Error::Tls(err) - } -} - impl From for Error { fn from(err: httparse::Error) -> Self { match err { From 652a6b776eb80f8c9b9a98f4429db502c8310bf2 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:12:33 +0000 Subject: [PATCH 09/15] Rename CapacityErrorType to just CapacityError, implement using thiserror --- src/error.rs | 25 ++++++++----------------- src/handshake/machine.rs | 4 ++-- src/protocol/frame/mod.rs | 10 +++++----- src/protocol/message.rs | 4 ++-- src/protocol/mod.rs | 6 +++--- 5 files changed, 20 insertions(+), 29 deletions(-) diff --git a/src/error.rs b/src/error.rs index 220aed4..8ddff52 100644 --- a/src/error.rs +++ b/src/error.rs @@ -52,7 +52,7 @@ pub enum Error { /// - When writing: your message is bigger than the configured max message size /// (64MB by default). #[error("Space limit exceeded: {0}")] - Capacity(CapacityErrorType), + Capacity(CapacityError), /// Protocol violation. #[error("WebSocket protocol error: {0}")] Protocol(ProtocolErrorType), @@ -118,20 +118,23 @@ impl From for Error { impl From for Error { fn from(err: httparse::Error) -> Self { match err { - httparse::Error::TooManyHeaders => Error::Capacity(CapacityErrorType::TooManyHeaders), + httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders), e => Error::Protocol(ProtocolErrorType::HttparseError(e)), } } } /// Indicates the specific type/cause of a capacity error. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum CapacityErrorType { +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum CapacityError { /// Too many headers provided (see [`httparse::Error::TooManyHeaders`]). + #[error("Too many headers")] TooManyHeaders, /// Received header is too long. + #[error("Header too long")] HeaderTooLong, /// Message is bigger than the maximum allowed size. + #[error("Message too long: {size} > {max_size}")] MessageTooLong { /// The size of the message. size: usize, @@ -139,22 +142,10 @@ pub enum CapacityErrorType { max_size: usize, }, /// TCP buffer is full. + #[error("Incoming TCP buffer is full")] TcpBufferFull, } -impl fmt::Display for CapacityErrorType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - CapacityErrorType::TooManyHeaders => write!(f, "Too many headers"), - CapacityErrorType::HeaderTooLong => write!(f, "Header too long"), - CapacityErrorType::MessageTooLong { size, max_size } => { - write!(f, "Message too long: {} > {}", size, max_size) - } - CapacityErrorType::TcpBufferFull => write!(f, "Incoming TCP buffer is full"), - } - } -} - /// Indicates the specific type/cause of a protocol error. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ProtocolErrorType { diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 05d3a6d..78521a0 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{CapacityErrorType, Error, ProtocolErrorType, Result}, + error::{CapacityError, Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -46,7 +46,7 @@ impl HandshakeMachine { let read = buf .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) // TODO limit size - .map_err(|_| Error::Capacity(CapacityErrorType::HeaderTooLong))? + .map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))? .read_from(&mut self.stream) .no_block()?; match read { diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index c435762..ba190e5 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -8,7 +8,7 @@ mod mask; pub use self::frame::{CloseFrame, Frame, FrameHeader}; -use crate::error::{CapacityErrorType, Error, Result}; +use crate::error::{CapacityError, Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; use log::*; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; @@ -133,7 +133,7 @@ impl FrameCodec { // Enforce frame size limit early and make sure `length` // is not too big (fits into `usize`). if length > max_size as u64 { - return Err(Error::Capacity(CapacityErrorType::MessageTooLong { + return Err(Error::Capacity(CapacityError::MessageTooLong { size: length as usize, max_size, })); @@ -156,7 +156,7 @@ impl FrameCodec { .in_buffer .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) - .map_err(|_| Error::Capacity(CapacityErrorType::TcpBufferFull))? + .map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))? .read_from(stream)?; if size == 0 { trace!("no frame received"); @@ -207,7 +207,7 @@ impl FrameCodec { #[cfg(test)] mod tests { - use crate::error::{CapacityErrorType, Error}; + use crate::error::{CapacityError, Error}; use super::{Frame, FrameSocket}; @@ -270,7 +270,7 @@ mod tests { let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); match sock.read_frame(Some(5)) { - Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 7, max_size: 5 })) => {} + Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) => {} _ => panic!(), } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 1e9ce42..6720c3c 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -6,7 +6,7 @@ use std::{ }; use super::frame::CloseFrame; -use crate::error::{CapacityErrorType, Error, Result}; +use crate::error::{CapacityError, Error, Result}; mod string_collect { use utf8::DecodeError; @@ -122,7 +122,7 @@ impl IncompleteMessage { let portion_size = tail.as_ref().len(); // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { - return Err(Error::Capacity(CapacityErrorType::MessageTooLong { + return Err(Error::Capacity(CapacityError::MessageTooLong { size: my_size + portion_size, max_size, })); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 21465b5..a166a11 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -669,7 +669,7 @@ impl CheckConnectionReset for Result { #[cfg(test)] mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use crate::error::{CapacityErrorType, Error}; + use crate::error::{CapacityError, Error}; use std::{io, io::Cursor}; @@ -714,7 +714,7 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); match socket.read_message() { - Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 13, max_size: 10 })) => {} + Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) => {} _ => panic!(), } } @@ -726,7 +726,7 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); match socket.read_message() { - Err(Error::Capacity(CapacityErrorType::MessageTooLong { size: 3, max_size: 2 })) => {} + Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) => {} _ => panic!(), } } From 98377cf3dddc6f7a8e5d2fa5aa3d3018e68a95b4 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:23:03 +0000 Subject: [PATCH 10/15] Rename ProtocolErrorType to just ProtocolError, implement using thiserror --- src/error.rs | 101 +++++++++++------------------------- src/handshake/client.rs | 14 ++--- src/handshake/machine.rs | 4 +- src/handshake/server.rs | 24 ++++----- src/protocol/frame/frame.rs | 6 +-- src/protocol/mod.rs | 26 +++++----- 6 files changed, 66 insertions(+), 109 deletions(-) diff --git a/src/error.rs b/src/error.rs index 8ddff52..29a944c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,7 +55,7 @@ pub enum Error { Capacity(CapacityError), /// Protocol violation. #[error("WebSocket protocol error: {0}")] - Protocol(ProtocolErrorType), + Protocol(ProtocolError), /// Message send queue full. #[error("Send queue is full")] SendQueueFull(Message), @@ -119,7 +119,7 @@ impl From for Error { fn from(err: httparse::Error) -> Self { match err { httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders), - e => Error::Protocol(ProtocolErrorType::HttparseError(e)), + e => Error::Protocol(ProtocolError::HttparseError(e)), } } } @@ -147,126 +147,85 @@ pub enum CapacityError { } /// Indicates the specific type/cause of a protocol error. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ProtocolErrorType { +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolError { /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + #[error("Unsupported HTTP method used - only GET is allowed")] WrongHttpMethod, /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + #[error("HTTP version must be 1.1 or higher")] WrongHttpVersion, /// Missing `Connection: upgrade` HTTP header. + #[error("No \"Connection: upgrade\" header")] MissingConnectionUpgradeHeader, /// Missing `Upgrade: websocket` HTTP header. + #[error("No \"Upgrade: websocket\" header")] MissingUpgradeWebSocketHeader, /// Missing `Sec-WebSocket-Version: 13` HTTP header. + #[error("No \"Sec-WebSocket-Version: 13\" header")] MissingSecWebSocketVersionHeader, /// Missing `Sec-WebSocket-Key` HTTP header. + #[error("No \"Sec-WebSocket-Key\" header")] MissingSecWebSocketKey, /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + #[error("Key mismatch in \"Sec-WebSocket-Accept\" header")] SecWebSocketAcceptKeyMismatch, /// Garbage data encountered after client request. + #[error("Junk after client request")] JunkAfterRequest, /// Custom responses must be unsuccessful. + #[error("Custom response must not be successful")] CustomResponseSuccessful, /// No more data while still performing handshake. + #[error("Handshake not finished")] HandshakeIncomplete, /// Wrapper around a [`httparse::Error`] value. - HttparseError(httparse::Error), + #[error("httparse error: {0}")] + HttparseError(#[from] httparse::Error), /// Not allowed to send after having sent a closing frame. + #[error("Sending after closing is not allowed")] SendAfterClosing, /// Remote sent data after sending a closing frame. + #[error("Remote sent after having closed")] ReceivedAfterClosing, /// Reserved bits in frame header are non-zero. + #[error("Reserved bits are non-zero")] NonZeroReservedBits, /// The server must close the connection when an unmasked frame is received. + #[error("Received an unmasked frame from client")] UnmaskedFrameFromClient, /// The client must close the connection when a masked frame is received. + #[error("Received a masked frame from server")] MaskedFrameFromServer, /// Control frames must not be fragmented. + #[error("Fragmented control frame")] FragmentedControlFrame, /// Control frames must have a payload of 125 bytes or less. + #[error("Control frame too big (payload must be 125 bytes or less)")] ControlFrameTooBig, /// Type of control frame not recognised. + #[error("Unknown control frame type: {0}")] UnknownControlFrameType(u8), /// Type of data frame not recognised. + #[error("Unknown data frame type: {0}")] UnknownDataFrameType(u8), /// Received a continue frame despite there being nothing to continue. + #[error("Continue frame but nothing to continue")] UnexpectedContinueFrame, /// Received data while waiting for more fragments. + #[error("While waiting for more fragments received: {0}")] ExpectedFragment(Data), /// Connection closed without performing the closing handshake. + #[error("Connection reset without closing handshake")] ResetWithoutClosingHandshake, /// Encountered an invalid opcode. + #[error("Encountered invalid opcode: {0}")] InvalidOpcode(u8), /// The payload for the closing frame is invalid. + #[error("Invalid close sequence")] InvalidCloseSequence, } -impl fmt::Display for ProtocolErrorType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ProtocolErrorType::WrongHttpMethod => { - write!(f, "Unsupported HTTP method used, only GET is allowed") - } - ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"), - ProtocolErrorType::MissingConnectionUpgradeHeader => { - write!(f, "No \"Connection: upgrade\" header") - } - ProtocolErrorType::MissingUpgradeWebSocketHeader => { - write!(f, "No \"Upgrade: websocket\" header") - } - ProtocolErrorType::MissingSecWebSocketVersionHeader => { - write!(f, "No \"Sec-WebSocket-Version: 13\" header") - } - ProtocolErrorType::MissingSecWebSocketKey => { - write!(f, "No \"Sec-WebSocket-Key\" header") - } - ProtocolErrorType::SecWebSocketAcceptKeyMismatch => { - write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header") - } - ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"), - ProtocolErrorType::CustomResponseSuccessful => { - write!(f, "Custom response must not be successful") - } - ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"), - ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e), - ProtocolErrorType::SendAfterClosing => { - write!(f, "Sending after closing is not allowed") - } - ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"), - ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"), - ProtocolErrorType::UnmaskedFrameFromClient => { - write!(f, "Received an unmasked frame from client") - } - ProtocolErrorType::MaskedFrameFromServer => { - write!(f, "Received a masked frame from server") - } - ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"), - ProtocolErrorType::ControlFrameTooBig => { - write!(f, "Control frame too big (payload must be 125 bytes or less)") - } - ProtocolErrorType::UnknownControlFrameType(i) => { - write!(f, "Unknown control frame type: {}", i) - } - ProtocolErrorType::UnknownDataFrameType(i) => { - write!(f, "Unknown data frame type: {}", i) - } - ProtocolErrorType::UnexpectedContinueFrame => { - write!(f, "Continue frame but nothing to continue") - } - ProtocolErrorType::ExpectedFragment(c) => { - write!(f, "While waiting for more fragments received: {}", c) - } - ProtocolErrorType::ResetWithoutClosingHandshake => { - write!(f, "Connection reset without closing handshake") - } - ProtocolErrorType::InvalidOpcode(opcode) => { - write!(f, "Encountered invalid opcode: {}", opcode) - } - ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"), - } - } -} - /// Indicates the specific type/cause of URL error. #[derive(Debug, PartialEq, Eq)] pub enum UrlErrorType { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 5b9f934..fc4c579 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolErrorType, Result, UrlErrorType}, + error::{Error, ProtocolError, Result, UrlErrorType}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -42,11 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } // Check the URI scheme: only ws or wss are supported @@ -163,7 +163,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -175,14 +175,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch)); + return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -216,7 +216,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } let headers = HeaderMap::from_httparse(raw.headers)?; diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 78521a0..ced0153 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{CapacityError, Error, ProtocolErrorType, Result}, + error::{CapacityError, Error, ProtocolError, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -50,7 +50,7 @@ impl HandshakeMachine { .read_from(&mut self.stream) .no_block()?; match read { - Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)), + Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 372bff7..f80c11b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -19,7 +19,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolErrorType, Result}, + error::{Error, ProtocolError, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse>; fn create_parts(request: &HttpRequest) -> Result { if request.method() != http::Method::GET { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } if !request @@ -48,7 +48,7 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } if !request @@ -58,17 +58,17 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader)); + return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader)); } let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or(Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; + .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) @@ -125,11 +125,11 @@ impl TryParse for Request { impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -237,7 +237,7 @@ impl HandshakeRole for ServerHandshake { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { - return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest)); + return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); } let response = create_response(&result)?; @@ -256,9 +256,7 @@ impl HandshakeRole for ServerHandshake { Err(resp) => { if resp.status().is_success() { - return Err(Error::Protocol( - ProtocolErrorType::CustomResponseSuccessful, - )); + return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); } self.error_response = Some(resp); diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 27281b6..986bba0 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -13,7 +13,7 @@ use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, }; -use crate::error::{Error, ProtocolErrorType, Result}; +use crate::error::{Error, ProtocolError, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,7 +186,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F))) + return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) } _ => (), } @@ -284,7 +284,7 @@ impl Frame { pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), - 1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)), + 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { let mut data = self.payload; let code = NetworkEndian::read_u16(&data[0..2]).into(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index a166a11..b7dc177 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,7 +21,7 @@ use self::{ message::{IncompleteMessage, IncompleteMessageType}, }; use crate::{ - error::{Error, ProtocolErrorType, Result}, + error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; @@ -331,7 +331,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing)); + return Err(Error::Protocol(ProtocolError::SendAfterClosing)); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -431,7 +431,7 @@ impl WebSocketContext { .check_connection_reset(self.state)? { if !self.state.can_read() { - return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing)); + return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing)); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -441,7 +441,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits)); + return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); } } @@ -456,13 +456,13 @@ impl WebSocketContext { // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient)); + return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer)); + return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); } } } @@ -473,14 +473,14 @@ impl WebSocketContext { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) _ if !frame.header().is_final => { - Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame)) + Err(Error::Protocol(ProtocolError::FragmentedControlFrame)) } _ if frame.payload().len() > 125 => { - Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig)) + Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Reserved(i) => { - Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i))) + Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { let data = frame.into_data(); @@ -502,7 +502,7 @@ impl WebSocketContext { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol( - ProtocolErrorType::UnexpectedContinueFrame, + ProtocolError::UnexpectedContinueFrame, )); } if fin { @@ -512,7 +512,7 @@ impl WebSocketContext { } } c if self.incomplete.is_some() => { - Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c))) + Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) } OpData::Text | OpData::Binary => { let msg = { @@ -533,7 +533,7 @@ impl WebSocketContext { } } OpData::Reserved(i) => { - Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i))) + Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) } } } @@ -544,7 +544,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)), + _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)), } } } From 3e485ddb96de9b9ed678cc784ef57b091a63a57d Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:31:03 +0000 Subject: [PATCH 11/15] Rename UrlErrorType to just UrlError, implement using thiserror --- src/client.rs | 14 +++++++------- src/error.rs | 31 ++++++++++++------------------- src/handshake/client.rs | 8 ++++---- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/client.rs b/src/client.rs index 41a813d..5ed89cf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ mod encryption { use std::net::TcpStream; use crate::{ - error::{Error, Result, UrlErrorType}, + error::{Error, Result, UrlError}, stream::Mode, }; @@ -62,7 +62,7 @@ mod encryption { pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { match mode { Mode::Plain => Ok(stream), - Mode::Tls => Err(Error::Url(UrlErrorType::TlsFeatureNotEnabled)), + Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), } } } @@ -71,7 +71,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::{ - error::{Error, Result, UrlErrorType}, + error::{Error, Result, UrlError}, handshake::{client::ClientHandshake, HandshakeError}, protocol::WebSocket, stream::{Mode, NoDelay}, @@ -103,7 +103,7 @@ pub fn connect_with_config( ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = request.uri().host().ok_or(Error::Url(UrlErrorType::NoHostName))?; + let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -165,7 +165,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri.host().ok_or(Error::Url(UrlErrorType::NoHostName))?; + let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { @@ -174,7 +174,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result Result { match uri.scheme_str() { Some("ws") => Ok(Mode::Plain), Some("wss") => Ok(Mode::Tls), - _ => Err(Error::Url(UrlErrorType::UnsupportedUrlScheme)), + _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)), } } diff --git a/src/error.rs b/src/error.rs index 29a944c..f4dfdf1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ //! Error handling. -use std::{fmt, io, result, str, string}; +use std::{io, result, str, string}; use crate::protocol::{frame::coding::Data, Message}; use http::Response; @@ -44,8 +44,8 @@ pub enum Error { /// underlying connection and you should probably consider them fatal. #[error("IO error: {0}")] Io(#[from] io::Error), - #[cfg(feature = "tls")] /// TLS error. + #[cfg(feature = "tls")] #[error("TLS error: {0}")] Tls(#[from] tls::Error), /// - When reading: buffer capacity exhausted. @@ -59,12 +59,12 @@ pub enum Error { /// Message send queue full. #[error("Send queue is full")] SendQueueFull(Message), - /// UTF coding error + /// UTF coding error. #[error("UTF-8 encoding error")] Utf8, /// Invalid URL. #[error("URL error: {0}")] - Url(UrlErrorType), + Url(UrlError), /// HTTP error. #[error("HTTP error: {}", .0.status())] Http(Response>), @@ -227,31 +227,24 @@ pub enum ProtocolError { } /// Indicates the specific type/cause of URL error. -#[derive(Debug, PartialEq, Eq)] -pub enum UrlErrorType { +#[derive(Error, Debug, PartialEq, Eq)] +pub enum UrlError { /// TLS is used despite not being compiled with the TLS feature enabled. + #[error("TLS support not compiled in")] TlsFeatureNotEnabled, /// The URL does not include a host name. + #[error("No host name in the URL")] NoHostName, /// Failed to connect with this URL. + #[error("Unable to connect to {0}")] UnableToConnect(String), /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). + #[error("URL scheme not supported")] UnsupportedUrlScheme, /// The URL host name, though included, is empty. + #[error("URL contains empty host name")] EmptyHostName, /// The URL does not include a path/query. + #[error("No path/query in URL")] NoPathOrQuery, } - -impl fmt::Display for UrlErrorType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - UrlErrorType::TlsFeatureNotEnabled => write!(f, "TLS support not compiled in"), - UrlErrorType::NoHostName => write!(f, "No host name in the URL"), - UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), - UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), - UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), - UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"), - } - } -} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index fc4c579..92e5477 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolError, Result, UrlErrorType}, + error::{Error, ProtocolError, Result, UrlError}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -97,7 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority().ok_or(Error::Url(UrlErrorType::NoHostName))?.as_str(); + let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -105,7 +105,7 @@ fn generate_request(request: Request, key: &str) -> Result> { authority }; if authority.is_empty() { - return Err(Error::Url(UrlErrorType::EmptyHostName)); + return Err(Error::Url(UrlError::EmptyHostName)); } write!( @@ -119,7 +119,7 @@ fn generate_request(request: Request, key: &str) -> Result> { Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = uri.path_and_query().ok_or(Error::Url(UrlErrorType::NoPathOrQuery))?.as_str(), + path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), key = key ) .unwrap(); From f4aa926092291a5ee824c41a15d48a2087e44810 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:35:06 +0000 Subject: [PATCH 12/15] Change ProtocolErrorType to ProtocolError in test --- tests/no_send_after_close.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index d3b6943..a9dcab2 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -8,7 +8,7 @@ use std::{ time::Duration, }; -use tungstenite::{accept, connect, error::ProtocolErrorType, Error, Message}; +use tungstenite::{accept, connect, error::ProtocolError, Error, Message}; use url::Url; #[test] @@ -46,7 +46,7 @@ fn test_no_send_after_close() { assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => assert_eq!(s, ProtocolErrorType::SendAfterClosing), + Error::Protocol(s) => assert_eq!(s, ProtocolError::SendAfterClosing), e => panic!("unexpected error: {:?}", e), } From 7c9e684dedd42e27b109545f6d54f27e65b88720 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:46:00 +0000 Subject: [PATCH 13/15] Add entry to CHANGELOG.md for version 0.13.0 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b78c43b..1bb2516 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 0.13.0 + +- Add `thiserror` to dependencies. +- Modify `Error` to be implemented using the `thiserror::Error` derive macro. +- Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively. +- Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages. + # 0.12.0 - Add facilities to allow clients to follow HTTP 3XX redirects. From 8b88fb2444851d30d10af3b23f637d9b74af59c5 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 11 Jan 2021 12:09:07 +0000 Subject: [PATCH 14/15] Remove unnecessary changelog lines --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bb2516..8939190 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,5 @@ # 0.13.0 -- Add `thiserror` to dependencies. -- Modify `Error` to be implemented using the `thiserror::Error` derive macro. - Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively. - Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages. From 79dcf9f77c04133608d1e07ba1d555d93d11de4a Mon Sep 17 00:00:00 2001 From: WiredSound Date: Mon, 11 Jan 2021 12:21:40 +0000 Subject: [PATCH 15/15] Use matches! macro in tests --- src/protocol/frame/mod.rs | 8 ++++---- src/protocol/mod.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index ba190e5..1e41853 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -269,9 +269,9 @@ mod tests { fn size_limit_hit() { let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); - match sock.read_frame(Some(5)) { - Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) => {} - _ => panic!(), - } + assert!(matches!( + sock.read_frame(Some(5)), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) + )); } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index b7dc177..215b061 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -713,10 +713,10 @@ mod tests { let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - match socket.read_message() { - Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) => {} - _ => panic!(), - } + assert!(matches!( + socket.read_message(), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) + )); } #[test] @@ -725,9 +725,9 @@ mod tests { let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - match socket.read_message() { - Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) => {} - _ => panic!(), - } + assert!(matches!( + socket.read_message(), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) + )); } }