Merge pull request #168 from WiredSound/master

Create specific error types for protocol, URL, and capacity errors
pull/174/head
Matěj Laitl 4 years ago committed by GitHub
commit 5586d0af51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      CHANGELOG.md
  2. 3
      Cargo.toml
  3. 2
      README.md
  4. 15
      src/client.rs
  5. 205
      src/error.rs
  6. 22
      src/handshake/client.rs
  7. 6
      src/handshake/machine.rs
  8. 24
      src/handshake/server.rs
  9. 2
      src/protocol/frame/coding.rs
  10. 8
      src/protocol/frame/frame.rs
  11. 21
      src/protocol/frame/mod.rs
  12. 9
      src/protocol/message.rs
  13. 53
      src/protocol/mod.rs
  14. 4
      tests/no_send_after_close.rs

@ -1,3 +1,8 @@
# 0.13.0
- Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively.
- Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages.
# 0.12.0 # 0.12.0
- Add facilities to allow clients to follow HTTP 3XX redirects. - Add facilities to allow clients to follow HTTP 3XX redirects.

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.12.0" documentation = "https://docs.rs/tungstenite/0.12.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.12.0" version = "0.13.0"
edition = "2018" edition = "2018"
[features] [features]
@ -29,6 +29,7 @@ rand = "0.8.0"
sha-1 = "0.9" sha-1 = "0.9"
url = "2.1.0" url = "2.1.0"
utf-8 = "0.7.5" utf-8 = "0.7.5"
thiserror = "1.0.23"
[dependencies.native-tls] [dependencies.native-tls]
optional = true optional = true

@ -62,7 +62,7 @@ Testing
------- -------
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for
WebSockets. It is also covered by internal unit tests as good as possible. WebSockets. It is also covered by internal unit tests as well as possible.
Contributing Contributing
------------ ------------

@ -52,7 +52,7 @@ mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result, UrlError},
stream::Mode, stream::Mode,
}; };
@ -62,7 +62,7 @@ mod encryption {
pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> { pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode { match mode {
Mode::Plain => Ok(stream), Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
} }
} }
} }
@ -71,7 +71,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result, UrlError},
handshake::{client::ClientHandshake, HandshakeError}, handshake::{client::ClientHandshake, HandshakeError},
protocol::WebSocket, protocol::WebSocket,
stream::{Mode, NoDelay}, stream::{Mode, NoDelay},
@ -103,8 +103,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,
@ -166,7 +165,7 @@ pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoSt
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {
@ -175,7 +174,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoSt
} }
} }
} }
Err(Error::Url(format!("Unable to connect to {}", uri).into())) Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
} }
/// Get the mode of the given URL. /// Get the mode of the given URL.
@ -186,7 +185,7 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() { match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain), Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls), Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())), _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
} }
} }

@ -1,9 +1,10 @@
//! Error handling. //! Error handling.
use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; use std::{io, result, str, string};
use crate::protocol::Message; use crate::protocol::{frame::coding::Data, Message};
use http::Response; use http::Response;
use thiserror::Error;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
@ -14,8 +15,8 @@ pub mod tls {
/// Result type of all Tungstenite library calls. /// Result type of all Tungstenite library calls.
pub type Result<T> = result::Result<T, Error>; pub type Result<T> = result::Result<T, Error>;
/// Possible WebSocket errors /// Possible WebSocket errors.
#[derive(Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
/// WebSocket connection closed normally. This informs you of the close. /// WebSocket connection closed normally. This informs you of the close.
/// It's not an error as such and nothing wrong happened. /// It's not an error as such and nothing wrong happened.
@ -28,6 +29,7 @@ pub enum Error {
/// ///
/// Receiving this error means that the WebSocket object is not usable anymore and the /// Receiving this error means that the WebSocket object is not usable anymore and the
/// only meaningful action with it is dropping it. /// only meaningful action with it is dropping it.
#[error("Connection closed normally")]
ConnectionClosed, ConnectionClosed,
/// Trying to work with already closed connection. /// Trying to work with already closed connection.
/// ///
@ -36,56 +38,39 @@ pub enum Error {
/// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the /// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the
/// connection when it really shouldn't anymore, so this really indicates a programmer /// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part. /// error on your part.
#[error("Trying to work with closed connection")]
AlreadyClosed, AlreadyClosed,
/// Input-output error. Apart from WouldBlock, these are generally errors with the /// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal. /// underlying connection and you should probably consider them fatal.
Io(io::Error), #[error("IO error: {0}")]
Io(#[from] io::Error),
/// TLS error.
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
/// TLS error #[error("TLS error: {0}")]
Tls(tls::Error), Tls(#[from] tls::Error),
/// - When reading: buffer capacity exhausted. /// - When reading: buffer capacity exhausted.
/// - When writing: your message is bigger than the configured max message size /// - When writing: your message is bigger than the configured max message size
/// (64MB by default). /// (64MB by default).
Capacity(Cow<'static, str>), #[error("Space limit exceeded: {0}")]
Capacity(CapacityError),
/// Protocol violation. /// Protocol violation.
Protocol(Cow<'static, str>), #[error("WebSocket protocol error: {0}")]
Protocol(ProtocolError),
/// Message send queue full. /// Message send queue full.
#[error("Send queue is full")]
SendQueueFull(Message), SendQueueFull(Message),
/// UTF coding error /// UTF coding error.
#[error("UTF-8 encoding error")]
Utf8, Utf8,
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), #[error("URL error: {0}")]
Url(UrlError),
/// HTTP error. /// HTTP error.
#[error("HTTP error: {}", .0.status())]
Http(Response<Option<String>>), Http(Response<Option<String>>),
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), #[error("HTTP format error: {0}")]
} HttpFormat(#[from] http::Error),
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::ConnectionClosed => write!(f, "Connection closed normally"),
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature = "tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
}
impl ErrorTrait for Error {}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Error::Io(err)
}
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -130,24 +115,136 @@ impl From<http::status::InvalidStatusCode> for Error {
} }
} }
impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}
#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)
}
}
impl From<httparse::Error> for Error { impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self { fn from(err: httparse::Error) -> Self {
match err { match err {
httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders),
e => Error::Protocol(e.to_string().into()), e => Error::Protocol(ProtocolError::HttparseError(e)),
} }
} }
} }
/// Indicates the specific type/cause of a capacity error.
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
pub enum CapacityError {
/// Too many headers provided (see [`httparse::Error::TooManyHeaders`]).
#[error("Too many headers")]
TooManyHeaders,
/// Received header is too long.
#[error("Header too long")]
HeaderTooLong,
/// Message is bigger than the maximum allowed size.
#[error("Message too long: {size} > {max_size}")]
MessageTooLong {
/// The size of the message.
size: usize,
/// The maximum allowed message size.
max_size: usize,
},
/// TCP buffer is full.
#[error("Incoming TCP buffer is full")]
TcpBufferFull,
}
/// Indicates the specific type/cause of a protocol error.
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolError {
/// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
#[error("Unsupported HTTP method used - only GET is allowed")]
WrongHttpMethod,
/// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher).
#[error("HTTP version must be 1.1 or higher")]
WrongHttpVersion,
/// Missing `Connection: upgrade` HTTP header.
#[error("No \"Connection: upgrade\" header")]
MissingConnectionUpgradeHeader,
/// Missing `Upgrade: websocket` HTTP header.
#[error("No \"Upgrade: websocket\" header")]
MissingUpgradeWebSocketHeader,
/// Missing `Sec-WebSocket-Version: 13` HTTP header.
#[error("No \"Sec-WebSocket-Version: 13\" header")]
MissingSecWebSocketVersionHeader,
/// Missing `Sec-WebSocket-Key` HTTP header.
#[error("No \"Sec-WebSocket-Key\" header")]
MissingSecWebSocketKey,
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
#[error("Key mismatch in \"Sec-WebSocket-Accept\" header")]
SecWebSocketAcceptKeyMismatch,
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,
/// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")]
CustomResponseSuccessful,
/// No more data while still performing handshake.
#[error("Handshake not finished")]
HandshakeIncomplete,
/// Wrapper around a [`httparse::Error`] value.
#[error("httparse error: {0}")]
HttparseError(#[from] httparse::Error),
/// Not allowed to send after having sent a closing frame.
#[error("Sending after closing is not allowed")]
SendAfterClosing,
/// Remote sent data after sending a closing frame.
#[error("Remote sent after having closed")]
ReceivedAfterClosing,
/// Reserved bits in frame header are non-zero.
#[error("Reserved bits are non-zero")]
NonZeroReservedBits,
/// The server must close the connection when an unmasked frame is received.
#[error("Received an unmasked frame from client")]
UnmaskedFrameFromClient,
/// The client must close the connection when a masked frame is received.
#[error("Received a masked frame from server")]
MaskedFrameFromServer,
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
/// Type of control frame not recognised.
#[error("Unknown control frame type: {0}")]
UnknownControlFrameType(u8),
/// Type of data frame not recognised.
#[error("Unknown data frame type: {0}")]
UnknownDataFrameType(u8),
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
/// Connection closed without performing the closing handshake.
#[error("Connection reset without closing handshake")]
ResetWithoutClosingHandshake,
/// Encountered an invalid opcode.
#[error("Encountered invalid opcode: {0}")]
InvalidOpcode(u8),
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
}
/// Indicates the specific type/cause of URL error.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum UrlError {
/// TLS is used despite not being compiled with the TLS feature enabled.
#[error("TLS support not compiled in")]
TlsFeatureNotEnabled,
/// The URL does not include a host name.
#[error("No host name in the URL")]
NoHostName,
/// Failed to connect with this URL.
#[error("Unable to connect to {0}")]
UnableToConnect(String),
/// Unsupported URL scheme used (only `ws://` or `wss://` may be used).
#[error("URL scheme not supported")]
UnsupportedUrlScheme,
/// The URL host name, though included, is empty.
#[error("URL contains empty host name")]
EmptyHostName,
/// The URL does not include a path/query.
#[error("No path/query in URL")]
NoPathOrQuery,
}

@ -16,7 +16,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult, HandshakeRole, MidHandshake, ProcessingResult,
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result, UrlError},
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -42,11 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
// Check the URI scheme: only ws or wss are supported // Check the URI scheme: only ws or wss are supported
@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri(); let uri = request.uri();
let authority = let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
let host = if let Some(idx) = authority.find('@') { let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@ // handle possible name:password@
authority.split_at(idx + 1).1 authority.split_at(idx + 1).1
@ -106,7 +105,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
authority authority
}; };
if authority.is_empty() { if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into())); return Err(Error::Url(UrlError::EmptyHostName));
} }
write!( write!(
@ -120,8 +119,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
version = request.version(), version = request.version(),
host = host, host = host,
path = path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
@ -165,7 +163,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
} }
// 3. If the response lacks a |Connection| header field or the // 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
@ -177,14 +175,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
} }
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or // 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
} }
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
@ -218,7 +216,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;

@ -3,7 +3,7 @@ use log::*;
use std::io::{Cursor, Read, Write}; use std::io::{Cursor, Read, Write};
use crate::{ use crate::{
error::{Error, Result}, error::{CapacityError, Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
@ -46,11 +46,11 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
let read = buf let read = buf
.prepare_reserve(MIN_READ) .prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) // TODO limit size .with_limit(usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))? .map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))?
.read_from(&mut self.stream) .read_from(&mut self.stream)
.no_block()?; .no_block()?;
match read { match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())), Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size); buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading { RoundResult::StageFinished(StageResult::DoneReading {

@ -19,7 +19,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult, HandshakeRole, MidHandshake, ProcessingResult,
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result},
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse<Option<String>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
if !request if !request
@ -48,7 +48,7 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
} }
if !request if !request
@ -58,17 +58,17 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
} }
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
} }
let key = request let key = request
.headers() .headers()
.get("Sec-WebSocket-Key") .get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
let builder = Response::builder() let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS) .status(StatusCode::SWITCHING_PROTOCOLS)
@ -125,11 +125,11 @@ impl TryParse for Request {
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" { if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -237,7 +237,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?; let response = create_response(&result)?;
@ -256,9 +256,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Err(resp) => { Err(resp) => {
if resp.status().is_success() { if resp.status().is_success() {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
"Custom response must not be successful".into(),
));
} }
self.error_response = Some(resp); self.error_response = Some(resp);

@ -143,7 +143,7 @@ pub enum CloseCode {
Abnormal, Abnormal,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection
/// because it has received data within a message that was not /// because it has received data within a message that was not
/// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
/// data within a text message). /// data within a text message).
Invalid, Invalid,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection

@ -13,7 +13,7 @@ use super::{
coding::{CloseCode, Control, Data, OpCode}, coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask}, mask::{apply_mask, generate_mask},
}; };
use crate::error::{Error, Result}; use crate::error::{Error, ProtocolError, Result};
/// A struct representing the close command. /// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
@ -186,9 +186,7 @@ impl FrameHeader {
// Disallow bad opcode // Disallow bad opcode
match opcode { match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
} }
_ => (), _ => (),
} }
@ -286,7 +284,7 @@ impl Frame {
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> { pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() { match self.payload.len() {
0 => Ok(None), 0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => { _ => {
let mut data = self.payload; let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();

@ -8,7 +8,7 @@ mod mask;
pub use self::frame::{CloseFrame, Frame, FrameHeader}; pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::error::{Error, Result}; use crate::error::{CapacityError, Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use log::*; use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
@ -133,9 +133,10 @@ impl FrameCodec {
// Enforce frame size limit early and make sure `length` // Enforce frame size limit early and make sure `length`
// is not too big (fits into `usize`). // is not too big (fits into `usize`).
if length > max_size as u64 { if length > max_size as u64 {
return Err(Error::Capacity( return Err(Error::Capacity(CapacityError::MessageTooLong {
format!("Message length too big: {} > {}", length, max_size).into(), size: length as usize,
)); max_size,
}));
} }
let input_size = cursor.get_ref().len() as u64 - cursor.position(); let input_size = cursor.get_ref().len() as u64 - cursor.position();
@ -155,7 +156,7 @@ impl FrameCodec {
.in_buffer .in_buffer
.prepare_reserve(MIN_READ) .prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) .with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))?
.read_from(stream)?; .read_from(stream)?;
if size == 0 { if size == 0 {
trace!("no frame received"); trace!("no frame received");
@ -206,6 +207,8 @@ impl FrameCodec {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::error::{CapacityError, Error};
use super::{Frame, FrameSocket}; use super::{Frame, FrameSocket};
use std::io::Cursor; use std::io::Cursor;
@ -266,9 +269,9 @@ mod tests {
fn size_limit_hit() { fn size_limit_hit() {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!( assert!(matches!(
sock.read_frame(Some(5)).unwrap_err().to_string(), sock.read_frame(Some(5)),
"Space limit exceeded: Message length too big: 7 > 5" Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
); ));
} }
} }

@ -6,7 +6,7 @@ use std::{
}; };
use super::frame::CloseFrame; use super::frame::CloseFrame;
use crate::error::{Error, Result}; use crate::error::{CapacityError, Error, Result};
mod string_collect { mod string_collect {
use utf8::DecodeError; use utf8::DecodeError;
@ -122,9 +122,10 @@ impl IncompleteMessage {
let portion_size = tail.as_ref().len(); let portion_size = tail.as_ref().len();
// Be careful about integer overflows here. // Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size { if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity( return Err(Error::Capacity(CapacityError::MessageTooLong {
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), size: my_size + portion_size,
)); max_size,
}));
} }
match self.collector { match self.collector {

@ -21,7 +21,7 @@ use self::{
message::{IncompleteMessage, IncompleteMessageType}, message::{IncompleteMessage, IncompleteMessageType},
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
@ -331,7 +331,7 @@ impl WebSocketContext {
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol("Sending after closing is not allowed".into())); return Err(Error::Protocol(ProtocolError::SendAfterClosing));
} }
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
@ -431,9 +431,7 @@ impl WebSocketContext {
.check_connection_reset(self.state)? .check_connection_reset(self.state)?
{ {
if !self.state.can_read() { if !self.state.can_read() {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing));
"Remote sent frame after having sent a Close Frame".into(),
));
} }
// MUST be 0 unless an extension is negotiated that defines meanings // MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of // for non-zero values. If a nonzero value is received and none of
@ -443,7 +441,7 @@ impl WebSocketContext {
{ {
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into())); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
} }
@ -458,15 +456,13 @@ impl WebSocketContext {
// frame that is not masked. (RFC 6455) // frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given // The only exception here is if the user explicitly accepts given
// stream by setting WebSocketConfig.accept_unmasked_frames to true // stream by setting WebSocketConfig.accept_unmasked_frames to true
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
"Received an unmasked frame from client".into(),
));
} }
} }
Role::Client => { Role::Client => {
if frame.is_masked() { if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455) // A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into())); return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer));
} }
} }
} }
@ -477,14 +473,14 @@ impl WebSocketContext {
// All control frames MUST have a payload length of 125 bytes or less // All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455) // and MUST NOT be fragmented. (RFC 6455)
_ if !frame.header().is_final => { _ if !frame.header().is_final => {
Err(Error::Protocol("Fragmented control frame".into())) Err(Error::Protocol(ProtocolError::FragmentedControlFrame))
} }
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into())) Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
} }
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => { OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
} }
OpCtl::Ping => { OpCtl::Ping => {
let data = frame.into_data(); let data = frame.into_data();
@ -506,7 +502,7 @@ impl WebSocketContext {
msg.extend(frame.into_data(), self.config.max_message_size)?; msg.extend(frame.into_data(), self.config.max_message_size)?;
} else { } else {
return Err(Error::Protocol( return Err(Error::Protocol(
"Continue frame but nothing to continue".into(), ProtocolError::UnexpectedContinueFrame,
)); ));
} }
if fin { if fin {
@ -515,9 +511,9 @@ impl WebSocketContext {
Ok(None) Ok(None)
} }
} }
c if self.incomplete.is_some() => Err(Error::Protocol( c if self.incomplete.is_some() => {
format!("Received {} while waiting for more fragments", c).into(), Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
)), }
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = { let msg = {
let message_type = match data { let message_type = match data {
@ -537,7 +533,7 @@ impl WebSocketContext {
} }
} }
OpData::Reserved(i) => { OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
} }
} }
} }
@ -548,7 +544,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} }
_ => Err(Error::Protocol("Connection reset without closing handshake".into())), _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)),
} }
} }
} }
@ -673,6 +669,7 @@ impl<T> CheckConnectionReset for Result<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::error::{CapacityError, Error};
use std::{io, io::Cursor}; use std::{io, io::Cursor};
@ -715,10 +712,11 @@ mod tests {
]); ]);
let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(), assert!(matches!(
"Space limit exceeded: Message too big: 7 + 6 > 10" socket.read_message(),
); Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
));
} }
#[test] #[test]
@ -726,9 +724,10 @@ mod tests {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(), assert!(matches!(
"Space limit exceeded: Message too big: 0 + 3 > 2" socket.read_message(),
); Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
));
} }
} }

@ -8,7 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url; use url::Url;
#[test] #[test]
@ -46,7 +46,7 @@ fn test_no_send_after_close() {
assert!(err.is_err()); assert!(err.is_err());
match err.unwrap_err() { match err.unwrap_err() {
Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), Error::Protocol(s) => assert_eq!(s, ProtocolError::SendAfterClosing),
e => panic!("unexpected error: {:?}", e), e => panic!("unexpected error: {:?}", e),
} }

Loading…
Cancel
Save