From 15621a0b9f9a6334a046d2ea7a918d12ad2aa86a Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 23 Sep 2020 11:33:46 +0100 Subject: [PATCH] Refactors deflate extension --- .gitignore | 2 - examples/autobahn-client.rs | 12 +- src/extensions/deflate.rs | 314 ++++++++++++++++++++++-------------- src/extensions/mod.rs | 7 +- src/protocol/mod.rs | 7 +- 5 files changed, 207 insertions(+), 135 deletions(-) diff --git a/.gitignore b/.gitignore index cfcabc5..6416a18 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,3 @@ Cargo.lock autobahn/client/* autobahn/server/* - -.idea \ No newline at end of file diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 523056d..81643fc 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -2,11 +2,11 @@ use log::*; use url::Url; use tungstenite::client::connect_with_config; -use tungstenite::extensions::deflate::DeflateExt; +use tungstenite::extensions::deflate::{DeflateConfigBuilder, DeflateExt}; use tungstenite::protocol::WebSocketConfig; use tungstenite::{connect, Error, Message, Result}; -const AGENT: &str = "Tungstenite"; +const AGENT: &str = "Tungstenite-final-comp-slice"; fn get_case_count() -> Result { let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; @@ -34,12 +34,16 @@ fn run_test(case: u32) -> Result<()> { case, AGENT )) .unwrap(); + let deflate_config = DeflateConfigBuilder::default() + .max_message_size(None) + .build(); + let (mut socket, _) = connect_with_config( case_url, Some(WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: DeflateExt::default(), + encoder: DeflateExt::new(deflate_config), }), )?; @@ -54,8 +58,6 @@ fn run_test(case: u32) -> Result<()> { } fn main() { - println!("Starting"); - env_logger::init(); let total = get_case_count().unwrap(); diff --git a/src/extensions/deflate.rs b/src/extensions/deflate.rs index fe1f7d9..fd1d086 100644 --- a/src/extensions/deflate.rs +++ b/src/extensions/deflate.rs @@ -6,9 +6,9 @@ use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::Frame; -use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; use crate::protocol::MAX_MESSAGE_SIZE; -use crate::{Error, Message}; +use crate::Message; +use bytes::BufMut; use flate2::{ Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress, FlushDecompress, Status, @@ -23,7 +23,7 @@ use std::slice; const EXT_IDENT: &str = "permessage-deflate"; /// The minimum size of the LZ77 sliding window size. -const LZ77_MIN_WINDOW_SIZE: u8 = 9; +const LZ77_MIN_WINDOW_SIZE: u8 = 8; /// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter /// indicates that the client can receive messages compressed using an LZ77 sliding window of up to @@ -33,13 +33,12 @@ const LZ77_MAX_WINDOW_SIZE: u8 = 15; /// A permessage-deflate configuration. #[derive(Clone, Copy, Debug)] pub struct DeflateConfig { - /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB - /// which should be reasonably big for all normal use-cases but small enough to prevent - /// memory eating by a malicious user. - max_message_size: Option, + /// The maximum size of a message. The default value is 64 MiB which should be reasonably big + /// for all normal use-cases but small enough to prevent memory eating by a malicious user. + max_message_size: usize, /// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this /// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in - /// range 9..=15. + /// range 8..15 inclusive. max_window_bits: u8, /// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1. request_no_context_takeover: bool, @@ -65,7 +64,7 @@ impl DeflateConfig { } /// Returns the maximum message size permitted. - pub fn max_message_size(&self) -> Option { + pub fn max_message_size(&self) -> usize { self.max_message_size } @@ -101,7 +100,7 @@ impl DeflateConfig { /// Sets the maximum message size permitted. pub fn set_max_message_size(&mut self, max_message_size: Option) { - self.max_message_size = max_message_size; + self.max_message_size = max_message_size.unwrap_or_else(usize::max_value); } /// Sets the LZ77 sliding window size. @@ -124,7 +123,7 @@ impl DeflateConfig { impl Default for DeflateConfig { fn default() -> Self { DeflateConfig { - max_message_size: Some(MAX_MESSAGE_SIZE), + max_message_size: MAX_MESSAGE_SIZE, max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, accept_no_context_takeover: true, @@ -166,11 +165,11 @@ impl DeflateConfigBuilder { self } - /// Sets the LZ77 sliding window size. Panics if the provided size is not in `9..=15`. + /// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`. pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { assert!( (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), - "max window bits must be in range 9..=15" + "max window bits must be in range 8..=15" ); self.max_window_bits = max_window_bits; self @@ -197,7 +196,7 @@ impl DeflateConfigBuilder { /// Consumes the builder and produces a `DeflateConfig.` pub fn build(self) -> DeflateConfig { DeflateConfig { - max_message_size: self.max_message_size, + max_message_size: self.max_message_size.unwrap_or_else(usize::max_value), max_window_bits: self.max_window_bits, request_no_context_takeover: self.request_no_context_takeover, accept_no_context_takeover: self.accept_no_context_takeover, @@ -215,8 +214,8 @@ pub struct DeflateExt { enabled: bool, /// The configuration for the extension. config: DeflateConfig, - /// A stack of continuation frames awaiting `fin`. - fragments: Vec, + /// A stack of continuation frames awaiting `fin` and the total size of all of the fragments. + fragment_buffer: FragmentBuffer, /// The deflate decompressor. inflator: Inflator, /// The deflate compressor. @@ -231,38 +230,23 @@ impl DeflateExt { DeflateExt { enabled: false, config, - fragments: vec![], + fragment_buffer: FragmentBuffer::new(config.max_message_size), inflator: Inflator::new(), deflator: Deflator::new(Compression::fast()), - uncompressed_extension: UncompressedExt::new(config.max_message_size()), + uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), } } - fn complete_message(&self, data: Vec, opcode: OpCode) -> Result { - let message_type = match opcode { - OpCode::Data(Data::Text) => IncompleteMessageType::Text, - OpCode::Data(Data::Binary) => IncompleteMessageType::Binary, - _ => panic!("Bug: message is not text nor binary"), - }; - - let mut incomplete_message = IncompleteMessage::new(message_type); - incomplete_message.extend(data, self.config.max_message_size())?; - incomplete_message.complete() - } - fn parse_window_parameter<'a>( - &self, + &mut self, mut param_iter: impl Iterator, ) -> Result, String> { if let Some(window_bits_str) = param_iter.next() { match window_bits_str.trim().parse() { - Ok(mut window_bits) => { - if window_bits == 8 { - window_bits = LZ77_MIN_WINDOW_SIZE; - } - + Ok(window_bits) => { if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { if window_bits != self.config.max_window_bits() { + self.config.max_window_bits = window_bits; Ok(Some(window_bits)) } else { Ok(None) @@ -293,6 +277,8 @@ pub enum DeflateExtensionError { InflateError(String), /// An error produced during the WebSocket negotiation. NegotiationError(String), + /// Produced when fragment buffer grew beyond the maximum configured size. + Capacity(Cow<'static, str>), } impl Display for DeflateExtensionError { @@ -307,6 +293,7 @@ impl Display for DeflateExtensionError { DeflateExtensionError::NegotiationError(m) => { write!(f, "An upgrade error was encountered: {}", m) } + DeflateExtensionError::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), } } } @@ -336,7 +323,7 @@ impl WebSocketExtension for DeflateExt { fn new(max_message_size: Option) -> Self { DeflateExt::new(DeflateConfig { - max_message_size, + max_message_size: max_message_size.unwrap_or_else(usize::max_value), ..Default::default() }) } @@ -389,7 +376,7 @@ impl WebSocketExtension for DeflateExt { let mut client_max_bits = false; for param in header.split(';') { - match param.trim() { + match param.trim().to_lowercase().as_str() { "permessage-deflate" => response_str.push_str("permessage-deflate"), "server_no_context_takeover" => { if server_takeover { @@ -419,13 +406,10 @@ impl WebSocketExtension for DeflateExt { match self.parse_window_parameter(param.split('=').skip(1)) { Ok(Some(bits)) => { - self.deflator = Deflator { - compress: Compress::new_with_window_bits( - self.config.compression_level(), - false, - bits, - ), - }; + self.deflator = Deflator::new_with_window_bits( + self.config.compression_level, + bits, + ); response_str.push_str("; "); response_str.push_str(param) } @@ -444,11 +428,8 @@ impl WebSocketExtension for DeflateExt { match self.parse_window_parameter(param.split('=').skip(1)) { Ok(Some(bits)) => { - self.inflator = Inflator { - decompress: Decompress::new_with_window_bits( - false, bits, - ), - }; + self.inflator = Inflator::new_with_window_bits(bits); + response_str.push_str("; "); response_str.push_str(param); continue; @@ -527,11 +508,11 @@ impl WebSocketExtension for DeflateExt { match header.to_str() { Ok(header) => { for param in header.split(';') { - match param.trim() { + match param.trim().to_lowercase().as_str() { "permessage-deflate" => { if extension_name { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter permessage-deflate" + "Duplicate extension parameter: permessage-deflate" ))); } else { self.enabled = true; @@ -541,7 +522,7 @@ impl WebSocketExtension for DeflateExt { "server_no_context_takeover" => { if server_takeover { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter server_no_context_takeover" + "Duplicate extension parameter: server_no_context_takeover" ))); } else { server_takeover = true; @@ -551,7 +532,7 @@ impl WebSocketExtension for DeflateExt { "client_no_context_takeover" => { if client_takeover { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter client_no_context_takeover" + "Duplicate extension parameter: client_no_context_takeover" ))); } else { client_takeover = true; @@ -568,20 +549,14 @@ impl WebSocketExtension for DeflateExt { param if param.starts_with("server_max_window_bits") => { if server_max_window_bits { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter server_max_window_bits" + "Duplicate extension parameter: server_max_window_bits" ))); } else { server_max_window_bits = true; match self.parse_window_parameter(param.split("=").skip(1)) { Ok(Some(bits)) => { - self.deflator = Deflator { - compress: Compress::new_with_window_bits( - self.config.compression_level(), - false, - bits, - ), - }; + self.inflator = Inflator::new_with_window_bits(bits); } Ok(None) => {} Err(e) => { @@ -598,18 +573,17 @@ impl WebSocketExtension for DeflateExt { param if param.starts_with("client_max_window_bits") => { if client_max_window_bits { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter client_max_window_bits" + "Duplicate extension parameter: client_max_window_bits" ))); } else { client_max_window_bits = true; match self.parse_window_parameter(param.split("=").skip(1)) { Ok(Some(bits)) => { - self.inflator = Inflator { - decompress: Decompress::new_with_window_bits( - false, bits, - ), - }; + self.deflator = Deflator::new_with_window_bits( + self.config.compression_level, + bits, + ); } Ok(None) => {} Err(e) => { @@ -623,10 +597,10 @@ impl WebSocketExtension for DeflateExt { } } } - param => { + p => { return Err(DeflateExtensionError::NegotiationError(format!( "Unknown permessage-deflate parameter: {}", - param + p ))); } } @@ -666,61 +640,63 @@ impl WebSocketExtension for DeflateExt { Ok(frame) } - fn on_receive_frame(&mut self, mut frame: Frame) -> Result, Self::Error> { - match frame.header().opcode { - OpCode::Control(_) => unreachable!(), - _ => { - if self.enabled && (!self.fragments.is_empty() || frame.header().rsv1) { - if !frame.header().is_final { - self.fragments.push(frame); - Ok(None) - } else { - let message = if let OpCode::Data(Data::Continue) = frame.header().opcode { - self.fragments.push(frame); + fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error> { + let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) { + if !frame.header().is_final { + self.fragment_buffer + .try_push_frame(frame) + .map_err(|s| DeflateExtensionError::Capacity(s.into()))?; + Ok(None) + } else { + let mut compressed = if self.fragment_buffer.is_empty() { + Vec::with_capacity(frame.payload().len()) + } else { + Vec::with_capacity(self.fragment_buffer.len() + frame.payload().len()) + }; - let opcode = self.fragments.first().unwrap().header().opcode; - let size = self - .fragments - .iter() - .fold(0, |len, frame| len + frame.payload().len()); - let mut compressed = Vec::with_capacity(size); - let mut decompressed = Vec::with_capacity(size * 2); + let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); - replace(&mut self.fragments, Vec::with_capacity(10)) - .into_iter() - .for_each(|f| { - compressed.extend(f.into_data()); - }); + let opcode = match frame.header().opcode { + OpCode::Data(Data::Continue) => { + self.fragment_buffer + .try_push_frame(frame) + .map_err(|s| DeflateExtensionError::Capacity(s.into()))?; - compressed.extend(&[0, 0, 255, 255]); + let opcode = self.fragment_buffer.first().unwrap().header().opcode; - self.inflator.decompress(&compressed, &mut decompressed)?; + self.fragment_buffer.reset().into_iter().for_each(|f| { + compressed.extend(f.into_data()); + }); - self.complete_message(decompressed, opcode) - } else { - frame.payload_mut().extend(&[0, 0, 255, 255]); - let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); - self.inflator - .decompress(frame.payload(), &mut decompressed)?; + opcode + } + _ => { + compressed.put_slice(frame.payload()); + frame.header().opcode + } + }; - self.complete_message(decompressed, frame.header().opcode) - }; + compressed.extend(&[0, 0, 255, 255]); - if self.config.decompress_reset() { - self.inflator.reset(false); - } + self.inflator.decompress(&compressed, &mut decompressed)?; - match message { - Ok(message) => Ok(Some(message)), - Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())), - } - } - } else { - self.uncompressed_extension - .on_receive_frame(frame) - .map_err(|e| DeflateExtensionError::DeflateError(e.to_string())) + if self.config.decompress_reset() { + self.inflator.reset(false); } + + self.uncompressed_extension.on_receive_frame(Frame::message( + decompressed, + opcode, + true, + )) } + } else { + self.uncompressed_extension.on_receive_frame(frame) + }; + + match r { + Ok(msg) => Ok(msg), + Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())), } } } @@ -743,17 +719,28 @@ struct Deflator { } impl Deflator { - pub fn new(compresion: Compression) -> Deflator { + fn new(compresion: Compression) -> Deflator { Deflator { compress: Compress::new(compresion, false), } } + fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator { + // https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303 + if window_size == 8 { + window_size = 9; + } + + Deflator { + compress: Compress::new_with_window_bits(compression, false, window_size), + } + } + fn reset(&mut self) { self.compress.reset() } - pub fn compress(&mut self, input: &[u8], output: &mut Vec) -> Result<(), CompressError> { + fn compress(&mut self, input: &[u8], output: &mut Vec) -> Result<(), CompressError> { let mut read_buff = Vec::from(input); let mut output_size; @@ -767,9 +754,16 @@ impl Deflator { let before_out = self.compress.total_out(); let before_in = self.compress.total_in(); + let out_slice = unsafe { + slice::from_raw_parts_mut( + output.as_mut_ptr().offset(output_size as isize), + output.capacity() - output_size, + ) + }; + let status = self .compress - .compress_vec(&read_buff, output, FlushCompress::Sync)?; + .compress(&read_buff, out_slice, FlushCompress::Sync)?; let consumed = (self.compress.total_in() - before_in) as usize; read_buff = read_buff.split_off(consumed); @@ -796,21 +790,28 @@ struct Inflator { } impl Inflator { - pub fn new() -> Inflator { + fn new() -> Inflator { Inflator { decompress: Decompress::new(false), } } + fn new_with_window_bits(mut window_size: u8) -> Inflator { + // https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303 + if window_size == 8 { + window_size = 9; + } + + Inflator { + decompress: Decompress::new_with_window_bits(false, window_size), + } + } + fn reset(&mut self, zlib_header: bool) { self.decompress.reset(zlib_header) } - pub fn decompress( - &mut self, - input: &[u8], - output: &mut Vec, - ) -> Result<(), DecompressError> { + fn decompress(&mut self, input: &[u8], output: &mut Vec) -> Result<(), DecompressError> { let mut read_buff = Vec::from(input); let mut output_size; @@ -853,3 +854,68 @@ impl Inflator { } } } + +/// A buffer for holding continuation frames. Ensures that the total length of all of the frame's +/// payloads does not exceed `max_len`. +/// +/// Defaults to an initial capacity of ten frames. +#[derive(Debug)] +struct FragmentBuffer { + fragments: Vec, + fragments_len: usize, + max_len: usize, +} + +impl FragmentBuffer { + /// Creates a new fragment buffer that will permit a maximum length of `max_len`. + fn new(max_len: usize) -> FragmentBuffer { + FragmentBuffer { + fragments: Vec::with_capacity(10), + fragments_len: 0, + max_len, + } + } + + /// Attempts to push a frame into the buffer. This will fail if the new length of the buffer's + /// frames exceeds the maximum capacity of `max_len`. + fn try_push_frame(&mut self, frame: Frame) -> Result<(), String> { + let FragmentBuffer { + fragments, + fragments_len, + max_len, + } = self; + + *fragments_len += frame.payload().len(); + + if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len { + return Err(format!( + "Message too big: {} + {} > {}", + fragments_len, fragments_len, max_len + ) + .into()); + } else { + fragments.push(frame); + Ok(()) + } + } + + /// Returns the total length of all of the frames that have been pushed into the buffer. + fn len(&self) -> usize { + self.fragments_len + } + + /// Returns whether the buffer is empty. + fn is_empty(&self) -> bool { + self.fragments.is_empty() + } + + /// Returns the first element of the fragments slice, or `None` if it is empty. + fn first(&self) -> Option<&Frame> { + self.fragments.first() + } + + /// Drains the buffer and resets it to an initial capacity of 10 elements. + fn reset(&mut self) -> Vec { + replace(&mut self.fragments, Vec::with_capacity(10)) + } +} diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 0e56bea..6a9cee1 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -11,8 +11,8 @@ pub mod deflate; /// An uncompressed message handler for a WebSocket. pub mod uncompressed; -/// A trait for defining WebSocket extensions. Extensions may be stacked by nesting them inside -/// one another. +/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions +/// may be stacked by nesting them inside one another. pub trait WebSocketExtension { /// An error type that the extension produces. type Error: Into; @@ -50,6 +50,7 @@ pub trait WebSocketExtension { Ok(frame) } - /// Called when a frame has been received. + /// Called when a frame has been received and unmasked. The frame provided frame will be of the + /// type `OpCode::Data`. fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error>; } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 28c3a43..bb695e5 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -601,7 +601,12 @@ where } } - // let frame = self.config.encoder.on_send_frame(frame)?; + if frame.header().is_final { + frame = match self.config.encoder.on_send_frame(frame) { + Ok(frame) => frame, + Err(e) => return Err(e.into()), + }; + } trace!("Sending frame: {:?}", frame); self.frame