From 0c8ae53633866b51f6efb8f01377d0bee8df28d8 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 16 Sep 2020 17:36:09 +0100 Subject: [PATCH] Tidy up --- examples/autobahn-client.rs | 1 - examples/autobahn-server.rs | 1 - src/extensions/deflate.rs | 469 +++++++++++++++++++++------------ src/extensions/mod.rs | 19 +- src/extensions/uncompressed.rs | 22 +- src/lib.rs | 20 +- src/protocol/mod.rs | 22 +- 7 files changed, 344 insertions(+), 210 deletions(-) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 61fdf72..1b4ee94 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -42,7 +42,6 @@ fn run_test(case: u32) -> Result<()> { case_url, Some(WebSocketConfig { max_send_queue: None, - max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), encoder: DeflateExt::default(), }), diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 362724b..6e88f83 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -20,7 +20,6 @@ fn handle_client(stream: TcpStream) -> Result<()> { stream, Some(WebSocketConfig { max_send_queue: None, - max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), encoder: DeflateExt::default(), }), diff --git a/src/extensions/deflate.rs b/src/extensions/deflate.rs index 40a4fca..b172e78 100644 --- a/src/extensions/deflate.rs +++ b/src/extensions/deflate.rs @@ -18,12 +18,195 @@ use http::{HeaderValue, Request, Response}; use std::mem::replace; use std::slice; +const EXT_NAME: &str = "permessage-deflate"; + +/// A permessage-deflate configuration. +#[derive(Clone, Copy, Debug)] +pub struct DeflateConfig { + /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB + /// which should be reasonably big for all normal use-cases but small enough to prevent + /// memory eating by a malicious user. + max_message_size: Option, + /// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this + /// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in + /// range 9..=15. + max_window_bits: u8, + /// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1. + request_no_context_takeover: bool, + accept_no_context_takeover: bool, + compress_reset: bool, + decompress_reset: bool, + compression_level: Compression, +} + +impl DeflateConfig { + /// Builds a new `DeflateConfig` using the `compression_level` and the defaults for all other + /// members. + pub fn with_compression_level(compression_level: Compression) -> DeflateConfig { + DeflateConfig { + compression_level, + ..Default::default() + } + } + + /// Returns the maximum message size permitted. + pub fn max_message_size(&self) -> Option { + self.max_message_size + } + + /// Returns the maximum LZ77 window size permitted. + pub fn max_window_bits(&self) -> u8 { + self.max_window_bits + } + + /// Returns whether `no_context_takeover` has been requested. + pub fn request_no_context_takeover(&self) -> bool { + self.request_no_context_takeover + } + + /// Returns whether this WebSocket will accept `no_context_takeover`. + pub fn accept_no_context_takeover(&self) -> bool { + self.accept_no_context_takeover + } + + /// Returns whether or not the inner compressor is set to reset after completing a message. + pub fn compress_reset(&self) -> bool { + self.compress_reset + } + + /// Returns whether or not the inner decompressor is set to reset after completing a message. + pub fn decompress_reset(&self) -> bool { + self.decompress_reset + } + + /// Returns the active compression level. + pub fn compression_level(&self) -> Compression { + self.compression_level + } + + /// Sets the maximum message size permitted. + pub fn set_max_message_size(&mut self, max_message_size: Option) { + self.max_message_size = max_message_size; + } + + /// Sets the LZ77 sliding window size. + pub fn set_max_window_bits(&mut self, max_window_bits: u8) { + assert!((9u8..=15u8).contains(&max_window_bits)); + self.max_window_bits = max_window_bits; + } + + /// Sets the WebSocket to request `no_context_takeover` if `true`. + pub fn set_request_no_context_takeover(&mut self, request_no_context_takeover: bool) { + self.request_no_context_takeover = request_no_context_takeover; + } + + /// Sets the WebSocket to accept `no_context_takeover` if `true`. + pub fn set_accept_no_context_takeover(&mut self, accept_no_context_takeover: bool) { + self.accept_no_context_takeover = accept_no_context_takeover; + } +} + +impl Default for DeflateConfig { + fn default() -> Self { + DeflateConfig { + max_message_size: Some(MAX_MESSAGE_SIZE), + max_window_bits: 15, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Compression::best(), + } + } +} + +/// A `DeflateConfig` builder. +#[derive(Debug, Copy, Clone)] +pub struct DeflateConfigBuilder { + max_message_size: Option, + max_window_bits: u8, + request_no_context_takeover: bool, + accept_no_context_takeover: bool, + fragments_grow: bool, + compression_level: Compression, +} + +impl Default for DeflateConfigBuilder { + fn default() -> Self { + DeflateConfigBuilder { + max_message_size: Some(MAX_MESSAGE_SIZE), + max_window_bits: 15, + request_no_context_takeover: false, + accept_no_context_takeover: true, + fragments_grow: true, + compression_level: Compression::fast(), + } + } +} + +impl DeflateConfigBuilder { + /// Sets the maximum message size permitted. + pub fn max_message_size(mut self, max_message_size: Option) -> DeflateConfigBuilder { + self.max_message_size = max_message_size; + self + } + + /// Sets the LZ77 sliding window size. Panics if the provided size is not in `9..=15`. + pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { + assert!( + (9u8..=15u8).contains(&max_window_bits), + "max window bits must be in range 9..=15" + ); + self.max_window_bits = max_window_bits; + self + } + + /// Sets the WebSocket to request `no_context_takeover`. + pub fn request_no_context_takeover( + mut self, + request_no_context_takeover: bool, + ) -> DeflateConfigBuilder { + self.request_no_context_takeover = request_no_context_takeover; + self + } + + /// Sets the WebSocket to accept `no_context_takeover`. + pub fn accept_no_context_takeover( + mut self, + accept_no_context_takeover: bool, + ) -> DeflateConfigBuilder { + self.accept_no_context_takeover = accept_no_context_takeover; + self + } + + /// Consumes the builder and produces a `DeflateConfig.` + pub fn build(self) -> DeflateConfig { + DeflateConfig { + max_message_size: self.max_message_size, + max_window_bits: self.max_window_bits, + request_no_context_takeover: self.request_no_context_takeover, + accept_no_context_takeover: self.accept_no_context_takeover, + compression_level: self.compression_level, + ..Default::default() + } + } +} + +/// A permessage-deflate encoding WebSocket extension. +#[derive(Debug)] pub struct DeflateExt { + /// Defines whether the extension is enabled. Following a successful handshake, this will be + /// `true`. enabled: bool, + /// The configuration for the extension. config: DeflateConfig, + /// A stack of continuation frames awaiting `fin`. fragments: Vec, + /// The deflate decompressor. inflator: Inflator, + /// The deflate compressor. deflator: Deflator, + /// If this deflate extension is not used, messages will be forwarded to this extension. uncompressed_extension: PlainTextExt, } @@ -34,8 +217,8 @@ impl Clone for DeflateExt { config: self.config, fragments: vec![], inflator: Inflator::new(), - deflator: Deflator::new(self.config.compression_level), - uncompressed_extension: PlainTextExt::new(self.config.max_message_size), + deflator: Deflator::new(self.config.compression_level()), + uncompressed_extension: PlainTextExt::new(self.config.max_message_size()), } } } @@ -47,6 +230,7 @@ impl Default for DeflateExt { } impl DeflateExt { + /// Creates a `DeflateExt` instance using the provided configuration. pub fn new(config: DeflateConfig) -> DeflateExt { DeflateExt { enabled: false, @@ -54,7 +238,7 @@ impl DeflateExt { fragments: vec![], inflator: Inflator::new(), deflator: Deflator::new(Compression::fast()), - uncompressed_extension: PlainTextExt::new(config.max_message_size), + uncompressed_extension: PlainTextExt::new(config.max_message_size()), } } @@ -66,7 +250,7 @@ impl DeflateExt { }; let mut incomplete_message = IncompleteMessage::new(message_type); - incomplete_message.extend(data, self.config.max_message_size)?; + incomplete_message.extend(data, self.config.max_message_size())?; incomplete_message.complete() } @@ -82,16 +266,13 @@ impl DeflateExt { } if window_bits >= 9 && window_bits <= 15 { - if window_bits as u8 != self.config.max_window_bits { + if window_bits != self.config.max_window_bits() { Ok(Some(window_bits)) } else { Ok(None) } } else { - Err(format!( - "Invalid server_max_window_bits parameter: {}", - window_bits - )) + Err(format!("Invalid window parameter: {}", window_bits)) } } Err(e) => Err(e.to_string()), @@ -107,57 +288,29 @@ impl DeflateExt { } } -#[derive(Clone, Copy, Debug)] -pub struct DeflateConfig { - pub max_message_size: Option, - pub max_window_bits: u8, - pub request_no_context_takeover: bool, - pub accept_no_context_takeover: bool, - pub fragments_capacity: usize, - pub fragments_grow: bool, - pub compress_reset: bool, - pub decompress_reset: bool, - pub compression_level: Compression, -} - -impl DeflateConfig { - pub fn with_compression_level(compression_level: Compression) -> DeflateConfig { - DeflateConfig { - compression_level, - ..Default::default() - } - } -} - -impl Default for DeflateConfig { - fn default() -> Self { - DeflateConfig { - max_message_size: Some(MAX_MESSAGE_SIZE), - max_window_bits: 15, - request_no_context_takeover: false, - accept_no_context_takeover: true, - fragments_capacity: 10, - fragments_grow: true, - compress_reset: false, - decompress_reset: false, - compression_level: Compression::best(), - } - } -} - +/// A permessage-deflate extension error. #[derive(Debug, Clone)] pub enum DeflateExtensionError { + /// An error produced when deflating a message. DeflateError(String), + /// An error produced when inflating a message. InflateError(String), + /// An error produced during the WebSocket negotiation. NegotiationError(String), } impl Display for DeflateExtensionError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - DeflateExtensionError::DeflateError(m) => write!(f, "{}", m), - DeflateExtensionError::InflateError(m) => write!(f, "{}", m), - DeflateExtensionError::NegotiationError(m) => write!(f, "{}", m), + DeflateExtensionError::DeflateError(m) => { + write!(f, "An error was produced during decompression: {}", m) + } + DeflateExtensionError::InflateError(m) => { + write!(f, "An error was produced during compression: {}", m) + } + DeflateExtensionError::NegotiationError(m) => { + write!(f, "An upgrade error was encountered: {}", m) + } } } } @@ -176,21 +329,18 @@ impl From for DeflateExtensionError { } } -const EXT_NAME: &str = "permessage-deflate"; - impl WebSocketExtension for DeflateExt { type Error = DeflateExtensionError; - fn enabled(&self) -> bool { - self.enabled + fn new(max_message_size: Option) -> Self { + DeflateExt::new(DeflateConfig { + max_message_size, + ..Default::default() + }) } - fn rsv1(&self) -> bool { - if self.enabled { - true - } else { - self.uncompressed_extension.rsv1() - } + fn enabled(&self) -> bool { + self.enabled } fn on_make_request(&mut self, mut request: Request) -> Request { @@ -228,147 +378,136 @@ impl WebSocketExtension for DeflateExt { response: &mut Response, ) -> Result<(), Self::Error> { for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { - match header.to_str() { + return match header.to_str() { Ok(header) => { - let mut res_ext = String::with_capacity(header.len()); - let mut s_takeover = false; - let mut c_takeover = false; - let mut s_max = false; - let mut c_max = false; + let mut response_str = String::with_capacity(header.len()); + let mut server_takeover = false; + let mut client_takeover = false; + let mut server_max_bits = false; + let mut client_max_bits = false; for param in header.split(';') { match param.trim() { - "permessage-deflate" => res_ext.push_str("permessage-deflate"), + "permessage-deflate" => response_str.push_str("permessage-deflate"), "server_no_context_takeover" => { - if s_takeover { + if server_takeover { self.decline(response); } else { - s_takeover = true; - if self.config.accept_no_context_takeover { + server_takeover = true; + if self.config.accept_no_context_takeover() { self.config.compress_reset = true; - res_ext.push_str("; server_no_context_takeover"); + response_str.push_str("; server_no_context_takeover"); } } } "client_no_context_takeover" => { - if c_takeover { + if client_takeover { self.decline(response); } else { - c_takeover = true; + client_takeover = true; self.config.decompress_reset = true; - res_ext.push_str("; client_no_context_takeover"); + response_str.push_str("; client_no_context_takeover"); } } param if param.starts_with("server_max_window_bits") => { - if s_max { + if server_max_bits { self.decline(response); } else { - s_max = true; - let mut param_iter = param.split('='); - param_iter.next(); // we already know the name - if let Some(window_bits_str) = param_iter.next() { - if let Ok(window_bits) = window_bits_str.trim().parse() { - if window_bits >= 9 && window_bits <= 15 { - if window_bits < self.config.max_window_bits { - self.deflator = Deflator { - compress: Compress::new_with_window_bits( - self.config.compression_level, - false, - window_bits, - ), - }; - res_ext.push_str("; "); - res_ext.push_str(param) - } - } else { - self.decline(response); - } - } else { + server_max_bits = true; + + match self.parse_window_parameter(param.split('=').skip(1)) { + Ok(Some(bits)) => { + self.deflator = Deflator { + compress: Compress::new_with_window_bits( + self.config.compression_level(), + false, + bits, + ), + }; + response_str.push_str("; "); + response_str.push_str(param) + } + Ok(None) => {} + Err(_) => { self.decline(response); } } } } param if param.starts_with("client_max_window_bits") => { - if c_max { + if client_max_bits { self.decline(response); } else { - c_max = true; - let mut param_iter = param.split('='); - param_iter.next(); // we already know the name - if let Some(window_bits_str) = param_iter.next() { - if let Ok(window_bits) = window_bits_str.trim().parse() { - if window_bits >= 9 && window_bits <= 15 { - if window_bits < self.config.max_window_bits { - self.inflator = Inflator { - decompress: - Decompress::new_with_window_bits( - false, - window_bits, - ), - }; - res_ext.push_str("; "); - res_ext.push_str(param); - continue; - } - } else { - self.decline(response); - } - } else { + client_max_bits = true; + + match self.parse_window_parameter(param.split('=').skip(1)) { + Ok(Some(bits)) => { + self.inflator = Inflator { + decompress: Decompress::new_with_window_bits( + false, bits, + ), + }; + response_str.push_str("; "); + response_str.push_str(param); + continue; + } + Ok(None) => {} + Err(_) => { self.decline(response); } } - res_ext.push_str("; "); - res_ext.push_str(&format!( + + response_str.push_str("; "); + response_str.push_str(&format!( "client_max_window_bits={}", - self.config.max_window_bits + self.config.max_window_bits() )) } } _ => { - // decline all extension offers because we got a bad parameter self.decline(response); } } } - if !res_ext.contains("client_no_context_takeover") - && self.config.request_no_context_takeover + if !response_str.contains("client_no_context_takeover") + && self.config.request_no_context_takeover() { self.config.decompress_reset = true; - res_ext.push_str("; client_no_context_takeover"); + response_str.push_str("; client_no_context_takeover"); } - if !res_ext.contains("server_max_window_bits") { - res_ext.push_str("; "); - res_ext.push_str(&format!( + if !response_str.contains("server_max_window_bits") { + response_str.push_str("; "); + response_str.push_str(&format!( "server_max_window_bits={}", - self.config.max_window_bits + self.config.max_window_bits() )) } - if !res_ext.contains("client_max_window_bits") - && self.config.max_window_bits < 15 + if !response_str.contains("client_max_window_bits") + && self.config.max_window_bits() < 15 { continue; } - response - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_str(&res_ext)?); + response.headers_mut().insert( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_str)?, + ); self.enabled = true; - return Ok(()); + Ok(()) } Err(e) => { self.enabled = false; - return Err(DeflateExtensionError::NegotiationError(format!( - "Failed to parse header: {}", + Err(DeflateExtensionError::NegotiationError(format!( + "Failed to parse request header: {}", e, - ))); + ))) } - } + }; } self.decline(response); @@ -382,7 +521,7 @@ impl WebSocketExtension for DeflateExt { let mut server_max_window_bits = false; let mut client_max_window_bits = false; - for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { + for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() { match header.to_str() { Ok(header) => { for param in header.split(';') { @@ -390,7 +529,7 @@ impl WebSocketExtension for DeflateExt { "permessage-deflate" => { if extension_name { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension name permessage-deflate" + "Duplicate extension parameter permessage-deflate" ))); } else { self.enabled = true; @@ -415,7 +554,7 @@ impl WebSocketExtension for DeflateExt { } else { client_takeover = true; - if self.config.accept_no_context_takeover { + if self.config.accept_no_context_takeover() { self.config.compress_reset = true; } else { return Err(DeflateExtensionError::NegotiationError( @@ -436,7 +575,7 @@ impl WebSocketExtension for DeflateExt { Ok(Some(bits)) => { self.deflator = Deflator { compress: Compress::new_with_window_bits( - self.config.compression_level, + self.config.compression_level(), false, bits, ), @@ -493,7 +632,6 @@ impl WebSocketExtension for DeflateExt { } Err(e) => { self.enabled = false; - return Err(DeflateExtensionError::NegotiationError(format!( "Failed to parse extension parameter: {}", e @@ -517,7 +655,7 @@ impl WebSocketExtension for DeflateExt { *frame.payload_mut() = compressed; frame.header_mut().rsv1 = true; - if self.config.compress_reset { + if self.config.compress_reset() { self.deflator.reset(); } } @@ -536,15 +674,7 @@ impl WebSocketExtension for DeflateExt { return Ok(None); } else { let message = if let OpCode::Data(Data::Continue) = frame.header().opcode { - if !self.config.fragments_grow - && self.config.fragments_capacity == self.fragments.len() - { - return Err(DeflateExtensionError::DeflateError( - "Exceeded max fragments.".into(), - )); - } else { - self.fragments.push(frame); - } + self.fragments.push(frame); let opcode = self.fragments.first().unwrap().header().opcode; let size = self @@ -554,14 +684,11 @@ impl WebSocketExtension for DeflateExt { let mut compressed = Vec::with_capacity(size); let mut decompressed = Vec::with_capacity(size * 2); - replace( - &mut self.fragments, - Vec::with_capacity(self.config.fragments_capacity), - ) - .into_iter() - .for_each(|f| { - compressed.extend(f.into_data()); - }); + replace(&mut self.fragments, Vec::with_capacity(10)) + .into_iter() + .for_each(|f| { + compressed.extend(f.into_data()); + }); compressed.extend(&[0, 0, 255, 255]); @@ -570,16 +697,14 @@ impl WebSocketExtension for DeflateExt { self.complete_message(decompressed, opcode) } else { frame.payload_mut().extend(&[0, 0, 255, 255]); - - let mut decompress_output = - Vec::with_capacity(frame.payload().len() * 2); + let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); self.inflator - .decompress(frame.payload(), &mut decompress_output)?; + .decompress(frame.payload(), &mut decompressed)?; - self.complete_message(decompress_output, frame.header().opcode) + self.complete_message(decompressed, frame.header().opcode) }; - if self.config.decompress_reset { + if self.config.decompress_reset() { self.inflator.reset(false); } @@ -610,6 +735,7 @@ impl From for DeflateExtensionError { } } +#[derive(Debug)] struct Deflator { compress: Compress, } @@ -656,12 +782,13 @@ impl Deflator { return Ok(()); } } - s => panic!(s), + s => panic!("Compression error: {:?}", s), } } } } +#[derive(Debug)] struct Inflator { decompress: Decompress, } @@ -719,7 +846,7 @@ impl Inflator { return Ok(()); } } - s => panic!(s), + s => panic!("Decompression error: {:?}", s), } } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 0017d29..d8207de 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -5,25 +5,32 @@ use http::{Request, Response}; use crate::protocol::frame::Frame; use crate::Message; +/// A permessage-deflate WebSocket extension (RFC 7692). #[cfg(feature = "deflate")] pub mod deflate; +/// An uncompressed message handler for a WebSocket. pub mod uncompressed; +/// A trait for defining WebSocket extensions. Extensions may be stacked by nesting them inside +/// one another. pub trait WebSocketExtension: Default + Clone { + /// An error type that the extension produces. type Error: Into; - fn enabled(&self) -> bool { - false - } + /// Constructs a new WebSocket extension that will permit messages of the provided size. + fn new(max_message_size: Option) -> Self; - fn rsv1(&self) -> bool { + /// Returns whether or not the extension is enabled. + fn enabled(&self) -> bool { false } + /// For WebSocket clients, this will be called when a `Request` is being constructed. fn on_make_request(&mut self, request: Request) -> Request { request } + /// For WebSocket server, this will be called when a `Request` has been received. fn on_receive_request( &mut self, _request: &Request, @@ -32,13 +39,17 @@ pub trait WebSocketExtension: Default + Clone { Ok(()) } + /// For WebSocket clients, this will be called when a response from the server has been + /// received. If an error is produced, then subsequent calls to `rsv1()` should return `false`. fn on_response(&mut self, _response: &Response) -> Result<(), Self::Error> { Ok(()) } + /// Called when a frame is about to be sent. fn on_send_frame(&mut self, frame: Frame) -> Result { Ok(frame) } + /// Called when a frame has been received. fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error>; } diff --git a/src/extensions/uncompressed.rs b/src/extensions/uncompressed.rs index d998a14..c4f4643 100644 --- a/src/extensions/uncompressed.rs +++ b/src/extensions/uncompressed.rs @@ -5,6 +5,7 @@ use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; use crate::protocol::MAX_MESSAGE_SIZE; use crate::{Error, Message}; +/// An uncompressed message handler for a WebSocket. #[derive(Debug)] pub struct PlainTextExt { incomplete: Option, @@ -12,6 +13,8 @@ pub struct PlainTextExt { } impl PlainTextExt { + /// Builds a new `PlainTextExt` that will permit a maximum message size of `max_message_size` + /// or will be unbounded if `None`. pub fn new(max_message_size: Option) -> PlainTextExt { PlainTextExt { incomplete: None, @@ -38,17 +41,28 @@ impl Default for PlainTextExt { impl WebSocketExtension for PlainTextExt { type Error = Error; - fn enabled(&self) -> bool { - true + fn new(max_message_size: Option) -> Self { + PlainTextExt { + incomplete: None, + max_message_size, + } } - fn rsv1(&self) -> bool { - false + fn enabled(&self) -> bool { + true } fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error> { let fin = frame.header().is_final; + let hdr = frame.header(); + + if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { + return Err(Error::Protocol( + "Reserved bits are non-zero and no WebSocket extensions are enabled".into(), + )); + } + match frame.header().opcode { OpCode::Data(data) => match data { Data::Continue => { diff --git a/src/lib.rs b/src/lib.rs index e2ee1e6..36b947d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,15 @@ //! Lightweight, flexible WebSockets for Rust. #![deny( - // missing_docs, - // missing_copy_implementations, - // missing_debug_implementations, - // trivial_casts, - // trivial_numeric_casts, - // unstable_features, - // unused_must_use, - // unused_mut, - // unused_imports, - // unused_import_braces + missing_docs, + missing_copy_implementations, + missing_debug_implementations, + trivial_casts, + trivial_numeric_casts, + unstable_features, + unused_must_use, + unused_mut, + unused_imports, + unused_import_braces )] pub use http; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index db997db..e95207f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -41,10 +41,6 @@ where /// means here that the size of the queue is unlimited. The default value is the unlimited /// queue. pub max_send_queue: Option, - /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB - /// which should be reasonably big for all normal use-cases but small enough to prevent - /// memory eating by a malicious user. - pub max_message_size: Option, /// The maximum size of a single message frame. `None` means no size limit. The limit is for /// frame payload NOT including the frame header. The default value is 16 MiB which should /// be reasonably big for all normal use-cases but small enough to prevent memory eating @@ -61,9 +57,8 @@ where fn default() -> Self { WebSocketConfig { max_send_queue: None, - max_message_size: Some(MAX_MESSAGE_SIZE), max_frame_size: Some(16 << 20), - encoder: Default::default(), + encoder: E::new(Some(MAX_MESSAGE_SIZE)), } } } @@ -72,10 +67,11 @@ impl WebSocketConfig where E: WebSocketExtension, { + /// Creates a `WebSocketConfig` instance using the default configuration and the provided + /// encoder for new connections. pub fn default_with_encoder(encoder: E) -> WebSocketConfig { WebSocketConfig { max_send_queue: None, - max_message_size: Some(MAX_MESSAGE_SIZE), max_frame_size: Some(16 << 20), encoder, } @@ -476,16 +472,6 @@ where )); } - { - let hdr = frame.header(); - - if !self.get_config().encoder.rsv1() && hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol( - "Reserved bits are non-zero and no WebSocket extensions are enabled".into(), - )); - } - } - match self.role { Role::Server => { if frame.is_masked() { @@ -735,7 +721,6 @@ mod tests { 0x6c, 0x64, 0x21, ]); let limit = WebSocketConfig { - max_message_size: Some(10), max_send_queue: None, max_frame_size: Some(16 << 20), encoder: PlainTextExt::new(Some(10)), @@ -751,7 +736,6 @@ mod tests { fn size_limiting_binary() { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let limit = WebSocketConfig { - max_message_size: Some(2), max_send_queue: None, max_frame_size: Some(16 << 20), encoder: PlainTextExt::new(Some(2)),