Refactors deflate extension

pull/144/head
SirCipher 5 years ago
parent bd170a0de6
commit 15621a0b9f
  1. 2
      .gitignore
  2. 12
      examples/autobahn-client.rs
  3. 314
      src/extensions/deflate.rs
  4. 7
      src/extensions/mod.rs
  5. 7
      src/protocol/mod.rs

2
.gitignore vendored

@ -3,5 +3,3 @@ Cargo.lock
autobahn/client/*
autobahn/server/*
.idea

@ -2,11 +2,11 @@ use log::*;
use url::Url;
use tungstenite::client::connect_with_config;
use tungstenite::extensions::deflate::DeflateExt;
use tungstenite::extensions::deflate::{DeflateConfigBuilder, DeflateExt};
use tungstenite::protocol::WebSocketConfig;
use tungstenite::{connect, Error, Message, Result};
const AGENT: &str = "Tungstenite";
const AGENT: &str = "Tungstenite-final-comp-slice";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
@ -34,12 +34,16 @@ fn run_test(case: u32) -> Result<()> {
case, AGENT
))
.unwrap();
let deflate_config = DeflateConfigBuilder::default()
.max_message_size(None)
.build();
let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: DeflateExt::default(),
encoder: DeflateExt::new(deflate_config),
}),
)?;
@ -54,8 +58,6 @@ fn run_test(case: u32) -> Result<()> {
}
fn main() {
println!("Starting");
env_logger::init();
let total = get_case_count().unwrap();

@ -6,9 +6,9 @@ use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::{Error, Message};
use crate::Message;
use bytes::BufMut;
use flate2::{
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress,
FlushDecompress, Status,
@ -23,7 +23,7 @@ use std::slice;
const EXT_IDENT: &str = "permessage-deflate";
/// The minimum size of the LZ77 sliding window size.
const LZ77_MIN_WINDOW_SIZE: u8 = 9;
const LZ77_MIN_WINDOW_SIZE: u8 = 8;
/// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter
/// indicates that the client can receive messages compressed using an LZ77 sliding window of up to
@ -33,13 +33,12 @@ const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)]
pub struct DeflateConfig {
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user.
max_message_size: Option<usize>,
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user.
max_message_size: usize,
/// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this
/// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in
/// range 9..=15.
/// range 8..15 inclusive.
max_window_bits: u8,
/// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1.
request_no_context_takeover: bool,
@ -65,7 +64,7 @@ impl DeflateConfig {
}
/// Returns the maximum message size permitted.
pub fn max_message_size(&self) -> Option<usize> {
pub fn max_message_size(&self) -> usize {
self.max_message_size
}
@ -101,7 +100,7 @@ impl DeflateConfig {
/// Sets the maximum message size permitted.
pub fn set_max_message_size(&mut self, max_message_size: Option<usize>) {
self.max_message_size = max_message_size;
self.max_message_size = max_message_size.unwrap_or_else(usize::max_value);
}
/// Sets the LZ77 sliding window size.
@ -124,7 +123,7 @@ impl DeflateConfig {
impl Default for DeflateConfig {
fn default() -> Self {
DeflateConfig {
max_message_size: Some(MAX_MESSAGE_SIZE),
max_message_size: MAX_MESSAGE_SIZE,
max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
@ -166,11 +165,11 @@ impl DeflateConfigBuilder {
self
}
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `9..=15`.
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`.
pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder {
assert!(
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits),
"max window bits must be in range 9..=15"
"max window bits must be in range 8..=15"
);
self.max_window_bits = max_window_bits;
self
@ -197,7 +196,7 @@ impl DeflateConfigBuilder {
/// Consumes the builder and produces a `DeflateConfig.`
pub fn build(self) -> DeflateConfig {
DeflateConfig {
max_message_size: self.max_message_size,
max_message_size: self.max_message_size.unwrap_or_else(usize::max_value),
max_window_bits: self.max_window_bits,
request_no_context_takeover: self.request_no_context_takeover,
accept_no_context_takeover: self.accept_no_context_takeover,
@ -215,8 +214,8 @@ pub struct DeflateExt {
enabled: bool,
/// The configuration for the extension.
config: DeflateConfig,
/// A stack of continuation frames awaiting `fin`.
fragments: Vec<Frame>,
/// A stack of continuation frames awaiting `fin` and the total size of all of the fragments.
fragment_buffer: FragmentBuffer,
/// The deflate decompressor.
inflator: Inflator,
/// The deflate compressor.
@ -231,38 +230,23 @@ impl DeflateExt {
DeflateExt {
enabled: false,
config,
fragments: vec![],
fragment_buffer: FragmentBuffer::new(config.max_message_size),
inflator: Inflator::new(),
deflator: Deflator::new(Compression::fast()),
uncompressed_extension: UncompressedExt::new(config.max_message_size()),
uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())),
}
}
fn complete_message(&self, data: Vec<u8>, opcode: OpCode) -> Result<Message, Error> {
let message_type = match opcode {
OpCode::Data(Data::Text) => IncompleteMessageType::Text,
OpCode::Data(Data::Binary) => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut incomplete_message = IncompleteMessage::new(message_type);
incomplete_message.extend(data, self.config.max_message_size())?;
incomplete_message.complete()
}
fn parse_window_parameter<'a>(
&self,
&mut self,
mut param_iter: impl Iterator<Item = &'a str>,
) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() {
match window_bits_str.trim().parse() {
Ok(mut window_bits) => {
if window_bits == 8 {
window_bits = LZ77_MIN_WINDOW_SIZE;
}
Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
if window_bits != self.config.max_window_bits() {
self.config.max_window_bits = window_bits;
Ok(Some(window_bits))
} else {
Ok(None)
@ -293,6 +277,8 @@ pub enum DeflateExtensionError {
InflateError(String),
/// An error produced during the WebSocket negotiation.
NegotiationError(String),
/// Produced when fragment buffer grew beyond the maximum configured size.
Capacity(Cow<'static, str>),
}
impl Display for DeflateExtensionError {
@ -307,6 +293,7 @@ impl Display for DeflateExtensionError {
DeflateExtensionError::NegotiationError(m) => {
write!(f, "An upgrade error was encountered: {}", m)
}
DeflateExtensionError::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
}
}
}
@ -336,7 +323,7 @@ impl WebSocketExtension for DeflateExt {
fn new(max_message_size: Option<usize>) -> Self {
DeflateExt::new(DeflateConfig {
max_message_size,
max_message_size: max_message_size.unwrap_or_else(usize::max_value),
..Default::default()
})
}
@ -389,7 +376,7 @@ impl WebSocketExtension for DeflateExt {
let mut client_max_bits = false;
for param in header.split(';') {
match param.trim() {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => response_str.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if server_takeover {
@ -419,13 +406,10 @@ impl WebSocketExtension for DeflateExt {
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator {
compress: Compress::new_with_window_bits(
self.config.compression_level(),
false,
bits,
),
};
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
response_str.push_str("; ");
response_str.push_str(param)
}
@ -444,11 +428,8 @@ impl WebSocketExtension for DeflateExt {
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator {
decompress: Decompress::new_with_window_bits(
false, bits,
),
};
self.inflator = Inflator::new_with_window_bits(bits);
response_str.push_str("; ");
response_str.push_str(param);
continue;
@ -527,11 +508,11 @@ impl WebSocketExtension for DeflateExt {
match header.to_str() {
Ok(header) => {
for param in header.split(';') {
match param.trim() {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
if extension_name {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter permessage-deflate"
"Duplicate extension parameter: permessage-deflate"
)));
} else {
self.enabled = true;
@ -541,7 +522,7 @@ impl WebSocketExtension for DeflateExt {
"server_no_context_takeover" => {
if server_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter server_no_context_takeover"
"Duplicate extension parameter: server_no_context_takeover"
)));
} else {
server_takeover = true;
@ -551,7 +532,7 @@ impl WebSocketExtension for DeflateExt {
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter client_no_context_takeover"
"Duplicate extension parameter: client_no_context_takeover"
)));
} else {
client_takeover = true;
@ -568,20 +549,14 @@ impl WebSocketExtension for DeflateExt {
param if param.starts_with("server_max_window_bits") => {
if server_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter server_max_window_bits"
"Duplicate extension parameter: server_max_window_bits"
)));
} else {
server_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator {
compress: Compress::new_with_window_bits(
self.config.compression_level(),
false,
bits,
),
};
self.inflator = Inflator::new_with_window_bits(bits);
}
Ok(None) => {}
Err(e) => {
@ -598,18 +573,17 @@ impl WebSocketExtension for DeflateExt {
param if param.starts_with("client_max_window_bits") => {
if client_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter client_max_window_bits"
"Duplicate extension parameter: client_max_window_bits"
)));
} else {
client_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator {
decompress: Decompress::new_with_window_bits(
false, bits,
),
};
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
}
Ok(None) => {}
Err(e) => {
@ -623,10 +597,10 @@ impl WebSocketExtension for DeflateExt {
}
}
}
param => {
p => {
return Err(DeflateExtensionError::NegotiationError(format!(
"Unknown permessage-deflate parameter: {}",
param
p
)));
}
}
@ -666,61 +640,63 @@ impl WebSocketExtension for DeflateExt {
Ok(frame)
}
fn on_receive_frame(&mut self, mut frame: Frame) -> Result<Option<Message>, Self::Error> {
match frame.header().opcode {
OpCode::Control(_) => unreachable!(),
_ => {
if self.enabled && (!self.fragments.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragments.push(frame);
Ok(None)
} else {
let message = if let OpCode::Data(Data::Continue) = frame.header().opcode {
self.fragments.push(frame);
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragment_buffer
.try_push_frame(frame)
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?;
Ok(None)
} else {
let mut compressed = if self.fragment_buffer.is_empty() {
Vec::with_capacity(frame.payload().len())
} else {
Vec::with_capacity(self.fragment_buffer.len() + frame.payload().len())
};
let opcode = self.fragments.first().unwrap().header().opcode;
let size = self
.fragments
.iter()
.fold(0, |len, frame| len + frame.payload().len());
let mut compressed = Vec::with_capacity(size);
let mut decompressed = Vec::with_capacity(size * 2);
let mut decompressed = Vec::with_capacity(frame.payload().len() * 2);
replace(&mut self.fragments, Vec::with_capacity(10))
.into_iter()
.for_each(|f| {
compressed.extend(f.into_data());
});
let opcode = match frame.header().opcode {
OpCode::Data(Data::Continue) => {
self.fragment_buffer
.try_push_frame(frame)
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?;
compressed.extend(&[0, 0, 255, 255]);
let opcode = self.fragment_buffer.first().unwrap().header().opcode;
self.inflator.decompress(&compressed, &mut decompressed)?;
self.fragment_buffer.reset().into_iter().for_each(|f| {
compressed.extend(f.into_data());
});
self.complete_message(decompressed, opcode)
} else {
frame.payload_mut().extend(&[0, 0, 255, 255]);
let mut decompressed = Vec::with_capacity(frame.payload().len() * 2);
self.inflator
.decompress(frame.payload(), &mut decompressed)?;
opcode
}
_ => {
compressed.put_slice(frame.payload());
frame.header().opcode
}
};
self.complete_message(decompressed, frame.header().opcode)
};
compressed.extend(&[0, 0, 255, 255]);
if self.config.decompress_reset() {
self.inflator.reset(false);
}
self.inflator.decompress(&compressed, &mut decompressed)?;
match message {
Ok(message) => Ok(Some(message)),
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())),
}
}
} else {
self.uncompressed_extension
.on_receive_frame(frame)
.map_err(|e| DeflateExtensionError::DeflateError(e.to_string()))
if self.config.decompress_reset() {
self.inflator.reset(false);
}
self.uncompressed_extension.on_receive_frame(Frame::message(
decompressed,
opcode,
true,
))
}
} else {
self.uncompressed_extension.on_receive_frame(frame)
};
match r {
Ok(msg) => Ok(msg),
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())),
}
}
}
@ -743,17 +719,28 @@ struct Deflator {
}
impl Deflator {
pub fn new(compresion: Compression) -> Deflator {
fn new(compresion: Compression) -> Deflator {
Deflator {
compress: Compress::new(compresion, false),
}
}
fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
}
Deflator {
compress: Compress::new_with_window_bits(compression, false, window_size),
}
}
fn reset(&mut self) {
self.compress.reset()
}
pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), CompressError> {
fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), CompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
@ -767,9 +754,16 @@ impl Deflator {
let before_out = self.compress.total_out();
let before_in = self.compress.total_in();
let out_slice = unsafe {
slice::from_raw_parts_mut(
output.as_mut_ptr().offset(output_size as isize),
output.capacity() - output_size,
)
};
let status = self
.compress
.compress_vec(&read_buff, output, FlushCompress::Sync)?;
.compress(&read_buff, out_slice, FlushCompress::Sync)?;
let consumed = (self.compress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
@ -796,21 +790,28 @@ struct Inflator {
}
impl Inflator {
pub fn new() -> Inflator {
fn new() -> Inflator {
Inflator {
decompress: Decompress::new(false),
}
}
fn new_with_window_bits(mut window_size: u8) -> Inflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
}
Inflator {
decompress: Decompress::new_with_window_bits(false, window_size),
}
}
fn reset(&mut self, zlib_header: bool) {
self.decompress.reset(zlib_header)
}
pub fn decompress(
&mut self,
input: &[u8],
output: &mut Vec<u8>,
) -> Result<(), DecompressError> {
fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), DecompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
@ -853,3 +854,68 @@ impl Inflator {
}
}
}
/// A buffer for holding continuation frames. Ensures that the total length of all of the frame's
/// payloads does not exceed `max_len`.
///
/// Defaults to an initial capacity of ten frames.
#[derive(Debug)]
struct FragmentBuffer {
fragments: Vec<Frame>,
fragments_len: usize,
max_len: usize,
}
impl FragmentBuffer {
/// Creates a new fragment buffer that will permit a maximum length of `max_len`.
fn new(max_len: usize) -> FragmentBuffer {
FragmentBuffer {
fragments: Vec::with_capacity(10),
fragments_len: 0,
max_len,
}
}
/// Attempts to push a frame into the buffer. This will fail if the new length of the buffer's
/// frames exceeds the maximum capacity of `max_len`.
fn try_push_frame(&mut self, frame: Frame) -> Result<(), String> {
let FragmentBuffer {
fragments,
fragments_len,
max_len,
} = self;
*fragments_len += frame.payload().len();
if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len {
return Err(format!(
"Message too big: {} + {} > {}",
fragments_len, fragments_len, max_len
)
.into());
} else {
fragments.push(frame);
Ok(())
}
}
/// Returns the total length of all of the frames that have been pushed into the buffer.
fn len(&self) -> usize {
self.fragments_len
}
/// Returns whether the buffer is empty.
fn is_empty(&self) -> bool {
self.fragments.is_empty()
}
/// Returns the first element of the fragments slice, or `None` if it is empty.
fn first(&self) -> Option<&Frame> {
self.fragments.first()
}
/// Drains the buffer and resets it to an initial capacity of 10 elements.
fn reset(&mut self) -> Vec<Frame> {
replace(&mut self.fragments, Vec::with_capacity(10))
}
}

@ -11,8 +11,8 @@ pub mod deflate;
/// An uncompressed message handler for a WebSocket.
pub mod uncompressed;
/// A trait for defining WebSocket extensions. Extensions may be stacked by nesting them inside
/// one another.
/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions
/// may be stacked by nesting them inside one another.
pub trait WebSocketExtension {
/// An error type that the extension produces.
type Error: Into<crate::Error>;
@ -50,6 +50,7 @@ pub trait WebSocketExtension {
Ok(frame)
}
/// Called when a frame has been received.
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
/// type `OpCode::Data`.
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error>;
}

@ -601,7 +601,12 @@ where
}
}
// let frame = self.config.encoder.on_send_frame(frame)?;
if frame.header().is_final {
frame = match self.config.encoder.on_send_frame(frame) {
Ok(frame) => frame,
Err(e) => return Err(e.into()),
};
}
trace!("Sending frame: {:?}", frame);
self.frame

Loading…
Cancel
Save