Merge branch 'master' into rustls

# Conflicts:
#	src/error.rs
pull/166/head
Dominik Nakamura 4 years ago
commit 5a3dd8acfd
No known key found for this signature in database
GPG Key ID: E4C6A749B2491910
  1. 5
      CHANGELOG.md
  2. 3
      Cargo.toml
  3. 6
      README.md
  4. 15
      src/client.rs
  5. 227
      src/error.rs
  6. 22
      src/handshake/client.rs
  7. 6
      src/handshake/machine.rs
  8. 24
      src/handshake/server.rs
  9. 2
      src/protocol/frame/coding.rs
  10. 8
      src/protocol/frame/frame.rs
  11. 21
      src/protocol/frame/mod.rs
  12. 9
      src/protocol/message.rs
  13. 53
      src/protocol/mod.rs
  14. 4
      tests/no_send_after_close.rs

@ -1,3 +1,8 @@
# 0.13.0
- Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively.
- Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages.
# 0.12.0 # 0.12.0
- Add facilities to allow clients to follow HTTP 3XX redirects. - Add facilities to allow clients to follow HTTP 3XX redirects.

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.12.0" documentation = "https://docs.rs/tungstenite/0.12.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.12.0" version = "0.13.0"
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
@ -31,6 +31,7 @@ input_buffer = "0.4.0"
log = "0.4.8" log = "0.4.8"
rand = "0.8.0" rand = "0.8.0"
sha-1 = "0.9" sha-1 = "0.9"
thiserror = "1.0.23"
url = "2.1.0" url = "2.1.0"
utf-8 = "0.7.5" utf-8 = "0.7.5"

@ -15,7 +15,7 @@ fn main () {
let mut websocket = accept(stream.unwrap()).unwrap(); let mut websocket = accept(stream.unwrap()).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages. // We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap(); websocket.write_message(msg).unwrap();
@ -44,8 +44,6 @@ and asynchronous usage and is easy to integrate into any third-party event loops
WebSocket protocol but still makes them accessible for those who wants full control over the WebSocket protocol but still makes them accessible for those who wants full control over the
network. network.
This library is a work in progress. Feel free to ask questions and send us pull requests.
Why Tungstenite? Why Tungstenite?
---------------- ----------------
@ -64,7 +62,7 @@ Testing
------- -------
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for
WebSockets. It is also covered by internal unit tests as good as possible. WebSockets. It is also covered by internal unit tests as well as possible.
Contributing Contributing
------------ ------------

@ -85,7 +85,7 @@ mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result, UrlError},
stream::Mode, stream::Mode,
}; };
@ -95,7 +95,7 @@ mod encryption {
pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> { pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode { match mode {
Mode::Plain => Ok(stream), Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
} }
} }
} }
@ -104,7 +104,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::{ use crate::{
error::{Error, Result}, error::{Error, Result, UrlError},
handshake::{client::ClientHandshake, HandshakeError}, handshake::{client::ClientHandshake, HandshakeError},
protocol::WebSocket, protocol::WebSocket,
stream::{Mode, NoDelay}, stream::{Mode, NoDelay},
@ -136,8 +136,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,
@ -199,7 +198,7 @@ pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoSt
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {
@ -208,7 +207,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoSt
} }
} }
} }
Err(Error::Url(format!("Unable to connect to {}", uri).into())) Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
} }
/// Get the mode of the given URL. /// Get the mode of the given URL.
@ -219,7 +218,7 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() { match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain), Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls), Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())), _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
} }
} }

@ -1,15 +1,16 @@
//! Error handling. //! Error handling.
use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; use std::{io, result, str, string};
use crate::protocol::Message; use crate::protocol::{frame::coding::Data, Message};
use http::Response; use http::Response;
use thiserror::Error;
/// Result type of all Tungstenite library calls. /// Result type of all Tungstenite library calls.
pub type Result<T> = result::Result<T, Error>; pub type Result<T> = result::Result<T, Error>;
/// Possible WebSocket errors /// Possible WebSocket errors.
#[derive(Debug)] #[derive(Error, Debug)]
#[non_exhaustive] #[non_exhaustive]
pub enum Error { pub enum Error {
/// WebSocket connection closed normally. This informs you of the close. /// WebSocket connection closed normally. This informs you of the close.
@ -23,6 +24,7 @@ pub enum Error {
/// ///
/// Receiving this error means that the WebSocket object is not usable anymore and the /// Receiving this error means that the WebSocket object is not usable anymore and the
/// only meaningful action with it is dropping it. /// only meaningful action with it is dropping it.
#[error("Connection closed normally")]
ConnectionClosed, ConnectionClosed,
/// Trying to work with already closed connection. /// Trying to work with already closed connection.
/// ///
@ -31,66 +33,47 @@ pub enum Error {
/// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the /// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the
/// connection when it really shouldn't anymore, so this really indicates a programmer /// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part. /// error on your part.
#[error("Trying to work with closed connection")]
AlreadyClosed, AlreadyClosed,
/// Input-output error. Apart from WouldBlock, these are generally errors with the /// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal. /// underlying connection and you should probably consider them fatal.
Io(io::Error), #[error("IO error: {0}")]
#[cfg(feature = "use-native-tls")] Io(#[from] io::Error),
/// TLS error /// TLS error
TlsNative(native_tls::Error), #[cfg(feature = "use-native-tls")]
#[cfg(feature = "use-rustls")] #[error("TLS (native-tls) error: {0}")]
TlsNative(#[from] native_tls::Error),
/// TLS error /// TLS error
TlsRustls(rustls::TLSError),
#[cfg(feature = "use-rustls")] #[cfg(feature = "use-rustls")]
#[error("TLS (rustls) error: {0}")]
TlsRustls(#[from] rustls::TLSError),
/// DNS name resolution error. /// DNS name resolution error.
Dns(webpki::InvalidDNSNameError), #[cfg(feature = "use-rustls")]
#[error("Invalid DNS name: {0}")]
Dns(#[from] webpki::InvalidDNSNameError),
/// - When reading: buffer capacity exhausted. /// - When reading: buffer capacity exhausted.
/// - When writing: your message is bigger than the configured max message size /// - When writing: your message is bigger than the configured max message size
/// (64MB by default). /// (64MB by default).
Capacity(Cow<'static, str>), #[error("Space limit exceeded: {0}")]
Capacity(CapacityError),
/// Protocol violation. /// Protocol violation.
Protocol(Cow<'static, str>), #[error("WebSocket protocol error: {0}")]
Protocol(ProtocolError),
/// Message send queue full. /// Message send queue full.
#[error("Send queue is full")]
SendQueueFull(Message), SendQueueFull(Message),
/// UTF coding error /// UTF coding error.
#[error("UTF-8 encoding error")]
Utf8, Utf8,
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), #[error("URL error: {0}")]
Url(UrlError),
/// HTTP error. /// HTTP error.
#[error("HTTP error: {}", .0.status())]
Http(Response<Option<String>>), Http(Response<Option<String>>),
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), #[error("HTTP format error: {0}")]
} HttpFormat(#[from] http::Error),
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::ConnectionClosed => write!(f, "Connection closed normally"),
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature = "use-native-tls")]
Error::TlsNative(ref err) => write!(f, "TLS (native-tls) error: {}", err),
#[cfg(feature = "use-rustls")]
Error::TlsRustls(ref err) => write!(f, "TLS (rustls) error: {}", err),
#[cfg(feature = "use-rustls")]
Error::Dns(ref err) => write!(f, "Invalid DNS name: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
}
impl ErrorTrait for Error {}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Error::Io(err)
}
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -135,38 +118,136 @@ impl From<http::status::InvalidStatusCode> for Error {
} }
} }
impl From<http::Error> for Error { impl From<httparse::Error> for Error {
fn from(err: http::Error) -> Self { fn from(err: httparse::Error) -> Self {
Error::HttpFormat(err) match err {
} httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders),
} e => Error::Protocol(ProtocolError::HttparseError(e)),
}
#[cfg(feature = "use-native-tls")]
impl From<native_tls::Error> for Error {
fn from(err: native_tls::Error) -> Self {
Error::TlsNative(err)
} }
} }
#[cfg(feature = "use-rustls")] /// Indicates the specific type/cause of a capacity error.
impl From<rustls::TLSError> for Error { #[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
fn from(err: rustls::TLSError) -> Self { pub enum CapacityError {
Error::TlsRustls(err) /// Too many headers provided (see [`httparse::Error::TooManyHeaders`]).
} #[error("Too many headers")]
TooManyHeaders,
/// Received header is too long.
#[error("Header too long")]
HeaderTooLong,
/// Message is bigger than the maximum allowed size.
#[error("Message too long: {size} > {max_size}")]
MessageTooLong {
/// The size of the message.
size: usize,
/// The maximum allowed message size.
max_size: usize,
},
/// TCP buffer is full.
#[error("Incoming TCP buffer is full")]
TcpBufferFull,
} }
#[cfg(feature = "use-rustls")] /// Indicates the specific type/cause of a protocol error.
impl From<webpki::InvalidDNSNameError> for Error { #[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
fn from(err: webpki::InvalidDNSNameError) -> Self { pub enum ProtocolError {
Error::Dns(err) /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
} #[error("Unsupported HTTP method used - only GET is allowed")]
WrongHttpMethod,
/// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher).
#[error("HTTP version must be 1.1 or higher")]
WrongHttpVersion,
/// Missing `Connection: upgrade` HTTP header.
#[error("No \"Connection: upgrade\" header")]
MissingConnectionUpgradeHeader,
/// Missing `Upgrade: websocket` HTTP header.
#[error("No \"Upgrade: websocket\" header")]
MissingUpgradeWebSocketHeader,
/// Missing `Sec-WebSocket-Version: 13` HTTP header.
#[error("No \"Sec-WebSocket-Version: 13\" header")]
MissingSecWebSocketVersionHeader,
/// Missing `Sec-WebSocket-Key` HTTP header.
#[error("No \"Sec-WebSocket-Key\" header")]
MissingSecWebSocketKey,
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
#[error("Key mismatch in \"Sec-WebSocket-Accept\" header")]
SecWebSocketAcceptKeyMismatch,
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,
/// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")]
CustomResponseSuccessful,
/// No more data while still performing handshake.
#[error("Handshake not finished")]
HandshakeIncomplete,
/// Wrapper around a [`httparse::Error`] value.
#[error("httparse error: {0}")]
HttparseError(#[from] httparse::Error),
/// Not allowed to send after having sent a closing frame.
#[error("Sending after closing is not allowed")]
SendAfterClosing,
/// Remote sent data after sending a closing frame.
#[error("Remote sent after having closed")]
ReceivedAfterClosing,
/// Reserved bits in frame header are non-zero.
#[error("Reserved bits are non-zero")]
NonZeroReservedBits,
/// The server must close the connection when an unmasked frame is received.
#[error("Received an unmasked frame from client")]
UnmaskedFrameFromClient,
/// The client must close the connection when a masked frame is received.
#[error("Received a masked frame from server")]
MaskedFrameFromServer,
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
/// Type of control frame not recognised.
#[error("Unknown control frame type: {0}")]
UnknownControlFrameType(u8),
/// Type of data frame not recognised.
#[error("Unknown data frame type: {0}")]
UnknownDataFrameType(u8),
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
/// Connection closed without performing the closing handshake.
#[error("Connection reset without closing handshake")]
ResetWithoutClosingHandshake,
/// Encountered an invalid opcode.
#[error("Encountered invalid opcode: {0}")]
InvalidOpcode(u8),
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
} }
impl From<httparse::Error> for Error { /// Indicates the specific type/cause of URL error.
fn from(err: httparse::Error) -> Self { #[derive(Error, Debug, PartialEq, Eq)]
match err { pub enum UrlError {
httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), /// TLS is used despite not being compiled with the TLS feature enabled.
e => Error::Protocol(e.to_string().into()), #[error("TLS support not compiled in")]
} TlsFeatureNotEnabled,
} /// The URL does not include a host name.
#[error("No host name in the URL")]
NoHostName,
/// Failed to connect with this URL.
#[error("Unable to connect to {0}")]
UnableToConnect(String),
/// Unsupported URL scheme used (only `ws://` or `wss://` may be used).
#[error("URL scheme not supported")]
UnsupportedUrlScheme,
/// The URL host name, though included, is empty.
#[error("URL contains empty host name")]
EmptyHostName,
/// The URL does not include a path/query.
#[error("No path/query in URL")]
NoPathOrQuery,
} }

@ -16,7 +16,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult, HandshakeRole, MidHandshake, ProcessingResult,
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result, UrlError},
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -42,11 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
// Check the URI scheme: only ws or wss are supported // Check the URI scheme: only ws or wss are supported
@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri(); let uri = request.uri();
let authority = let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
let host = if let Some(idx) = authority.find('@') { let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@ // handle possible name:password@
authority.split_at(idx + 1).1 authority.split_at(idx + 1).1
@ -106,7 +105,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
authority authority
}; };
if authority.is_empty() { if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into())); return Err(Error::Url(UrlError::EmptyHostName));
} }
write!( write!(
@ -120,8 +119,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
version = request.version(), version = request.version(),
host = host, host = host,
path = path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
@ -165,7 +163,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
} }
// 3. If the response lacks a |Connection| header field or the // 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
@ -177,14 +175,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
} }
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or // 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
} }
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
@ -218,7 +216,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;

@ -3,7 +3,7 @@ use log::*;
use std::io::{Cursor, Read, Write}; use std::io::{Cursor, Read, Write};
use crate::{ use crate::{
error::{Error, Result}, error::{CapacityError, Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
@ -46,11 +46,11 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
let read = buf let read = buf
.prepare_reserve(MIN_READ) .prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) // TODO limit size .with_limit(usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))? .map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))?
.read_from(&mut self.stream) .read_from(&mut self.stream)
.no_block()?; .no_block()?;
match read { match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())), Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size); buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading { RoundResult::StageFinished(StageResult::DoneReading {

@ -19,7 +19,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult, HandshakeRole, MidHandshake, ProcessingResult,
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result},
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse<Option<String>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
if !request if !request
@ -48,7 +48,7 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
} }
if !request if !request
@ -58,17 +58,17 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
} }
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
} }
let key = request let key = request
.headers() .headers()
.get("Sec-WebSocket-Key") .get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
let builder = Response::builder() let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS) .status(StatusCode::SWITCHING_PROTOCOLS)
@ -125,11 +125,11 @@ impl TryParse for Request {
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" { if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -237,7 +237,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?; let response = create_response(&result)?;
@ -256,9 +256,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Err(resp) => { Err(resp) => {
if resp.status().is_success() { if resp.status().is_success() {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
"Custom response must not be successful".into(),
));
} }
self.error_response = Some(resp); self.error_response = Some(resp);

@ -143,7 +143,7 @@ pub enum CloseCode {
Abnormal, Abnormal,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection
/// because it has received data within a message that was not /// because it has received data within a message that was not
/// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
/// data within a text message). /// data within a text message).
Invalid, Invalid,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection

@ -13,7 +13,7 @@ use super::{
coding::{CloseCode, Control, Data, OpCode}, coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask}, mask::{apply_mask, generate_mask},
}; };
use crate::error::{Error, Result}; use crate::error::{Error, ProtocolError, Result};
/// A struct representing the close command. /// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
@ -186,9 +186,7 @@ impl FrameHeader {
// Disallow bad opcode // Disallow bad opcode
match opcode { match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
} }
_ => (), _ => (),
} }
@ -286,7 +284,7 @@ impl Frame {
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> { pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() { match self.payload.len() {
0 => Ok(None), 0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => { _ => {
let mut data = self.payload; let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();

@ -8,7 +8,7 @@ mod mask;
pub use self::frame::{CloseFrame, Frame, FrameHeader}; pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::error::{Error, Result}; use crate::error::{CapacityError, Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use log::*; use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
@ -133,9 +133,10 @@ impl FrameCodec {
// Enforce frame size limit early and make sure `length` // Enforce frame size limit early and make sure `length`
// is not too big (fits into `usize`). // is not too big (fits into `usize`).
if length > max_size as u64 { if length > max_size as u64 {
return Err(Error::Capacity( return Err(Error::Capacity(CapacityError::MessageTooLong {
format!("Message length too big: {} > {}", length, max_size).into(), size: length as usize,
)); max_size,
}));
} }
let input_size = cursor.get_ref().len() as u64 - cursor.position(); let input_size = cursor.get_ref().len() as u64 - cursor.position();
@ -155,7 +156,7 @@ impl FrameCodec {
.in_buffer .in_buffer
.prepare_reserve(MIN_READ) .prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) .with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))?
.read_from(stream)?; .read_from(stream)?;
if size == 0 { if size == 0 {
trace!("no frame received"); trace!("no frame received");
@ -206,6 +207,8 @@ impl FrameCodec {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::error::{CapacityError, Error};
use super::{Frame, FrameSocket}; use super::{Frame, FrameSocket};
use std::io::Cursor; use std::io::Cursor;
@ -266,9 +269,9 @@ mod tests {
fn size_limit_hit() { fn size_limit_hit() {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!( assert!(matches!(
sock.read_frame(Some(5)).unwrap_err().to_string(), sock.read_frame(Some(5)),
"Space limit exceeded: Message length too big: 7 > 5" Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
); ));
} }
} }

@ -6,7 +6,7 @@ use std::{
}; };
use super::frame::CloseFrame; use super::frame::CloseFrame;
use crate::error::{Error, Result}; use crate::error::{CapacityError, Error, Result};
mod string_collect { mod string_collect {
use utf8::DecodeError; use utf8::DecodeError;
@ -122,9 +122,10 @@ impl IncompleteMessage {
let portion_size = tail.as_ref().len(); let portion_size = tail.as_ref().len();
// Be careful about integer overflows here. // Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size { if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity( return Err(Error::Capacity(CapacityError::MessageTooLong {
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), size: my_size + portion_size,
)); max_size,
}));
} }
match self.collector { match self.collector {

@ -21,7 +21,7 @@ use self::{
message::{IncompleteMessage, IncompleteMessageType}, message::{IncompleteMessage, IncompleteMessageType},
}; };
use crate::{ use crate::{
error::{Error, Result}, error::{Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
@ -331,7 +331,7 @@ impl WebSocketContext {
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol("Sending after closing is not allowed".into())); return Err(Error::Protocol(ProtocolError::SendAfterClosing));
} }
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
@ -431,9 +431,7 @@ impl WebSocketContext {
.check_connection_reset(self.state)? .check_connection_reset(self.state)?
{ {
if !self.state.can_read() { if !self.state.can_read() {
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing));
"Remote sent frame after having sent a Close Frame".into(),
));
} }
// MUST be 0 unless an extension is negotiated that defines meanings // MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of // for non-zero values. If a nonzero value is received and none of
@ -443,7 +441,7 @@ impl WebSocketContext {
{ {
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into())); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
} }
@ -458,15 +456,13 @@ impl WebSocketContext {
// frame that is not masked. (RFC 6455) // frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given // The only exception here is if the user explicitly accepts given
// stream by setting WebSocketConfig.accept_unmasked_frames to true // stream by setting WebSocketConfig.accept_unmasked_frames to true
return Err(Error::Protocol( return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
"Received an unmasked frame from client".into(),
));
} }
} }
Role::Client => { Role::Client => {
if frame.is_masked() { if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455) // A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into())); return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer));
} }
} }
} }
@ -477,14 +473,14 @@ impl WebSocketContext {
// All control frames MUST have a payload length of 125 bytes or less // All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455) // and MUST NOT be fragmented. (RFC 6455)
_ if !frame.header().is_final => { _ if !frame.header().is_final => {
Err(Error::Protocol("Fragmented control frame".into())) Err(Error::Protocol(ProtocolError::FragmentedControlFrame))
} }
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into())) Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
} }
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => { OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
} }
OpCtl::Ping => { OpCtl::Ping => {
let data = frame.into_data(); let data = frame.into_data();
@ -506,7 +502,7 @@ impl WebSocketContext {
msg.extend(frame.into_data(), self.config.max_message_size)?; msg.extend(frame.into_data(), self.config.max_message_size)?;
} else { } else {
return Err(Error::Protocol( return Err(Error::Protocol(
"Continue frame but nothing to continue".into(), ProtocolError::UnexpectedContinueFrame,
)); ));
} }
if fin { if fin {
@ -515,9 +511,9 @@ impl WebSocketContext {
Ok(None) Ok(None)
} }
} }
c if self.incomplete.is_some() => Err(Error::Protocol( c if self.incomplete.is_some() => {
format!("Received {} while waiting for more fragments", c).into(), Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
)), }
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = { let msg = {
let message_type = match data { let message_type = match data {
@ -537,7 +533,7 @@ impl WebSocketContext {
} }
} }
OpData::Reserved(i) => { OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
} }
} }
} }
@ -548,7 +544,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} }
_ => Err(Error::Protocol("Connection reset without closing handshake".into())), _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)),
} }
} }
} }
@ -673,6 +669,7 @@ impl<T> CheckConnectionReset for Result<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::error::{CapacityError, Error};
use std::{io, io::Cursor}; use std::{io, io::Cursor};
@ -715,10 +712,11 @@ mod tests {
]); ]);
let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(), assert!(matches!(
"Space limit exceeded: Message too big: 7 + 6 > 10" socket.read_message(),
); Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
));
} }
#[test] #[test]
@ -726,9 +724,10 @@ mod tests {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(), assert!(matches!(
"Space limit exceeded: Message too big: 0 + 3 > 2" socket.read_message(),
); Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
));
} }
} }

@ -8,7 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url; use url::Url;
#[test] #[test]
@ -46,7 +46,7 @@ fn test_no_send_after_close() {
assert!(err.is_err()); assert!(err.is_err());
match err.unwrap_err() { match err.unwrap_err() {
Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), Error::Protocol(s) => assert_eq!(s, ProtocolError::SendAfterClosing),
e => panic!("unexpected error: {:?}", e), e => panic!("unexpected error: {:?}", e),
} }

Loading…
Cancel
Save