From 98377cf3dddc6f7a8e5d2fa5aa3d3018e68a95b4 Mon Sep 17 00:00:00 2001 From: WiredSound Date: Sat, 9 Jan 2021 21:23:03 +0000 Subject: [PATCH] Rename ProtocolErrorType to just ProtocolError, implement using thiserror --- src/error.rs | 101 +++++++++++------------------------- src/handshake/client.rs | 14 ++--- src/handshake/machine.rs | 4 +- src/handshake/server.rs | 24 ++++----- src/protocol/frame/frame.rs | 6 +-- src/protocol/mod.rs | 26 +++++----- 6 files changed, 66 insertions(+), 109 deletions(-) diff --git a/src/error.rs b/src/error.rs index 8ddff52..29a944c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,7 +55,7 @@ pub enum Error { Capacity(CapacityError), /// Protocol violation. #[error("WebSocket protocol error: {0}")] - Protocol(ProtocolErrorType), + Protocol(ProtocolError), /// Message send queue full. #[error("Send queue is full")] SendQueueFull(Message), @@ -119,7 +119,7 @@ impl From for Error { fn from(err: httparse::Error) -> Self { match err { httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders), - e => Error::Protocol(ProtocolErrorType::HttparseError(e)), + e => Error::Protocol(ProtocolError::HttparseError(e)), } } } @@ -147,126 +147,85 @@ pub enum CapacityError { } /// Indicates the specific type/cause of a protocol error. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ProtocolErrorType { +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolError { /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + #[error("Unsupported HTTP method used - only GET is allowed")] WrongHttpMethod, /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + #[error("HTTP version must be 1.1 or higher")] WrongHttpVersion, /// Missing `Connection: upgrade` HTTP header. + #[error("No \"Connection: upgrade\" header")] MissingConnectionUpgradeHeader, /// Missing `Upgrade: websocket` HTTP header. + #[error("No \"Upgrade: websocket\" header")] MissingUpgradeWebSocketHeader, /// Missing `Sec-WebSocket-Version: 13` HTTP header. + #[error("No \"Sec-WebSocket-Version: 13\" header")] MissingSecWebSocketVersionHeader, /// Missing `Sec-WebSocket-Key` HTTP header. + #[error("No \"Sec-WebSocket-Key\" header")] MissingSecWebSocketKey, /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + #[error("Key mismatch in \"Sec-WebSocket-Accept\" header")] SecWebSocketAcceptKeyMismatch, /// Garbage data encountered after client request. + #[error("Junk after client request")] JunkAfterRequest, /// Custom responses must be unsuccessful. + #[error("Custom response must not be successful")] CustomResponseSuccessful, /// No more data while still performing handshake. + #[error("Handshake not finished")] HandshakeIncomplete, /// Wrapper around a [`httparse::Error`] value. - HttparseError(httparse::Error), + #[error("httparse error: {0}")] + HttparseError(#[from] httparse::Error), /// Not allowed to send after having sent a closing frame. + #[error("Sending after closing is not allowed")] SendAfterClosing, /// Remote sent data after sending a closing frame. + #[error("Remote sent after having closed")] ReceivedAfterClosing, /// Reserved bits in frame header are non-zero. + #[error("Reserved bits are non-zero")] NonZeroReservedBits, /// The server must close the connection when an unmasked frame is received. + #[error("Received an unmasked frame from client")] UnmaskedFrameFromClient, /// The client must close the connection when a masked frame is received. + #[error("Received a masked frame from server")] MaskedFrameFromServer, /// Control frames must not be fragmented. + #[error("Fragmented control frame")] FragmentedControlFrame, /// Control frames must have a payload of 125 bytes or less. + #[error("Control frame too big (payload must be 125 bytes or less)")] ControlFrameTooBig, /// Type of control frame not recognised. + #[error("Unknown control frame type: {0}")] UnknownControlFrameType(u8), /// Type of data frame not recognised. + #[error("Unknown data frame type: {0}")] UnknownDataFrameType(u8), /// Received a continue frame despite there being nothing to continue. + #[error("Continue frame but nothing to continue")] UnexpectedContinueFrame, /// Received data while waiting for more fragments. + #[error("While waiting for more fragments received: {0}")] ExpectedFragment(Data), /// Connection closed without performing the closing handshake. + #[error("Connection reset without closing handshake")] ResetWithoutClosingHandshake, /// Encountered an invalid opcode. + #[error("Encountered invalid opcode: {0}")] InvalidOpcode(u8), /// The payload for the closing frame is invalid. + #[error("Invalid close sequence")] InvalidCloseSequence, } -impl fmt::Display for ProtocolErrorType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ProtocolErrorType::WrongHttpMethod => { - write!(f, "Unsupported HTTP method used, only GET is allowed") - } - ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"), - ProtocolErrorType::MissingConnectionUpgradeHeader => { - write!(f, "No \"Connection: upgrade\" header") - } - ProtocolErrorType::MissingUpgradeWebSocketHeader => { - write!(f, "No \"Upgrade: websocket\" header") - } - ProtocolErrorType::MissingSecWebSocketVersionHeader => { - write!(f, "No \"Sec-WebSocket-Version: 13\" header") - } - ProtocolErrorType::MissingSecWebSocketKey => { - write!(f, "No \"Sec-WebSocket-Key\" header") - } - ProtocolErrorType::SecWebSocketAcceptKeyMismatch => { - write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header") - } - ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"), - ProtocolErrorType::CustomResponseSuccessful => { - write!(f, "Custom response must not be successful") - } - ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"), - ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e), - ProtocolErrorType::SendAfterClosing => { - write!(f, "Sending after closing is not allowed") - } - ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"), - ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"), - ProtocolErrorType::UnmaskedFrameFromClient => { - write!(f, "Received an unmasked frame from client") - } - ProtocolErrorType::MaskedFrameFromServer => { - write!(f, "Received a masked frame from server") - } - ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"), - ProtocolErrorType::ControlFrameTooBig => { - write!(f, "Control frame too big (payload must be 125 bytes or less)") - } - ProtocolErrorType::UnknownControlFrameType(i) => { - write!(f, "Unknown control frame type: {}", i) - } - ProtocolErrorType::UnknownDataFrameType(i) => { - write!(f, "Unknown data frame type: {}", i) - } - ProtocolErrorType::UnexpectedContinueFrame => { - write!(f, "Continue frame but nothing to continue") - } - ProtocolErrorType::ExpectedFragment(c) => { - write!(f, "While waiting for more fragments received: {}", c) - } - ProtocolErrorType::ResetWithoutClosingHandshake => { - write!(f, "Connection reset without closing handshake") - } - ProtocolErrorType::InvalidOpcode(opcode) => { - write!(f, "Encountered invalid opcode: {}", opcode) - } - ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"), - } - } -} - /// Indicates the specific type/cause of URL error. #[derive(Debug, PartialEq, Eq)] pub enum UrlErrorType { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 5b9f934..fc4c579 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolErrorType, Result, UrlErrorType}, + error::{Error, ProtocolError, Result, UrlErrorType}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -42,11 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } // Check the URI scheme: only ws or wss are supported @@ -163,7 +163,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -175,14 +175,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch)); + return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -216,7 +216,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } let headers = HeaderMap::from_httparse(raw.headers)?; diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 78521a0..ced0153 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{CapacityError, Error, ProtocolErrorType, Result}, + error::{CapacityError, Error, ProtocolError, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -50,7 +50,7 @@ impl HandshakeMachine { .read_from(&mut self.stream) .no_block()?; match read { - Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)), + Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 372bff7..f80c11b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -19,7 +19,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolErrorType, Result}, + error::{Error, ProtocolError, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse>; fn create_parts(request: &HttpRequest) -> Result { if request.method() != http::Method::GET { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } if !request @@ -48,7 +48,7 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } if !request @@ -58,17 +58,17 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { - return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader)); + return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader)); } let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or(Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; + .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) @@ -125,11 +125,11 @@ impl TryParse for Request { impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -237,7 +237,7 @@ impl HandshakeRole for ServerHandshake { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { - return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest)); + return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); } let response = create_response(&result)?; @@ -256,9 +256,7 @@ impl HandshakeRole for ServerHandshake { Err(resp) => { if resp.status().is_success() { - return Err(Error::Protocol( - ProtocolErrorType::CustomResponseSuccessful, - )); + return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); } self.error_response = Some(resp); diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 27281b6..986bba0 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -13,7 +13,7 @@ use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, }; -use crate::error::{Error, ProtocolErrorType, Result}; +use crate::error::{Error, ProtocolError, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,7 +186,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F))) + return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) } _ => (), } @@ -284,7 +284,7 @@ impl Frame { pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), - 1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)), + 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { let mut data = self.payload; let code = NetworkEndian::read_u16(&data[0..2]).into(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index a166a11..b7dc177 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,7 +21,7 @@ use self::{ message::{IncompleteMessage, IncompleteMessageType}, }; use crate::{ - error::{Error, ProtocolErrorType, Result}, + error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; @@ -331,7 +331,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing)); + return Err(Error::Protocol(ProtocolError::SendAfterClosing)); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -431,7 +431,7 @@ impl WebSocketContext { .check_connection_reset(self.state)? { if !self.state.can_read() { - return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing)); + return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing)); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -441,7 +441,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits)); + return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); } } @@ -456,13 +456,13 @@ impl WebSocketContext { // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient)); + return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer)); + return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); } } } @@ -473,14 +473,14 @@ impl WebSocketContext { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) _ if !frame.header().is_final => { - Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame)) + Err(Error::Protocol(ProtocolError::FragmentedControlFrame)) } _ if frame.payload().len() > 125 => { - Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig)) + Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Reserved(i) => { - Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i))) + Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { let data = frame.into_data(); @@ -502,7 +502,7 @@ impl WebSocketContext { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol( - ProtocolErrorType::UnexpectedContinueFrame, + ProtocolError::UnexpectedContinueFrame, )); } if fin { @@ -512,7 +512,7 @@ impl WebSocketContext { } } c if self.incomplete.is_some() => { - Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c))) + Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) } OpData::Text | OpData::Binary => { let msg = { @@ -533,7 +533,7 @@ impl WebSocketContext { } } OpData::Reserved(i) => { - Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i))) + Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) } } } @@ -544,7 +544,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)), + _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)), } } }