diff --git a/src/extensions/compression/deflate.rs b/src/extensions/compression/deflate.rs index 5007171..86d01a6 100644 --- a/src/extensions/compression/deflate.rs +++ b/src/extensions/compression/deflate.rs @@ -128,6 +128,16 @@ impl DeflateConfig { pub fn set_accept_no_context_takeover(&mut self, accept_no_context_takeover: bool) { self.accept_no_context_takeover = accept_no_context_takeover; } + + #[cfg(test)] + pub fn set_compress_reset(&mut self, compress_reset: bool) { + self.compress_reset = compress_reset + } + + #[cfg(test)] + pub fn set_decompress_reset(&mut self, decompress_reset: bool) { + self.decompress_reset = decompress_reset; + } } impl Default for DeflateConfig { @@ -179,7 +189,7 @@ impl DeflateConfigBuilder { } /// Sets the server's LZ77 sliding window size. Panics if the provided size is not in `8..=15`. - pub fn servers_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { + pub fn server_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { assert!( (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), "max window bits must be in range 8..=15" @@ -259,28 +269,23 @@ impl DeflateExt { } fn parse_window_parameter<'a>( - mut param_iter: impl Iterator, + window_param: &str, max_window_bits: u8, -) -> Result, String> { - if let Some(window_bits_str) = param_iter.next() { - let window_bits_str = window_bits_str.replace("\"", ""); - - match window_bits_str.trim().parse() { - Ok(window_bits) => { - if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { - if window_bits != max_window_bits { - Ok(Some(window_bits)) - } else { - Ok(None) - } +) -> Result, DeflateExtensionError> { + let window_param = window_param.replace("\"", ""); + match window_param.trim().parse() { + Ok(window_bits) => { + if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { + if window_bits != max_window_bits { + Ok(Some(window_bits)) } else { - Err(format!("Invalid window parameter: {}", window_bits)) + Ok(None) } + } else { + Err(DeflateExtensionError::InvalidMaxWindowBits) } - Err(e) => Err(e.to_string()), } - } else { - Ok(None) + Err(_) => Err(DeflateExtensionError::InvalidMaxWindowBits), } } @@ -295,6 +300,8 @@ pub enum DeflateExtensionError { NegotiationError(String), /// Produced when fragment buffer grew beyond the maximum configured size. Capacity(Cow<'static, str>), + /// An invalid LZ77 window size was provided. + InvalidMaxWindowBits, } impl DeflateExtensionError { @@ -316,11 +323,17 @@ impl Display for DeflateExtensionError { write!(f, "An upgrade error was encountered: {}", m) } DeflateExtensionError::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), + DeflateExtensionError::InvalidMaxWindowBits => { + write!(f, "An invalid window bit size was provided") + } } } } -/// +/// Verifies any required Sec-WebSocket-Extension headers required for the configured compression +/// level from the HTTP response. Returns `Ok(true)` if a configuration could be agreed, `Ok(false)` +/// if the HTTP header was well formatted but no configuration could be agreed, or an error if it +/// was malformatted. pub fn on_response( response: &Response, config: &mut DeflateConfig, @@ -392,21 +405,18 @@ pub fn on_response( } else { seen_server_max_window_bits = true; - match parse_window_parameter( - param.split("=").skip(1), - *server_max_window_bits, - ) { - Ok(Some(bits)) => { - *server_max_window_bits = bits; + let mut window_param = param.split("=").skip(1); + match window_param.next() { + Some(window_param) => { + if let Some(bits) = parse_window_parameter( + window_param, + *server_max_window_bits, + )? { + *server_max_window_bits = bits; + } } - Ok(None) => {} - Err(e) => { - return Err(DeflateExtensionError::NegotiationError( - format!( - "server_max_window_bits parameter error: {}", - e - ), - )) + None => { + return Err(DeflateExtensionError::InvalidMaxWindowBits) } } } @@ -419,22 +429,14 @@ pub fn on_response( } else { seen_client_max_window_bits = true; - match parse_window_parameter( - param.split("=").skip(1), - *client_max_window_bits, - ) { - Ok(Some(bits)) => { + let mut window_param = param.split("=").skip(1); + if let Some(window_param) = window_param.next() { + if let Some(bits) = parse_window_parameter( + window_param, + *client_max_window_bits, + )? { *client_max_window_bits = bits; } - Ok(None) => {} - Err(e) => { - return Err(DeflateExtensionError::NegotiationError( - format!( - "client_max_window_bits parameter error: {}", - e - ), - )) - } } } } @@ -459,7 +461,7 @@ pub fn on_response( Ok(enabled) } -/// +/// Applies the required headers to negotiate this PCME. pub fn on_make_request(mut request: Request, config: &DeflateConfig) -> Request { let mut header_value = String::from(EXT_IDENT); @@ -542,19 +544,19 @@ fn validate_req_extensions( } else { server_max_bits = true; - match parse_window_parameter( - param.split('=').skip(1), - config.server_max_window_bits, - ) { - Ok(Some(bits)) => { - config.server_max_window_bits = bits; - - response_str.push_str("; "); - response_str.push_str(param) + let mut window_param = param.split("=").skip(1); + match window_param.next() { + Some(window_param) => { + if let Some(bits) = + parse_window_parameter(window_param, config.server_max_window_bits)? + { + config.server_max_window_bits = bits; + } } - Ok(None) => {} - Err(_) => { - return Err(DeflateExtensionError::malformatted()); + None => { + // If the client specifies 'server_max_window_bits' then a value must + // be provided. + return Err(DeflateExtensionError::InvalidMaxWindowBits); } } } @@ -565,20 +567,15 @@ fn validate_req_extensions( } else { client_max_bits = true; - match parse_window_parameter( - param.split('=').skip(1), - config.client_max_window_bits, - ) { - Ok(Some(bits)) => { + let mut window_param = param.split("=").skip(1); + if let Some(window_param) = window_param.next() { + // Absence of this parameter in an extension negotiation offer indicates + // that the client can receive messages compressed using an LZ77 sliding + // window of up to 32,768 bytes. + if let Some(bits) = + parse_window_parameter(window_param, config.client_max_window_bits)? + { config.client_max_window_bits = bits; - response_str.push_str("; "); - response_str.push_str(param); - - continue; - } - Ok(None) => {} - Err(_) => { - return Err(DeflateExtensionError::malformatted()); } } @@ -620,7 +617,9 @@ fn validate_req_extensions( Ok(Some(response_str)) } -/// +/// Verifies any required Sec-WebSocket-Extension headers in the HTTP request and updates the +/// response. Returns `Ok(true)` if a configuration could be agreed, `Ok(false)` if the HTTP header +/// was well formatted but no configuration could be agreed, or an error if it was malformatted. pub fn on_receive_request( request: &Request, response: &mut Response, @@ -996,219 +995,3 @@ impl FragmentBuffer { ) } } - -#[cfg(test)] -mod tests { - use crate::extensions::compression::deflate::{ - on_receive_request, DeflateConfig, DeflateExtensionError, LZ77_MIN_WINDOW_SIZE, - }; - use http::header::SEC_WEBSOCKET_EXTENSIONS; - use http::{HeaderValue, Request, Response}; - - #[test] - fn config_unchanged_on_err() { - let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\""; - let mut request = Request::new(()); - request - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); - - let mut response = Response::new(()); - let initial_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: false, - compression_level: Default::default(), - }; - - let mut parsed_config = initial_config.clone(); - - let r = on_receive_request(&request, &mut response, &mut parsed_config); - - assert_eq!( - r, - Err(DeflateExtensionError::NegotiationError( - "Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into() - )) - ); - assert_eq!(initial_config, parsed_config); - } - - #[test] - fn config_unchanged_on_mismatch() { - let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover"; - let mut request = Request::new(()); - request - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); - - let mut response = Response::new(()); - let initial_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: LZ77_MIN_WINDOW_SIZE, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: false, - compression_level: Default::default(), - }; - - let mut parsed_config = initial_config.clone(); - - let r = on_receive_request(&request, &mut response, &mut parsed_config); - - assert_eq!( - r, - Err(DeflateExtensionError::NegotiationError( - "Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into() - )) - ); - assert_eq!(initial_config, parsed_config); - } - - #[test] - fn parses_named_parameters() { - let s ="permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\""; - let mut request = Request::new(()); - request - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); - - let mut response = Response::new(()); - let mut parsed_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: false, - compression_level: Default::default(), - }; - - let r = on_receive_request(&request, &mut response, &mut parsed_config); - - assert_eq!(r, Ok(true)); - - let expected_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 8, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: true, - decompress_reset: true, - compression_level: Default::default(), - }; - - assert_eq!(parsed_config, expected_config); - - let parsed_header = response - .headers() - .get(SEC_WEBSOCKET_EXTENSIONS) - .expect("Missing header") - .to_str() - .expect("Failed to parse header"); - - assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=\"8\""); - } - - #[test] - fn splits() { - let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; - let mut request = Request::new(()); - request - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); - - let mut response = Response::new(()); - let mut parsed_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: false, - compression_level: Default::default(), - }; - - let r = on_receive_request(&request, &mut response, &mut parsed_config); - - assert_eq!(r, Ok(true)); - - let expected_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: true, - compression_level: Default::default(), - }; - - assert_eq!(parsed_config, expected_config); - - let parsed_header = response - .headers() - .get(SEC_WEBSOCKET_EXTENSIONS) - .expect("Missing header") - .to_str() - .expect("Failed to parse header"); - - assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); - } - - #[test] - fn splits_on_new_line() { - let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; - let mut request = Request::new(()); - request - .headers_mut() - .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); - - let mut response = Response::new(()); - let mut parsed_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: false, - compression_level: Default::default(), - }; - - let r = on_receive_request(&request, &mut response, &mut parsed_config); - - assert_eq!(r, Ok(true)); - - let expected_config = DeflateConfig { - max_message_size: None, - server_max_window_bits: 10, - client_max_window_bits: 11, - request_no_context_takeover: false, - accept_no_context_takeover: true, - compress_reset: false, - decompress_reset: true, - compression_level: Default::default(), - }; - - assert_eq!(parsed_config, expected_config); - - let parsed_header = response - .headers() - .get(SEC_WEBSOCKET_EXTENSIONS) - .expect("Missing header") - .to_str() - .expect("Failed to parse header"); - - assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); - } -} diff --git a/src/extensions/compression/mod.rs b/src/extensions/compression/mod.rs index d7b0a71..7027d5a 100644 --- a/src/extensions/compression/mod.rs +++ b/src/extensions/compression/mod.rs @@ -1,5 +1,8 @@ //! WebSocket compression +#[cfg(test)] +mod tests; + #[cfg(feature = "deflate")] use crate::extensions::compression::deflate::{DeflateConfig, DeflateExt}; use crate::extensions::compression::uncompressed::UncompressedExt; @@ -19,12 +22,12 @@ pub mod deflate; /// An uncompressed message handler for a WebSocket. pub mod uncompressed; -/// +/// The level of compression to use with the WebSocket. #[derive(Copy, Clone, Debug)] pub enum WsCompression { - /// + /// No compression is applied. None(Option), - /// + /// Per-message DEFLATE. #[cfg(feature = "deflate")] Deflate(DeflateConfig), } @@ -32,15 +35,15 @@ pub enum WsCompression { /// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`. #[derive(Debug)] pub enum CompressionSwitcher { - /// + /// No compression is applied. + Uncompressed(UncompressedExt), + /// Per-message DEFLATE. #[cfg(feature = "deflate")] Compressed(DeflateExt), - /// - Uncompressed(UncompressedExt), } impl CompressionSwitcher { - /// + /// Builds a new `CompressionSwitcher` from the provided compression level. pub fn from_config(config: WsCompression) -> CompressionSwitcher { match config { WsCompression::None(size) => { @@ -60,8 +63,8 @@ impl Default for CompressionSwitcher { } } +/// A generic compression error with the underlying cause. #[derive(Debug)] -/// pub struct CompressionError(String); impl Error for CompressionError {} @@ -108,8 +111,9 @@ impl WebSocketExtension for CompressionSwitcher { } } -/// -pub fn build_compression_headers( +/// Applies any required Sec-WebSocket-Extension headers required for the configured compression +/// level to the HTTP request. +pub fn apply_compression_headers( request: Request, config: &mut Option, ) -> Request { @@ -123,7 +127,9 @@ pub fn build_compression_headers( } } -/// +/// Verifies any required Sec-WebSocket-Extension headers required for the configured compression +/// level from the HTTP response. If DEFLATE is not supported, then this reverts to applying no +/// compression. pub fn verify_compression_resp_headers( _response: &Response, config: &mut Option, @@ -150,7 +156,8 @@ pub fn verify_compression_resp_headers( } } -/// +/// Verifies any required Sec-WebSocket-Extension headers in the HTTP request and updates the +/// response. If DEFLATE is not supported, then this reverts to applying no compression. pub fn verify_compression_req_headers( _request: &Request, _response: &mut Response, diff --git a/src/extensions/compression/tests/deflate.rs b/src/extensions/compression/tests/deflate.rs new file mode 100644 index 0000000..a831839 --- /dev/null +++ b/src/extensions/compression/tests/deflate.rs @@ -0,0 +1,290 @@ +use crate::extensions::compression::deflate::{ + on_receive_request, DeflateConfigBuilder, DeflateExtensionError, +}; +use http::header::SEC_WEBSOCKET_EXTENSIONS; +use http::{HeaderValue, Request, Response}; + +mod server { + use super::*; + + #[test] + fn config_unchanged_on_err() { + let s = "permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\""; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let initial_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let mut parsed_config = initial_config.clone(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Err(DeflateExtensionError::InvalidMaxWindowBits)); + assert_eq!(initial_config, parsed_config); + } + + #[test] + fn missing_client_window_size() { + let s = "permessage-deflate; client_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut initial_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_receive_request(&request, &mut response, &mut initial_config); + + assert!(r.is_ok()); + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!( + parsed_header, + "permessage-deflate; client_max_window_bits=11; server_max_window_bits=10" + ); + } + + #[test] + fn missing_server_window_size() { + let s = "permessage-deflate; server_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let initial_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let mut parsed_config = initial_config.clone(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Err(DeflateExtensionError::InvalidMaxWindowBits)); + assert_eq!(initial_config, parsed_config); + } + + #[test] + fn config_unchanged_on_mismatch() { + let s = "permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let initial_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(8) + .build(); + + let mut parsed_config = initial_config.clone(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!( + r, + Err(DeflateExtensionError::NegotiationError( + "Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into() + )) + ); + assert_eq!(initial_config, parsed_config); + } + + #[test] + fn parses_named_parameters() { + let s = "permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\""; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let mut expected_config = DeflateConfigBuilder::default() + .server_max_window_bits(8) + .client_max_window_bits(11) + .build(); + + expected_config.set_compress_reset(true); + expected_config.set_decompress_reset(true); + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header, "permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=8"); + } + + #[test] + fn splits() { + let s = "not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let mut expected_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + expected_config.set_decompress_reset(true); + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header, "permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); + } + + #[test] + fn splits_on_new_line() { + let s = "not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits"; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let mut expected_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + expected_config.set_decompress_reset(true); + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header, "permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10"); + } +} + +mod client { + use super::*; + use crate::extensions::compression::deflate::on_response; + + #[test] + fn splits_on_new_line() { + let s = "permessage-deflate; client_no_context_takeover; client_max_window_bits=8; server_max_window_bits=10"; + + let mut response = Response::new(()); + response + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut parsed_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_response(&mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let mut expected_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(8) + .build(); + expected_config.set_compress_reset(true); + + assert_eq!(parsed_config, expected_config); + } + + #[test] + fn parses_named_parameters() { + let s = "permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\""; + let mut request = Request::new(()); + request + .headers_mut() + .insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s)); + + let mut response = Response::new(()); + let mut parsed_config = DeflateConfigBuilder::default() + .server_max_window_bits(10) + .client_max_window_bits(11) + .build(); + + let r = on_receive_request(&request, &mut response, &mut parsed_config); + + assert_eq!(r, Ok(true)); + + let mut expected_config = DeflateConfigBuilder::default() + .server_max_window_bits(8) + .client_max_window_bits(11) + .build(); + + expected_config.set_compress_reset(true); + expected_config.set_decompress_reset(true); + + assert_eq!(parsed_config, expected_config); + + let parsed_header = response + .headers() + .get(SEC_WEBSOCKET_EXTENSIONS) + .expect("Missing header") + .to_str() + .expect("Failed to parse header"); + + assert_eq!(parsed_header, "permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=8"); + } +} diff --git a/src/extensions/compression/tests/mod.rs b/src/extensions/compression/tests/mod.rs new file mode 100644 index 0000000..1d3559a --- /dev/null +++ b/src/extensions/compression/tests/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "deflate")] +mod deflate; diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 4eadc3f..350347c 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; -use crate::extensions::compression::{build_compression_headers, verify_compression_resp_headers}; +use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Client request type. @@ -115,7 +115,7 @@ fn generate_request( key: &str, config: &mut Option, ) -> Result> { - let request = build_compression_headers(request, config); + let request = apply_compression_headers(request, config); let mut req = Vec::new(); let uri = request.uri();