From c62eccc8dfa1d8a2eed388e994aac4207c57d90e Mon Sep 17 00:00:00 2001 From: kazk Date: Tue, 14 Sep 2021 21:51:14 -0700 Subject: [PATCH] Make `DeflateContext` private and add `Extensions` container --- src/extensions/mod.rs | 10 +++++- src/handshake/client.rs | 28 +++++++++-------- src/handshake/server.rs | 18 +++++------ src/protocol/mod.rs | 67 ++++++++++++++++++++++++----------------- 4 files changed, 73 insertions(+), 50 deletions(-) diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 95ee2bc..0c19120 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -2,9 +2,17 @@ // Only `permessage-deflate` is supported at the moment. mod compression; -pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError}; +use compression::deflate::DeflateContext; +pub use compression::deflate::{DeflateConfig, DeflateError}; use http::HeaderValue; +/// Container for configured extensions. +#[derive(Debug, Default)] +pub struct Extensions { + // Per-Message Compression. Only `permessage-deflate` is supported. + pub(crate) compression: Option, +} + /// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values. pub(crate) fn iter_all<'a>( values: impl Iterator, diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 04c6a05..e662b8e 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -17,7 +17,7 @@ use super::{ }; use crate::{ error::{Error, ProtocolError, Result, UrlError}, - extensions::{self, DeflateContext}, + extensions::{self, Extensions}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -83,14 +83,15 @@ impl HandshakeRole for ClientHandshake { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } StageResult::DoneReading { stream, result, tail } => { - let (result, pmce) = self.verify_data.verify_response(result, &self.config)?; + let (result, extensions) = + self.verify_data.verify_response(result, &self.config)?; debug!("Client handshake done."); - let websocket = WebSocket::from_partially_read_with_compression( + let websocket = WebSocket::from_partially_read_with_extensions( stream, tail, Role::Client, self.config, - pmce, + extensions, ); ProcessingResult::Done((websocket, result)) } @@ -161,7 +162,7 @@ impl VerifyData { &self, response: Response, config: &Option, - ) -> Result<(Response, Option)> { + ) -> Result<(Response, Option)> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { @@ -201,15 +202,15 @@ impl VerifyData { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); } - let mut pmce = None; + let mut extensions = None; // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter(); - if let Some(value) = extensions.next() { - if extensions.next().is_some() { + let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter(); + if let Some(value) = extensions_values.next() { + if extensions_values.next().is_some() { return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse)); } @@ -223,12 +224,15 @@ impl VerifyData { } // Already had PMCE configured - if pmce.is_some() { + if extensions.is_some() { return Err(Error::Protocol(ProtocolError::ExtensionConflict( name.to_string(), ))); } - pmce = Some(compression.accept_response(params)?); + + extensions = Some(Extensions { + compression: Some(compression.accept_response(params)?), + }); } } else if let Some((name, _)) = exts.next() { // The client didn't request anything, but got something @@ -243,7 +247,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok((response, pmce)) + Ok((response, extensions)) } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 7db9878..69bea67 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -20,7 +20,7 @@ use super::{ }; use crate::{ error::{Error, ProtocolError, Result}, - extensions, + extensions::Extensions, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -203,8 +203,8 @@ pub struct ServerHandshake { config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. error_response: Option, - // Negotiated Per-Message Compression Extension context for server. - pmce: Option, + // Negotiated extension context for server. + extensions: Option, /// Internal stream type. _marker: PhantomData, } @@ -222,7 +222,7 @@ impl ServerHandshake { callback: Some(callback), config, error_response: None, - pmce: None, + extensions: None, _marker: PhantomData, }, } @@ -246,10 +246,10 @@ impl HandshakeRole for ServerHandshake { let mut response = create_response(&result)?; if let Some(config) = &self.config { - let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter(); - if let Some((agreed, pmce)) = config.accept_offers(extensions) { - self.pmce = Some(pmce); + let values = result.headers().get_all("Sec-WebSocket-Extensions").iter(); + if let Some((agreed, extensions)) = config.accept_offers(values) { response.headers_mut().insert("Sec-WebSocket-Extensions", agreed); + self.extensions = Some(extensions); } } @@ -292,11 +292,11 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Http(err)); } else { debug!("Server handshake done."); - let websocket = WebSocket::from_raw_socket_with_compression( + let websocket = WebSocket::from_raw_socket_with_extensions( stream, Role::Server, self.config, - self.pmce.take(), + self.extensions.take(), ); ProcessingResult::Done(websocket) } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index d998582..8f83319 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -23,7 +23,7 @@ use self::{ }; use crate::{ error::{Error, ProtocolError, Result}, - extensions::{self, DeflateContext}, + extensions::{self, Extensions}, util::NonBlockingResult, }; @@ -81,15 +81,14 @@ impl WebSocketConfig { self.compression.map(|c| c.generate_offer()) } - // TODO Replace `DeflateContext` with something more general - // This can be used with `WebSocket::from_raw_socket_with_compression` for integration. - /// Returns negotiation response based on offers and `DeflateContext` to manage per message compression. + // This can be used with `WebSocket::from_raw_socket_with_extensions` for integration. + /// Returns negotiation response based on offers and `Extensions` to manage extensions. pub fn accept_offers<'a>( &'a self, extensions: impl Iterator, - ) -> Option<(HeaderValue, DeflateContext)> { + ) -> Option<(HeaderValue, Extensions)> { if let Some(compression) = &self.compression { - let extensions = crate::extensions::iter_all(extensions); + let extensions = extensions::iter_all(extensions); let offers = extensions.filter_map( |(k, v)| { @@ -100,7 +99,12 @@ impl WebSocketConfig { } }, ); - compression.accept_offer(offers) + + // To support more extensions, store extension context in `Extensions` and + // concatenate negotiation responses from each extension. + compression + .accept_offer(offers) + .map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) })) } else { None } @@ -130,14 +134,14 @@ impl WebSocket { } /// Convert a raw socket into a WebSocket without performing a handshake. - pub fn from_raw_socket_with_compression( + pub fn from_raw_socket_with_extensions( stream: Stream, role: Role, config: Option, - pmce: Option, + extensions: Option, ) -> Self { let mut context = WebSocketContext::new(role, config); - context.pmce = pmce; + context.extensions = extensions; WebSocket { socket: stream, context } } @@ -158,17 +162,17 @@ impl WebSocket { } } - pub(crate) fn from_partially_read_with_compression( + pub(crate) fn from_partially_read_with_extensions( stream: Stream, part: Vec, role: Role, config: Option, - pmce: Option, + extensions: Option, ) -> Self { WebSocket { socket: stream, - context: WebSocketContext::from_partially_read_with_compression( - part, role, config, pmce, + context: WebSocketContext::from_partially_read_with_extensions( + part, role, config, extensions, ), } } @@ -306,8 +310,8 @@ pub struct WebSocketContext { pong: Option, /// The configuration for the websocket session. config: WebSocketConfig, - /// Per-Message Compression Extension. Only deflate is supported at the moment. - pub(crate) pmce: Option, + // Container for extensions. + pub(crate) extensions: Option, } impl WebSocketContext { @@ -321,7 +325,7 @@ impl WebSocketContext { send_queue: VecDeque::new(), pong: None, config: config.unwrap_or_else(WebSocketConfig::default), - pmce: None, + extensions: None, } } @@ -333,15 +337,15 @@ impl WebSocketContext { } } - pub(crate) fn from_partially_read_with_compression( + pub(crate) fn from_partially_read_with_extensions( part: Vec, role: Role, config: Option, - pmce: Option, + extensions: Option, ) -> Self { WebSocketContext { frame: FrameCodec::from_partially_read(part), - pmce, + extensions, ..WebSocketContext::new(role, config) } } @@ -447,11 +451,12 @@ impl WebSocketContext { debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind"); let opcode = OpCode::Data(opdata); let is_final = true; - let frame = if let Some(pmce) = self.pmce.as_mut() { - Frame::compressed_message(pmce.compress(&data)?, opcode, is_final) - } else { - Frame::message(data, opcode, is_final) - }; + let frame = + if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) { + Frame::compressed_message(pmce.compress(&data)?, opcode, is_final) + } else { + Frame::message(data, opcode, is_final) + }; Ok(frame) } @@ -533,7 +538,7 @@ impl WebSocketContext { // Connection_. let is_compressed = { let hdr = frame.header(); - if (hdr.rsv1 && self.pmce.is_none()) || hdr.rsv2 || hdr.rsv3 { + if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 { return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); } @@ -606,8 +611,9 @@ impl WebSocketContext { if let Some(ref mut msg) = self.incomplete { let data = if msg.compressed() { // `msg.compressed` is only set when compression is enabled so it's safe to unwrap - self.pmce + self.extensions .as_mut() + .and_then(|x| x.compression.as_mut()) .unwrap() .decompress(frame.into_data(), fin)? } else { @@ -637,8 +643,9 @@ impl WebSocketContext { }; let mut m = IncompleteMessage::new(message_type, is_compressed); let data = if is_compressed { - self.pmce + self.extensions .as_mut() + .and_then(|x| x.compression.as_mut()) .unwrap() .decompress(frame.into_data(), fin)? } else { @@ -729,6 +736,10 @@ impl WebSocketContext { trace!("Sending frame: {:?}", frame); self.frame.write_frame(stream, frame).check_connection_reset(self.state) } + + fn has_compression(&self) -> bool { + self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some() + } } /// The current connection state.