From 734a0b983070d7f965c984d5489c5be038e01530 Mon Sep 17 00:00:00 2001 From: kazk Date: Wed, 15 Sep 2021 00:48:01 -0700 Subject: [PATCH] Add `deflate` feature --- .travis.yml | 1 + Cargo.toml | 15 ++++- examples/srv_accept_unmasked_frames.rs | 3 +- scripts/autobahn-client.sh | 2 +- scripts/autobahn-server.sh | 2 +- src/error.rs | 8 +-- src/extensions/mod.rs | 7 ++- src/handshake/client.rs | 70 +++++++++++++--------- src/protocol/frame/frame.rs | 1 + src/protocol/message.rs | 1 + src/protocol/mod.rs | 80 ++++++++++++++++++-------- 11 files changed, 130 insertions(+), 60 deletions(-) diff --git a/.travis.yml b/.travis.yml index 32457c3..072d20e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,5 +10,6 @@ before_script: script: - cargo test --release + - cargo test --release --features=deflate - echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh - echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh diff --git a/Cargo.toml b/Cargo.toml index 647b104..56a0b90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,12 +23,21 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] __rustls-tls = ["rustls", "webpki"] +deflate = ["flate2"] +# deflate-zlib = ["flate2/zlib"] + +[[example]] +name = "autobahn-client" +required-features = ["deflate"] + +[[example]] +name = "autobahn-server" +required-features = ["deflate"] [dependencies] base64 = "0.13.0" byteorder = "1.3.2" bytes = "1.0" -flate2 = "1.0" http = "0.2" httparse = "1.3.4" log = "0.4.8" @@ -38,6 +47,10 @@ thiserror = "1.0.23" url = "2.1.0" utf-8 = "0.7.5" +[dependencies.flate2] +optional = true +version = "1.0" + [dependencies.native-tls-crate] optional = true package = "native-tls" diff --git a/examples/srv_accept_unmasked_frames.rs b/examples/srv_accept_unmasked_frames.rs index a3134f0..ebbd302 100644 --- a/examples/srv_accept_unmasked_frames.rs +++ b/examples/srv_accept_unmasked_frames.rs @@ -35,7 +35,8 @@ fn main() { // rare cases where it is necessary to integrate with existing/legacy // clients which are sending unmasked frames accept_unmasked_frames: true, - ..WebSocketConfig::default() + #[cfg(feature = "deflate")] + compression: None, }); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); diff --git a/scripts/autobahn-client.sh b/scripts/autobahn-client.sh index 3312a60..4ba13c5 100755 --- a/scripts/autobahn-client.sh +++ b/scripts/autobahn-client.sh @@ -32,5 +32,5 @@ docker run -d --rm \ wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' sleep 3 -cargo run --release --example autobahn-client +cargo run --release --example autobahn-client --features=deflate test_diff diff --git a/scripts/autobahn-server.sh b/scripts/autobahn-server.sh index 3b7349f..4c1cf08 100755 --- a/scripts/autobahn-server.sh +++ b/scripts/autobahn-server.sh @@ -22,7 +22,7 @@ function test_diff() { fi } -cargo run --release --example autobahn-server & WSSERVER_PID=$! +cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$! sleep 3 docker run --rm \ diff --git a/src/error.rs b/src/error.rs index a502135..929f6c5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,10 +2,7 @@ use std::{io, result, str, string}; -use crate::{ - extensions, - protocol::{frame::coding::Data, Message}, -}; +use crate::protocol::{frame::coding::Data, Message}; use http::Response; use thiserror::Error; @@ -71,8 +68,9 @@ pub enum Error { #[error("HTTP format error: {0}")] HttpFormat(#[from] http::Error), /// Error from `permessage-deflate` extension. + #[cfg(feature = "deflate")] #[error("Deflate error: {0}")] - Deflate(#[from] extensions::DeflateError), + Deflate(#[from] crate::extensions::DeflateError), } impl From for Error { diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 0c19120..ada0de5 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,15 +1,20 @@ //! WebSocket extensions. // Only `permessage-deflate` is supported at the moment. +use http::HeaderValue; +#[cfg(feature = "deflate")] mod compression; +#[cfg(feature = "deflate")] use compression::deflate::DeflateContext; +#[cfg(feature = "deflate")] pub use compression::deflate::{DeflateConfig, DeflateError}; -use http::HeaderValue; /// Container for configured extensions. #[derive(Debug, Default)] +#[allow(missing_copy_implementations)] pub struct Extensions { // Per-Message Compression. Only `permessage-deflate` is supported. + #[cfg(feature = "deflate")] pub(crate) compression: Option, } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e662b8e..6246e51 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -17,7 +17,7 @@ use super::{ }; use crate::{ error::{Error, ProtocolError, Result, UrlError}, - extensions::{self, Extensions}, + extensions::Extensions, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -161,7 +161,7 @@ impl VerifyData { pub fn verify_response( &self, response: Response, - config: &Option, + _config: &Option, ) -> Result<(Response, Option)> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) @@ -202,43 +202,58 @@ impl VerifyData { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); } - let mut extensions = None; // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter(); - if let Some(value) = extensions_values.next() { + let extensions = if let Some(value) = extensions_values.next() { if extensions_values.next().is_some() { return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse)); } - let mut exts = extensions::iter_all(std::iter::once(value)); - if let Some(compression) = &config.and_then(|c| c.compression) { - for (name, params) in exts { - if name != compression.name() { - return Err(Error::Protocol(ProtocolError::InvalidExtension( - name.to_string(), - ))); - } - - // Already had PMCE configured - if extensions.is_some() { - return Err(Error::Protocol(ProtocolError::ExtensionConflict( - name.to_string(), - ))); + let mut exts = crate::extensions::iter_all(std::iter::once(value)); + #[cfg(feature = "deflate")] + { + let mut extensions = None; + if let Some(compression) = _config.and_then(|c| c.compression) { + for (name, params) in exts { + if name != compression.name() { + return Err(Error::Protocol(ProtocolError::InvalidExtension( + name.to_string(), + ))); + } + + // Already had PMCE configured + if extensions.is_some() { + return Err(Error::Protocol(ProtocolError::ExtensionConflict( + name.to_string(), + ))); + } + + extensions = Some(Extensions { + compression: Some(compression.accept_response(params)?), + }); } + } else if let Some((name, _)) = exts.next() { + // The client didn't request anything, but got something + return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string()))); + } + extensions + } - extensions = Some(Extensions { - compression: Some(compression.accept_response(params)?), - }); + #[cfg(not(feature = "deflate"))] + { + if let Some((name, _)) = exts.next() { + // The client didn't request anything, but got something + return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string()))); } - } else if let Some((name, _)) = exts.next() { - // The client didn't request anything, but got something - return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string()))); + None } - } + } else { + None + }; // 6. If the response includes a |Sec-WebSocket-Protocol| header field // and this header field indicates the use of a subprotocol that was @@ -292,7 +307,9 @@ fn generate_key() -> String { #[cfg(test)] mod tests { use super::{super::machine::TryParse, generate_key, generate_request, Response}; - use crate::{client::IntoClientRequest, extensions::DeflateConfig, protocol::WebSocketConfig}; + use crate::client::IntoClientRequest; + #[cfg(feature = "deflate")] + use crate::{extensions::DeflateConfig, protocol::WebSocketConfig}; #[test] fn random_keys() { @@ -361,6 +378,7 @@ mod tests { assert_eq!(&request[..], &correct[..]); } + #[cfg(feature = "deflate")] #[test] fn request_with_compression() { let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index e6051d5..057d3da 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -306,6 +306,7 @@ impl Frame { /// Create a new compressed data frame. #[inline] + #[cfg(feature = "deflate")] pub(crate) fn compressed_message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 6fa8301..5f35115 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -107,6 +107,7 @@ impl IncompleteMessage { } } + #[cfg(feature = "deflate")] pub fn compressed(&self) -> bool { self.compressed } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8f83319..c0a5817 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -23,7 +23,7 @@ use self::{ }; use crate::{ error::{Error, ProtocolError, Result}, - extensions::{self, Extensions}, + extensions::Extensions, util::NonBlockingResult, }; @@ -59,7 +59,8 @@ pub struct WebSocketConfig { /// By default this option is set to `false`, i.e. according to RFC 6455. pub accept_unmasked_frames: bool, /// Optional configuration for Per-Message Compression Extension. - pub compression: Option, + #[cfg(feature = "deflate")] + pub compression: Option, } impl Default for WebSocketConfig { @@ -69,6 +70,7 @@ impl Default for WebSocketConfig { max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false, + #[cfg(feature = "deflate")] compression: None, } } @@ -78,34 +80,48 @@ impl WebSocketConfig { // Generate extension negotiation offers for configured extensions. // Only `permessage-deflate` is supported at the moment. pub(crate) fn generate_offers(&self) -> Option { - self.compression.map(|c| c.generate_offer()) + #[cfg(feature = "deflate")] + { + self.compression.map(|c| c.generate_offer()) + } + #[cfg(not(feature = "deflate"))] + { + None + } } // This can be used with `WebSocket::from_raw_socket_with_extensions` for integration. /// Returns negotiation response based on offers and `Extensions` to manage extensions. pub fn accept_offers<'a>( &'a self, - extensions: impl Iterator, + _extensions: impl Iterator, ) -> Option<(HeaderValue, Extensions)> { - if let Some(compression) = &self.compression { - let extensions = extensions::iter_all(extensions); - let offers = - extensions.filter_map( - |(k, v)| { - if k == compression.name() { - Some(v) - } else { - None - } - }, - ); - - // To support more extensions, store extension context in `Extensions` and - // concatenate negotiation responses from each extension. - compression - .accept_offer(offers) - .map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) })) - } else { + #[cfg(feature = "deflate")] + { + if let Some(compression) = &self.compression { + let extensions = crate::extensions::iter_all(_extensions); + let offers = + extensions.filter_map( + |(k, v)| { + if k == compression.name() { + Some(v) + } else { + None + } + }, + ); + + // To support more extensions, store extension context in `Extensions` and + // concatenate negotiation responses from each extension. + compression + .accept_offer(offers) + .map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) })) + } else { + None + } + } + #[cfg(not(feature = "deflate"))] + { None } } @@ -451,12 +467,15 @@ impl WebSocketContext { debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind"); let opcode = OpCode::Data(opdata); let is_final = true; + #[cfg(feature = "deflate")] let frame = if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) { Frame::compressed_message(pmce.compress(&data)?, opcode, is_final) } else { Frame::message(data, opcode, is_final) }; + #[cfg(not(feature = "deflate"))] + let frame = Frame::message(data, opcode, is_final); Ok(frame) } @@ -609,6 +628,7 @@ impl WebSocketContext { } if let Some(ref mut msg) = self.incomplete { + #[cfg(feature = "deflate")] let data = if msg.compressed() { // `msg.compressed` is only set when compression is enabled so it's safe to unwrap self.extensions @@ -619,6 +639,8 @@ impl WebSocketContext { } else { frame.into_data() }; + #[cfg(not(feature = "deflate"))] + let data = frame.into_data(); msg.extend(data, self.config.max_message_size)?; if fin { Ok(Some(self.incomplete.take().unwrap().complete()?)) @@ -642,6 +664,7 @@ impl WebSocketContext { _ => panic!("Bug: message is not text nor binary"), }; let mut m = IncompleteMessage::new(message_type, is_compressed); + #[cfg(feature = "deflate")] let data = if is_compressed { self.extensions .as_mut() @@ -651,6 +674,8 @@ impl WebSocketContext { } else { frame.into_data() }; + #[cfg(not(feature = "deflate"))] + let data = frame.into_data(); m.extend(data, self.config.max_message_size)?; m }; @@ -738,7 +763,14 @@ impl WebSocketContext { } fn has_compression(&self) -> bool { - self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some() + #[cfg(feature = "deflate")] + { + self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some() + } + #[cfg(not(feature = "deflate"))] + { + false + } } }