Add `permessage-deflate` support

permessage-deflate
kazk 3 years ago committed by Daniel Abramov
parent e1033afd95
commit edb2377540
  1. 2
      .gitignore
  2. 1
      .travis.yml
  3. 14
      Cargo.toml
  4. 2
      README.md
  5. 576
      autobahn/expected-results.json
  6. 14
      examples/autobahn-client.rs
  7. 14
      examples/autobahn-server.rs
  8. 2
      examples/srv_accept_unmasked_frames.rs
  9. 2
      scripts/autobahn-client.sh
  10. 2
      scripts/autobahn-server.sh
  11. 19
      src/error.rs
  12. 442
      src/extensions/compression/deflate.rs
  13. 4
      src/extensions/compression/mod.rs
  14. 18
      src/extensions/mod.rs
  15. 123
      src/handshake/client.rs
  16. 26
      src/handshake/server.rs
  17. 1
      src/lib.rs
  18. 12
      src/protocol/frame/frame.rs
  19. 22
      src/protocol/message.rs
  20. 204
      src/protocol/mod.rs

2
.gitignore vendored

@ -1,2 +1,4 @@
target
Cargo.lock
autobahn/client/
autobahn/server/

@ -10,5 +10,6 @@ before_script:
script:
- cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -25,6 +25,15 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies]
data-encoding = { version = "2", optional = true }
@ -38,6 +47,11 @@ sha1 = { version = "0.10", optional = true }
thiserror = "1.0.23"
url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5"
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate]
optional = true

@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing
-------

File diff suppressed because it is too large Load Diff

@ -1,7 +1,10 @@
use log::*;
use url::Url;
use tungstenite::{connect, Error, Message, Result};
use tungstenite::{
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite";
@ -24,7 +27,14 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?;
let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {

@ -4,7 +4,10 @@ use std::{
};
use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::{
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -14,7 +17,14 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test");
loop {
match socket.read_message()? {

@ -35,6 +35,8 @@ fn main() {
// rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames
accept_unmasked_frames: true,
#[cfg(feature = "deflate")]
compression: None,
});
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3
cargo run --release --example autobahn-client
cargo run --release --example autobahn-client --features=deflate
test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi
}
cargo run --release --example autobahn-server & WSSERVER_PID=$!
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3
docker run --rm \

@ -70,6 +70,10 @@ pub enum Error {
#[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] crate::extensions::DeflateError),
}
impl From<str::Utf8Error> for Error {
@ -206,6 +210,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
@ -218,6 +225,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
@ -230,6 +240,15 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
InvalidExtensionsHeader,
}
/// Indicates the specific type/cause of URL error.

@ -0,0 +1,442 @@
use std::convert::TryFrom;
use bytes::BytesMut;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use headers::WebsocketExtension;
use http::HeaderValue;
use thiserror::Error;
use crate::protocol::Role;
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Errors from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress(#[source] std::io::Error),
/// Decompress failed
#[error("Failed to decompress")]
Decompress(#[source] std::io::Error),
/// Extension negotiation failed.
#[error("Extension negotiation failed")]
Negotiation(#[source] NegotiationError),
}
/// Errors from `permessage-deflate` extension negotiation.
#[derive(Debug, Error)]
pub enum NegotiationError {
/// Unknown parameter in a negotiation response.
#[error("Unknown parameter in a negotiation response: {0}")]
UnknownParameter(String),
/// Duplicate parameter in a negotiation response.
#[error("Duplicate parameter in a negotiation response: {0}")]
DuplicateParameter(String),
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
UnexpectedClientMaxWindowBits,
/// Received unsupported `server_max_window_bits` in a negotiation response.
#[error("Received unsupported server_max_window_bits in a negotiation response")]
ServerMaxWindowBitsNotSupported,
/// Invalid `client_max_window_bits` value in a negotiation response.
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
InvalidClientMaxWindowBitsValue(String),
/// Invalid `server_max_window_bits` value in a negotiation response.
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
InvalidServerMaxWindowBitsValue(String),
/// Missing `server_max_window_bits` value in a negotiation response.
#[error("Missing server_max_window_bits value in a negotiation response")]
MissingServerMaxWindowBitsValue,
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
// > a client informs the peer server of a hint that even if the server doesn't include the
// > "client_no_context_takeover" extension parameter in the corresponding
// > extension negotiation response to the offer, the client is not going
// > to use context takeover.
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
if self.client_no_context_takeover {
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub(crate) fn accept_offer(
&self,
offers: &headers::SecWebsocketExtensions,
) -> Option<(WebsocketExtension, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
// 2. The negotiation offer contains an extension parameter with an invalid value.
// 3. The negotiation offer contains multiple extension parameters with the same name.
// 4. The server doesn't support the offered configuration.
offers.iter().find_map(|extension| {
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for (key, val) in params {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
return None;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
return None;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `server_max_window_bits` requires a value in range [8, 15].
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
} else {
return None;
}
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
return None;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `client_max_window_bits` requires a value in range [8, 15] or no value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
}
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
return None;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
return None;
}
}
}
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
} else {
None
}
})
}
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = DeflateConfig {
compression: self.compression,
// If this was hinted in the offer, the client won't use context takeover
// even if the response doesn't include it.
// See `generate_offer`.
client_no_context_takeover: self.client_no_context_takeover,
..DeflateConfig::default()
};
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
// A client MUST _Fail the WebSocket Connection_ if the peer server
// accepted an extension negotiation offer for this extension with an
// extension negotiation response meeting any of the following
// conditions:
// 1. The negotiation response contains an extension parameter not defined for use in a response.
// 2. The negotiation response contains an extension parameter with an invalid value.
// 3. The negotiation response contains multiple extension parameters with the same name.
// 4. The client does not support the configuration that the response represents.
for (key, val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_server_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_server_no_context_takeover = true;
// A server MAY include the "server_no_context_takeover" extension
// parameter in an extension negotiation response even if the extension
// negotiation offer being accepted by the extension negotiation
// response didn't include the "server_no_context_takeover" extension
// parameter.
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_client_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_client_no_context_takeover = true;
// The server may include this parameter in the response and the client MUST support it.
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
));
}
} else {
return Err(DeflateError::Negotiation(
NegotiationError::MissingServerMaxWindowBitsValue,
));
}
// A server may include the "server_max_window_bits" extension parameter
// in an extension negotiation response even if the extension
// negotiation offer being accepted by the response didn't include the
// "server_max_window_bits" extension parameter.
//
// However, but we need to fail the connection because we don't support it (condition 4).
return Err(DeflateError::Negotiation(
NegotiationError::ServerMaxWindowBitsNotSupported,
));
}
CLIENT_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
));
}
}
// Fail the connection because the parameter is invalid when the client didn't offer.
//
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
return Err(DeflateError::Negotiation(
NegotiationError::UnexpectedClientMaxWindowBits,
));
}
// Response with unknown parameter MUST fail the WebSocket connection.
_ => {
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
key.to_owned(),
)));
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
fn is_valid_max_window_bits(bits: &str) -> bool {
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
}
#[cfg(test)]
mod tests {
use super::is_valid_max_window_bits;
#[test]
fn valid_max_window_bits() {
for bits in 8..=15 {
assert!(is_valid_max_window_bits(&bits.to_string()));
}
}
#[test]
fn invalid_max_window_bits() {
assert!(!is_valid_max_window_bits(""));
assert!(!is_valid_max_window_bits("0"));
assert!(!is_valid_max_window_bits("08"));
assert!(!is_valid_max_window_bits("+8"));
assert!(!is_valid_max_window_bits("-8"));
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
for param in params {
buf.extend_from_slice(b"; ");
buf.extend_from_slice(param.as_bytes());
}
let header = HeaderValue::from_maybe_shared(buf.freeze())
.expect("semicolon separated HeaderValue is valid");
WebsocketExtension::try_from(header).expect("valid extension")
}

@ -0,0 +1,4 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -0,0 +1,18 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -5,6 +5,7 @@ use std::{
marker::PhantomData,
};
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -19,6 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -56,7 +58,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let (request, key) = generate_request(request, &config)?;
let machine = HandshakeMachine::start_write(stream, request);
@ -83,18 +85,24 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) {
let (result, extensions) =
match self.verify_data.verify_response(result, &self.config) {
Ok(r) => r,
Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail);
return Err(Error::Http(e))
},
return Err(Error::Http(e));
}
Err(e) => return Err(e),
};
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read_with_extensions(
stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result))
}
})
@ -102,7 +110,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
pub fn generate_request(
mut request: Request,
config: &Option<WebSocketConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new();
write!(
req,
@ -173,6 +184,9 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
}
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
}
writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok((req, key))
@ -186,7 +200,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> {
pub fn verify_response(
&self,
response: Response,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -231,7 +249,14 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
let extensions = if let Some(agreed) = headers
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
{
verify_extensions(&agreed, _config)?
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -240,8 +265,47 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(response)
Ok((response, extensions))
}
}
fn verify_extensions(
agreed_extensions: &headers::SecWebsocketExtensions,
_config: &Option<WebSocketConfig>,
) -> Result<Option<Extensions>> {
#[cfg(feature = "deflate")]
{
if let Some(compression) = _config.and_then(|c| c.compression) {
let mut extensions = None;
for extension in agreed_extensions.iter() {
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
// > the client MUST _Fail the WebSocket Connection_.
if extension.name() != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
extension.name().to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
extension.name().to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(extension.params())?),
});
}
return Ok(extensions);
}
}
if let Some(extension) = agreed_extensions.iter().next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
}
Ok(None)
}
impl TryParse for Response {
@ -286,6 +350,8 @@ pub fn generate_key() -> String {
mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test]
fn random_keys() {
@ -322,7 +388,7 @@ mod tests {
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -330,7 +396,7 @@ mod tests {
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -338,11 +404,40 @@ mod tests {
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n",
host = "localhost",
key = key
)
.into_bytes();
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -354,6 +449,6 @@ mod tests {
#[test]
fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err());
assert!(generate_request(request, &None).is_err());
}
}

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult,
};
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -20,6 +21,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
extensions: None,
_marker: PhantomData,
},
}
@ -240,7 +245,19 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let mut response = create_response(&result)?;
if let Some(config) = &self.config {
if let Some((agreed, extensions)) = result
.headers()
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
.and_then(|values| config.accept_offers(&values))
{
response.headers_mut().typed_insert(agreed);
self.extensions = Some(extensions);
}
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
@ -283,7 +300,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body)));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
self.extensions.take(),
);
ProcessingResult::Done(websocket)
}
}

@ -19,6 +19,7 @@ pub mod buffer;
#[cfg(feature = "handshake")]
pub mod client;
pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")]
pub mod handshake;
pub mod protocol;

@ -311,6 +311,18 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
}
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {

@ -84,6 +84,8 @@ use self::string_collect::StringCollector;
#[derive(Debug)]
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
}
#[derive(Debug)]
@ -94,6 +96,7 @@ enum IncompleteMessageCollector {
impl IncompleteMessage {
/// Create new.
#[cfg(not(feature = "deflate"))]
pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
@ -105,6 +108,25 @@ impl IncompleteMessage {
}
}
/// Create new.
#[cfg(feature = "deflate")]
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer.
pub fn len(&self) -> usize {
match self.collector {

@ -22,6 +22,7 @@ use self::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::Extensions,
util::NonBlockingResult,
};
@ -56,6 +57,9 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
}
impl Default for WebSocketConfig {
@ -65,6 +69,64 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
}
impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
#[cfg(feature = "deflate")]
{
let mut offers = Vec::new();
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
offers.push(compression);
}
if offers.is_empty() {
None
} else {
Some(headers::SecWebsocketExtensions::new(offers))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers(
&self,
_offers: &headers::SecWebsocketExtensions,
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
#[cfg(feature = "deflate")]
{
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
let mut agreed_extensions = Vec::new();
let mut extensions = Extensions::default();
if let Some(compression) = &self.compression {
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
agreed_extensions.push(agreed);
extensions.compression = Some(compression);
}
}
if agreed_extensions.is_empty() {
None
} else {
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
}
@ -91,6 +153,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
@ -108,6 +182,21 @@ impl<Stream> WebSocket<Stream> {
}
}
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.socket
@ -241,6 +330,8 @@ pub struct WebSocketContext {
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
}
impl WebSocketContext {
@ -254,6 +345,7 @@ impl WebSocketContext {
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_default(),
extensions: None,
}
}
@ -265,6 +357,19 @@ impl WebSocketContext {
}
}
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
extensions,
..WebSocketContext::new(role, config)
}
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config)
@ -348,8 +453,8 @@ impl WebSocketContext {
}
let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
@ -363,6 +468,17 @@ impl WebSocketContext {
self.write_pending(stream)
}
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
#[cfg(feature = "deflate")]
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
}
Ok(Frame::message(data, opcode, is_final))
}
/// Flush the pending send queue.
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
@ -439,12 +555,14 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
{
let is_compressed = {
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
}
hdr.rsv1
};
match self.role {
Role::Server => {
@ -479,6 +597,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
}
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -499,39 +621,34 @@ impl WebSocketContext {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
if self.incomplete.is_some() && is_compressed {
return Err(Error::Protocol(
ProtocolError::UnexpectedContinueFrame,
ProtocolError::CompressedContinueFrame,
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
let msg = self
.incomplete
.take()
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
self.extend_incomplete(msg, frame.into_data(), fin)
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
}
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
#[cfg(feature = "deflate")]
let msg = IncompleteMessage::new(message_type, is_compressed);
#[cfg(not(feature = "deflate"))]
let msg = IncompleteMessage::new(message_type);
self.extend_incomplete(msg, frame.into_data(), fin)
}
OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
@ -550,6 +667,32 @@ impl WebSocketContext {
}
}
fn extend_incomplete(
&mut self,
mut msg: IncompleteMessage,
data: Vec<u8>,
is_final: bool,
) -> Result<Option<Message>> {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(data, is_final)?
} else {
data
};
msg.extend(data, self.config.max_message_size)?;
if is_final {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
/// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
@ -605,6 +748,17 @@ impl WebSocketContext {
trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
fn has_compression(&self) -> bool {
#[cfg(feature = "deflate")]
{
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
#[cfg(not(feature = "deflate"))]
{
false
}
}
}
/// The current connection state.

Loading…
Cancel
Save