diff --git a/src/extensions/compression/deflate.rs b/src/extensions/compression/deflate.rs index b4a7510..5007171 100644 --- a/src/extensions/compression/deflate.rs +++ b/src/extensions/compression/deflate.rs @@ -32,7 +32,7 @@ const LZ77_MIN_WINDOW_SIZE: u8 = 8; const LZ77_MAX_WINDOW_SIZE: u8 = 15; /// A permessage-deflate configuration. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub struct DeflateConfig { /// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// for all normal use-cases but small enough to prevent memory eating by a malicious user. @@ -263,6 +263,8 @@ fn parse_window_parameter<'a>( max_window_bits: u8, ) -> Result, String> { if let Some(window_bits_str) = param_iter.next() { + let window_bits_str = window_bits_str.replace("\"", ""); + match window_bits_str.trim().parse() { Ok(window_bits) => { if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { @@ -282,12 +284,8 @@ fn parse_window_parameter<'a>( } } -fn decline(res: &mut Response) { - res.headers_mut().remove(EXT_IDENT); -} - /// A permessage-deflate extension error. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum DeflateExtensionError { /// An error produced when deflating a message. DeflateError(String), @@ -299,6 +297,12 @@ pub enum DeflateExtensionError { Capacity(Cow<'static, str>), } +impl DeflateExtensionError { + fn malformatted() -> DeflateExtensionError { + DeflateExtensionError::NegotiationError("Malformatted header value".into()) + } +} + impl Display for DeflateExtensionError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -342,10 +346,11 @@ pub fn on_response( Ok(header) => { for param in header.split(';') { match param.trim().to_lowercase().as_str() { - "permessage-deflate" => { + EXT_IDENT => { if seen_extension_name { return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: permessage-deflate" + "Duplicate extension parameter: {}", + EXT_IDENT ))); } else { enabled = true; @@ -455,7 +460,7 @@ pub fn on_response( } /// -pub fn on_request(mut request: Request, config: &DeflateConfig) -> Request { +pub fn on_make_request(mut request: Request, config: &DeflateConfig) -> Request { let mut header_value = String::from(EXT_IDENT); let DeflateConfig { @@ -488,135 +493,163 @@ pub fn on_request(mut request: Request, config: &DeflateConfig) -> Request request } -/// -pub fn on_receive_request( - request: &Request, - response: &mut Response, +fn validate_req_extensions( + header: &str, config: &mut DeflateConfig, -) -> Result { - let mut enabled = false; - - for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { - return match header.to_str() { - Ok(header) => { - let mut response_str = String::with_capacity(header.len()); - let mut server_takeover = false; - let mut client_takeover = false; - let mut server_max_bits = false; - let mut client_max_bits = false; +) -> Result, DeflateExtensionError> { + let mut response_str = String::with_capacity(header.len()); + let mut param_iter = header.split(';'); - for param in header.split(';') { - match param.trim().to_lowercase().as_str() { - "permessage-deflate" => { - enabled = true; - response_str.push_str("permessage-deflate"); - } - "server_no_context_takeover" => { - if server_takeover { - decline(response); - } else { - server_takeover = true; - if config.accept_no_context_takeover() { - config.compress_reset = true; - response_str.push_str("; server_no_context_takeover"); - } - } - } - "client_no_context_takeover" => { - if client_takeover { - decline(response); - } else { - client_takeover = true; - config.decompress_reset = true; - response_str.push_str("; client_no_context_takeover"); - } - } - param if param.starts_with("server_max_window_bits") => { - if server_max_bits { - decline(response); - } else { - server_max_bits = true; - - match parse_window_parameter( - param.split('=').skip(1), - config.server_max_window_bits, - ) { - Ok(Some(bits)) => { - config.server_max_window_bits = bits; + match param_iter.next() { + Some(name) if name.trim() == EXT_IDENT => { + response_str.push_str(EXT_IDENT); + } + _ => { + return Ok(None); + } + } - response_str.push_str("; "); - response_str.push_str(param) - } - Ok(None) => {} - Err(_) => { - decline(response); - } - } - } - } - param if param.starts_with("client_max_window_bits") => { - if client_max_bits { - decline(response); - } else { - client_max_bits = true; + let mut server_takeover = false; + let mut client_takeover = false; + let mut server_max_bits = false; + let mut client_max_bits = false; - match parse_window_parameter( - param.split('=').skip(1), - config.client_max_window_bits, - ) { - Ok(Some(bits)) => { - config.client_max_window_bits = bits; - response_str.push_str("; "); - response_str.push_str(param); + while let Some(param) = param_iter.next() { + match param.trim().to_lowercase().as_str() { + "server_no_context_takeover" => { + if server_takeover { + return Err(DeflateExtensionError::malformatted()); + } else { + server_takeover = true; + if config.accept_no_context_takeover() { + config.compress_reset = true; + response_str.push_str("; server_no_context_takeover"); + } + } + } + "client_no_context_takeover" => { + if client_takeover { + return Err(DeflateExtensionError::malformatted()); + } else { + client_takeover = true; + config.decompress_reset = true; + response_str.push_str("; client_no_context_takeover"); + } + } + param if param.starts_with("server_max_window_bits") => { + if server_max_bits { + return Err(DeflateExtensionError::malformatted()); + } else { + server_max_bits = true; - continue; - } - Ok(None) => {} - Err(_) => { - decline(response); - } - } + match parse_window_parameter( + param.split('=').skip(1), + config.server_max_window_bits, + ) { + Ok(Some(bits)) => { + config.server_max_window_bits = bits; - response_str.push_str("; "); - response_str.push_str(&format!( - "client_max_window_bits={}", - config.client_max_window_bits() - )) - } + response_str.push_str("; "); + response_str.push_str(param) } - _ => { - decline(response); + Ok(None) => {} + Err(_) => { + return Err(DeflateExtensionError::malformatted()); } } } + } + param if param.starts_with("client_max_window_bits") => { + if client_max_bits { + return Err(DeflateExtensionError::malformatted()); + } else { + client_max_bits = true; + + match parse_window_parameter( + param.split('=').skip(1), + config.client_max_window_bits, + ) { + Ok(Some(bits)) => { + config.client_max_window_bits = bits; + response_str.push_str("; "); + response_str.push_str(param); + + continue; + } + Ok(None) => {} + Err(_) => { + return Err(DeflateExtensionError::malformatted()); + } + } - if !response_str.contains("client_no_context_takeover") - && config.request_no_context_takeover() - { - config.decompress_reset = true; - response_str.push_str("; client_no_context_takeover"); - } - - if !response_str.contains("server_max_window_bits") { response_str.push_str("; "); response_str.push_str(&format!( - "server_max_window_bits={}", - config.server_max_window_bits() + "client_max_window_bits={}", + config.client_max_window_bits() )) } + } + p => { + return Err(DeflateExtensionError::NegotiationError( + format!("Unknown permessage-deflate parameter: {}", p).into(), + )) + } + } + } - if !response_str.contains("client_max_window_bits") - && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE - { - continue; - } + if !response_str.contains("client_no_context_takeover") && config.request_no_context_takeover() + { + config.decompress_reset = true; + response_str.push_str("; client_no_context_takeover"); + } + + if !response_str.contains("server_max_window_bits") { + response_str.push_str("; "); + response_str.push_str(&format!( + "server_max_window_bits={}", + config.server_max_window_bits() + )) + } - response.headers_mut().insert( - SEC_WEBSOCKET_EXTENSIONS, - HeaderValue::from_str(&response_str)?, - ); + if !response_str.contains("client_max_window_bits") + && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE + { + return Ok(None); + } + + Ok(Some(response_str)) +} - Ok(enabled) +/// +pub fn on_receive_request( + request: &Request, + response: &mut Response, + config: &mut DeflateConfig, +) -> Result { + for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { + return match header.to_str() { + Ok(header) => { + for header in header.split(',') { + let mut parser_config = config.clone(); + + match validate_req_extensions(header, &mut parser_config) { + Ok(Some(response_str)) => { + response.headers_mut().insert( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_str)?, + ); + + *config = parser_config; + return Ok(true); + } + Ok(None) => continue, + Err(e) => { + response.headers_mut().remove(EXT_IDENT); + return Err(e); + } + } + } + Ok(false) } Err(e) => Err(DeflateExtensionError::NegotiationError(format!( "Failed to parse request header: {}", @@ -625,7 +658,6 @@ pub fn on_receive_request( }; } - decline(response); Ok(false) } @@ -964,3 +996,219 @@ impl FragmentBuffer { ) } } + +#[cfg(test)] +mod tests { + use crate::extensions::compression::deflate::{ + on_receive_request, DeflateConfig, DeflateExtensionError, LZ77_MIN_WINDOW_SIZE, + }; + use http::header::SEC_WEBSOCKET_EXTENSIONS; + use http::{HeaderValue, Request, Response}; + + #[test] + fn config_unchanged_on_err() { + let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\""; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let initial_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Default::default(), + }; + + let mut parsed_config = initial_config.clone(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!( + r, + Err(DeflateExtensionError::NegotiationError( + "Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into() + )) + ); + assert_eq!(initial_config, parsed_config); + } + + #[test] + fn config_unchanged_on_mismatch() { + let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let initial_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: LZ77_MIN_WINDOW_SIZE, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Default::default(), + }; + + let mut parsed_config = initial_config.clone(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!( + r, + Err(DeflateExtensionError::NegotiationError( + "Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into() + )) + ); + assert_eq!(initial_config, parsed_config); + } + + #[test] + fn parses_named_parameters() { + let s ="permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\""; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Default::default(), + }; + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let expected_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 8, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: true, + decompress_reset: true, + compression_level: Default::default(), + }; + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=\"8\""); + } + + #[test] + fn splits() { + let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Default::default(), + }; + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let expected_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: true, + compression_level: Default::default(), + }; + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); + } + + #[test] + fn splits_on_new_line() { + let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: false, + compression_level: Default::default(), + }; + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let expected_config = DeflateConfig { + max_message_size: None, + server_max_window_bits: 10, + client_max_window_bits: 11, + request_no_context_takeover: false, + accept_no_context_takeover: true, + compress_reset: false, + decompress_reset: true, + compression_level: Default::default(), + }; + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); + } +} diff --git a/src/extensions/compression/mod.rs b/src/extensions/compression/mod.rs index 19f0ff1..d7b0a71 100644 --- a/src/extensions/compression/mod.rs +++ b/src/extensions/compression/mod.rs @@ -117,7 +117,7 @@ pub fn build_compression_headers( Some(ref mut config) => match &config.compression { WsCompression::None(_) => request, #[cfg(feature = "deflate")] - WsCompression::Deflate(config) => deflate::on_request(request, config), + WsCompression::Deflate(config) => deflate::on_make_request(request, config), }, None => request, } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index c0c84ce..4fad147 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -14,8 +14,7 @@ pub trait WebSocketExtension { Ok(frame) } - /// Called when a frame has been received and unmasked. The frame provided frame will be of the - /// type `OpCode::Data`. + /// Called when a WebSocket frame has been received. fn on_receive_frame( &mut self, data_opcode: Data,