diff --git a/autobahn/fuzzingclient.json b/autobahn/fuzzingclient.json index a265c37..5cc1933 100644 --- a/autobahn/fuzzingclient.json +++ b/autobahn/fuzzingclient.json @@ -6,7 +6,7 @@ "url": "ws://127.0.0.1:9002" } ], - "cases": ["*"], + "cases": ["13.7.11"], "exclude-cases": [], "exclude-agent-cases": {} } diff --git a/src/extensions/compression/deflate.rs b/src/extensions/compression/deflate.rs index b17591f..b4a7510 100644 --- a/src/extensions/compression/deflate.rs +++ b/src/extensions/compression/deflate.rs @@ -5,7 +5,8 @@ use std::fmt::{Display, Formatter}; use crate::extensions::compression::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; -use crate::protocol::frame::Frame; +use crate::protocol::frame::{ExtensionHeaders, Frame}; +use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; use crate::protocol::MAX_MESSAGE_SIZE; use crate::Message; use bytes::BufMut; @@ -35,7 +36,7 @@ const LZ77_MAX_WINDOW_SIZE: u8 = 15; pub struct DeflateConfig { /// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// for all normal use-cases but small enough to prevent memory eating by a malicious user. - max_message_size: usize, + max_message_size: Option, /// The client's LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, /// this conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must /// be in range 8..15 inclusive. @@ -68,7 +69,7 @@ impl DeflateConfig { } /// Returns the maximum message size permitted. - pub fn max_message_size(&self) -> usize { + pub fn max_message_size(&self) -> Option { self.max_message_size } @@ -109,7 +110,7 @@ impl DeflateConfig { /// Sets the maximum message size permitted. pub fn set_max_message_size(&mut self, max_message_size: Option) { - self.max_message_size = max_message_size.unwrap_or_else(usize::max_value); + self.max_message_size = max_message_size; } /// Sets the LZ77 sliding window size. @@ -132,7 +133,7 @@ impl DeflateConfig { impl Default for DeflateConfig { fn default() -> Self { DeflateConfig { - max_message_size: MAX_MESSAGE_SIZE, + max_message_size: Some(MAX_MESSAGE_SIZE), server_max_window_bits: LZ77_MAX_WINDOW_SIZE, client_max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, @@ -218,7 +219,7 @@ impl DeflateConfigBuilder { /// Consumes the builder and produces a `DeflateConfig.` pub fn build(self) -> DeflateConfig { DeflateConfig { - max_message_size: self.max_message_size.unwrap_or_else(usize::max_value), + max_message_size: self.max_message_size, server_max_window_bits: self.server_max_window_bits, client_max_window_bits: self.client_max_window_bits, request_no_context_takeover: self.request_no_context_takeover, @@ -252,7 +253,7 @@ impl DeflateExt { fragment_buffer: FragmentBuffer::new(config.max_message_size), inflator: Inflator::new(config.server_max_window_bits), deflator: Deflator::new(config.compression_level, config.client_max_window_bits), - uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), + uncompressed_extension: UncompressedExt::new(config.max_message_size()), } } } @@ -492,7 +493,9 @@ pub fn on_receive_request( request: &Request, response: &mut Response, config: &mut DeflateConfig, -) -> Result<(), DeflateExtensionError> { +) -> Result { + let mut enabled = false; + for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { return match header.to_str() { Ok(header) => { @@ -504,7 +507,10 @@ pub fn on_receive_request( for param in header.split(';') { match param.trim().to_lowercase().as_str() { - "permessage-deflate" => response_str.push_str("permessage-deflate"), + "permessage-deflate" => { + enabled = true; + response_str.push_str("permessage-deflate"); + } "server_no_context_takeover" => { if server_takeover { decline(response); @@ -610,7 +616,7 @@ pub fn on_receive_request( HeaderValue::from_str(&response_str)?, ); - Ok(()) + Ok(enabled) } Err(e) => Err(DeflateExtensionError::NegotiationError(format!( "Failed to parse request header: {}", @@ -620,7 +626,7 @@ pub fn on_receive_request( } decline(response); - Ok(()) + Ok(false) } impl std::error::Error for DeflateExtensionError {} @@ -653,7 +659,7 @@ impl WebSocketExtension for DeflateExt { compressed.truncate(len - 4); *frame.payload_mut() = compressed; - frame.header_mut().rsv1 = true; + frame.header_mut().ext_headers.rsv1 = true; if self.config.compress_reset() { self.deflator.reset(); @@ -663,39 +669,46 @@ impl WebSocketExtension for DeflateExt { Ok(frame) } - fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { - if !self.fragment_buffer.is_empty() || frame.header().rsv1 { - if !frame.header().is_final { - self.fragment_buffer - .try_push_frame(frame) - .map_err(|s| DeflateExtensionError::Capacity(s.into()))?; + fn on_receive_frame( + &mut self, + data_opcode: Data, + is_final: bool, + header: ExtensionHeaders, + payload: Vec, + ) -> Result, crate::Error> { + if !self.fragment_buffer.is_empty() || header.rsv1 { + if !is_final { + self.fragment_buffer.try_push(data_opcode, payload)?; Ok(None) } else { let mut compressed = if self.fragment_buffer.is_empty() { - Vec::with_capacity(frame.payload().len()) + Vec::with_capacity(payload.len()) } else { - Vec::with_capacity(self.fragment_buffer.len() + frame.payload().len()) + Vec::with_capacity(self.fragment_buffer.len() + payload.len()) }; - let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); - - let opcode = match frame.header().opcode { - OpCode::Data(Data::Continue) => { - self.fragment_buffer - .try_push_frame(frame) - .map_err(|s| DeflateExtensionError::Capacity(s.into()))?; - - let opcode = self.fragment_buffer.first().unwrap().header().opcode; + let mut decompressed = Vec::with_capacity(payload.len() * 2); - self.fragment_buffer.reset().into_iter().for_each(|f| { - compressed.extend(f.into_data()); - }); + let message_type = match data_opcode { + Data::Continue => { + self.fragment_buffer.try_push(data_opcode, payload)?; + let (opcode, payload) = self.fragment_buffer.reset(); + decompressed = payload; opcode } - _ => { - compressed.put_slice(frame.payload()); - frame.header().opcode + Data::Binary => { + compressed.put_slice(payload.as_slice()); + IncompleteMessageType::Binary + } + Data::Text => { + compressed.put_slice(payload.as_slice()); + IncompleteMessageType::Text + } + Data::Reserved(_) => { + return Err(crate::Error::ExtensionError( + "Unexpected reserved frame received".into(), + )) } }; @@ -707,14 +720,14 @@ impl WebSocketExtension for DeflateExt { self.inflator.reset(false); } - self.uncompressed_extension.on_receive_frame(Frame::message( - decompressed, - opcode, - true, - )) + let mut msg = IncompleteMessage::new(message_type); + msg.extend(decompressed.as_slice(), self.config.max_message_size)?; + + Ok(Some(msg.complete()?)) } } else { - self.uncompressed_extension.on_receive_frame(frame) + self.uncompressed_extension + .on_receive_frame(data_opcode, is_final, header, payload) } } } @@ -867,46 +880,72 @@ impl Inflator { /// Defaults to an initial capacity of ten frames. #[derive(Debug)] struct FragmentBuffer { - fragments: Vec, - fragments_len: usize, - max_len: usize, + frame_opcode: Option, + fragments: Vec, + max_len: Option, } impl FragmentBuffer { /// Creates a new fragment buffer that will permit a maximum length of `max_len`. - fn new(max_len: usize) -> FragmentBuffer { + fn new(max_len: Option) -> FragmentBuffer { FragmentBuffer { - fragments: Vec::with_capacity(10), - fragments_len: 0, + frame_opcode: None, + fragments: Vec::new(), max_len, } } /// Attempts to push a frame into the buffer. This will fail if the new length of the buffer's /// frames exceeds the maximum capacity of `max_len`. - fn try_push_frame(&mut self, frame: Frame) -> Result<(), String> { + fn try_push(&mut self, opcode: Data, payload: Vec) -> Result<(), DeflateExtensionError> { let FragmentBuffer { fragments, - fragments_len, max_len, + frame_opcode, } = self; - *fragments_len += frame.payload().len(); + if fragments.is_empty() { + let ty = match opcode { + Data::Text => IncompleteMessageType::Text, + Data::Binary => IncompleteMessageType::Binary, + opc => { + return Err(DeflateExtensionError::Capacity( + format!("Expected a text or binary frame but received: {}", opc).into(), + )) + } + }; - if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len { - Err(format!( - "Message too big: {} + {} > {}", - fragments_len, fragments_len, max_len - )) - } else { - fragments.push(frame); - Ok(()) + *frame_opcode = Some(ty); + } + + match max_len { + Some(max_len) => { + let mut fragments_len = fragments.len(); + fragments_len += payload.len(); + + if fragments_len > *max_len || payload.len() > *max_len - fragments_len { + return Err(DeflateExtensionError::Capacity( + format!( + "Message too big: {} + {} > {}", + fragments_len, fragments_len, max_len + ) + .into(), + )); + } else { + fragments.extend(payload); + Ok(()) + } + } + None => { + fragments.extend(payload); + Ok(()) + } } } - /// Returns the total length of all of the frames that have been pushed into the buffer. + /// Returns the total length of all of the payloads that have been pushed into the buffer. fn len(&self) -> usize { - self.fragments_len + self.fragments.len() } /// Returns whether the buffer is empty. @@ -914,14 +953,14 @@ impl FragmentBuffer { self.fragments.is_empty() } - /// Returns the first element of the fragments slice, or `None` if it is empty. - fn first(&self) -> Option<&Frame> { - self.fragments.first() - } - - /// Drains the buffer and resets it to an initial capacity of 10 elements. - fn reset(&mut self) -> Vec { - self.fragments_len = 0; - replace(&mut self.fragments, Vec::with_capacity(10)) + /// Drains the buffer. Returning the message's opcode and its payload. + fn reset(&mut self) -> (IncompleteMessageType, Vec) { + let payloads = replace(&mut self.fragments, Vec::new()); + ( + self.frame_opcode + .take() + .expect("Inconsistent state: missing opcode"), + payloads, + ) } } diff --git a/src/extensions/compression/mod.rs b/src/extensions/compression/mod.rs index cfde4db..19f0ff1 100644 --- a/src/extensions/compression/mod.rs +++ b/src/extensions/compression/mod.rs @@ -4,7 +4,8 @@ use crate::extensions::compression::deflate::{DeflateConfig, DeflateExt}; use crate::extensions::compression::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; -use crate::protocol::frame::Frame; +use crate::protocol::frame::coding::Data; +use crate::protocol::frame::{ExtensionHeaders, Frame}; use crate::protocol::WebSocketConfig; use crate::Message; use http::{Request, Response}; @@ -88,11 +89,21 @@ impl WebSocketExtension for CompressionSwitcher { } } - fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { + fn on_receive_frame( + &mut self, + data_opcode: Data, + is_final: bool, + header: ExtensionHeaders, + payload: Vec, + ) -> Result, crate::Error> { match self { - CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame), + CompressionSwitcher::Uncompressed(ext) => { + ext.on_receive_frame(data_opcode, is_final, header, payload) + } #[cfg(feature = "deflate")] - CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame), + CompressionSwitcher::Compressed(ext) => { + ext.on_receive_frame(data_opcode, is_final, header, payload) + } } } } @@ -128,8 +139,7 @@ pub fn verify_compression_resp_headers( match result { Ok(true) => Ok(()), Ok(false) => { - config.compression = - WsCompression::None(Some(deflate_config.max_message_size())); + config.compression = WsCompression::None(deflate_config.max_message_size()); Ok(()) } Err(e) => Err(e), @@ -151,8 +161,17 @@ pub fn verify_compression_req_headers( WsCompression::None(_) => Ok(()), #[cfg(feature = "deflate")] WsCompression::Deflate(ref mut deflate_config) => { - deflate::on_receive_request(_request, _response, deflate_config) - .map_err(|e| CompressionError(e.to_string())) + let result = deflate::on_receive_request(_request, _response, deflate_config) + .map_err(|e| CompressionError(e.to_string())); + + match result { + Ok(true) => Ok(()), + Ok(false) => { + config.compression = WsCompression::None(deflate_config.max_message_size()); + Ok(()) + } + Err(e) => Err(e), + } } }, None => Ok(()), diff --git a/src/extensions/compression/uncompressed.rs b/src/extensions/compression/uncompressed.rs index d29092f..9cb0050 100644 --- a/src/extensions/compression/uncompressed.rs +++ b/src/extensions/compression/uncompressed.rs @@ -1,6 +1,6 @@ use crate::extensions::WebSocketExtension; -use crate::protocol::frame::coding::{Data, OpCode}; -use crate::protocol::frame::Frame; +use crate::protocol::frame::coding::Data; +use crate::protocol::frame::ExtensionHeaders; use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; use crate::protocol::MAX_MESSAGE_SIZE; use crate::{Error, Message}; @@ -33,58 +33,60 @@ impl UncompressedExt { } impl WebSocketExtension for UncompressedExt { - fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { - let fin = frame.header().is_final; - let hdr = frame.header(); + fn on_receive_frame( + &mut self, + data_opcode: Data, + is_final: bool, + header: ExtensionHeaders, + payload: Vec, + ) -> Result, crate::Error> { + let fin = is_final; - if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { + if header.rsv1 || header.rsv2 || header.rsv3 { return Err(Error::Protocol( "Reserved bits are non-zero and no WebSocket extensions are enabled".into(), )); } - match frame.header().opcode { - OpCode::Data(data) => match data { - Data::Continue => { - if let Some(ref mut msg) = self.incomplete { - msg.extend(frame.into_data(), self.max_message_size)?; - } else { - return Err(Error::Protocol( - "Continue frame but nothing to continue".into(), - )); - } - if fin { - Ok(Some(self.incomplete.take().unwrap().complete()?)) - } else { - Ok(None) - } + match data_opcode { + Data::Continue => { + if let Some(ref mut msg) = self.incomplete { + msg.extend(payload, self.max_message_size)?; + } else { + return Err(Error::Protocol( + "Continue frame but nothing to continue".into(), + )); } - c if self.incomplete.is_some() => Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into(), - )), - Data::Text | Data::Binary => { - let msg = { - let message_type = match data { - Data::Text => IncompleteMessageType::Text, - Data::Binary => IncompleteMessageType::Binary, - _ => panic!("Bug: message is not text nor binary"), - }; - let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data(), self.max_message_size)?; - m + if fin { + Ok(Some(self.incomplete.take().unwrap().complete()?)) + } else { + Ok(None) + } + } + c if self.incomplete.is_some() => Err(Error::Protocol( + format!("Received {} while waiting for more fragments", c).into(), + )), + Data::Text | Data::Binary => { + let msg = { + let message_type = match data_opcode { + Data::Text => IncompleteMessageType::Text, + Data::Binary => IncompleteMessageType::Binary, + _ => panic!("Bug: message is not text nor binary"), }; - if fin { - Ok(Some(msg.complete()?)) - } else { - self.incomplete = Some(msg); - Ok(None) - } + let mut m = IncompleteMessage::new(message_type); + m.extend(payload, self.max_message_size)?; + m + }; + if fin { + Ok(Some(msg.complete()?)) + } else { + self.incomplete = Some(msg); + Ok(None) } - Data::Reserved(i) => Err(Error::Protocol( - format!("Unknown data frame type {}", i).into(), - )), - }, - _ => unreachable!(), + } + Data::Reserved(i) => Err(Error::Protocol( + format!("Unknown data frame type {}", i).into(), + )), } } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 1a24612..c0c84ce 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -2,7 +2,8 @@ pub mod compression; -use crate::protocol::frame::Frame; +use crate::protocol::frame::coding::Data; +use crate::protocol::frame::{ExtensionHeaders, Frame}; use crate::Message; /// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions @@ -15,5 +16,11 @@ pub trait WebSocketExtension { /// Called when a frame has been received and unmasked. The frame provided frame will be of the /// type `OpCode::Data`. - fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error>; + fn on_receive_frame( + &mut self, + data_opcode: Data, + is_final: bool, + header: ExtensionHeaders, + payload: Vec, + ) -> Result, crate::Error>; } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index e6a0009..e067d51 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -42,25 +42,41 @@ impl<'t> fmt::Display for CloseFrame<'t> { pub struct FrameHeader { /// Indicates that the frame is the last one of a possibly fragmented message. pub is_final: bool, + /// Reserved extension headers/bits. + pub ext_headers: ExtensionHeaders, + /// WebSocket protocol opcode. + pub opcode: OpCode, + /// A frame mask, if any. + pub mask: Option<[u8; 4]>, +} + +/// A struct representing reserved extension headers from a WebSocket frame. +#[allow(missing_copy_implementations)] +#[derive(Debug, Clone)] +pub struct ExtensionHeaders { /// Reserved for protocol extensions. pub rsv1: bool, /// Reserved for protocol extensions. pub rsv2: bool, /// Reserved for protocol extensions. pub rsv3: bool, - /// WebSocket protocol opcode. - pub opcode: OpCode, - /// A frame mask, if any. - pub mask: Option<[u8; 4]>, } -impl Default for FrameHeader { +impl Default for ExtensionHeaders { fn default() -> Self { - FrameHeader { - is_final: true, + ExtensionHeaders { rsv1: false, rsv2: false, rsv3: false, + } + } +} + +impl Default for FrameHeader { + fn default() -> Self { + FrameHeader { + is_final: true, + ext_headers: Default::default(), opcode: OpCode::Control(Control::Close), mask: None, } @@ -93,9 +109,9 @@ impl FrameHeader { let one = { code | if self.is_final { 0x80 } else { 0 } - | if self.rsv1 { 0x40 } else { 0 } - | if self.rsv2 { 0x20 } else { 0 } - | if self.rsv3 { 0x10 } else { 0 } + | if self.ext_headers.rsv1 { 0x40 } else { 0 } + | if self.ext_headers.rsv2 { 0x20 } else { 0 } + | if self.ext_headers.rsv3 { 0x10 } else { 0 } }; let lenfmt = LengthFormat::for_length(length); @@ -192,11 +208,10 @@ impl FrameHeader { _ => (), } + let ext_headers = ExtensionHeaders { rsv1, rsv2, rsv3 }; let hdr = FrameHeader { is_final, - rsv1, - rsv2, - rsv3, + ext_headers, opcode, mask, }; @@ -381,6 +396,12 @@ impl Frame { output.write_all(self.payload())?; Ok(()) } + + /// Splits the frame into a tuple of its header and payload. + pub fn split(self) -> (FrameHeader, Vec) { + let Frame { header, payload } = self; + (header, payload) + } } impl fmt::Display for Frame { @@ -397,9 +418,9 @@ payload length: {} payload: 0x{} ", self.header.is_final, - self.header.rsv1, - self.header.rsv2, - self.header.rsv3, + self.header.ext_headers.rsv1, + self.header.ext_headers.rsv2, + self.header.ext_headers.rsv3, self.header.opcode, // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 6756f0a..7145d4e 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -7,7 +7,7 @@ mod frame; mod mask; pub use self::frame::CloseFrame; -pub use self::frame::{Frame, FrameHeader}; +pub use self::frame::{ExtensionHeaders, Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 7019494..c948d19 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -159,6 +159,7 @@ impl IncompleteMessage { } /// The type of incomplete message. +#[derive(Debug, Copy, Clone)] pub enum IncompleteMessageType { Text, Binary, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index ba52ec3..91c9860 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -489,7 +489,15 @@ impl WebSocketContext { OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), } } - _ => self.decoder.on_receive_frame(frame), + OpCode::Data(data) => { + let (header, payload) = frame.split(); + self.decoder.on_receive_frame( + data, + header.is_final, + header.ext_headers, + payload, + ) + } } } else { // Connection closed by peer