Add `permessage-deflate` support

permessage-deflate
kazk 3 years ago committed by Daniel Abramov
parent e1033afd95
commit edb2377540
  1. 2
      .gitignore
  2. 1
      .travis.yml
  3. 14
      Cargo.toml
  4. 2
      README.md
  5. 576
      autobahn/expected-results.json
  6. 14
      examples/autobahn-client.rs
  7. 14
      examples/autobahn-server.rs
  8. 2
      examples/srv_accept_unmasked_frames.rs
  9. 2
      scripts/autobahn-client.sh
  10. 2
      scripts/autobahn-server.sh
  11. 19
      src/error.rs
  12. 442
      src/extensions/compression/deflate.rs
  13. 4
      src/extensions/compression/mod.rs
  14. 18
      src/extensions/mod.rs
  15. 123
      src/handshake/client.rs
  16. 26
      src/handshake/server.rs
  17. 1
      src/lib.rs
  18. 12
      src/protocol/frame/frame.rs
  19. 22
      src/protocol/message.rs
  20. 204
      src/protocol/mod.rs

2
.gitignore vendored

@ -1,2 +1,4 @@
target target
Cargo.lock Cargo.lock
autobahn/client/
autobahn/server/

@ -10,5 +10,6 @@ before_script:
script: script:
- cargo test --release - cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh - echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh - echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -25,6 +25,15 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"] __rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies] [dependencies]
data-encoding = { version = "2", optional = true } data-encoding = { version = "2", optional = true }
@ -38,6 +47,11 @@ sha1 = { version = "0.10", optional = true }
thiserror = "1.0.23" thiserror = "1.0.23"
url = { version = "2.1.0", optional = true } url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5" utf-8 = "0.7.5"
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate] [dependencies.native-tls-crate]
optional = true optional = true

@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features, By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints. otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing Testing
------- -------

File diff suppressed because it is too large Load Diff

@ -1,7 +1,10 @@
use log::*; use log::*;
use url::Url; use url::Url;
use tungstenite::{connect, Error, Message, Result}; use tungstenite::{
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite"; const AGENT: &str = "Tungstenite";
@ -24,7 +27,14 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {

@ -4,7 +4,10 @@ use std::{
}; };
use log::*; use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result}; use tungstenite::{
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
@ -14,7 +17,14 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test"); info!("Running test");
loop { loop {
match socket.read_message()? { match socket.read_message()? {

@ -35,6 +35,8 @@ fn main() {
// rare cases where it is necessary to integrate with existing/legacy // rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames // clients which are sending unmasked frames
accept_unmasked_frames: true, accept_unmasked_frames: true,
#[cfg(feature = "deflate")]
compression: None,
}); });
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3 sleep 3
cargo run --release --example autobahn-client cargo run --release --example autobahn-client --features=deflate
test_diff test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi fi
} }
cargo run --release --example autobahn-server & WSSERVER_PID=$! cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3 sleep 3
docker run --rm \ docker run --rm \

@ -70,6 +70,10 @@ pub enum Error {
#[error("HTTP format error: {0}")] #[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error), HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] crate::extensions::DeflateError),
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -206,6 +210,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented. /// Control frames must not be fragmented.
#[error("Fragmented control frame")] #[error("Fragmented control frame")]
FragmentedControlFrame, FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less. /// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")] #[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig, ControlFrameTooBig,
@ -218,6 +225,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue. /// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")] #[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame, UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments. /// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")] #[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data), ExpectedFragment(Data),
@ -230,6 +240,15 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid. /// The payload for the closing frame is invalid.
#[error("Invalid close sequence")] #[error("Invalid close sequence")]
InvalidCloseSequence, InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
InvalidExtensionsHeader,
} }
/// Indicates the specific type/cause of URL error. /// Indicates the specific type/cause of URL error.

@ -0,0 +1,442 @@
use std::convert::TryFrom;
use bytes::BytesMut;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use headers::WebsocketExtension;
use http::HeaderValue;
use thiserror::Error;
use crate::protocol::Role;
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Errors from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress(#[source] std::io::Error),
/// Decompress failed
#[error("Failed to decompress")]
Decompress(#[source] std::io::Error),
/// Extension negotiation failed.
#[error("Extension negotiation failed")]
Negotiation(#[source] NegotiationError),
}
/// Errors from `permessage-deflate` extension negotiation.
#[derive(Debug, Error)]
pub enum NegotiationError {
/// Unknown parameter in a negotiation response.
#[error("Unknown parameter in a negotiation response: {0}")]
UnknownParameter(String),
/// Duplicate parameter in a negotiation response.
#[error("Duplicate parameter in a negotiation response: {0}")]
DuplicateParameter(String),
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
UnexpectedClientMaxWindowBits,
/// Received unsupported `server_max_window_bits` in a negotiation response.
#[error("Received unsupported server_max_window_bits in a negotiation response")]
ServerMaxWindowBitsNotSupported,
/// Invalid `client_max_window_bits` value in a negotiation response.
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
InvalidClientMaxWindowBitsValue(String),
/// Invalid `server_max_window_bits` value in a negotiation response.
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
InvalidServerMaxWindowBitsValue(String),
/// Missing `server_max_window_bits` value in a negotiation response.
#[error("Missing server_max_window_bits value in a negotiation response")]
MissingServerMaxWindowBitsValue,
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
// > a client informs the peer server of a hint that even if the server doesn't include the
// > "client_no_context_takeover" extension parameter in the corresponding
// > extension negotiation response to the offer, the client is not going
// > to use context takeover.
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
if self.client_no_context_takeover {
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub(crate) fn accept_offer(
&self,
offers: &headers::SecWebsocketExtensions,
) -> Option<(WebsocketExtension, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
// 2. The negotiation offer contains an extension parameter with an invalid value.
// 3. The negotiation offer contains multiple extension parameters with the same name.
// 4. The server doesn't support the offered configuration.
offers.iter().find_map(|extension| {
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for (key, val) in params {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
return None;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
return None;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `server_max_window_bits` requires a value in range [8, 15].
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
} else {
return None;
}
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
return None;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `client_max_window_bits` requires a value in range [8, 15] or no value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
}
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
return None;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
return None;
}
}
}
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
} else {
None
}
})
}
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = DeflateConfig {
compression: self.compression,
// If this was hinted in the offer, the client won't use context takeover
// even if the response doesn't include it.
// See `generate_offer`.
client_no_context_takeover: self.client_no_context_takeover,
..DeflateConfig::default()
};
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
// A client MUST _Fail the WebSocket Connection_ if the peer server
// accepted an extension negotiation offer for this extension with an
// extension negotiation response meeting any of the following
// conditions:
// 1. The negotiation response contains an extension parameter not defined for use in a response.
// 2. The negotiation response contains an extension parameter with an invalid value.
// 3. The negotiation response contains multiple extension parameters with the same name.
// 4. The client does not support the configuration that the response represents.
for (key, val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_server_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_server_no_context_takeover = true;
// A server MAY include the "server_no_context_takeover" extension
// parameter in an extension negotiation response even if the extension
// negotiation offer being accepted by the extension negotiation
// response didn't include the "server_no_context_takeover" extension
// parameter.
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_client_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_client_no_context_takeover = true;
// The server may include this parameter in the response and the client MUST support it.
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
));
}
} else {
return Err(DeflateError::Negotiation(
NegotiationError::MissingServerMaxWindowBitsValue,
));
}
// A server may include the "server_max_window_bits" extension parameter
// in an extension negotiation response even if the extension
// negotiation offer being accepted by the response didn't include the
// "server_max_window_bits" extension parameter.
//
// However, but we need to fail the connection because we don't support it (condition 4).
return Err(DeflateError::Negotiation(
NegotiationError::ServerMaxWindowBitsNotSupported,
));
}
CLIENT_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
));
}
}
// Fail the connection because the parameter is invalid when the client didn't offer.
//
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
return Err(DeflateError::Negotiation(
NegotiationError::UnexpectedClientMaxWindowBits,
));
}
// Response with unknown parameter MUST fail the WebSocket connection.
_ => {
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
key.to_owned(),
)));
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
fn is_valid_max_window_bits(bits: &str) -> bool {
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
}
#[cfg(test)]
mod tests {
use super::is_valid_max_window_bits;
#[test]
fn valid_max_window_bits() {
for bits in 8..=15 {
assert!(is_valid_max_window_bits(&bits.to_string()));
}
}
#[test]
fn invalid_max_window_bits() {
assert!(!is_valid_max_window_bits(""));
assert!(!is_valid_max_window_bits("0"));
assert!(!is_valid_max_window_bits("08"));
assert!(!is_valid_max_window_bits("+8"));
assert!(!is_valid_max_window_bits("-8"));
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
for param in params {
buf.extend_from_slice(b"; ");
buf.extend_from_slice(param.as_bytes());
}
let header = HeaderValue::from_maybe_shared(buf.freeze())
.expect("semicolon separated HeaderValue is valid");
WebsocketExtension::try_from(header).expect("valid extension")
}

@ -0,0 +1,4 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -0,0 +1,18 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -5,6 +5,7 @@ use std::{
marker::PhantomData, marker::PhantomData,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -19,6 +20,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result, UrlError}, error::{Error, ProtocolError, Result, UrlError},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -56,7 +58,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC. // Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request). // Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?; let (request, key) = generate_request(request, &config)?;
let machine = HandshakeMachine::start_write(stream, request); let machine = HandshakeMachine::start_write(stream, request);
@ -83,18 +85,24 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) { let (result, extensions) =
match self.verify_data.verify_response(result, &self.config) {
Ok(r) => r, Ok(r) => r,
Err(Error::Http(mut e)) => { Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail); *e.body_mut() = Some(tail);
return Err(Error::Http(e)) return Err(Error::Http(e));
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket = WebSocket::from_partially_read_with_extensions(
WebSocket::from_partially_read(stream, tail, Role::Client, self.config); stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
}) })
@ -102,7 +110,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
} }
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it. /// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> { pub fn generate_request(
mut request: Request,
config: &Option<WebSocketConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new(); let mut req = Vec::new();
write!( write!(
req, req,
@ -173,6 +184,9 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap(); writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
} }
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
}
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req)); trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok((req, key)) Ok((req, key))
@ -186,7 +200,11 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> { pub fn verify_response(
&self,
response: Response,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -231,7 +249,14 @@ impl VerifyData {
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO let extensions = if let Some(agreed) = headers
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
{
verify_extensions(&agreed, _config)?
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field // 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was // and this header field indicates the use of a subprotocol that was
@ -240,8 +265,47 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok(response) Ok((response, extensions))
}
}
fn verify_extensions(
agreed_extensions: &headers::SecWebsocketExtensions,
_config: &Option<WebSocketConfig>,
) -> Result<Option<Extensions>> {
#[cfg(feature = "deflate")]
{
if let Some(compression) = _config.and_then(|c| c.compression) {
let mut extensions = None;
for extension in agreed_extensions.iter() {
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
// > the client MUST _Fail the WebSocket Connection_.
if extension.name() != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
extension.name().to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
extension.name().to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(extension.params())?),
});
}
return Ok(extensions);
}
}
if let Some(extension) = agreed_extensions.iter().next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
} }
Ok(None)
} }
impl TryParse for Response { impl TryParse for Response {
@ -286,6 +350,8 @@ pub fn generate_key() -> String {
mod tests { mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response}; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest; use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -322,7 +388,7 @@ mod tests {
#[test] #[test]
fn request_formatting() { fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost", &key); let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -330,7 +396,7 @@ mod tests {
#[test] #[test]
fn request_formatting_with_host() { fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -338,11 +404,40 @@ mod tests {
#[test] #[test]
fn request_formatting_with_at() { fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n",
host = "localhost",
key = key
)
.into_bytes();
assert_eq!(&request[..], &correct[..]);
}
#[test] #[test]
fn response_parsing() { fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -354,6 +449,6 @@ mod tests {
#[test] #[test]
fn invalid_custom_request() { fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap(); let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err()); assert!(generate_request(request, &None).is_err());
} }
} }

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult, result::Result as StdResult,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -20,6 +21,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>, error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback), callback: Some(callback),
config, config,
error_response: None, error_response: None,
extensions: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -240,7 +245,19 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?; let mut response = create_response(&result)?;
if let Some(config) = &self.config {
if let Some((agreed, extensions)) = result
.headers()
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
.and_then(|values| config.accept_offers(&values))
{
response.headers_mut().typed_insert(agreed);
self.extensions = Some(extensions);
}
}
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result, response)
} else { } else {
@ -283,7 +300,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body))); return Err(Error::Http(http::Response::from_parts(parts, body)));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
self.extensions.take(),
);
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }
} }

@ -19,6 +19,7 @@ pub mod buffer;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod client; pub mod client;
pub mod error; pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod handshake; pub mod handshake;
pub mod protocol; pub mod protocol;

@ -311,6 +311,18 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
} }
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame. /// Create a new Pong control frame.
#[inline] #[inline]
pub fn pong(data: Vec<u8>) -> Frame { pub fn pong(data: Vec<u8>) -> Frame {

@ -84,6 +84,8 @@ use self::string_collect::StringCollector;
#[derive(Debug)] #[derive(Debug)]
pub struct IncompleteMessage { pub struct IncompleteMessage {
collector: IncompleteMessageCollector, collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -94,6 +96,7 @@ enum IncompleteMessageCollector {
impl IncompleteMessage { impl IncompleteMessage {
/// Create new. /// Create new.
#[cfg(not(feature = "deflate"))]
pub fn new(message_type: IncompleteMessageType) -> Self { pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage { IncompleteMessage {
collector: match message_type { collector: match message_type {
@ -105,6 +108,25 @@ impl IncompleteMessage {
} }
} }
/// Create new.
#[cfg(feature = "deflate")]
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer. /// Get the current filled size of the buffer.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self.collector { match self.collector {

@ -22,6 +22,7 @@ use self::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
util::NonBlockingResult, util::NonBlockingResult,
}; };
@ -56,6 +57,9 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC. /// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455. /// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool, pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
@ -65,6 +69,64 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
accept_unmasked_frames: false, accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
}
impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
#[cfg(feature = "deflate")]
{
let mut offers = Vec::new();
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
offers.push(compression);
}
if offers.is_empty() {
None
} else {
Some(headers::SecWebsocketExtensions::new(offers))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers(
&self,
_offers: &headers::SecWebsocketExtensions,
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
#[cfg(feature = "deflate")]
{
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
let mut agreed_extensions = Vec::new();
let mut extensions = Extensions::default();
if let Some(compression) = &self.compression {
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
agreed_extensions.push(agreed);
extensions.compression = Some(compression);
}
}
if agreed_extensions.is_empty() {
None
} else {
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
}
}
#[cfg(not(feature = "deflate"))]
{
None
} }
} }
} }
@ -91,6 +153,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) } WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
} }
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
/// ///
/// Call this function if you're using Tungstenite as a part of a web framework /// Call this function if you're using Tungstenite as a part of a web framework
@ -108,6 +182,21 @@ impl<Stream> WebSocket<Stream> {
} }
} }
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
&self.socket &self.socket
@ -241,6 +330,8 @@ pub struct WebSocketContext {
pong: Option<Frame>, pong: Option<Frame>,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
} }
impl WebSocketContext { impl WebSocketContext {
@ -254,6 +345,7 @@ impl WebSocketContext {
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
pong: None, pong: None,
config: config.unwrap_or_default(), config: config.unwrap_or_default(),
extensions: None,
} }
} }
@ -265,6 +357,19 @@ impl WebSocketContext {
} }
} }
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
extensions,
..WebSocketContext::new(role, config)
}
}
/// Change the configuration. /// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config) set_func(&mut self.config)
@ -348,8 +453,8 @@ impl WebSocketContext {
} }
let frame = match message { let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
self.pong = Some(Frame::pong(data)); self.pong = Some(Frame::pong(data));
@ -363,6 +468,17 @@ impl WebSocketContext {
self.write_pending(stream) self.write_pending(stream)
} }
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
#[cfg(feature = "deflate")]
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
}
Ok(Frame::message(data, opcode, is_final))
}
/// Flush the pending send queue. /// Flush the pending send queue.
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()> pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where where
@ -439,12 +555,14 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero // the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket // value, the receiving endpoint MUST _Fail the WebSocket
// Connection_. // Connection_.
{ let is_compressed = {
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
}
hdr.rsv1
};
match self.role { match self.role {
Role::Server => { Role::Server => {
@ -479,6 +597,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
} }
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => { OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -499,39 +621,34 @@ impl WebSocketContext {
let fin = frame.header().is_final; let fin = frame.header().is_final;
match data { match data {
OpData::Continue => { OpData::Continue => {
if let Some(ref mut msg) = self.incomplete { if self.incomplete.is_some() && is_compressed {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol( return Err(Error::Protocol(
ProtocolError::UnexpectedContinueFrame, ProtocolError::CompressedContinueFrame,
)); ));
} }
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?)) let msg = self
} else { .incomplete
Ok(None) .take()
} .ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
self.extend_incomplete(msg, frame.into_data(), fin)
} }
c if self.incomplete.is_some() => { c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
} }
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data { let message_type = match data {
OpData::Text => IncompleteMessageType::Text, OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary, OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"), _ => panic!("Bug: message is not text nor binary"),
}; };
let mut m = IncompleteMessage::new(message_type); #[cfg(feature = "deflate")]
m.extend(frame.into_data(), self.config.max_message_size)?; let msg = IncompleteMessage::new(message_type, is_compressed);
m #[cfg(not(feature = "deflate"))]
}; let msg = IncompleteMessage::new(message_type);
if fin { self.extend_incomplete(msg, frame.into_data(), fin)
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
} }
OpData::Reserved(i) => { OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
@ -550,6 +667,32 @@ impl WebSocketContext {
} }
} }
fn extend_incomplete(
&mut self,
mut msg: IncompleteMessage,
data: Vec<u8>,
is_final: bool,
) -> Result<Option<Message>> {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(data, is_final)?
} else {
data
};
msg.extend(data, self.config.max_message_size)?;
if is_final {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
/// Received a close frame. Tells if we need to return a close frame to the user. /// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)] #[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> { fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
@ -605,6 +748,17 @@ impl WebSocketContext {
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state) self.frame.write_frame(stream, frame).check_connection_reset(self.state)
} }
fn has_compression(&self) -> bool {
#[cfg(feature = "deflate")]
{
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
#[cfg(not(feature = "deflate"))]
{
false
}
}
} }
/// The current connection state. /// The current connection state.

Loading…
Cancel
Save