Add `permessage-deflate` support

pull/235/head
kazk 4 years ago
parent 89697449ff
commit 54acf30635
  1. 2
      .gitignore
  2. 1
      Cargo.toml
  3. 2
      README.md
  4. 1086
      autobahn/expected-results.json
  5. 14
      examples/autobahn-client.rs
  6. 14
      examples/autobahn-server.rs
  7. 1
      examples/srv_accept_unmasked_frames.rs
  8. 22
      src/error.rs
  9. 289
      src/extensions/compression/deflate.rs
  10. 4
      src/extensions/compression/mod.rs
  11. 81
      src/extensions/mod.rs
  12. 95
      src/handshake/client.rs
  13. 26
      src/handshake/server.rs
  14. 1
      src/lib.rs
  15. 11
      src/protocol/frame/frame.rs
  16. 8
      src/protocol/message.rs
  17. 123
      src/protocol/mod.rs

2
.gitignore vendored

@ -1,2 +1,4 @@
target
Cargo.lock
autobahn/client/
autobahn/server/

@ -28,6 +28,7 @@ __rustls-tls = ["rustls", "webpki"]
base64 = "0.13.0"
byteorder = "1.3.2"
bytes = "1.0"
flate2 = "1.0"
http = "0.2"
httparse = "1.3.4"
log = "0.4.8"

@ -63,8 +63,6 @@ TLS is supported on all platforms using native-tls or rustls available through t
and `rustls-tls` feature flags. By default **no TLS feature is activated**, so make sure you
use `native-tls` or `rustls-tls` feature if you need support of the TLS.
There is no support for permessage-deflate at the moment. It's planned.
Testing
-------

File diff suppressed because it is too large Load Diff

@ -1,7 +1,10 @@
use log::*;
use url::Url;
use tungstenite::{connect, Error, Message, Result};
use tungstenite::{
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite";
@ -24,7 +27,14 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?;
let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {

@ -4,7 +4,10 @@ use std::{
};
use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::{
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -14,7 +17,14 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test");
loop {
match socket.read_message()? {

@ -35,6 +35,7 @@ fn main() {
// rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames
accept_unmasked_frames: true,
..WebSocketConfig::default()
});
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

@ -2,7 +2,10 @@
use std::{io, result, str, string};
use crate::protocol::{frame::coding::Data, Message};
use crate::{
extensions,
protocol::{frame::coding::Data, Message},
};
use http::Response;
use thiserror::Error;
@ -67,6 +70,9 @@ pub enum Error {
/// HTTP format error.
#[error("HTTP format error: {0}")]
HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[error("deflate error: {0}")]
Deflate(#[from] extensions::DeflateError),
}
impl From<str::Utf8Error> for Error {
@ -138,7 +144,7 @@ pub enum CapacityError {
}
/// Indicates the specific type/cause of a protocol error.
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum ProtocolError {
/// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
#[error("Unsupported HTTP method used - only GET is allowed")]
@ -191,6 +197,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
@ -203,6 +212,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
@ -215,6 +227,12 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
}
/// Indicates the specific type/cause of URL error.

@ -0,0 +1,289 @@
use std::io::Write;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use http::HeaderValue;
use thiserror::Error;
use crate::{
extensions::{self, Param},
protocol::Role,
};
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Error from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("failed to compress: {0}")]
Compress(std::io::Error),
/// Decompress failed
#[error("failed to decompress: {0}")]
Decompress(std::io::Error),
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
// TODO Configs for how the server accepts these offers.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl Default for DeflateConfig {
fn default() -> Self {
Self {
compression: Compression::default(),
server_no_context_takeover: false,
client_no_context_takeover: false,
}
}
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn negotiation_offers(&self) -> HeaderValue {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER));
}
if self.client_no_context_takeover {
offers.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
// This can be used for `WebSocket::from_raw_socket_with_compression`.
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub fn negotiation_response(&self, extensions: &str) -> Option<(HeaderValue, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// * The negotiation offer contains an extension parameter not defined
// for use in an offer.
// * The negotiation offer contains an extension parameter with an
// invalid value.
// * The negotiation offer contains multiple extension parameters with
// the same name.
// * The server doesn't support the offered configuration.
'outer: for (_, offer) in
extensions::parse_header(extensions).iter().filter(|(k, _)| k == self.name())
{
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for param in offer {
match param.name() {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
continue 'outer;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
continue 'outer;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
continue 'outer;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
continue 'outer;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
continue 'outer;
}
}
}
return Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)));
}
None
}
pub(crate) fn accept_response(&self, agreed: &[Param]) -> Result<DeflateContext, DeflateError> {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
for param in agreed {
match param.name() {
SERVER_NO_CONTEXT_TAKEOVER => {
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {}
CLIENT_MAX_WINDOW_BITS => {}
_ => {
//
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[Param]) -> HeaderValue {
let mut value = Vec::new();
write!(value, "{}", PER_MESSAGE_DEFLATE).unwrap();
for param in params {
if let Some(v) = param.value() {
write!(value, "; {}={}", param.name(), v).unwrap();
} else {
write!(value, "; {}", param.name()).unwrap();
}
}
HeaderValue::from_bytes(&value).unwrap()
}

@ -0,0 +1,4 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -0,0 +1,81 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
use std::borrow::Cow;
mod compression;
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError};
/// Extension parameter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Param<'a> {
name: Cow<'a, str>,
value: Option<Cow<'a, str>>,
}
impl<'a> Param<'a> {
/// Create a new parameter with name.
pub fn new(name: impl Into<Cow<'a, str>>) -> Self {
Param { name: name.into(), value: None }
}
/// Consume itself to create a parameter with value.
pub fn with_value(mut self, value: impl Into<Cow<'a, str>>) -> Self {
self.value = Some(value.into());
self
}
/// Get the name of the parameter.
pub fn name(&self) -> &str {
&self.name
}
/// Get the optional value of the parameter.
pub fn value(&self) -> Option<&str> {
self.value.as_ref().map(|v| v.as_ref())
}
}
// NOTE This doesn't support quoted values
/// Parse `Sec-WebSocket-Extensions` offer/response.
pub(crate) fn parse_header(exts: &str) -> Vec<(Cow<'_, str>, Vec<Param<'_>>)> {
let mut collected = Vec::new();
// ext-name; a; b=c, ext-name; x, y=z
for ext in exts.split(',') {
let mut parts = ext.split(';');
if let Some(name) = parts.next().map(str::trim) {
let mut params = Vec::new();
for p in parts {
let mut kv = p.splitn(2, '=');
if let Some(key) = kv.next().map(str::trim) {
let param = if let Some(value) = kv.next().map(str::trim) {
Param::new(key).with_value(value)
} else {
Param::new(key)
};
params.push(param);
}
}
collected.push((Cow::from(name), params));
}
}
collected
}
#[test]
fn test_parse_extensions() {
let extensions = "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits";
assert_eq!(
parse_header(extensions),
vec![
(
Cow::from("permessage-deflate"),
vec![
Param::new("client_max_window_bits"),
Param::new("server_max_window_bits").with_value("10")
]
),
(Cow::from("permessage-deflate"), vec![Param::new("client_max_window_bits")])
]
);
}

@ -17,6 +17,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
extensions::{self, DeflateContext},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -55,7 +56,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let key = generate_key();
let machine = {
let req = generate_request(request, &key)?;
let req = generate_request(request, &key, &config)?;
HandshakeMachine::start_write(stream, req)
};
@ -82,10 +83,15 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let result = self.verify_data.verify_response(result)?;
let (result, pmce) = self.verify_data.verify_response(result, &self.config)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read_with_compression(
stream,
tail,
Role::Client,
self.config,
pmce,
);
ProcessingResult::Done((websocket, result))
}
})
@ -93,7 +99,11 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Generate client request.
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
fn generate_request(
request: Request,
key: &str,
config: &Option<WebSocketConfig>,
) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();
@ -131,6 +141,10 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
}
if let Some(compression) = &config.and_then(|c| c.compression) {
let offer = compression.negotiation_offers();
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offer.to_str()?).unwrap();
}
writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok(req)
@ -144,7 +158,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> {
pub fn verify_response(
&self,
response: Response,
config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<DeflateContext>)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -184,12 +202,39 @@ impl VerifyData {
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
}
let mut pmce = None;
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
if let Some(exts) = headers
.get("Sec-WebSocket-Extensions")
.and_then(|h| h.to_str().ok())
.map(extensions::parse_header)
{
if let Some(compression) = &config.and_then(|c| c.compression) {
for (name, params) in exts {
if name != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
name.to_string(),
)));
}
// Already had PMCE configured
if pmce.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(),
)));
}
pmce = Some(compression.accept_response(&params)?);
}
} else if let Some((name, _)) = exts.get(0) {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
}
}
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -198,7 +243,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(response)
Ok((response, pmce))
}
}
@ -243,7 +288,7 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use crate::{client::IntoClientRequest, extensions::DeflateConfig, protocol::WebSocketConfig};
#[test]
fn random_keys() {
@ -273,7 +318,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &None).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -290,7 +335,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &None).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -307,7 +352,33 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &None).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n";
let request = generate_request(
request,
key,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -20,6 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -202,6 +203,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated Per-Message Compression Extension context for server.
pmce: Option<extensions::DeflateContext>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -219,6 +222,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
pmce: None,
_marker: PhantomData,
},
}
@ -240,7 +244,20 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let mut response = create_response(&result)?;
if let Some(compression) = &self.config.and_then(|c| c.compression) {
if let Some(extensions) = result
.headers()
.get("Sec-WebSocket-Extensions")
.and_then(|v| v.to_str().ok())
{
if let Some((agreed, pmce)) = compression.negotiation_response(extensions) {
self.pmce = Some(pmce);
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
}
}
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
@ -280,7 +297,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket_with_compression(
stream,
Role::Server,
self.config,
self.pmce.take(),
);
ProcessingResult::Done(websocket)
}
}

@ -17,6 +17,7 @@ pub use http;
pub mod buffer;
pub mod client;
pub mod error;
pub mod extensions;
pub mod handshake;
pub mod protocol;
mod server;

@ -304,6 +304,17 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
}
/// Create a new compressed data frame.
#[inline]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {

@ -84,6 +84,7 @@ use self::string_collect::StringCollector;
#[derive(Debug)]
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
compressed: bool,
}
#[derive(Debug)]
@ -94,7 +95,7 @@ enum IncompleteMessageCollector {
impl IncompleteMessage {
/// Create new.
pub fn new(message_type: IncompleteMessageType) -> Self {
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
@ -102,9 +103,14 @@ impl IncompleteMessage {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer.
pub fn len(&self) -> usize {
match self.collector {

@ -22,6 +22,7 @@ use self::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::{self, DeflateContext},
util::NonBlockingResult,
};
@ -56,6 +57,8 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
pub compression: Option<extensions::DeflateConfig>,
}
impl Default for WebSocketConfig {
@ -65,6 +68,7 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
accept_unmasked_frames: false,
compression: None,
}
}
}
@ -91,6 +95,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_compression(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.pmce = pmce;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
@ -108,6 +124,21 @@ impl<Stream> WebSocket<Stream> {
}
}
pub(crate) fn from_partially_read_with_compression(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_compression(
part, role, config, pmce,
),
}
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.socket
@ -241,6 +272,8 @@ pub struct WebSocketContext {
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
/// Per-Message Compression Extension. Only deflate is supported at the moment.
pub(crate) pmce: Option<extensions::DeflateContext>,
}
impl WebSocketContext {
@ -254,6 +287,7 @@ impl WebSocketContext {
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_else(WebSocketConfig::default),
pmce: None,
}
}
@ -265,6 +299,19 @@ impl WebSocketContext {
}
}
pub(crate) fn from_partially_read_with_compression(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
pmce,
..WebSocketContext::new(role, config)
}
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config)
@ -348,8 +395,28 @@ impl WebSocketContext {
}
let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Text(data) => {
if let Some(pmce) = self.pmce.as_mut() {
Frame::compressed_message(
pmce.compress(data.as_bytes())?,
OpCode::Data(OpData::Text),
true,
)
} else {
Frame::message(data.into(), OpCode::Data(OpData::Text), true)
}
}
Message::Binary(data) => {
if let Some(pmce) = self.pmce.as_mut() {
Frame::compressed_message(
pmce.compress(&data)?,
OpCode::Data(OpData::Binary),
true,
)
} else {
Frame::message(data, OpCode::Data(OpData::Binary), true)
}
}
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
@ -438,11 +505,16 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
let mut is_compressed = false;
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
if (hdr.rsv1 && self.pmce.is_none()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
if hdr.rsv1 && self.pmce.is_some() {
is_compressed = true;
}
}
match self.role {
@ -478,6 +550,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
}
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -498,22 +574,37 @@ impl WebSocketContext {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
if self.incomplete.is_some() && is_compressed {
return Err(Error::Protocol(
ProtocolError::UnexpectedContinueFrame,
ProtocolError::CompressedContinueFrame,
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
if let Some(ref mut msg) = self.incomplete {
let data = if msg.compressed() {
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap
self.pmce
.as_mut()
.unwrap()
.decompress(frame.into_data(), fin)?
} else {
frame.into_data()
};
msg.extend(data, self.config.max_message_size)?;
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
} else {
Ok(None)
Err(Error::Protocol(ProtocolError::UnexpectedContinueFrame))
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
}
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
@ -521,8 +612,16 @@ impl WebSocketContext {
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
let mut m = IncompleteMessage::new(message_type, is_compressed);
let data = if is_compressed {
self.pmce
.as_mut()
.unwrap()
.decompress(frame.into_data(), fin)?
} else {
frame.into_data()
};
m.extend(data, self.config.max_message_size)?;
m
};
if fin {

Loading…
Cancel
Save