Add `deflate` feature

pull/235/head
kazk 4 years ago
parent c62eccc8df
commit 734a0b9830
  1. 1
      .travis.yml
  2. 15
      Cargo.toml
  3. 3
      examples/srv_accept_unmasked_frames.rs
  4. 2
      scripts/autobahn-client.sh
  5. 2
      scripts/autobahn-server.sh
  6. 8
      src/error.rs
  7. 7
      src/extensions/mod.rs
  8. 70
      src/handshake/client.rs
  9. 1
      src/protocol/frame/frame.rs
  10. 1
      src/protocol/message.rs
  11. 80
      src/protocol/mod.rs

@ -10,5 +10,6 @@ before_script:
script: script:
- cargo test --release - cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh - echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh - echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -23,12 +23,21 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"] __rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
# deflate-zlib = ["flate2/zlib"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies] [dependencies]
base64 = "0.13.0" base64 = "0.13.0"
byteorder = "1.3.2" byteorder = "1.3.2"
bytes = "1.0" bytes = "1.0"
flate2 = "1.0"
http = "0.2" http = "0.2"
httparse = "1.3.4" httparse = "1.3.4"
log = "0.4.8" log = "0.4.8"
@ -38,6 +47,10 @@ thiserror = "1.0.23"
url = "2.1.0" url = "2.1.0"
utf-8 = "0.7.5" utf-8 = "0.7.5"
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate] [dependencies.native-tls-crate]
optional = true optional = true
package = "native-tls" package = "native-tls"

@ -35,7 +35,8 @@ fn main() {
// rare cases where it is necessary to integrate with existing/legacy // rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames // clients which are sending unmasked frames
accept_unmasked_frames: true, accept_unmasked_frames: true,
..WebSocketConfig::default() #[cfg(feature = "deflate")]
compression: None,
}); });
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3 sleep 3
cargo run --release --example autobahn-client cargo run --release --example autobahn-client --features=deflate
test_diff test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi fi
} }
cargo run --release --example autobahn-server & WSSERVER_PID=$! cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3 sleep 3
docker run --rm \ docker run --rm \

@ -2,10 +2,7 @@
use std::{io, result, str, string}; use std::{io, result, str, string};
use crate::{ use crate::protocol::{frame::coding::Data, Message};
extensions,
protocol::{frame::coding::Data, Message},
};
use http::Response; use http::Response;
use thiserror::Error; use thiserror::Error;
@ -71,8 +68,9 @@ pub enum Error {
#[error("HTTP format error: {0}")] #[error("HTTP format error: {0}")]
HttpFormat(#[from] http::Error), HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension. /// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")] #[error("Deflate error: {0}")]
Deflate(#[from] extensions::DeflateError), Deflate(#[from] crate::extensions::DeflateError),
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {

@ -1,15 +1,20 @@
//! WebSocket extensions. //! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment. // Only `permessage-deflate` is supported at the moment.
use http::HeaderValue;
#[cfg(feature = "deflate")]
mod compression; mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext; use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError}; pub use compression::deflate::{DeflateConfig, DeflateError};
use http::HeaderValue;
/// Container for configured extensions. /// Container for configured extensions.
#[derive(Debug, Default)] #[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions { pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported. // Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>, pub(crate) compression: Option<DeflateContext>,
} }

@ -17,7 +17,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result, UrlError}, error::{Error, ProtocolError, Result, UrlError},
extensions::{self, Extensions}, extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -161,7 +161,7 @@ impl VerifyData {
pub fn verify_response( pub fn verify_response(
&self, &self,
response: Response, response: Response,
config: &Option<WebSocketConfig>, _config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> { ) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
@ -202,43 +202,58 @@ impl VerifyData {
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
} }
let mut extensions = None;
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter(); let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions_values.next() { let extensions = if let Some(value) = extensions_values.next() {
if extensions_values.next().is_some() { if extensions_values.next().is_some() {
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse)); return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
} }
let mut exts = extensions::iter_all(std::iter::once(value)); let mut exts = crate::extensions::iter_all(std::iter::once(value));
if let Some(compression) = &config.and_then(|c| c.compression) { #[cfg(feature = "deflate")]
for (name, params) in exts { {
if name != compression.name() { let mut extensions = None;
return Err(Error::Protocol(ProtocolError::InvalidExtension( if let Some(compression) = _config.and_then(|c| c.compression) {
name.to_string(), for (name, params) in exts {
))); if name != compression.name() {
} return Err(Error::Protocol(ProtocolError::InvalidExtension(
name.to_string(),
// Already had PMCE configured )));
if extensions.is_some() { }
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(), // Already had PMCE configured
))); if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(params)?),
});
} }
} else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
}
extensions
}
extensions = Some(Extensions { #[cfg(not(feature = "deflate"))]
compression: Some(compression.accept_response(params)?), {
}); if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
} }
} else if let Some((name, _)) = exts.next() { None
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
} }
} } else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field // 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was // and this header field indicates the use of a subprotocol that was
@ -292,7 +307,9 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response}; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::{client::IntoClientRequest, extensions::DeflateConfig, protocol::WebSocketConfig}; use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -361,6 +378,7 @@ mod tests {
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[cfg(feature = "deflate")]
#[test] #[test]
fn request_with_compression() { fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); let request = "ws://localhost/getCaseCount".into_client_request().unwrap();

@ -306,6 +306,7 @@ impl Frame {
/// Create a new compressed data frame. /// Create a new compressed data frame.
#[inline] #[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");

@ -107,6 +107,7 @@ impl IncompleteMessage {
} }
} }
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool { pub fn compressed(&self) -> bool {
self.compressed self.compressed
} }

@ -23,7 +23,7 @@ use self::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::{self, Extensions}, extensions::Extensions,
util::NonBlockingResult, util::NonBlockingResult,
}; };
@ -59,7 +59,8 @@ pub struct WebSocketConfig {
/// By default this option is set to `false`, i.e. according to RFC 6455. /// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool, pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension. /// Optional configuration for Per-Message Compression Extension.
pub compression: Option<extensions::DeflateConfig>, #[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
@ -69,6 +70,7 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
accept_unmasked_frames: false, accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None, compression: None,
} }
} }
@ -78,34 +80,48 @@ impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions. // Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment. // Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<HeaderValue> { pub(crate) fn generate_offers(&self) -> Option<HeaderValue> {
self.compression.map(|c| c.generate_offer()) #[cfg(feature = "deflate")]
{
self.compression.map(|c| c.generate_offer())
}
#[cfg(not(feature = "deflate"))]
{
None
}
} }
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration. // This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions. /// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers<'a>( pub fn accept_offers<'a>(
&'a self, &'a self,
extensions: impl Iterator<Item = &'a HeaderValue>, _extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, Extensions)> { ) -> Option<(HeaderValue, Extensions)> {
if let Some(compression) = &self.compression { #[cfg(feature = "deflate")]
let extensions = extensions::iter_all(extensions); {
let offers = if let Some(compression) = &self.compression {
extensions.filter_map( let extensions = crate::extensions::iter_all(_extensions);
|(k, v)| { let offers =
if k == compression.name() { extensions.filter_map(
Some(v) |(k, v)| {
} else { if k == compression.name() {
None Some(v)
} } else {
}, None
); }
},
// To support more extensions, store extension context in `Extensions` and );
// concatenate negotiation responses from each extension.
compression // To support more extensions, store extension context in `Extensions` and
.accept_offer(offers) // concatenate negotiation responses from each extension.
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) })) compression
} else { .accept_offer(offers)
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
} else {
None
}
}
#[cfg(not(feature = "deflate"))]
{
None None
} }
} }
@ -451,12 +467,15 @@ impl WebSocketContext {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind"); debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata); let opcode = OpCode::Data(opdata);
let is_final = true; let is_final = true;
#[cfg(feature = "deflate")]
let frame = let frame =
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) { if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final) Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
} else { } else {
Frame::message(data, opcode, is_final) Frame::message(data, opcode, is_final)
}; };
#[cfg(not(feature = "deflate"))]
let frame = Frame::message(data, opcode, is_final);
Ok(frame) Ok(frame)
} }
@ -609,6 +628,7 @@ impl WebSocketContext {
} }
if let Some(ref mut msg) = self.incomplete { if let Some(ref mut msg) = self.incomplete {
#[cfg(feature = "deflate")]
let data = if msg.compressed() { let data = if msg.compressed() {
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap // `msg.compressed` is only set when compression is enabled so it's safe to unwrap
self.extensions self.extensions
@ -619,6 +639,8 @@ impl WebSocketContext {
} else { } else {
frame.into_data() frame.into_data()
}; };
#[cfg(not(feature = "deflate"))]
let data = frame.into_data();
msg.extend(data, self.config.max_message_size)?; msg.extend(data, self.config.max_message_size)?;
if fin { if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?)) Ok(Some(self.incomplete.take().unwrap().complete()?))
@ -642,6 +664,7 @@ impl WebSocketContext {
_ => panic!("Bug: message is not text nor binary"), _ => panic!("Bug: message is not text nor binary"),
}; };
let mut m = IncompleteMessage::new(message_type, is_compressed); let mut m = IncompleteMessage::new(message_type, is_compressed);
#[cfg(feature = "deflate")]
let data = if is_compressed { let data = if is_compressed {
self.extensions self.extensions
.as_mut() .as_mut()
@ -651,6 +674,8 @@ impl WebSocketContext {
} else { } else {
frame.into_data() frame.into_data()
}; };
#[cfg(not(feature = "deflate"))]
let data = frame.into_data();
m.extend(data, self.config.max_message_size)?; m.extend(data, self.config.max_message_size)?;
m m
}; };
@ -738,7 +763,14 @@ impl WebSocketContext {
} }
fn has_compression(&self) -> bool { fn has_compression(&self) -> bool {
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some() #[cfg(feature = "deflate")]
{
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
#[cfg(not(feature = "deflate"))]
{
false
}
} }
} }

Loading…
Cancel
Save