diff --git a/src/client.rs b/src/client.rs index 074c7ff..f4c5f8d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ mod encryption { use std::net::TcpStream; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, Result, UrlErrorType}, stream::Mode, }; @@ -71,7 +71,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, Result, UrlErrorType}, handshake::{client::ClientHandshake, HandshakeError}, protocol::WebSocket, stream::{Mode, NoDelay}, @@ -103,8 +103,7 @@ pub fn connect_with_config( ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = - request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; + let host = request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, diff --git a/src/error.rs b/src/error.rs index f9891f5..e11aef3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; -use crate::protocol::Message; +use crate::protocol::{frame::coding::Data, Message}; use http::Response; #[cfg(feature = "tls")] @@ -41,14 +41,14 @@ pub enum Error { /// underlying connection and you should probably consider them fatal. Io(io::Error), #[cfg(feature = "tls")] - /// TLS error + /// TLS error. Tls(tls::Error), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). Capacity(Cow<'static, str>), /// Protocol violation. - Protocol(Cow<'static, str>), + Protocol(ProtocolErrorType), /// Message send queue full. SendQueueFull(Message), /// UTF coding error @@ -147,13 +147,13 @@ impl From for Error { fn from(err: httparse::Error) -> Self { match err { httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), - e => Error::Protocol(e.to_string().into()), + e => Error::Protocol(ProtocolErrorType::HttparseError(e)), } } } /// Indicates the specific type/cause of URL error. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum UrlErrorType { /// TLS is used despite not being compiled with the TLS feature enabled. TlsFeatureNotEnabled, @@ -166,7 +166,7 @@ pub enum UrlErrorType { /// The URL host name, though included, is empty. EmptyHostName, /// The URL does not include a path/query. - NoPathOrQuery + NoPathOrQuery, } impl fmt::Display for UrlErrorType { @@ -177,7 +177,128 @@ impl fmt::Display for UrlErrorType { UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri), UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"), UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"), - UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL") + UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"), + } + } +} + +/// Indicates the specific type/cause of a protocol error. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolErrorType { + /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + WrongHttpMethod, + /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + WrongHttpVersion, + /// Missing `Connection: upgrade` HTTP header. + MissingConnectionUpgradeHeader, + /// Missing `Upgrade: websocket` HTTP header. + MissingUpgradeWebSocketHeader, + /// Missing `Sec-WebSocket-Version: 13` HTTP header. + MissingSecWebSocketVersionHeader, + /// Missing `Sec-WebSocket-Key` HTTP header. + MissingSecWebSocketKey, + /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + SecWebSocketAcceptKeyMismatch, + /// Garbage data encountered after client request. + JunkAfterRequest, + /// Custom responses must be unsuccessful. + CustomResponseSuccessful, + /// No more data while still performing handshake. + HandshakeIncomplete, + /// Wrapper around a [`httparse::Error`] value. + HttparseError(httparse::Error), + /// Not allowed to send after having sent a closing frame. + SendAfterClosing, + /// Remote sent data after sending a closing frame. + ReceivedAfterClosing, + /// Reserved bits in frame header are non-zero. + NonZeroReservedBits, + /// The server must close the connection when an unmasked frame is received. + UnmaskedFrameFromClient, + /// The client must close the connection when a masked frame is received. + MaskedFrameFromServer, + /// Control frames must not be fragmented. + FragmentedControlFrame, + /// Control frames must have a payload of 125 bytes or less. + ControlFrameTooBig, + /// Type of control frame not recognised. + UnknownControlFrameType(u8), + /// Type of data frame not recognised. + UnknownDataFrameType(u8), + /// Received a continue frame despite there being nothing to continue. + UnexpectedContinueFrame, + /// Received data while waiting for more fragments. + ExpectedFragment(Data), + /// Connection closed without performing the closing handshake. + ResetWithoutClosingHandshake, + /// Encountered an invalid opcode. + InvalidOpcode(u8), + /// The payload for the closing frame is invalid. + InvalidCloseSequence, +} + +impl fmt::Display for ProtocolErrorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ProtocolErrorType::WrongHttpMethod => { + write!(f, "Unsupported HTTP method used, only GET is allowed") + } + ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"), + ProtocolErrorType::MissingConnectionUpgradeHeader => { + write!(f, "No \"Connection: upgrade\" header") + } + ProtocolErrorType::MissingUpgradeWebSocketHeader => { + write!(f, "No \"Upgrade: websocket\" header") + } + ProtocolErrorType::MissingSecWebSocketVersionHeader => { + write!(f, "No \"Sec-WebSocket-Version: 13\" header") + } + ProtocolErrorType::MissingSecWebSocketKey => { + write!(f, "No \"Sec-WebSocket-Key\" header") + } + ProtocolErrorType::SecWebSocketAcceptKeyMismatch => { + write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header") + } + ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"), + ProtocolErrorType::CustomResponseSuccessful => { + write!(f, "Custom response must not be successful") + } + ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"), + ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e), + ProtocolErrorType::SendAfterClosing => { + write!(f, "Sending after closing is not allowed") + } + ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"), + ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"), + ProtocolErrorType::UnmaskedFrameFromClient => { + write!(f, "Received an unmasked frame from client") + } + ProtocolErrorType::MaskedFrameFromServer => { + write!(f, "Received a masked frame from server") + } + ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"), + ProtocolErrorType::ControlFrameTooBig => { + write!(f, "Control frame too big (payload must be 125 bytes or less)") + } + ProtocolErrorType::UnknownControlFrameType(i) => { + write!(f, "Unknown control frame type: {}", i) + } + ProtocolErrorType::UnknownDataFrameType(i) => { + write!(f, "Unknown data frame type: {}", i) + } + ProtocolErrorType::UnexpectedContinueFrame => { + write!(f, "Continue frame but nothing to continue") + } + ProtocolErrorType::ExpectedFragment(c) => { + write!(f, "While waiting for more fragments received: {}", c) + } + ProtocolErrorType::ResetWithoutClosingHandshake => { + write!(f, "Connection reset without closing handshake") + } + ProtocolErrorType::InvalidOpcode(opcode) => { + write!(f, "Encountered invalid opcode: {}", opcode) + } + ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"), } } -} \ No newline at end of file +} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 558eb85..e7247d9 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, UrlErrorType, Result}, + error::{Error, ProtocolErrorType, Result, UrlErrorType}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -42,11 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } // Check the URI scheme: only ws or wss are supported @@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = - uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); + let authority = uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -165,7 +164,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -177,14 +176,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { - return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch)); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -218,7 +217,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } let headers = HeaderMap::from_httparse(raw.headers)?; diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 5c7e000..e23fb15 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -50,7 +50,7 @@ impl HandshakeMachine { .read_from(&mut self.stream) .no_block()?; match read { - Some(0) => Err(Error::Protocol("Handshake not finished".into())), + Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 53227ab..c7877bc 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -19,7 +19,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse>; fn create_parts(request: &HttpRequest) -> Result { if request.method() != http::Method::GET { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } if !request @@ -48,7 +48,7 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader)); } if !request @@ -58,17 +58,17 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader)); } if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { - return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); + return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader)); } let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + .ok_or_else(|| Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) @@ -125,11 +125,11 @@ impl TryParse for Request { impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod)); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion)); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -237,7 +237,7 @@ impl HandshakeRole for ServerHandshake { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { - return Err(Error::Protocol("Junk after client request".into())); + return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest)); } let response = create_response(&result)?; @@ -257,7 +257,7 @@ impl HandshakeRole for ServerHandshake { Err(resp) => { if resp.status().is_success() { return Err(Error::Protocol( - "Custom response must not be successful".into(), + ProtocolErrorType::CustomResponseSuccessful, )); } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index ff64fa2..27281b6 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -13,7 +13,7 @@ use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, }; -use crate::error::{Error, Result}; +use crate::error::{Error, ProtocolErrorType, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,9 +186,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol( - format!("Encountered invalid opcode: {}", first & 0x0F).into(), - )) + return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F))) } _ => (), } @@ -286,7 +284,7 @@ impl Frame { pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), - 1 => Err(Error::Protocol("Invalid close sequence".into())), + 1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)), _ => { let mut data = self.payload; let code = NetworkEndian::read_u16(&data[0..2]).into(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 63763f0..be002e8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,7 +21,7 @@ use self::{ message::{IncompleteMessage, IncompleteMessageType}, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolErrorType, Result}, util::NonBlockingResult, }; @@ -331,7 +331,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol("Sending after closing is not allowed".into())); + return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing)); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -431,9 +431,7 @@ impl WebSocketContext { .check_connection_reset(self.state)? { if !self.state.can_read() { - return Err(Error::Protocol( - "Remote sent frame after having sent a Close Frame".into(), - )); + return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing)); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -443,7 +441,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol("Reserved bits are non-zero".into())); + return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits)); } } @@ -458,15 +456,13 @@ impl WebSocketContext { // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol( - "Received an unmasked frame from client".into(), - )); + return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient)); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol("Received a masked frame from server".into())); + return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer)); } } } @@ -477,14 +473,14 @@ impl WebSocketContext { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) _ if !frame.header().is_final => { - Err(Error::Protocol("Fragmented control frame".into())) + Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame)) } _ if frame.payload().len() > 125 => { - Err(Error::Protocol("Control frame too big".into())) + Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig)) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Reserved(i) => { - Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i))) } OpCtl::Ping => { let data = frame.into_data(); @@ -506,7 +502,7 @@ impl WebSocketContext { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol( - "Continue frame but nothing to continue".into(), + ProtocolErrorType::UnexpectedContinueFrame, )); } if fin { @@ -515,9 +511,9 @@ impl WebSocketContext { Ok(None) } } - c if self.incomplete.is_some() => Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into(), - )), + c if self.incomplete.is_some() => { + Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c))) + } OpData::Text | OpData::Binary => { let msg = { let message_type = match data { @@ -537,7 +533,7 @@ impl WebSocketContext { } } OpData::Reserved(i) => { - Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i))) } } } @@ -548,7 +544,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol("Connection reset without closing handshake".into())), + _ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)), } } } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index f348eca..d3b6943 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -8,7 +8,7 @@ use std::{ time::Duration, }; -use tungstenite::{accept, connect, Error, Message}; +use tungstenite::{accept, connect, error::ProtocolErrorType, Error, Message}; use url::Url; #[test] @@ -46,7 +46,7 @@ fn test_no_send_after_close() { assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), + Error::Protocol(s) => assert_eq!(s, ProtocolErrorType::SendAfterClosing), e => panic!("unexpected error: {:?}", e), }