diff --git a/src/client.rs b/src/client.rs index 20af922..31213c1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -66,8 +66,6 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::error::{Error, Result}; -use crate::extensions::uncompressed::UncompressedExt; -use crate::extensions::WebSocketExtension; use crate::handshake::client::ClientHandshake; use crate::handshake::HandshakeError; use crate::protocol::WebSocket; @@ -88,13 +86,12 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config( +pub fn connect_with_config( request: Req, - config: Option>, -) -> Result<(WebSocket, Response)> + config: Option, +) -> Result<(WebSocket, Response)> where Req: IntoClientRequest, - Ext: WebSocketExtension, { let request: Request = request.into_client_request()?; let uri = request.uri(); @@ -128,9 +125,7 @@ where /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect( - request: Req, -) -> Result<(WebSocket, Response)> { +pub fn connect(request: Req) -> Result<(WebSocket, Response)> { connect_with_config(request, None) } @@ -167,15 +162,14 @@ pub fn uri_mode(uri: &Uri) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client_with_config( +pub fn client_with_config( request: Req, stream: Stream, - config: Option>, -) -> StdResult<(WebSocket, Response), HandshakeError>> + config: Option, +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: IntoClientRequest, - Ext: WebSocketExtension, { ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake() } @@ -188,10 +182,7 @@ where pub fn client( request: Req, stream: Stream, -) -> StdResult< - (WebSocket, Response), - HandshakeError>, -> +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: IntoClientRequest, diff --git a/src/extensions/deflate.rs b/src/extensions/deflate.rs index 82caf36..bfd54ff 100644 --- a/src/extensions/deflate.rs +++ b/src/extensions/deflate.rs @@ -231,41 +231,39 @@ impl DeflateExt { enabled: false, config, fragment_buffer: FragmentBuffer::new(config.max_message_size), - inflator: Inflator::new(), - deflator: Deflator::new(Compression::fast()), + inflator: Inflator::new(config.max_window_bits), + deflator: Deflator::new(config.compression_level, config.max_window_bits), uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), } } +} - fn parse_window_parameter<'a>( - &mut self, - mut param_iter: impl Iterator, - ) -> Result, String> { - if let Some(window_bits_str) = param_iter.next() { - match window_bits_str.trim().parse() { - Ok(window_bits) => { - if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { - if window_bits != self.config.max_window_bits() { - self.config.max_window_bits = window_bits; - Ok(Some(window_bits)) - } else { - Ok(None) - } +fn parse_window_parameter<'a>( + mut param_iter: impl Iterator, + max_window_bits: u8, +) -> Result, String> { + if let Some(window_bits_str) = param_iter.next() { + match window_bits_str.trim().parse() { + Ok(window_bits) => { + if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { + if window_bits != max_window_bits { + Ok(Some(window_bits)) } else { - Err(format!("Invalid window parameter: {}", window_bits)) + Ok(None) } + } else { + Err(format!("Invalid window parameter: {}", window_bits)) } - Err(e) => Err(e.to_string()), } - } else { - Ok(None) + Err(e) => Err(e.to_string()), } + } else { + Ok(None) } +} - fn decline(&mut self, res: &mut Response) { - self.enabled = false; - res.headers_mut().remove(EXT_IDENT); - } +fn decline(res: &mut Response) { + res.headers_mut().remove(EXT_IDENT); } /// A permessage-deflate extension error. @@ -298,328 +296,332 @@ impl Display for DeflateExtensionError { } } -impl std::error::Error for DeflateExtensionError {} - -impl From for crate::Error { - fn from(e: DeflateExtensionError) -> Self { - crate::Error::ExtensionError(Cow::from(e.to_string())) - } -} - -impl From for DeflateExtensionError { - fn from(e: InvalidHeaderValue) -> Self { - DeflateExtensionError::NegotiationError(e.to_string()) - } -} - -impl Default for DeflateExt { - fn default() -> Self { - DeflateExt::new(Default::default()) - } -} - -impl WebSocketExtension for DeflateExt { - type Error = DeflateExtensionError; - - fn new(max_message_size: Option) -> Self { - DeflateExt::new(DeflateConfig { - max_message_size: max_message_size.unwrap_or_else(usize::max_value), - ..Default::default() - }) - } - - fn enabled(&self) -> bool { - self.enabled - } - - fn on_make_request(&mut self, mut request: Request) -> Request { - let mut header_value = String::from(EXT_IDENT); - let DeflateConfig { - max_window_bits, - request_no_context_takeover, - .. - } = self.config; - - if max_window_bits < LZ77_MAX_WINDOW_SIZE { - header_value.push_str(&format!( - "; client_max_window_bits={}; server_max_window_bits={}", - max_window_bits, max_window_bits - )) - } else { - header_value.push_str("; client_max_window_bits") - } - - if request_no_context_takeover { - header_value.push_str("; server_no_context_takeover") - } - - request.headers_mut().append( - SEC_WEBSOCKET_EXTENSIONS, - HeaderValue::from_str(&header_value).unwrap(), - ); - - request - } - - fn on_receive_request( - &mut self, - request: &Request, - response: &mut Response, - ) -> Result<(), Self::Error> { - for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { - return match header.to_str() { - Ok(header) => { - let mut response_str = String::with_capacity(header.len()); - let mut server_takeover = false; - let mut client_takeover = false; - let mut server_max_bits = false; - let mut client_max_bits = false; - - for param in header.split(';') { - match param.trim().to_lowercase().as_str() { - "permessage-deflate" => response_str.push_str("permessage-deflate"), - "server_no_context_takeover" => { - if server_takeover { - self.decline(response); - } else { - server_takeover = true; - if self.config.accept_no_context_takeover() { - self.config.compress_reset = true; - response_str.push_str("; server_no_context_takeover"); - } - } +/// +pub fn on_response( + response: &Response, + config: &mut DeflateConfig, +) -> Result { + let mut extension_name = false; + let mut server_takeover = false; + let mut client_takeover = false; + let mut server_max_window_bits = false; + let mut client_max_window_bits = false; + let mut enabled = false; + + let DeflateConfig { + max_window_bits, + accept_no_context_takeover, + compress_reset, + decompress_reset, + .. + } = config; + + for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() { + match header.to_str() { + Ok(header) => { + for param in header.split(';') { + match param.trim().to_lowercase().as_str() { + "permessage-deflate" => { + if extension_name { + return Err(DeflateExtensionError::NegotiationError(format!( + "Duplicate extension parameter: permessage-deflate" + ))); + } else { + enabled = true; + extension_name = true; } - "client_no_context_takeover" => { - if client_takeover { - self.decline(response); + } + "server_no_context_takeover" => { + if server_takeover { + return Err(DeflateExtensionError::NegotiationError(format!( + "Duplicate extension parameter: server_no_context_takeover" + ))); + } else { + server_takeover = true; + *decompress_reset = true; + } + } + "client_no_context_takeover" => { + if client_takeover { + return Err(DeflateExtensionError::NegotiationError(format!( + "Duplicate extension parameter: client_no_context_takeover" + ))); + } else { + client_takeover = true; + + if *accept_no_context_takeover { + *compress_reset = true; } else { - client_takeover = true; - self.config.decompress_reset = true; - response_str.push_str("; client_no_context_takeover"); + return Err(DeflateExtensionError::NegotiationError(format!( + "The client requires context takeover." + ))); } } - param if param.starts_with("server_max_window_bits") => { - if server_max_bits { - self.decline(response); - } else { - server_max_bits = true; - - match self.parse_window_parameter(param.split('=').skip(1)) { - Ok(Some(bits)) => { - self.deflator = Deflator::new_with_window_bits( - self.config.compression_level, - bits, - ); - response_str.push_str("; "); - response_str.push_str(param) - } - Ok(None) => {} - Err(_) => { - self.decline(response); - } + } + param if param.starts_with("server_max_window_bits") => { + if server_max_window_bits { + return Err(DeflateExtensionError::NegotiationError(format!( + "Duplicate extension parameter: server_max_window_bits" + ))); + } else { + server_max_window_bits = true; + + match parse_window_parameter( + param.split("=").skip(1), + *max_window_bits, + ) { + Ok(Some(bits)) => { + *max_window_bits = bits; + } + Ok(None) => {} + Err(e) => { + return Err(DeflateExtensionError::NegotiationError( + format!( + "server_max_window_bits parameter error: {}", + e + ), + )) } } } - param if param.starts_with("client_max_window_bits") => { - if client_max_bits { - self.decline(response); - } else { - client_max_bits = true; - - match self.parse_window_parameter(param.split('=').skip(1)) { - Ok(Some(bits)) => { - self.inflator = Inflator::new_with_window_bits(bits); - - response_str.push_str("; "); - response_str.push_str(param); - continue; - } - Ok(None) => {} - Err(_) => { - self.decline(response); - } + } + param if param.starts_with("client_max_window_bits") => { + if client_max_window_bits { + return Err(DeflateExtensionError::NegotiationError(format!( + "Duplicate extension parameter: client_max_window_bits" + ))); + } else { + client_max_window_bits = true; + + match parse_window_parameter( + param.split("=").skip(1), + *max_window_bits, + ) { + Ok(Some(bits)) => { + *max_window_bits = bits; + } + Ok(None) => {} + Err(e) => { + return Err(DeflateExtensionError::NegotiationError( + format!( + "client_max_window_bits parameter error: {}", + e + ), + )) } - - response_str.push_str("; "); - response_str.push_str(&format!( - "client_max_window_bits={}", - self.config.max_window_bits() - )) } } - _ => { - self.decline(response); - } + } + p => { + return Err(DeflateExtensionError::NegotiationError(format!( + "Unknown permessage-deflate parameter: {}", + p + ))); } } + } + } + Err(e) => { + return Err(DeflateExtensionError::NegotiationError(format!( + "Failed to parse extension parameter: {}", + e + ))); + } + } + } - if !response_str.contains("client_no_context_takeover") - && self.config.request_no_context_takeover() - { - self.config.decompress_reset = true; - response_str.push_str("; client_no_context_takeover"); - } - - if !response_str.contains("server_max_window_bits") { - response_str.push_str("; "); - response_str.push_str(&format!( - "server_max_window_bits={}", - self.config.max_window_bits() - )) - } - - if !response_str.contains("client_max_window_bits") - && self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE - { - continue; - } + Ok(enabled) +} - response.headers_mut().insert( - SEC_WEBSOCKET_EXTENSIONS, - HeaderValue::from_str(&response_str)?, - ); +/// +pub fn on_request(mut request: Request, config: &DeflateConfig) -> Request { + let mut header_value = String::from(EXT_IDENT); - self.enabled = true; + let DeflateConfig { + max_window_bits, + request_no_context_takeover, + .. + } = config; - Ok(()) - } - Err(e) => { - self.enabled = false; - Err(DeflateExtensionError::NegotiationError(format!( - "Failed to parse request header: {}", - e, - ))) - } - }; - } + if *max_window_bits < LZ77_MAX_WINDOW_SIZE { + header_value.push_str(&format!( + "; client_max_window_bits={}; server_max_window_bits={}", + max_window_bits, max_window_bits + )) + } else { + header_value.push_str("; client_max_window_bits") + } - self.decline(response); - Ok(()) + if *request_no_context_takeover { + header_value.push_str("; server_no_context_takeover") } - fn on_response(&mut self, response: &Response) -> Result<(), Self::Error> { - let mut extension_name = false; - let mut server_takeover = false; - let mut client_takeover = false; - let mut server_max_window_bits = false; - let mut client_max_window_bits = false; + request.headers_mut().append( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&header_value).unwrap(), + ); - for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() { - match header.to_str() { - Ok(header) => { - for param in header.split(';') { - match param.trim().to_lowercase().as_str() { - "permessage-deflate" => { - if extension_name { - return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: permessage-deflate" - ))); - } else { - self.enabled = true; - extension_name = true; + request +} + +/// +pub fn on_receive_request( + request: &Request, + response: &mut Response, + config: &mut DeflateConfig, +) -> Result<(), DeflateExtensionError> { + for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { + return match header.to_str() { + Ok(header) => { + let mut response_str = String::with_capacity(header.len()); + let mut server_takeover = false; + let mut client_takeover = false; + let mut server_max_bits = false; + let mut client_max_bits = false; + + for param in header.split(';') { + match param.trim().to_lowercase().as_str() { + "permessage-deflate" => response_str.push_str("permessage-deflate"), + "server_no_context_takeover" => { + if server_takeover { + decline(response); + } else { + server_takeover = true; + if config.accept_no_context_takeover() { + config.compress_reset = true; + response_str.push_str("; server_no_context_takeover"); } } - "server_no_context_takeover" => { - if server_takeover { - return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: server_no_context_takeover" - ))); - } else { - server_takeover = true; - self.config.decompress_reset = true; - } + } + "client_no_context_takeover" => { + if client_takeover { + decline(response); + } else { + client_takeover = true; + config.decompress_reset = true; + response_str.push_str("; client_no_context_takeover"); } - "client_no_context_takeover" => { - if client_takeover { - return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: client_no_context_takeover" - ))); - } else { - client_takeover = true; - - if self.config.accept_no_context_takeover() { - self.config.compress_reset = true; - } else { - return Err(DeflateExtensionError::NegotiationError( - format!("The client requires context takeover."), - )); + } + param if param.starts_with("server_max_window_bits") => { + if server_max_bits { + decline(response); + } else { + server_max_bits = true; + + match parse_window_parameter( + param.split('=').skip(1), + config.max_window_bits, + ) { + Ok(Some(bits)) => { + config.max_window_bits = bits; + + response_str.push_str("; "); + response_str.push_str(param) } - } - } - param if param.starts_with("server_max_window_bits") => { - if server_max_window_bits { - return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: server_max_window_bits" - ))); - } else { - server_max_window_bits = true; - - match self.parse_window_parameter(param.split("=").skip(1)) { - Ok(Some(bits)) => { - self.inflator = Inflator::new_with_window_bits(bits); - } - Ok(None) => {} - Err(e) => { - return Err(DeflateExtensionError::NegotiationError( - format!( - "server_max_window_bits parameter error: {}", - e - ), - )) - } + Ok(None) => {} + Err(_) => { + decline(response); } } } - param if param.starts_with("client_max_window_bits") => { - if client_max_window_bits { - return Err(DeflateExtensionError::NegotiationError(format!( - "Duplicate extension parameter: client_max_window_bits" - ))); - } else { - client_max_window_bits = true; - - match self.parse_window_parameter(param.split("=").skip(1)) { - Ok(Some(bits)) => { - self.deflator = Deflator::new_with_window_bits( - self.config.compression_level, - bits, - ); - } - Ok(None) => {} - Err(e) => { - return Err(DeflateExtensionError::NegotiationError( - format!( - "client_max_window_bits parameter error: {}", - e - ), - )) - } + } + param if param.starts_with("client_max_window_bits") => { + if client_max_bits { + decline(response); + } else { + client_max_bits = true; + + match parse_window_parameter( + param.split('=').skip(1), + config.max_window_bits, + ) { + Ok(Some(bits)) => { + config.max_window_bits = bits; + response_str.push_str("; "); + response_str.push_str(param); + + continue; + } + Ok(None) => {} + Err(_) => { + decline(response); } } + + response_str.push_str("; "); + response_str.push_str(&format!( + "client_max_window_bits={}", + config.max_window_bits() + )) } - p => { - return Err(DeflateExtensionError::NegotiationError(format!( - "Unknown permessage-deflate parameter: {}", - p - ))); - } + } + _ => { + decline(response); } } } - Err(e) => { - self.enabled = false; - return Err(DeflateExtensionError::NegotiationError(format!( - "Failed to parse extension parameter: {}", - e - ))); + + if !response_str.contains("client_no_context_takeover") + && config.request_no_context_takeover() + { + config.decompress_reset = true; + response_str.push_str("; client_no_context_takeover"); + } + + if !response_str.contains("server_max_window_bits") { + response_str.push_str("; "); + response_str.push_str(&format!( + "server_max_window_bits={}", + config.max_window_bits() + )) + } + + if !response_str.contains("client_max_window_bits") + && config.max_window_bits() < LZ77_MAX_WINDOW_SIZE + { + continue; } + + response.headers_mut().insert( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_str)?, + ); + + Ok(()) } - } + Err(e) => Err(DeflateExtensionError::NegotiationError(format!( + "Failed to parse request header: {}", + e, + ))), + }; + } + + decline(response); + Ok(()) +} + +impl std::error::Error for DeflateExtensionError {} + +impl From for crate::Error { + fn from(e: DeflateExtensionError) -> Self { + crate::Error::ExtensionError(Cow::from(e.to_string())) + } +} - Ok(()) +impl From for DeflateExtensionError { + fn from(e: InvalidHeaderValue) -> Self { + DeflateExtensionError::NegotiationError(e.to_string()) + } +} + +impl Default for DeflateExt { + fn default() -> Self { + DeflateExt::new(Default::default()) } +} - fn on_send_frame(&mut self, mut frame: Frame) -> Result { +impl WebSocketExtension for DeflateExt { + fn on_send_frame(&mut self, mut frame: Frame) -> Result { if self.enabled { if let OpCode::Data(_) = frame.header().opcode { let mut compressed = Vec::with_capacity(frame.payload().len()); @@ -640,7 +642,7 @@ impl WebSocketExtension for DeflateExt { Ok(frame) } - fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error> { + fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) { if !frame.header().is_final { self.fragment_buffer @@ -696,20 +698,20 @@ impl WebSocketExtension for DeflateExt { match r { Ok(msg) => Ok(msg), - Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())), + Err(e) => Err(crate::Error::ExtensionError(e.to_string().into())), } } } -impl From for DeflateExtensionError { +impl From for crate::Error { fn from(e: DecompressError) -> Self { - DeflateExtensionError::InflateError(e.to_string()) + crate::Error::ExtensionError(e.to_string().into()) } } -impl From for DeflateExtensionError { +impl From for crate::Error { fn from(e: CompressError) -> Self { - DeflateExtensionError::DeflateError(e.to_string()) + crate::Error::ExtensionError(e.to_string().into()) } } @@ -719,13 +721,7 @@ struct Deflator { } impl Deflator { - fn new(compresion: Compression) -> Deflator { - Deflator { - compress: Compress::new(compresion, false), - } - } - - fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator { + fn new(compression: Compression, mut window_size: u8) -> Deflator { // https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303 if window_size == 8 { window_size = 9; @@ -790,13 +786,7 @@ struct Inflator { } impl Inflator { - fn new() -> Inflator { - Inflator { - decompress: Decompress::new(false), - } - } - - fn new_with_window_bits(mut window_size: u8) -> Inflator { + fn new(mut window_size: u8) -> Inflator { // https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303 if window_size == 8 { window_size = 9; @@ -888,11 +878,10 @@ impl FragmentBuffer { *fragments_len += frame.payload().len(); if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len { - return Err(format!( + Err(format!( "Message too big: {} + {} > {}", fragments_len, fragments_len, max_len - ) - .into()); + )) } else { fragments.push(frame); Ok(()) diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 6a9cee1..0c2a169 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -2,8 +2,15 @@ use http::{Request, Response}; +#[cfg(feature = "deflate")] +use crate::extensions::deflate::{DeflateConfig, DeflateExt}; +use crate::extensions::uncompressed::UncompressedExt; use crate::protocol::frame::Frame; +use crate::protocol::WebSocketConfig; use crate::Message; +use std::borrow::Cow; +use std::error::Error; +use std::fmt::{Display, Formatter}; /// A permessage-deflate WebSocket extension (RFC 7692). #[cfg(feature = "deflate")] @@ -11,46 +18,155 @@ pub mod deflate; /// An uncompressed message handler for a WebSocket. pub mod uncompressed; +/// +#[derive(Copy, Clone, Debug)] +pub enum WsCompression { + /// + None(Option), + /// + #[cfg(feature = "deflate")] + Deflate(DeflateConfig), +} + /// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions /// may be stacked by nesting them inside one another. pub trait WebSocketExtension { - /// An error type that the extension produces. - type Error: Into; + /// Called when a frame is about to be sent. + fn on_send_frame(&mut self, frame: Frame) -> Result { + Ok(frame) + } + + /// Called when a frame has been received and unmasked. The frame provided frame will be of the + /// type `OpCode::Data`. + fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error>; +} - /// Constructs a new WebSocket extension that will permit messages of the provided size. - fn new(max_message_size: Option) -> Self; +/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`. +#[derive(Debug)] +pub enum CompressionSwitcher { + /// + #[cfg(feature = "deflate")] + Compressed(DeflateExt), + /// + Uncompressed(UncompressedExt), +} - /// Returns whether or not the extension is enabled. - fn enabled(&self) -> bool { - false +impl CompressionSwitcher { + /// + pub fn from_config(config: WsCompression) -> CompressionSwitcher { + match config { + WsCompression::None(size) => { + CompressionSwitcher::Uncompressed(UncompressedExt::new(size)) + } + #[cfg(feature = "deflate")] + WsCompression::Deflate(config) => { + CompressionSwitcher::Compressed(DeflateExt::new(config)) + } + } } +} - /// For WebSocket clients, this will be called when a `Request` is being constructed. - fn on_make_request(&mut self, request: Request) -> Request { - request +impl Default for CompressionSwitcher { + fn default() -> Self { + CompressionSwitcher::Uncompressed(UncompressedExt::default()) } +} + +#[derive(Debug)] +/// +pub struct CompressionError(String); - /// For WebSocket server, this will be called when a `Request` has been received. - fn on_receive_request( - &mut self, - _request: &Request, - _response: &mut Response, - ) -> Result<(), Self::Error> { - Ok(()) +impl Error for CompressionError {} + +impl From for crate::Error { + fn from(e: CompressionError) -> Self { + crate::Error::ExtensionError(Cow::from(e.to_string())) } +} - /// For WebSocket clients, this will be called when a response from the server has been - /// received. If an error is produced, then subsequent calls to `rsv1()` should return `false`. - fn on_response(&mut self, _response: &Response) -> Result<(), Self::Error> { - Ok(()) +impl Display for CompressionError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CompressionError") + .field("error", &self.0) + .finish() } +} - /// Called when a frame is about to be sent. - fn on_send_frame(&mut self, frame: Frame) -> Result { - Ok(frame) +impl WebSocketExtension for CompressionSwitcher { + fn on_send_frame(&mut self, frame: Frame) -> Result { + match self { + CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame), + #[cfg(feature = "deflate")] + CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame), + } } - /// Called when a frame has been received and unmasked. The frame provided frame will be of the - /// type `OpCode::Data`. - fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error>; + fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { + match self { + CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame), + #[cfg(feature = "deflate")] + CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame), + } + } +} + +/// +pub fn build_compression_headers( + request: Request, + config: &mut Option, +) -> Request { + match config { + Some(ref mut config) => match &config.compression { + WsCompression::None(_) => request, + #[cfg(feature = "deflate")] + WsCompression::Deflate(config) => deflate::on_request(request, config), + }, + None => request, + } +} + +/// +pub fn verify_compression_resp_headers( + _response: &Response, + config: &mut Option, +) -> Result<(), CompressionError> { + match config { + Some(ref mut config) => match &mut config.compression { + WsCompression::None(_) => Ok(()), + #[cfg(feature = "deflate")] + WsCompression::Deflate(ref mut deflate_config) => { + let result = deflate::on_response(_response, deflate_config) + .map_err(|e| CompressionError(e.to_string())); + match result { + Ok(true) => Ok(()), + Ok(false) => { + config.compression = + WsCompression::None(Some(deflate_config.max_message_size())); + Ok(()) + } + Err(e) => Err(e), + } + } + }, + None => Ok(()), + } +} + +/// +pub fn verify_compression_req_headers( + _request: &Request, + _response: &mut Response, + config: &mut Option, +) -> Result<(), CompressionError> { + match config { + Some(ref mut config) => match &mut config.compression { + WsCompression::None(_) => Ok(()), + #[cfg(feature = "deflate")] + WsCompression::Deflate(ref mut deflate_config) => { + deflate::on_receive_request(_request, _response, deflate_config) + .map_err(|e| CompressionError(e.to_string())) + } + }, + None => Ok(()), + } } diff --git a/src/extensions/uncompressed.rs b/src/extensions/uncompressed.rs index b16fb49..939e934 100644 --- a/src/extensions/uncompressed.rs +++ b/src/extensions/uncompressed.rs @@ -2,8 +2,8 @@ use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::Frame; use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; -use crate::{Error, Message}; use crate::protocol::MAX_MESSAGE_SIZE; +use crate::{Error, Message}; /// An uncompressed message handler for a WebSocket. #[derive(Debug)] @@ -16,7 +16,7 @@ impl Default for UncompressedExt { fn default() -> Self { UncompressedExt { incomplete: None, - max_message_size: Some(MAX_MESSAGE_SIZE) + max_message_size: Some(MAX_MESSAGE_SIZE), } } } @@ -33,20 +33,7 @@ impl UncompressedExt { } impl WebSocketExtension for UncompressedExt { - type Error = Error; - - fn new(max_message_size: Option) -> Self { - UncompressedExt { - incomplete: None, - max_message_size, - } - } - - fn enabled(&self) -> bool { - true - } - - fn on_receive_frame(&mut self, frame: Frame) -> Result, Self::Error> { + fn on_receive_frame(&mut self, frame: Frame) -> Result, crate::Error> { let fin = frame.header().is_final; let hdr = frame.header(); diff --git a/src/handshake/client.rs b/src/handshake/client.rs index d8c84d7..efb0df4 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; -use crate::extensions::WebSocketExtension; +use crate::extensions::{build_compression_headers, verify_compression_resp_headers}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Client request type. @@ -22,25 +22,21 @@ pub type Response = HttpResponse<()>; /// Client handshake role. #[derive(Debug)] -pub struct ClientHandshake -where - Extension: WebSocketExtension, -{ +pub struct ClientHandshake { verify_data: VerifyData, - config: Option>>, + config: Option>, _marker: PhantomData, } -impl ClientHandshake +impl ClientHandshake where Stream: Read + Write, - Ext: WebSocketExtension, { /// Initiate a client handshake. pub fn start( stream: Stream, request: Request, - mut config: Option>, + mut config: Option, ) -> Result> { if request.method() != http::Method::GET { return Err(Error::Protocol( @@ -81,14 +77,13 @@ where } } -impl HandshakeRole for ClientHandshake +impl HandshakeRole for ClientHandshake where Stream: Read + Write, - Ext: WebSocketExtension, { type IncomingData = Response; type InternalStream = Stream; - type FinalResult = (WebSocket, Response); + type FinalResult = (WebSocket, Response); fn stage_finished( &mut self, @@ -115,18 +110,12 @@ where } /// Generate client request. -fn generate_request( +fn generate_request( request: Request, key: &str, - config: &mut Option>, -) -> Result> -where - Ext: WebSocketExtension, -{ - let request = match config { - Some(ref mut config) => config.encoder.on_make_request(request), - None => request, - }; + config: &mut Option, +) -> Result> { + let request = build_compression_headers(request, config); let mut req = Vec::new(); let uri = request.uri(); @@ -183,14 +172,11 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response( + pub fn verify_response( &self, response: &Response, - config: &mut Option>, - ) -> Result<()> - where - Ext: WebSocketExtension, - { + config: &mut Option, + ) -> Result<()> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { @@ -246,11 +232,7 @@ impl VerifyData { // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if let Some(config) = config { - if let Err(e) = config.encoder.on_response(response) { - return Err(e.into()); - } - } + verify_compression_resp_headers(response, config)?; // 6. If the response includes a |Sec-WebSocket-Protocol| header field // and this header field indicates the use of a subprotocol that was @@ -308,7 +290,6 @@ mod tests { use super::super::machine::TryParse; use super::{generate_key, generate_request, Response}; use crate::client::IntoClientRequest; - use crate::extensions::uncompressed::UncompressedExt; #[test] fn random_keys() { @@ -338,9 +319,7 @@ mod tests { Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; - let request = - generate_request::(request, key, &mut Some(Default::default())) - .unwrap(); + let request = generate_request(request, key, &mut Some(Default::default())).unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } @@ -359,9 +338,7 @@ mod tests { Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; - let request = - generate_request::(request, key, &mut Some(Default::default())) - .unwrap(); + let request = generate_request(request, key, &mut Some(Default::default())).unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } @@ -380,9 +357,7 @@ mod tests { Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; - let request = - generate_request::(request, key, &mut Some(Default::default())) - .unwrap(); + let request = generate_request(request, key, &mut Some(Default::default())).unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index c04755a..181d44b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -12,7 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; -use crate::extensions::WebSocketExtension; +use crate::extensions::verify_compression_req_headers; use crate::protocol::{Role, WebSocket, WebSocketConfig}; /// Server request type. @@ -191,43 +191,35 @@ impl Callback for NoCallback { /// Server handshake role. #[allow(missing_copy_implementations)] #[derive(Debug)] -pub struct ServerHandshake -where - Ext: WebSocketExtension, -{ +pub struct ServerHandshake { /// Callback which is called whenever the server read the request from the client and is ready /// to reply to it. The callback returns an optional headers which will be added to the reply /// which the server sends to the user. callback: Option, /// WebSocket configuration. - config: Option>>, + config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. error_code: Option, /// Internal stream type. _marker: PhantomData, } -impl ServerHandshake +impl ServerHandshake where S: Read + Write, C: Callback, - Ext: WebSocketExtension, { /// Start server handshake. `callback` specifies a custom callback which the user can pass to /// the handshake, this callback will be called when the a websocket client connnects to the /// server, you can specify the callback if you want to add additional header to the client /// upon join based on the incoming headers. - pub fn start( - stream: S, - callback: C, - config: Option>, - ) -> MidHandshake { + pub fn start(stream: S, callback: C, config: Option) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), role: ServerHandshake { callback: Some(callback), - config: Some(config), + config, error_code: None, _marker: PhantomData, }, @@ -235,15 +227,14 @@ where } } -impl HandshakeRole for ServerHandshake +impl HandshakeRole for ServerHandshake where S: Read + Write, C: Callback, - Ext: WebSocketExtension, { type IncomingData = Request; type InternalStream = S; - type FinalResult = WebSocket; + type FinalResult = WebSocket; fn stage_finished( &mut self, @@ -260,12 +251,7 @@ where } let mut response = create_response(&request)?; - - if let Some(ref mut config) = self.config.as_mut().unwrap() { - if let Err(e) = config.encoder.on_receive_request(&request, &mut response) { - return Err(e.into()); - } - } + verify_compression_req_headers(&request, &mut response, &mut self.config)?; let callback_result = if let Some(callback) = self.callback.take() { callback.on_request(&request, response) @@ -305,11 +291,7 @@ where return Err(Error::Http(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); - let websocket = WebSocket::from_raw_socket( - stream, - Role::Server, - self.config.take().unwrap(), - ); + let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); ProcessingResult::Done(websocket) } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index bb695e5..2241d75 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -16,8 +16,7 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use self::frame::{Frame, FrameCodec}; use self::message::IncompleteMessage; use crate::error::{Error, Result}; -use crate::extensions::uncompressed::UncompressedExt; -use crate::extensions::WebSocketExtension; +use crate::extensions::{CompressionSwitcher, WebSocketExtension, WsCompression}; use crate::util::NonBlockingResult; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; @@ -33,10 +32,7 @@ pub enum Role { /// The configuration for WebSocket connection. #[derive(Debug, Copy, Clone)] -pub struct WebSocketConfig -where - E: WebSocketExtension, -{ +pub struct WebSocketConfig { /// The size of the send queue. You can use it to turn on/off the backpressure features. `None` /// means here that the size of the queue is unlimited. The default value is the unlimited /// queue. @@ -46,34 +42,16 @@ where /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, - /// Per-message compression strategy. - pub encoder: E, + /// A per-message compression configuration. + pub compression: WsCompression, } -impl Default for WebSocketConfig -where - E: WebSocketExtension, -{ +impl Default for WebSocketConfig { fn default() -> Self { WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: E::new(Some(MAX_MESSAGE_SIZE)), - } - } -} - -impl WebSocketConfig -where - E: WebSocketExtension, -{ - /// Creates a `WebSocketConfig` instance using the default configuration and the provided - /// encoder for new connections. - pub fn default_with_encoder(encoder: E) -> WebSocketConfig { - WebSocketConfig { - max_send_queue: None, - max_frame_size: Some(16 << 20), - encoder, + compression: WsCompression::None(Some(MAX_MESSAGE_SIZE)), } } } @@ -83,30 +61,20 @@ where /// This is THE structure you want to create to be able to speak the WebSocket protocol. /// It may be created by calling `connect`, `accept` or `client` functions. #[derive(Debug)] -pub struct WebSocket -where - Ext: WebSocketExtension, -{ +pub struct WebSocket { /// The underlying socket. socket: Stream, /// The context for managing a WebSocket. - context: WebSocketContext, + context: WebSocketContext, } -impl WebSocket -where - Ext: WebSocketExtension, -{ +impl WebSocket { /// Convert a raw socket into a WebSocket without performing a handshake. /// /// Call this function if you're using Tungstenite as a part of a web framework /// or together with an existing one. If you need an initial handshake, use /// `connect()` or `accept()` functions of the crate to construct a websocket. - pub fn from_raw_socket( - stream: Stream, - role: Role, - config: Option>, - ) -> Self { + pub fn from_raw_socket(stream: Stream, role: Role, config: Option) -> Self { WebSocket { socket: stream, context: WebSocketContext::new(role, config), @@ -122,7 +90,7 @@ where stream: Stream, part: Vec, role: Role, - config: Option>, + config: Option, ) -> Self { WebSocket { socket: stream, @@ -141,12 +109,12 @@ where } /// Change the configuration. - pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { + pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { self.context.set_config(set_func) } /// Read the configuration. - pub fn get_config(&self) -> &WebSocketConfig { + pub fn get_config(&self) -> &WebSocketConfig { self.context.get_config() } @@ -166,10 +134,9 @@ where } } -impl WebSocket +impl WebSocket where Stream: Read + Write, - Ext: WebSocketExtension, { /// Read a message from stream, if possible. /// @@ -253,10 +220,7 @@ where /// A context for managing WebSocket stream. #[derive(Debug)] -pub struct WebSocketContext -where - Ext: WebSocketExtension, -{ +pub struct WebSocketContext { /// Server or client? role: Role, /// encoder/decoder of frame. @@ -270,16 +234,16 @@ where /// Send: an OOB pong message. pong: Option, /// The configuration for the websocket session. - config: WebSocketConfig, + config: WebSocketConfig, + /// A per-message compression strategy. + decoder: CompressionSwitcher, } -impl WebSocketContext -where - Ext: WebSocketExtension, -{ +impl WebSocketContext { /// Create a WebSocket context that manages a post-handshake stream. - pub fn new(role: Role, config: Option>) -> Self { + pub fn new(role: Role, config: Option) -> Self { let config = config.unwrap_or_else(Default::default); + let decoder = CompressionSwitcher::from_config(config.compression); WebSocketContext { role, @@ -289,15 +253,12 @@ where send_queue: VecDeque::new(), pong: None, config, + decoder, } } /// Create a WebSocket context that manages an post-handshake stream. - pub fn from_partially_read( - part: Vec, - role: Role, - config: Option>, - ) -> Self { + pub fn from_partially_read(part: Vec, role: Role, config: Option) -> Self { WebSocketContext { frame: FrameCodec::from_partially_read(part), ..WebSocketContext::new(role, config) @@ -305,12 +266,12 @@ where } /// Change the configuration. - pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { + pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { set_func(&mut self.config) } /// Read the configuration. - pub fn get_config(&self) -> &WebSocketConfig { + pub fn get_config(&self) -> &WebSocketConfig { &self.config } @@ -527,12 +488,8 @@ where OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), } } - - _ => match self.config.encoder.on_receive_frame(frame) { - Ok(r) => Ok(r), - Err(e) => Err(e.into()), - }, - } // match opcode + _ => self.decoder.on_receive_frame(frame), + } } else { // Connection closed by peer match replace(&mut self.state, WebSocketState::Terminated) { @@ -602,10 +559,7 @@ where } if frame.header().is_final { - frame = match self.config.encoder.on_send_frame(frame) { - Ok(frame) => frame, - Err(e) => return Err(e.into()), - }; + frame = self.decoder.on_send_frame(frame)?; } trace!("Sending frame: {:?}", frame); @@ -682,7 +636,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use crate::extensions::uncompressed::UncompressedExt; + use crate::extensions::WsCompression; use std::io; use std::io::Cursor; @@ -710,8 +664,7 @@ mod tests { 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02, 0x03, ]); - let mut socket: WebSocket<_, UncompressedExt> = - WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); + let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!( @@ -733,7 +686,7 @@ mod tests { let limit = WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: UncompressedExt::new(Some(10)), + compression: WsCompression::None(Some(10)), }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( @@ -748,7 +701,7 @@ mod tests { let limit = WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: UncompressedExt::new(Some(2)), + compression: WsCompression::None(Some(2)), }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( diff --git a/src/server.rs b/src/server.rs index 99e3757..2415254 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,8 +7,6 @@ use crate::handshake::HandshakeError; use crate::protocol::{WebSocket, WebSocketConfig}; -use crate::extensions::uncompressed::UncompressedExt; -use crate::extensions::WebSocketExtension; use std::io::{Read, Write}; /// Accept the given Stream as a WebSocket. @@ -20,13 +18,12 @@ use std::io::{Read, Write}; /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept_with_config( +pub fn accept_with_config( stream: Stream, - config: Option>, -) -> Result, HandshakeError>> + config: Option, +) -> Result, HandshakeError>> where Stream: Read + Write, - Ext: WebSocketExtension, { accept_hdr_with_config(stream, NoCallback, config) } @@ -39,10 +36,7 @@ where /// those from `Mio` and others. pub fn accept( stream: S, -) -> Result< - WebSocket, - HandshakeError>, -> { +) -> Result, HandshakeError>> { accept_with_config(stream, None) } @@ -54,15 +48,14 @@ pub fn accept( /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr_with_config( +pub fn accept_hdr_with_config( stream: S, callback: C, - config: Option>, -) -> Result, HandshakeError>> + config: Option, +) -> Result, HandshakeError>> where S: Read + Write, C: Callback, - Ext: WebSocketExtension, { ServerHandshake::start(stream, callback, config).handshake() } @@ -75,6 +68,6 @@ where pub fn accept_hdr( stream: S, callback: C, -) -> Result, HandshakeError>> { +) -> Result, HandshakeError>> { accept_hdr_with_config(stream, callback, None) } diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 87396ff..b86bfe0 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -8,16 +8,15 @@ use std::time::Duration; use native_tls::TlsStream; use net2::TcpStreamExt; -use tungstenite::extensions::uncompressed::UncompressedExt; use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; use url::Url; -type Sock = WebSocket>, Ext>; +type Sock = WebSocket>>; fn do_test(port: u16, client_task: CT, server_task: ST) where - CT: FnOnce(Sock) + Send + 'static, - ST: FnOnce(WebSocket), + CT: FnOnce(Sock) + Send + 'static, + ST: FnOnce(WebSocket), { env_logger::try_init().ok();