Add `deflate` feature

pull/235/head
kazk 4 years ago
parent c62eccc8df
commit 734a0b9830
  1. 1
      .travis.yml
  2. 15
      Cargo.toml
  3. 3
      examples/srv_accept_unmasked_frames.rs
  4. 2
      scripts/autobahn-client.sh
  5. 2
      scripts/autobahn-server.sh
  6. 8
      src/error.rs
  7. 7
      src/extensions/mod.rs
  8. 70
      src/handshake/client.rs
  9. 1
      src/protocol/frame/frame.rs
  10. 1
      src/protocol/message.rs
  11. 80
      src/protocol/mod.rs

@ -10,5 +10,6 @@ before_script:
script:
- cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -23,12 +23,21 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
# deflate-zlib = ["flate2/zlib"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies]
base64 = "0.13.0"
byteorder = "1.3.2"
bytes = "1.0"
flate2 = "1.0"
http = "0.2"
httparse = "1.3.4"
log = "0.4.8"
@ -38,6 +47,10 @@ thiserror = "1.0.23"
url = "2.1.0"
utf-8 = "0.7.5"
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate]
optional = true
package = "native-tls"

@ -35,7 +35,8 @@ fn main() {
// rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames
accept_unmasked_frames: true,
..WebSocketConfig::default()
#[cfg(feature = "deflate")]
compression: None,
});
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3
cargo run --release --example autobahn-client
cargo run --release --example autobahn-client --features=deflate
test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi
}
cargo run --release --example autobahn-server & WSSERVER_PID=$!
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3
docker run --rm \

@ -2,10 +2,7 @@
use std::{io, result, str, string};
use crate::{
extensions,
protocol::{frame::coding::Data, Message},
};
use crate::protocol::{frame::coding::Data, Message};
use http::Response;
use thiserror::Error;
@ -71,8 +68,9 @@ pub enum Error {
#[error("HTTP format error: {0}")]
HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] extensions::DeflateError),
Deflate(#[from] crate::extensions::DeflateError),
}
impl From<str::Utf8Error> for Error {

@ -1,15 +1,20 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
use http::HeaderValue;
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
use http::HeaderValue;
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -17,7 +17,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
extensions::{self, Extensions},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -161,7 +161,7 @@ impl VerifyData {
pub fn verify_response(
&self,
response: Response,
config: &Option<WebSocketConfig>,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
@ -202,43 +202,58 @@ impl VerifyData {
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
}
let mut extensions = None;
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions_values.next() {
let extensions = if let Some(value) = extensions_values.next() {
if extensions_values.next().is_some() {
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
}
let mut exts = extensions::iter_all(std::iter::once(value));
if let Some(compression) = &config.and_then(|c| c.compression) {
for (name, params) in exts {
if name != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
name.to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(),
)));
let mut exts = crate::extensions::iter_all(std::iter::once(value));
#[cfg(feature = "deflate")]
{
let mut extensions = None;
if let Some(compression) = _config.and_then(|c| c.compression) {
for (name, params) in exts {
if name != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
name.to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(params)?),
});
}
} else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
}
extensions
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(params)?),
});
#[cfg(not(feature = "deflate"))]
{
if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
}
} else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
None
}
}
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -292,7 +307,9 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::{client::IntoClientRequest, extensions::DeflateConfig, protocol::WebSocketConfig};
use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test]
fn random_keys() {
@ -361,6 +378,7 @@ mod tests {
assert_eq!(&request[..], &correct[..]);
}
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();

@ -306,6 +306,7 @@ impl Frame {
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");

@ -107,6 +107,7 @@ impl IncompleteMessage {
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}

@ -23,7 +23,7 @@ use self::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::{self, Extensions},
extensions::Extensions,
util::NonBlockingResult,
};
@ -59,7 +59,8 @@ pub struct WebSocketConfig {
/// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
pub compression: Option<extensions::DeflateConfig>,
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
}
impl Default for WebSocketConfig {
@ -69,6 +70,7 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
@ -78,34 +80,48 @@ impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<HeaderValue> {
self.compression.map(|c| c.generate_offer())
#[cfg(feature = "deflate")]
{
self.compression.map(|c| c.generate_offer())
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers<'a>(
&'a self,
extensions: impl Iterator<Item = &'a HeaderValue>,
_extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, Extensions)> {
if let Some(compression) = &self.compression {
let extensions = extensions::iter_all(extensions);
let offers =
extensions.filter_map(
|(k, v)| {
if k == compression.name() {
Some(v)
} else {
None
}
},
);
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
compression
.accept_offer(offers)
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
} else {
#[cfg(feature = "deflate")]
{
if let Some(compression) = &self.compression {
let extensions = crate::extensions::iter_all(_extensions);
let offers =
extensions.filter_map(
|(k, v)| {
if k == compression.name() {
Some(v)
} else {
None
}
},
);
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
compression
.accept_offer(offers)
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
} else {
None
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
@ -451,12 +467,15 @@ impl WebSocketContext {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
#[cfg(feature = "deflate")]
let frame =
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
} else {
Frame::message(data, opcode, is_final)
};
#[cfg(not(feature = "deflate"))]
let frame = Frame::message(data, opcode, is_final);
Ok(frame)
}
@ -609,6 +628,7 @@ impl WebSocketContext {
}
if let Some(ref mut msg) = self.incomplete {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap
self.extensions
@ -619,6 +639,8 @@ impl WebSocketContext {
} else {
frame.into_data()
};
#[cfg(not(feature = "deflate"))]
let data = frame.into_data();
msg.extend(data, self.config.max_message_size)?;
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
@ -642,6 +664,7 @@ impl WebSocketContext {
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type, is_compressed);
#[cfg(feature = "deflate")]
let data = if is_compressed {
self.extensions
.as_mut()
@ -651,6 +674,8 @@ impl WebSocketContext {
} else {
frame.into_data()
};
#[cfg(not(feature = "deflate"))]
let data = frame.into_data();
m.extend(data, self.config.max_message_size)?;
m
};
@ -738,7 +763,14 @@ impl WebSocketContext {
}
fn has_compression(&self) -> bool {
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
#[cfg(feature = "deflate")]
{
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
#[cfg(not(feature = "deflate"))]
{
false
}
}
}

Loading…
Cancel
Save