Add protocol error types

pull/168/head
WiredSound 4 years ago
parent 34c6e63d87
commit 6f846da0e3
  1. 7
      src/client.rs
  2. 137
      src/error.rs
  3. 17
      src/handshake/client.rs
  4. 4
      src/handshake/machine.rs
  5. 22
      src/handshake/server.rs
  6. 8
      src/protocol/frame/frame.rs
  7. 34
      src/protocol/mod.rs
  8. 4
      tests/no_send_after_close.rs

@ -52,7 +52,7 @@ mod encryption {
use std::net::TcpStream;
use crate::{
error::{Error, UrlErrorType, Result},
error::{Error, Result, UrlErrorType},
stream::Mode,
};
@ -71,7 +71,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::{
error::{Error, UrlErrorType, Result},
error::{Error, Result, UrlErrorType},
handshake::{client::ClientHandshake, HandshakeError},
protocol::WebSocket,
stream::{Mode, NoDelay},
@ -103,8 +103,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri();
let mode = uri_mode(uri)?;
let host =
request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?;
let host = request.uri().host().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,

@ -2,7 +2,7 @@
use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string};
use crate::protocol::Message;
use crate::protocol::{frame::coding::Data, Message};
use http::Response;
#[cfg(feature = "tls")]
@ -41,14 +41,14 @@ pub enum Error {
/// underlying connection and you should probably consider them fatal.
Io(io::Error),
#[cfg(feature = "tls")]
/// TLS error
/// TLS error.
Tls(tls::Error),
/// - When reading: buffer capacity exhausted.
/// - When writing: your message is bigger than the configured max message size
/// (64MB by default).
Capacity(Cow<'static, str>),
/// Protocol violation.
Protocol(Cow<'static, str>),
Protocol(ProtocolErrorType),
/// Message send queue full.
SendQueueFull(Message),
/// UTF coding error
@ -147,13 +147,13 @@ impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self {
match err {
httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()),
e => Error::Protocol(e.to_string().into()),
e => Error::Protocol(ProtocolErrorType::HttparseError(e)),
}
}
}
/// Indicates the specific type/cause of URL error.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum UrlErrorType {
/// TLS is used despite not being compiled with the TLS feature enabled.
TlsFeatureNotEnabled,
@ -166,7 +166,7 @@ pub enum UrlErrorType {
/// The URL host name, though included, is empty.
EmptyHostName,
/// The URL does not include a path/query.
NoPathOrQuery
NoPathOrQuery,
}
impl fmt::Display for UrlErrorType {
@ -177,7 +177,128 @@ impl fmt::Display for UrlErrorType {
UrlErrorType::UnableToConnect(uri) => write!(f, "Unable to connect to {}", uri),
UrlErrorType::UnsupportedUrlScheme => write!(f, "URL scheme not supported"),
UrlErrorType::EmptyHostName => write!(f, "URL contains empty host name"),
UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL")
UrlErrorType::NoPathOrQuery => write!(f, "No path/query in URL"),
}
}
}
/// Indicates the specific type/cause of a protocol error.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolErrorType {
/// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
WrongHttpMethod,
/// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher).
WrongHttpVersion,
/// Missing `Connection: upgrade` HTTP header.
MissingConnectionUpgradeHeader,
/// Missing `Upgrade: websocket` HTTP header.
MissingUpgradeWebSocketHeader,
/// Missing `Sec-WebSocket-Version: 13` HTTP header.
MissingSecWebSocketVersionHeader,
/// Missing `Sec-WebSocket-Key` HTTP header.
MissingSecWebSocketKey,
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
SecWebSocketAcceptKeyMismatch,
/// Garbage data encountered after client request.
JunkAfterRequest,
/// Custom responses must be unsuccessful.
CustomResponseSuccessful,
/// No more data while still performing handshake.
HandshakeIncomplete,
/// Wrapper around a [`httparse::Error`] value.
HttparseError(httparse::Error),
/// Not allowed to send after having sent a closing frame.
SendAfterClosing,
/// Remote sent data after sending a closing frame.
ReceivedAfterClosing,
/// Reserved bits in frame header are non-zero.
NonZeroReservedBits,
/// The server must close the connection when an unmasked frame is received.
UnmaskedFrameFromClient,
/// The client must close the connection when a masked frame is received.
MaskedFrameFromServer,
/// Control frames must not be fragmented.
FragmentedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
ControlFrameTooBig,
/// Type of control frame not recognised.
UnknownControlFrameType(u8),
/// Type of data frame not recognised.
UnknownDataFrameType(u8),
/// Received a continue frame despite there being nothing to continue.
UnexpectedContinueFrame,
/// Received data while waiting for more fragments.
ExpectedFragment(Data),
/// Connection closed without performing the closing handshake.
ResetWithoutClosingHandshake,
/// Encountered an invalid opcode.
InvalidOpcode(u8),
/// The payload for the closing frame is invalid.
InvalidCloseSequence,
}
impl fmt::Display for ProtocolErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ProtocolErrorType::WrongHttpMethod => {
write!(f, "Unsupported HTTP method used, only GET is allowed")
}
ProtocolErrorType::WrongHttpVersion => write!(f, "HTTP version must be 1.1 or higher"),
ProtocolErrorType::MissingConnectionUpgradeHeader => {
write!(f, "No \"Connection: upgrade\" header")
}
ProtocolErrorType::MissingUpgradeWebSocketHeader => {
write!(f, "No \"Upgrade: websocket\" header")
}
ProtocolErrorType::MissingSecWebSocketVersionHeader => {
write!(f, "No \"Sec-WebSocket-Version: 13\" header")
}
ProtocolErrorType::MissingSecWebSocketKey => {
write!(f, "No \"Sec-WebSocket-Key\" header")
}
ProtocolErrorType::SecWebSocketAcceptKeyMismatch => {
write!(f, "Key mismatch in \"Sec-WebSocket-Accept\" header")
}
ProtocolErrorType::JunkAfterRequest => write!(f, "Junk after client request"),
ProtocolErrorType::CustomResponseSuccessful => {
write!(f, "Custom response must not be successful")
}
ProtocolErrorType::HandshakeIncomplete => write!(f, "Handshake not finished"),
ProtocolErrorType::HttparseError(e) => write!(f, "httparse error: {}", e),
ProtocolErrorType::SendAfterClosing => {
write!(f, "Sending after closing is not allowed")
}
ProtocolErrorType::ReceivedAfterClosing => write!(f, "Remote sent after having closed"),
ProtocolErrorType::NonZeroReservedBits => write!(f, "Reserved bits are non-zero"),
ProtocolErrorType::UnmaskedFrameFromClient => {
write!(f, "Received an unmasked frame from client")
}
ProtocolErrorType::MaskedFrameFromServer => {
write!(f, "Received a masked frame from server")
}
ProtocolErrorType::FragmentedControlFrame => write!(f, "Fragmented control frame"),
ProtocolErrorType::ControlFrameTooBig => {
write!(f, "Control frame too big (payload must be 125 bytes or less)")
}
ProtocolErrorType::UnknownControlFrameType(i) => {
write!(f, "Unknown control frame type: {}", i)
}
ProtocolErrorType::UnknownDataFrameType(i) => {
write!(f, "Unknown data frame type: {}", i)
}
ProtocolErrorType::UnexpectedContinueFrame => {
write!(f, "Continue frame but nothing to continue")
}
ProtocolErrorType::ExpectedFragment(c) => {
write!(f, "While waiting for more fragments received: {}", c)
}
ProtocolErrorType::ResetWithoutClosingHandshake => {
write!(f, "Connection reset without closing handshake")
}
ProtocolErrorType::InvalidOpcode(opcode) => {
write!(f, "Encountered invalid opcode: {}", opcode)
}
ProtocolErrorType::InvalidCloseSequence => write!(f, "Invalid close sequence"),
}
}
}
}

@ -16,7 +16,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, UrlErrorType, Result},
error::{Error, ProtocolErrorType, Result, UrlErrorType},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -42,11 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
}
// Check the URI scheme: only ws or wss are supported
@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();
let authority =
uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str();
let authority = uri.authority().ok_or_else(|| Error::Url(UrlErrorType::NoHostName))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
@ -165,7 +164,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
@ -177,14 +176,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
return Err(Error::Protocol(ProtocolErrorType::SecWebSocketAcceptKeyMismatch));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
@ -218,7 +217,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
}
let headers = HeaderMap::from_httparse(raw.headers)?;

@ -3,7 +3,7 @@ use log::*;
use std::io::{Cursor, Read, Write};
use crate::{
error::{Error, Result},
error::{Error, ProtocolErrorType, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ};
@ -50,7 +50,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.read_from(&mut self.stream)
.no_block()?;
match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(0) => Err(Error::Protocol(ProtocolErrorType::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {

@ -19,7 +19,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
error::{Error, ProtocolErrorType, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse<Option<String>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET {
return Err(Error::Protocol("Method is not GET".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
}
if !request
@ -48,7 +48,7 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false)
{
return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into()));
return Err(Error::Protocol(ProtocolErrorType::MissingConnectionUpgradeHeader));
}
if !request
@ -58,17 +58,17 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into()));
return Err(Error::Protocol(ProtocolErrorType::MissingUpgradeWebSocketHeader));
}
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into()));
return Err(Error::Protocol(ProtocolErrorType::MissingSecWebSocketVersionHeader));
}
let key = request
.headers()
.get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
.ok_or_else(|| Error::Protocol(ProtocolErrorType::MissingSecWebSocketKey))?;
let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
@ -125,11 +125,11 @@ impl TryParse for Request {
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpMethod));
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(ProtocolErrorType::WrongHttpVersion));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -237,7 +237,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()));
return Err(Error::Protocol(ProtocolErrorType::JunkAfterRequest));
}
let response = create_response(&result)?;
@ -257,7 +257,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Err(resp) => {
if resp.status().is_success() {
return Err(Error::Protocol(
"Custom response must not be successful".into(),
ProtocolErrorType::CustomResponseSuccessful,
));
}

@ -13,7 +13,7 @@ use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, Result};
use crate::error::{Error, ProtocolErrorType, Result};
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
@ -186,9 +186,7 @@ impl FrameHeader {
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
return Err(Error::Protocol(ProtocolErrorType::InvalidOpcode(first & 0x0F)))
}
_ => (),
}
@ -286,7 +284,7 @@ impl Frame {
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() {
0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())),
1 => Err(Error::Protocol(ProtocolErrorType::InvalidCloseSequence)),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();

@ -21,7 +21,7 @@ use self::{
message::{IncompleteMessage, IncompleteMessageType},
};
use crate::{
error::{Error, Result},
error::{Error, ProtocolErrorType, Result},
util::NonBlockingResult,
};
@ -331,7 +331,7 @@ impl WebSocketContext {
// Do not write after sending a close frame.
if !self.state.is_active() {
return Err(Error::Protocol("Sending after closing is not allowed".into()));
return Err(Error::Protocol(ProtocolErrorType::SendAfterClosing));
}
if let Some(max_send_queue) = self.config.max_send_queue {
@ -431,9 +431,7 @@ impl WebSocketContext {
.check_connection_reset(self.state)?
{
if !self.state.can_read() {
return Err(Error::Protocol(
"Remote sent frame after having sent a Close Frame".into(),
));
return Err(Error::Protocol(ProtocolErrorType::ReceivedAfterClosing));
}
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
@ -443,7 +441,7 @@ impl WebSocketContext {
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()));
return Err(Error::Protocol(ProtocolErrorType::NonZeroReservedBits));
}
}
@ -458,15 +456,13 @@ impl WebSocketContext {
// frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given
// stream by setting WebSocketConfig.accept_unmasked_frames to true
return Err(Error::Protocol(
"Received an unmasked frame from client".into(),
));
return Err(Error::Protocol(ProtocolErrorType::UnmaskedFrameFromClient));
}
}
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into()));
return Err(Error::Protocol(ProtocolErrorType::MaskedFrameFromServer));
}
}
}
@ -477,14 +473,14 @@ impl WebSocketContext {
// All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455)
_ if !frame.header().is_final => {
Err(Error::Protocol("Fragmented control frame".into()))
Err(Error::Protocol(ProtocolErrorType::FragmentedControlFrame))
}
_ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into()))
Err(Error::Protocol(ProtocolErrorType::ControlFrameTooBig))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
Err(Error::Protocol(ProtocolErrorType::UnknownControlFrameType(i)))
}
OpCtl::Ping => {
let data = frame.into_data();
@ -506,7 +502,7 @@ impl WebSocketContext {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
ProtocolErrorType::UnexpectedContinueFrame,
));
}
if fin {
@ -515,9 +511,9 @@ impl WebSocketContext {
Ok(None)
}
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolErrorType::ExpectedFragment(c)))
}
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
@ -537,7 +533,7 @@ impl WebSocketContext {
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
Err(Error::Protocol(ProtocolErrorType::UnknownDataFrameType(i)))
}
}
}
@ -548,7 +544,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed)
}
_ => Err(Error::Protocol("Connection reset without closing handshake".into())),
_ => Err(Error::Protocol(ProtocolErrorType::ResetWithoutClosingHandshake)),
}
}
}

@ -8,7 +8,7 @@ use std::{
time::Duration,
};
use tungstenite::{accept, connect, Error, Message};
use tungstenite::{accept, connect, error::ProtocolErrorType, Error, Message};
use url::Url;
#[test]
@ -46,7 +46,7 @@ fn test_no_send_after_close() {
assert!(err.is_err());
match err.unwrap_err() {
Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s),
Error::Protocol(s) => assert_eq!(s, ProtocolErrorType::SendAfterClosing),
e => panic!("unexpected error: {:?}", e),
}

Loading…
Cancel
Save