Rename ProtocolErrorType to just ProtocolError, implement using thiserror

pull/168/head
WiredSound 4 years ago
parent 652a6b776e
commit 98377cf3dd
  1. 101
      src/error.rs
  2. 14
      src/handshake/client.rs
  3. 4
      src/handshake/machine.rs
  4. 24
      src/handshake/server.rs
  5. 6
      src/protocol/frame/frame.rs
  6. 26
      src/protocol/mod.rs

@ -55,7 +55,7 @@ pub enum Error {
Capacity(CapacityError),
/// Protocol violation.
#[error("WebSocket protocol error: {0}")]
Protocol(ProtocolErrorType),
Protocol(ProtocolError),
/// Message send queue full.
#[error("Send queue is full")]
SendQueueFull(Message),
@ -119,7 +119,7 @@ impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self {
match err {
httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders),
e => Error::Protocol(ProtocolErrorType::HttparseError(e)),
e => Error::Protocol(ProtocolError::HttparseError(e)),
}
}
}
@ -147,126 +147,85 @@ pub enum CapacityError {
}
/// Indicates the specific type/cause of a protocol error.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolErrorType {
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolError {
/// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
#[error("Unsupported HTTP method used - only GET is allowed")]
WrongHttpMethod,
/// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher).
#[error("HTTP version must be 1.1 or higher")]
WrongHttpVersion,
/// Missing `Connection: upgrade` HTTP header.
#[error("No \"Connection: upgrade\" header")]
MissingConnectionUpgradeHeader,
/// Missing `Upgrade: websocket` HTTP header.
#[error("No \"Upgrade: websocket\" header")]
MissingUpgradeWebSocketHeader,
/// Missing `Sec-WebSocket-Version: 13` HTTP header.
#[error("No \"Sec-WebSocket-Version: 13\" header")]
MissingSecWebSocketVersionHeader,
/// Missing `Sec-WebSocket-Key` HTTP header.
#[error("No \"Sec-WebSocket-Key\" header")]
MissingSecWebSocketKey,
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
#[error("Key mismatch in \"Sec-WebSocket-Accept\" header")]
SecWebSocketAcceptKeyMismatch,
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,
/// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")]
CustomResponseSuccessful,
/// No more data while still performing handshake.
#[error("Handshake not finished")]
HandshakeIncomplete,
/// Wrapper around a [`httparse::Error`] value.
HttparseError(httparse::Error),
#[error("httparse error: {0}")]
HttparseError(#[from] httparse::Error),
/// Not allowed to send after having sent a closing frame.
#[error("Sending after closing is not allowed")]
SendAfterClosing,
/// Remote sent data after sending a closing frame.
#[error("Remote sent after having closed")]
ReceivedAfterClosing,
/// Reserved bits in frame header are non-zero.
#[error("Reserved bits are non-zero")]
NonZeroReservedBits,
/// The server must close the connection when an unmasked frame is received.
#[error("Received an unmasked frame from client")]
UnmaskedFrameFromClient,
/// The client must close the connection when a masked frame is received.
#[error("Received a masked frame from server")]
MaskedFrameFromServer,
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
/// Type of control frame not recognised.
#[error("Unknown control frame type: {0}")]
UnknownControlFrameType(u8),
/// Type of data frame not recognised.
#[error("Unknown data frame type: {0}")]
UnknownDataFrameType(u8),
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
/// Connection closed without performing the closing handshake.
#[error("Connection reset without closing handshake")]
ResetWithoutClosingHandshake,
/// Encountered an invalid opcode.
#[error("Encountered invalid opcode: {0}")]
InvalidOpcode(u8),
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
}
impl fmt::Display for ProtocolErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ProtocolErrorType::WrongHttpMethod => {
write!(f, "Unsupported HTTP method used, only GET is allowed")
}
ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"),
ProtocolErrorType::MissingConnectionUpgradeHeader => {
write!(f, "No \"Connection: upgrade\" header")
}
ProtocolErrorType::MissingUpgradeWebSocketHeader => {
write!(f, "No \"Upgrade: websocket\" header")
}
ProtocolErrorType::MissingSecWebSocketVersionHeader => {
write!(f, "No \"Sec-WebSocket-Version: 13\" header")
}
ProtocolErrorType::MissingSecWebSocketKey => {
write!(f, "No \"Sec-WebSocket-Key\" header")
}
ProtocolErrorType::SecWebSocketAcceptKeyMismatch => {
write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header")
}
ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"),
ProtocolErrorType::CustomResponseSuccessful => {
write!(f, "Custom response must not be successful")
}
ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"),
ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e),
ProtocolErrorType::SendAfterClosing => {
write!(f, "Sending after closing is not allowed")
}
ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"),
ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"),
ProtocolErrorType::UnmaskedFrameFromClient => {
write!(f, "Received an unmasked frame from client")
}
ProtocolErrorType::MaskedFrameFromServer => {
write!(f, "Received a masked frame from server")
}
ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"),
ProtocolErrorType::ControlFrameTooBig => {
write!(f, "Control frame too big (payload must be 125 bytes or less)")
}
ProtocolErrorType::UnknownControlFrameType(i) => {
write!(f, "Unknown control frame type: {}", i)
}
ProtocolErrorType::UnknownDataFrameType(i) => {
write!(f, "Unknown data frame type: {}", i)
}
ProtocolErrorType::UnexpectedContinueFrame => {
write!(f, "Continue frame but nothing to continue")
}
ProtocolErrorType::ExpectedFragment(c) => {
write!(f, "While waiting for more fragments received: {}", c)
}
ProtocolErrorType::ResetWithoutClosingHandshake => {
write!(f, "Connection reset without closing handshake")
}
ProtocolErrorType::InvalidOpcode(opcode) => {
write!(f, "Encountered invalid opcode: {}", opcode)
}
ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"),
}
}
}
/// Indicates the specific type/cause of URL error.
#[derive(Debug, PartialEq, Eq)]
pub enum UrlErrorType {

@ -16,7 +16,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, ProtocolErrorType, Result, UrlErrorType},
error::{Error, ProtocolError, Result, UrlErrorType},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -42,11 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
}
// Check the URI scheme: only ws or wss are supported
@ -163,7 +163,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader));
return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
@ -175,14 +175,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader));
return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch));
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
@ -216,7 +216,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
}
let headers = HeaderMap::from_httparse(raw.headers)?;

@ -3,7 +3,7 @@ use log::*;
use std::io::{Cursor, Read, Write};
use crate::{
error::{CapacityError, Error, ProtocolErrorType, Result},
error::{CapacityError, Error, ProtocolError, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ};
@ -50,7 +50,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.read_from(&mut self.stream)
.no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)),
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {

@ -19,7 +19,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, ProtocolErrorType, Result},
error::{Error, ProtocolError, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse<Option<String>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
}
if !request
@ -48,7 +48,7 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader));
return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
}
if !request
@ -58,17 +58,17 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader));
return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
}
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader));
return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
}
let key = request
.headers()
.get("Sec-WebSocket-Key")
.ok_or(Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?;
.ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
@ -125,11 +125,11 @@ impl TryParse for Request {
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -237,7 +237,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() {
return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest));
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
@ -256,9 +256,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Err(resp) => {
if resp.status().is_success() {
return Err(Error::Protocol(
ProtocolErrorType::CustomResponseSuccessful,
));
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
}
self.error_response = Some(resp);

@ -13,7 +13,7 @@ use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, ProtocolErrorType, Result};
use crate::error::{Error, ProtocolError, Result};
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
@ -186,7 +186,7 @@ impl FrameHeader {
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F)))
return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
}
_ => (),
}
@ -284,7 +284,7 @@ impl Frame {
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() {
0 => Ok(None),
1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)),
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();

@ -21,7 +21,7 @@ use self::{
message::{IncompleteMessage, IncompleteMessageType},
};
use crate::{
error::{Error, ProtocolErrorType, Result},
error::{Error, ProtocolError, Result},
util::NonBlockingResult,
};
@ -331,7 +331,7 @@ impl WebSocketContext {
// Do not write after sending a close frame.
if !self.state.is_active() {
return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing));
return Err(Error::Protocol(ProtocolError::SendAfterClosing));
}
if let Some(max_send_queue) = self.config.max_send_queue {
@ -431,7 +431,7 @@ impl WebSocketContext {
.check_connection_reset(self.state)?
{
if !self.state.can_read() {
return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing));
return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing));
}
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
@ -441,7 +441,7 @@ impl WebSocketContext {
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits));
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
}
@ -456,13 +456,13 @@ impl WebSocketContext {
// frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given
// stream by setting WebSocketConfig.accept_unmasked_frames to true
return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient));
return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
}
}
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer));
return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer));
}
}
}
@ -473,14 +473,14 @@ impl WebSocketContext {
// All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455)
_ if !frame.header().is_final => {
Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame))
Err(Error::Protocol(ProtocolError::FragmentedControlFrame))
}
_ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig))
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i)))
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
}
OpCtl::Ping => {
let data = frame.into_data();
@ -502,7 +502,7 @@ impl WebSocketContext {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol(
ProtocolErrorType::UnexpectedContinueFrame,
ProtocolError::UnexpectedContinueFrame,
));
}
if fin {
@ -512,7 +512,7 @@ impl WebSocketContext {
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c)))
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
}
OpData::Text | OpData::Binary => {
let msg = {
@ -533,7 +533,7 @@ impl WebSocketContext {
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i)))
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
}
}
}
@ -544,7 +544,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed)
}
_ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)),
_ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)),
}
}
}

Loading…
Cancel
Save