From b658064b5e37f0f082e33188fe3b4f8475de42a0 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Mon, 2 Nov 2020 17:37:32 +0000 Subject: [PATCH] Splits client/server max_window_bits --- examples/autobahn-client.rs | 5 +- examples/autobahn-server.rs | 5 +- src/extensions/compression/deflate.rs | 154 +++++++++++--------- src/extensions/compression/mod.rs | 160 +++++++++++++++++++++ src/extensions/compression/uncompressed.rs | 1 - src/extensions/mod.rs | 155 +------------------- src/handshake/client.rs | 2 +- src/handshake/server.rs | 2 +- src/protocol/mod.rs | 5 +- 9 files changed, 257 insertions(+), 232 deletions(-) create mode 100644 src/extensions/compression/mod.rs diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 8e0726f..fb0a839 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -2,7 +2,8 @@ use log::*; use url::Url; use tungstenite::client::connect_with_config; -use tungstenite::extensions::deflate::{DeflateConfigBuilder, DeflateExt}; +use tungstenite::extensions::compression::deflate::DeflateConfigBuilder; +use tungstenite::extensions::compression::WsCompression; use tungstenite::protocol::WebSocketConfig; use tungstenite::{connect, Error, Message, Result}; @@ -43,7 +44,7 @@ fn run_test(case: u32) -> Result<()> { Some(WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: DeflateExt::new(deflate_config), + compression: WsCompression::Deflate(deflate_config), }), )?; diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index f6e7622..5e9fd2b 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -2,7 +2,8 @@ use std::net::{TcpListener, TcpStream}; use std::thread::spawn; use log::*; -use tungstenite::extensions::deflate::{DeflateExt, DeflateConfigBuilder}; +use tungstenite::extensions::compression::deflate::DeflateConfigBuilder; +use tungstenite::extensions::compression::WsCompression; use tungstenite::handshake::HandshakeRole; use tungstenite::protocol::WebSocketConfig; use tungstenite::server::accept_with_config; @@ -25,7 +26,7 @@ fn handle_client(stream: TcpStream) -> Result<()> { Some(WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: DeflateExt::new(deflate_config), + compression: WsCompression::Deflate(deflate_config), }), ) .map_err(must_not_block)?; diff --git a/src/extensions/compression/deflate.rs b/src/extensions/compression/deflate.rs index bfd54ff..b17591f 100644 --- a/src/extensions/compression/deflate.rs +++ b/src/extensions/compression/deflate.rs @@ -2,7 +2,7 @@ use std::fmt::{Display, Formatter}; -use crate::extensions::uncompressed::UncompressedExt; +use crate::extensions::compression::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::Frame; @@ -36,10 +36,14 @@ pub struct DeflateConfig { /// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// for all normal use-cases but small enough to prevent memory eating by a malicious user. max_message_size: usize, - /// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this - /// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in - /// range 8..15 inclusive. - max_window_bits: u8, + /// The client's LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, + /// this conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must + /// be in range 8..15 inclusive. + server_max_window_bits: u8, + /// The client's LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, + /// this conforms to RFC 7692 7.1.2.2. In server mode, this conforms to RFC 7692 7.1.2.2. Must + /// be in range 8..15 inclusive. + client_max_window_bits: u8, /// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1. request_no_context_takeover: bool, /// Whether to accept `no_context_takeover`. @@ -68,9 +72,14 @@ impl DeflateConfig { self.max_message_size } - /// Returns the maximum LZ77 window size permitted. - pub fn max_window_bits(&self) -> u8 { - self.max_window_bits + /// Returns the maximum LZ77 window size permitted for the server. + pub fn server_max_window_bits(&self) -> u8 { + self.server_max_window_bits + } + + /// Returns the maximum LZ77 window size permitted for the client. + pub fn client_max_window_bits(&self) -> u8 { + self.client_max_window_bits } /// Returns whether `no_context_takeover` has been requested. @@ -106,7 +115,7 @@ impl DeflateConfig { /// Sets the LZ77 sliding window size. pub fn set_max_window_bits(&mut self, max_window_bits: u8) { assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits)); - self.max_window_bits = max_window_bits; + self.client_max_window_bits = max_window_bits; } /// Sets the WebSocket to request `no_context_takeover` if `true`. @@ -124,7 +133,8 @@ impl Default for DeflateConfig { fn default() -> Self { DeflateConfig { max_message_size: MAX_MESSAGE_SIZE, - max_window_bits: LZ77_MAX_WINDOW_SIZE, + server_max_window_bits: LZ77_MAX_WINDOW_SIZE, + client_max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, accept_no_context_takeover: true, compress_reset: false, @@ -138,7 +148,8 @@ impl Default for DeflateConfig { #[derive(Debug, Copy, Clone)] pub struct DeflateConfigBuilder { max_message_size: Option, - max_window_bits: u8, + server_max_window_bits: u8, + client_max_window_bits: u8, request_no_context_takeover: bool, accept_no_context_takeover: bool, fragments_grow: bool, @@ -149,7 +160,8 @@ impl Default for DeflateConfigBuilder { fn default() -> Self { DeflateConfigBuilder { max_message_size: Some(MAX_MESSAGE_SIZE), - max_window_bits: LZ77_MAX_WINDOW_SIZE, + server_max_window_bits: LZ77_MAX_WINDOW_SIZE, + client_max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, accept_no_context_takeover: true, fragments_grow: true, @@ -165,13 +177,23 @@ impl DeflateConfigBuilder { self } - /// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`. - pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { + /// Sets the server's LZ77 sliding window size. Panics if the provided size is not in `8..=15`. + pub fn servers_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { + assert!( + (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), + "max window bits must be in range 8..=15" + ); + self.server_max_window_bits = max_window_bits; + self + } + + /// Sets the client's LZ77 sliding window size. Panics if the provided size is not in `8..=15`. + pub fn client_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { assert!( (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), "max window bits must be in range 8..=15" ); - self.max_window_bits = max_window_bits; + self.client_max_window_bits = max_window_bits; self } @@ -197,7 +219,8 @@ impl DeflateConfigBuilder { pub fn build(self) -> DeflateConfig { DeflateConfig { max_message_size: self.max_message_size.unwrap_or_else(usize::max_value), - max_window_bits: self.max_window_bits, + server_max_window_bits: self.server_max_window_bits, + client_max_window_bits: self.client_max_window_bits, request_no_context_takeover: self.request_no_context_takeover, accept_no_context_takeover: self.accept_no_context_takeover, compression_level: self.compression_level, @@ -209,9 +232,6 @@ impl DeflateConfigBuilder { /// A permessage-deflate encoding WebSocket extension. #[derive(Debug)] pub struct DeflateExt { - /// Defines whether the extension is enabled. Following a successful handshake, this will be - /// `true`. - enabled: bool, /// The configuration for the extension. config: DeflateConfig, /// A stack of continuation frames awaiting `fin` and the total size of all of the fragments. @@ -228,11 +248,10 @@ impl DeflateExt { /// Creates a `DeflateExt` instance using the provided configuration. pub fn new(config: DeflateConfig) -> DeflateExt { DeflateExt { - enabled: false, config, fragment_buffer: FragmentBuffer::new(config.max_message_size), - inflator: Inflator::new(config.max_window_bits), - deflator: Deflator::new(config.compression_level, config.max_window_bits), + inflator: Inflator::new(config.server_max_window_bits), + deflator: Deflator::new(config.compression_level, config.client_max_window_bits), uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), } } @@ -301,15 +320,16 @@ pub fn on_response( response: &Response, config: &mut DeflateConfig, ) -> Result { - let mut extension_name = false; - let mut server_takeover = false; - let mut client_takeover = false; - let mut server_max_window_bits = false; - let mut client_max_window_bits = false; + let mut seen_extension_name = false; + let mut seen_server_takeover = false; + let mut seen_client_takeover = false; + let mut seen_server_max_window_bits = false; + let mut seen_client_max_window_bits = false; let mut enabled = false; let DeflateConfig { - max_window_bits, + server_max_window_bits, + client_max_window_bits, accept_no_context_takeover, compress_reset, decompress_reset, @@ -322,32 +342,32 @@ pub fn on_response( for param in header.split(';') { match param.trim().to_lowercase().as_str() { "permessage-deflate" => { - if extension_name { + if seen_extension_name { return Err(DeflateExtensionError::NegotiationError(format!( "Duplicate extension parameter: permessage-deflate" ))); } else { enabled = true; - extension_name = true; + seen_extension_name = true; } } "server_no_context_takeover" => { - if server_takeover { + if seen_server_takeover { return Err(DeflateExtensionError::NegotiationError(format!( "Duplicate extension parameter: server_no_context_takeover" ))); } else { - server_takeover = true; + seen_server_takeover = true; *decompress_reset = true; } } "client_no_context_takeover" => { - if client_takeover { + if seen_client_takeover { return Err(DeflateExtensionError::NegotiationError(format!( "Duplicate extension parameter: client_no_context_takeover" ))); } else { - client_takeover = true; + seen_client_takeover = true; if *accept_no_context_takeover { *compress_reset = true; @@ -359,19 +379,19 @@ pub fn on_response( } } param if param.starts_with("server_max_window_bits") => { - if server_max_window_bits { + if seen_server_max_window_bits { return Err(DeflateExtensionError::NegotiationError(format!( "Duplicate extension parameter: server_max_window_bits" ))); } else { - server_max_window_bits = true; + seen_server_max_window_bits = true; match parse_window_parameter( param.split("=").skip(1), - *max_window_bits, + *server_max_window_bits, ) { Ok(Some(bits)) => { - *max_window_bits = bits; + *server_max_window_bits = bits; } Ok(None) => {} Err(e) => { @@ -386,19 +406,19 @@ pub fn on_response( } } param if param.starts_with("client_max_window_bits") => { - if client_max_window_bits { + if seen_client_max_window_bits { return Err(DeflateExtensionError::NegotiationError(format!( "Duplicate extension parameter: client_max_window_bits" ))); } else { - client_max_window_bits = true; + seen_client_max_window_bits = true; match parse_window_parameter( param.split("=").skip(1), - *max_window_bits, + *client_max_window_bits, ) { Ok(Some(bits)) => { - *max_window_bits = bits; + *client_max_window_bits = bits; } Ok(None) => {} Err(e) => { @@ -438,15 +458,18 @@ pub fn on_request(mut request: Request, config: &DeflateConfig) -> Request let mut header_value = String::from(EXT_IDENT); let DeflateConfig { - max_window_bits, + server_max_window_bits, + client_max_window_bits, request_no_context_takeover, .. } = config; - if *max_window_bits < LZ77_MAX_WINDOW_SIZE { + if *client_max_window_bits < LZ77_MAX_WINDOW_SIZE + || *server_max_window_bits < LZ77_MAX_WINDOW_SIZE + { header_value.push_str(&format!( "; client_max_window_bits={}; server_max_window_bits={}", - max_window_bits, max_window_bits + client_max_window_bits, server_max_window_bits )) } else { header_value.push_str("; client_max_window_bits") @@ -510,10 +533,10 @@ pub fn on_receive_request( match parse_window_parameter( param.split('=').skip(1), - config.max_window_bits, + config.server_max_window_bits, ) { Ok(Some(bits)) => { - config.max_window_bits = bits; + config.server_max_window_bits = bits; response_str.push_str("; "); response_str.push_str(param) @@ -533,10 +556,10 @@ pub fn on_receive_request( match parse_window_parameter( param.split('=').skip(1), - config.max_window_bits, + config.client_max_window_bits, ) { Ok(Some(bits)) => { - config.max_window_bits = bits; + config.client_max_window_bits = bits; response_str.push_str("; "); response_str.push_str(param); @@ -551,7 +574,7 @@ pub fn on_receive_request( response_str.push_str("; "); response_str.push_str(&format!( "client_max_window_bits={}", - config.max_window_bits() + config.client_max_window_bits() )) } } @@ -572,12 +595,12 @@ pub fn on_receive_request( response_str.push_str("; "); response_str.push_str(&format!( "server_max_window_bits={}", - config.max_window_bits() + config.server_max_window_bits() )) } if !response_str.contains("client_max_window_bits") - && config.max_window_bits() < LZ77_MAX_WINDOW_SIZE + && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE { continue; } @@ -622,20 +645,18 @@ impl Default for DeflateExt { impl WebSocketExtension for DeflateExt { fn on_send_frame(&mut self, mut frame: Frame) -> Result { - if self.enabled { - if let OpCode::Data(_) = frame.header().opcode { - let mut compressed = Vec::with_capacity(frame.payload().len()); - self.deflator.compress(frame.payload(), &mut compressed)?; + if let OpCode::Data(_) = frame.header().opcode { + let mut compressed = Vec::with_capacity(frame.payload().len()); + self.deflator.compress(frame.payload(), &mut compressed)?; - let len = compressed.len(); - compressed.truncate(len - 4); + let len = compressed.len(); + compressed.truncate(len - 4); - *frame.payload_mut() = compressed; - frame.header_mut().rsv1 = true; + *frame.payload_mut() = compressed; + frame.header_mut().rsv1 = true; - if self.config.compress_reset() { - self.deflator.reset(); - } + if self.config.compress_reset() { + self.deflator.reset(); } } @@ -643,7 +664,7 @@ impl WebSocketExtension for DeflateExt { } fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { - let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) { + if !self.fragment_buffer.is_empty() || frame.header().rsv1 { if !frame.header().is_final { self.fragment_buffer .try_push_frame(frame) @@ -694,11 +715,6 @@ impl WebSocketExtension for DeflateExt { } } else { self.uncompressed_extension.on_receive_frame(frame) - }; - - match r { - Ok(msg) => Ok(msg), - Err(e) => Err(crate::Error::ExtensionError(e.to_string().into())), } } } diff --git a/src/extensions/compression/mod.rs b/src/extensions/compression/mod.rs new file mode 100644 index 0000000..cfde4db --- /dev/null +++ b/src/extensions/compression/mod.rs @@ -0,0 +1,160 @@ +//! WebSocket compression + +#[cfg(feature = "deflate")] +use crate::extensions::compression::deflate::{DeflateConfig, DeflateExt}; +use crate::extensions::compression::uncompressed::UncompressedExt; +use crate::extensions::WebSocketExtension; +use crate::protocol::frame::Frame; +use crate::protocol::WebSocketConfig; +use crate::Message; +use http::{Request, Response}; +use std::borrow::Cow; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// A permessage-deflate WebSocket extension (RFC 7692). +#[cfg(feature = "deflate")] +pub mod deflate; +/// An uncompressed message handler for a WebSocket. +pub mod uncompressed; + +/// +#[derive(Copy, Clone, Debug)] +pub enum WsCompression { + /// + None(Option), + /// + #[cfg(feature = "deflate")] + Deflate(DeflateConfig), +} + +/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`. +#[derive(Debug)] +pub enum CompressionSwitcher { + /// + #[cfg(feature = "deflate")] + Compressed(DeflateExt), + /// + Uncompressed(UncompressedExt), +} + +impl CompressionSwitcher { + /// + pub fn from_config(config: WsCompression) -> CompressionSwitcher { + match config { + WsCompression::None(size) => { + CompressionSwitcher::Uncompressed(UncompressedExt::new(size)) + } + #[cfg(feature = "deflate")] + WsCompression::Deflate(config) => { + CompressionSwitcher::Compressed(DeflateExt::new(config)) + } + } + } +} + +impl Default for CompressionSwitcher { + fn default() -> Self { + CompressionSwitcher::Uncompressed(UncompressedExt::default()) + } +} + +#[derive(Debug)] +/// +pub struct CompressionError(String); + +impl Error for CompressionError {} + +impl From for crate::Error { + fn from(e: CompressionError) -> Self { + crate::Error::ExtensionError(Cow::from(e.to_string())) + } +} + +impl Display for CompressionError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CompressionError") + .field("error", &self.0) + .finish() + } +} + +impl WebSocketExtension for CompressionSwitcher { + fn on_send_frame(&mut self, frame: Frame) -> Result { + match self { + CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame), + #[cfg(feature = "deflate")] + CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame), + } + } + + fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { + match self { + CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame), + #[cfg(feature = "deflate")] + CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame), + } + } +} + +/// +pub fn build_compression_headers( + request: Request, + config: &mut Option, +) -> Request { + match config { + Some(ref mut config) => match &config.compression { + WsCompression::None(_) => request, + #[cfg(feature = "deflate")] + WsCompression::Deflate(config) => deflate::on_request(request, config), + }, + None => request, + } +} + +/// +pub fn verify_compression_resp_headers( + _response: &Response, + config: &mut Option, +) -> Result<(), CompressionError> { + match config { + Some(ref mut config) => match &mut config.compression { + WsCompression::None(_) => Ok(()), + #[cfg(feature = "deflate")] + WsCompression::Deflate(ref mut deflate_config) => { + let result = deflate::on_response(_response, deflate_config) + .map_err(|e| CompressionError(e.to_string())); + + match result { + Ok(true) => Ok(()), + Ok(false) => { + config.compression = + WsCompression::None(Some(deflate_config.max_message_size())); + Ok(()) + } + Err(e) => Err(e), + } + } + }, + None => Ok(()), + } +} + +/// +pub fn verify_compression_req_headers( + _request: &Request, + _response: &mut Response, + config: &mut Option, +) -> Result<(), CompressionError> { + match config { + Some(ref mut config) => match &mut config.compression { + WsCompression::None(_) => Ok(()), + #[cfg(feature = "deflate")] + WsCompression::Deflate(ref mut deflate_config) => { + deflate::on_receive_request(_request, _response, deflate_config) + .map_err(|e| CompressionError(e.to_string())) + } + }, + None => Ok(()), + } +} diff --git a/src/extensions/compression/uncompressed.rs b/src/extensions/compression/uncompressed.rs index 939e934..d29092f 100644 --- a/src/extensions/compression/uncompressed.rs +++ b/src/extensions/compression/uncompressed.rs @@ -35,7 +35,6 @@ impl UncompressedExt { impl WebSocketExtension for UncompressedExt { fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { let fin = frame.header().is_final; - let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 0c2a169..1a24612 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,32 +1,9 @@ //! WebSocket extensions -use http::{Request, Response}; +pub mod compression; -#[cfg(feature = "deflate")] -use crate::extensions::deflate::{DeflateConfig, DeflateExt}; -use crate::extensions::uncompressed::UncompressedExt; use crate::protocol::frame::Frame; -use crate::protocol::WebSocketConfig; use crate::Message; -use std::borrow::Cow; -use std::error::Error; -use std::fmt::{Display, Formatter}; - -/// A permessage-deflate WebSocket extension (RFC 7692). -#[cfg(feature = "deflate")] -pub mod deflate; -/// An uncompressed message handler for a WebSocket. -pub mod uncompressed; - -/// -#[derive(Copy, Clone, Debug)] -pub enum WsCompression { - /// - None(Option), - /// - #[cfg(feature = "deflate")] - Deflate(DeflateConfig), -} /// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions /// may be stacked by nesting them inside one another. @@ -40,133 +17,3 @@ pub trait WebSocketExtension { /// type `OpCode::Data`. fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error>; } - -/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`. -#[derive(Debug)] -pub enum CompressionSwitcher { - /// - #[cfg(feature = "deflate")] - Compressed(DeflateExt), - /// - Uncompressed(UncompressedExt), -} - -impl CompressionSwitcher { - /// - pub fn from_config(config: WsCompression) -> CompressionSwitcher { - match config { - WsCompression::None(size) => { - CompressionSwitcher::Uncompressed(UncompressedExt::new(size)) - } - #[cfg(feature = "deflate")] - WsCompression::Deflate(config) => { - CompressionSwitcher::Compressed(DeflateExt::new(config)) - } - } - } -} - -impl Default for CompressionSwitcher { - fn default() -> Self { - CompressionSwitcher::Uncompressed(UncompressedExt::default()) - } -} - -#[derive(Debug)] -/// -pub struct CompressionError(String); - -impl Error for CompressionError {} - -impl From for crate::Error { - fn from(e: CompressionError) -> Self { - crate::Error::ExtensionError(Cow::from(e.to_string())) - } -} - -impl Display for CompressionError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CompressionError") - .field("error", &self.0) - .finish() - } -} - -impl WebSocketExtension for CompressionSwitcher { - fn on_send_frame(&mut self, frame: Frame) -> Result { - match self { - CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame), - #[cfg(feature = "deflate")] - CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame), - } - } - - fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { - match self { - CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame), - #[cfg(feature = "deflate")] - CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame), - } - } -} - -/// -pub fn build_compression_headers( - request: Request, - config: &mut Option, -) -> Request { - match config { - Some(ref mut config) => match &config.compression { - WsCompression::None(_) => request, - #[cfg(feature = "deflate")] - WsCompression::Deflate(config) => deflate::on_request(request, config), - }, - None => request, - } -} - -/// -pub fn verify_compression_resp_headers( - _response: &Response, - config: &mut Option, -) -> Result<(), CompressionError> { - match config { - Some(ref mut config) => match &mut config.compression { - WsCompression::None(_) => Ok(()), - #[cfg(feature = "deflate")] - WsCompression::Deflate(ref mut deflate_config) => { - let result = deflate::on_response(_response, deflate_config) - .map_err(|e| CompressionError(e.to_string())); - match result { - Ok(true) => Ok(()), - Ok(false) => { - config.compression = - WsCompression::None(Some(deflate_config.max_message_size())); - Ok(()) - } - Err(e) => Err(e), - } - } - }, - None => Ok(()), - } -} - -/// -pub fn verify_compression_req_headers( - _request: &Request, - _response: &mut Response, - config: &mut Option, -) -> Result<(), CompressionError> { - match config { - Some(ref mut config) => match &mut config.compression { - WsCompression::None(_) => Ok(()), - #[cfg(feature = "deflate")] - WsCompression::Deflate(ref mut deflate_config) => { - deflate::on_receive_request(_request, _response, deflate_config) - .map_err(|e| CompressionError(e.to_string())) - } - }, - None => Ok(()), - } -} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index efb0df4..4eadc3f 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; -use crate::extensions::{build_compression_headers, verify_compression_resp_headers}; +use crate::extensions::compression::{build_compression_headers, verify_compression_resp_headers}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Client request type. diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 181d44b..110a574 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -12,7 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; -use crate::extensions::verify_compression_req_headers; +use crate::extensions::compression::verify_compression_req_headers; use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Server request type. diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 2241d75..ba52ec3 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -16,7 +16,8 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use self::frame::{Frame, FrameCodec}; use self::message::IncompleteMessage; use crate::error::{Error, Result}; -use crate::extensions::{CompressionSwitcher, WebSocketExtension, WsCompression}; +use crate::extensions::compression::{CompressionSwitcher, WsCompression}; +use crate::extensions::WebSocketExtension; use crate::util::NonBlockingResult; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; @@ -636,7 +637,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use crate::extensions::WsCompression; + use crate::extensions::compression::WsCompression; use std::io; use std::io::Cursor;