Splits client/server max_window_bits

pull/144/head
SirCipher 5 years ago
parent 7795ca1d08
commit b658064b5e
  1. 5
      examples/autobahn-client.rs
  2. 5
      examples/autobahn-server.rs
  3. 154
      src/extensions/compression/deflate.rs
  4. 160
      src/extensions/compression/mod.rs
  5. 1
      src/extensions/compression/uncompressed.rs
  6. 155
      src/extensions/mod.rs
  7. 2
      src/handshake/client.rs
  8. 2
      src/handshake/server.rs
  9. 5
      src/protocol/mod.rs

@ -2,7 +2,8 @@ use log::*;
use url::Url; use url::Url;
use tungstenite::client::connect_with_config; use tungstenite::client::connect_with_config;
use tungstenite::extensions::deflate::{DeflateConfigBuilder, DeflateExt}; use tungstenite::extensions::compression::deflate::DeflateConfigBuilder;
use tungstenite::extensions::compression::WsCompression;
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
use tungstenite::{connect, Error, Message, Result}; use tungstenite::{connect, Error, Message, Result};
@ -43,7 +44,7 @@ fn run_test(case: u32) -> Result<()> {
Some(WebSocketConfig { Some(WebSocketConfig {
max_send_queue: None, max_send_queue: None,
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
encoder: DeflateExt::new(deflate_config), compression: WsCompression::Deflate(deflate_config),
}), }),
)?; )?;

@ -2,7 +2,8 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn; use std::thread::spawn;
use log::*; use log::*;
use tungstenite::extensions::deflate::{DeflateExt, DeflateConfigBuilder}; use tungstenite::extensions::compression::deflate::DeflateConfigBuilder;
use tungstenite::extensions::compression::WsCompression;
use tungstenite::handshake::HandshakeRole; use tungstenite::handshake::HandshakeRole;
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
use tungstenite::server::accept_with_config; use tungstenite::server::accept_with_config;
@ -25,7 +26,7 @@ fn handle_client(stream: TcpStream) -> Result<()> {
Some(WebSocketConfig { Some(WebSocketConfig {
max_send_queue: None, max_send_queue: None,
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
encoder: DeflateExt::new(deflate_config), compression: WsCompression::Deflate(deflate_config),
}), }),
) )
.map_err(must_not_block)?; .map_err(must_not_block)?;

@ -2,7 +2,7 @@
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::compression::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension; use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame; use crate::protocol::frame::Frame;
@ -36,10 +36,14 @@ pub struct DeflateConfig {
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user. /// for all normal use-cases but small enough to prevent memory eating by a malicious user.
max_message_size: usize, max_message_size: usize,
/// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this /// The client's LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode,
/// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in /// this conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must
/// range 8..15 inclusive. /// be in range 8..15 inclusive.
max_window_bits: u8, server_max_window_bits: u8,
/// The client's LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode,
/// this conforms to RFC 7692 7.1.2.2. In server mode, this conforms to RFC 7692 7.1.2.2. Must
/// be in range 8..15 inclusive.
client_max_window_bits: u8,
/// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1. /// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1.
request_no_context_takeover: bool, request_no_context_takeover: bool,
/// Whether to accept `no_context_takeover`. /// Whether to accept `no_context_takeover`.
@ -68,9 +72,14 @@ impl DeflateConfig {
self.max_message_size self.max_message_size
} }
/// Returns the maximum LZ77 window size permitted. /// Returns the maximum LZ77 window size permitted for the server.
pub fn max_window_bits(&self) -> u8 { pub fn server_max_window_bits(&self) -> u8 {
self.max_window_bits self.server_max_window_bits
}
/// Returns the maximum LZ77 window size permitted for the client.
pub fn client_max_window_bits(&self) -> u8 {
self.client_max_window_bits
} }
/// Returns whether `no_context_takeover` has been requested. /// Returns whether `no_context_takeover` has been requested.
@ -106,7 +115,7 @@ impl DeflateConfig {
/// Sets the LZ77 sliding window size. /// Sets the LZ77 sliding window size.
pub fn set_max_window_bits(&mut self, max_window_bits: u8) { pub fn set_max_window_bits(&mut self, max_window_bits: u8) {
assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits)); assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits));
self.max_window_bits = max_window_bits; self.client_max_window_bits = max_window_bits;
} }
/// Sets the WebSocket to request `no_context_takeover` if `true`. /// Sets the WebSocket to request `no_context_takeover` if `true`.
@ -124,7 +133,8 @@ impl Default for DeflateConfig {
fn default() -> Self { fn default() -> Self {
DeflateConfig { DeflateConfig {
max_message_size: MAX_MESSAGE_SIZE, max_message_size: MAX_MESSAGE_SIZE,
max_window_bits: LZ77_MAX_WINDOW_SIZE, server_max_window_bits: LZ77_MAX_WINDOW_SIZE,
client_max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false, request_no_context_takeover: false,
accept_no_context_takeover: true, accept_no_context_takeover: true,
compress_reset: false, compress_reset: false,
@ -138,7 +148,8 @@ impl Default for DeflateConfig {
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct DeflateConfigBuilder { pub struct DeflateConfigBuilder {
max_message_size: Option<usize>, max_message_size: Option<usize>,
max_window_bits: u8, server_max_window_bits: u8,
client_max_window_bits: u8,
request_no_context_takeover: bool, request_no_context_takeover: bool,
accept_no_context_takeover: bool, accept_no_context_takeover: bool,
fragments_grow: bool, fragments_grow: bool,
@ -149,7 +160,8 @@ impl Default for DeflateConfigBuilder {
fn default() -> Self { fn default() -> Self {
DeflateConfigBuilder { DeflateConfigBuilder {
max_message_size: Some(MAX_MESSAGE_SIZE), max_message_size: Some(MAX_MESSAGE_SIZE),
max_window_bits: LZ77_MAX_WINDOW_SIZE, server_max_window_bits: LZ77_MAX_WINDOW_SIZE,
client_max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false, request_no_context_takeover: false,
accept_no_context_takeover: true, accept_no_context_takeover: true,
fragments_grow: true, fragments_grow: true,
@ -165,13 +177,23 @@ impl DeflateConfigBuilder {
self self
} }
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`. /// Sets the server's LZ77 sliding window size. Panics if the provided size is not in `8..=15`.
pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { pub fn servers_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder {
assert!(
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits),
"max window bits must be in range 8..=15"
);
self.server_max_window_bits = max_window_bits;
self
}
/// Sets the client's LZ77 sliding window size. Panics if the provided size is not in `8..=15`.
pub fn client_max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder {
assert!( assert!(
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits),
"max window bits must be in range 8..=15" "max window bits must be in range 8..=15"
); );
self.max_window_bits = max_window_bits; self.client_max_window_bits = max_window_bits;
self self
} }
@ -197,7 +219,8 @@ impl DeflateConfigBuilder {
pub fn build(self) -> DeflateConfig { pub fn build(self) -> DeflateConfig {
DeflateConfig { DeflateConfig {
max_message_size: self.max_message_size.unwrap_or_else(usize::max_value), max_message_size: self.max_message_size.unwrap_or_else(usize::max_value),
max_window_bits: self.max_window_bits, server_max_window_bits: self.server_max_window_bits,
client_max_window_bits: self.client_max_window_bits,
request_no_context_takeover: self.request_no_context_takeover, request_no_context_takeover: self.request_no_context_takeover,
accept_no_context_takeover: self.accept_no_context_takeover, accept_no_context_takeover: self.accept_no_context_takeover,
compression_level: self.compression_level, compression_level: self.compression_level,
@ -209,9 +232,6 @@ impl DeflateConfigBuilder {
/// A permessage-deflate encoding WebSocket extension. /// A permessage-deflate encoding WebSocket extension.
#[derive(Debug)] #[derive(Debug)]
pub struct DeflateExt { pub struct DeflateExt {
/// Defines whether the extension is enabled. Following a successful handshake, this will be
/// `true`.
enabled: bool,
/// The configuration for the extension. /// The configuration for the extension.
config: DeflateConfig, config: DeflateConfig,
/// A stack of continuation frames awaiting `fin` and the total size of all of the fragments. /// A stack of continuation frames awaiting `fin` and the total size of all of the fragments.
@ -228,11 +248,10 @@ impl DeflateExt {
/// Creates a `DeflateExt` instance using the provided configuration. /// Creates a `DeflateExt` instance using the provided configuration.
pub fn new(config: DeflateConfig) -> DeflateExt { pub fn new(config: DeflateConfig) -> DeflateExt {
DeflateExt { DeflateExt {
enabled: false,
config, config,
fragment_buffer: FragmentBuffer::new(config.max_message_size), fragment_buffer: FragmentBuffer::new(config.max_message_size),
inflator: Inflator::new(config.max_window_bits), inflator: Inflator::new(config.server_max_window_bits),
deflator: Deflator::new(config.compression_level, config.max_window_bits), deflator: Deflator::new(config.compression_level, config.client_max_window_bits),
uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())),
} }
} }
@ -301,15 +320,16 @@ pub fn on_response<T>(
response: &Response<T>, response: &Response<T>,
config: &mut DeflateConfig, config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> { ) -> Result<bool, DeflateExtensionError> {
let mut extension_name = false; let mut seen_extension_name = false;
let mut server_takeover = false; let mut seen_server_takeover = false;
let mut client_takeover = false; let mut seen_client_takeover = false;
let mut server_max_window_bits = false; let mut seen_server_max_window_bits = false;
let mut client_max_window_bits = false; let mut seen_client_max_window_bits = false;
let mut enabled = false; let mut enabled = false;
let DeflateConfig { let DeflateConfig {
max_window_bits, server_max_window_bits,
client_max_window_bits,
accept_no_context_takeover, accept_no_context_takeover,
compress_reset, compress_reset,
decompress_reset, decompress_reset,
@ -322,32 +342,32 @@ pub fn on_response<T>(
for param in header.split(';') { for param in header.split(';') {
match param.trim().to_lowercase().as_str() { match param.trim().to_lowercase().as_str() {
"permessage-deflate" => { "permessage-deflate" => {
if extension_name { if seen_extension_name {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate" "Duplicate extension parameter: permessage-deflate"
))); )));
} else { } else {
enabled = true; enabled = true;
extension_name = true; seen_extension_name = true;
} }
} }
"server_no_context_takeover" => { "server_no_context_takeover" => {
if server_takeover { if seen_server_takeover {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_no_context_takeover" "Duplicate extension parameter: server_no_context_takeover"
))); )));
} else { } else {
server_takeover = true; seen_server_takeover = true;
*decompress_reset = true; *decompress_reset = true;
} }
} }
"client_no_context_takeover" => { "client_no_context_takeover" => {
if client_takeover { if seen_client_takeover {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_no_context_takeover" "Duplicate extension parameter: client_no_context_takeover"
))); )));
} else { } else {
client_takeover = true; seen_client_takeover = true;
if *accept_no_context_takeover { if *accept_no_context_takeover {
*compress_reset = true; *compress_reset = true;
@ -359,19 +379,19 @@ pub fn on_response<T>(
} }
} }
param if param.starts_with("server_max_window_bits") => { param if param.starts_with("server_max_window_bits") => {
if server_max_window_bits { if seen_server_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_max_window_bits" "Duplicate extension parameter: server_max_window_bits"
))); )));
} else { } else {
server_max_window_bits = true; seen_server_max_window_bits = true;
match parse_window_parameter( match parse_window_parameter(
param.split("=").skip(1), param.split("=").skip(1),
*max_window_bits, *server_max_window_bits,
) { ) {
Ok(Some(bits)) => { Ok(Some(bits)) => {
*max_window_bits = bits; *server_max_window_bits = bits;
} }
Ok(None) => {} Ok(None) => {}
Err(e) => { Err(e) => {
@ -386,19 +406,19 @@ pub fn on_response<T>(
} }
} }
param if param.starts_with("client_max_window_bits") => { param if param.starts_with("client_max_window_bits") => {
if client_max_window_bits { if seen_client_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_max_window_bits" "Duplicate extension parameter: client_max_window_bits"
))); )));
} else { } else {
client_max_window_bits = true; seen_client_max_window_bits = true;
match parse_window_parameter( match parse_window_parameter(
param.split("=").skip(1), param.split("=").skip(1),
*max_window_bits, *client_max_window_bits,
) { ) {
Ok(Some(bits)) => { Ok(Some(bits)) => {
*max_window_bits = bits; *client_max_window_bits = bits;
} }
Ok(None) => {} Ok(None) => {}
Err(e) => { Err(e) => {
@ -438,15 +458,18 @@ pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request
let mut header_value = String::from(EXT_IDENT); let mut header_value = String::from(EXT_IDENT);
let DeflateConfig { let DeflateConfig {
max_window_bits, server_max_window_bits,
client_max_window_bits,
request_no_context_takeover, request_no_context_takeover,
.. ..
} = config; } = config;
if *max_window_bits < LZ77_MAX_WINDOW_SIZE { if *client_max_window_bits < LZ77_MAX_WINDOW_SIZE
|| *server_max_window_bits < LZ77_MAX_WINDOW_SIZE
{
header_value.push_str(&format!( header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}", "; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits client_max_window_bits, server_max_window_bits
)) ))
} else { } else {
header_value.push_str("; client_max_window_bits") header_value.push_str("; client_max_window_bits")
@ -510,10 +533,10 @@ pub fn on_receive_request<T>(
match parse_window_parameter( match parse_window_parameter(
param.split('=').skip(1), param.split('=').skip(1),
config.max_window_bits, config.server_max_window_bits,
) { ) {
Ok(Some(bits)) => { Ok(Some(bits)) => {
config.max_window_bits = bits; config.server_max_window_bits = bits;
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(param) response_str.push_str(param)
@ -533,10 +556,10 @@ pub fn on_receive_request<T>(
match parse_window_parameter( match parse_window_parameter(
param.split('=').skip(1), param.split('=').skip(1),
config.max_window_bits, config.client_max_window_bits,
) { ) {
Ok(Some(bits)) => { Ok(Some(bits)) => {
config.max_window_bits = bits; config.client_max_window_bits = bits;
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(param); response_str.push_str(param);
@ -551,7 +574,7 @@ pub fn on_receive_request<T>(
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(&format!( response_str.push_str(&format!(
"client_max_window_bits={}", "client_max_window_bits={}",
config.max_window_bits() config.client_max_window_bits()
)) ))
} }
} }
@ -572,12 +595,12 @@ pub fn on_receive_request<T>(
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(&format!( response_str.push_str(&format!(
"server_max_window_bits={}", "server_max_window_bits={}",
config.max_window_bits() config.server_max_window_bits()
)) ))
} }
if !response_str.contains("client_max_window_bits") if !response_str.contains("client_max_window_bits")
&& config.max_window_bits() < LZ77_MAX_WINDOW_SIZE && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE
{ {
continue; continue;
} }
@ -622,20 +645,18 @@ impl Default for DeflateExt {
impl WebSocketExtension for DeflateExt { impl WebSocketExtension for DeflateExt {
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, crate::Error> { fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, crate::Error> {
if self.enabled { if let OpCode::Data(_) = frame.header().opcode {
if let OpCode::Data(_) = frame.header().opcode { let mut compressed = Vec::with_capacity(frame.payload().len());
let mut compressed = Vec::with_capacity(frame.payload().len()); self.deflator.compress(frame.payload(), &mut compressed)?;
self.deflator.compress(frame.payload(), &mut compressed)?;
let len = compressed.len(); let len = compressed.len();
compressed.truncate(len - 4); compressed.truncate(len - 4);
*frame.payload_mut() = compressed; *frame.payload_mut() = compressed;
frame.header_mut().rsv1 = true; frame.header_mut().rsv1 = true;
if self.config.compress_reset() { if self.config.compress_reset() {
self.deflator.reset(); self.deflator.reset();
}
} }
} }
@ -643,7 +664,7 @@ impl WebSocketExtension for DeflateExt {
} }
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> { fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) { if !self.fragment_buffer.is_empty() || frame.header().rsv1 {
if !frame.header().is_final { if !frame.header().is_final {
self.fragment_buffer self.fragment_buffer
.try_push_frame(frame) .try_push_frame(frame)
@ -694,11 +715,6 @@ impl WebSocketExtension for DeflateExt {
} }
} else { } else {
self.uncompressed_extension.on_receive_frame(frame) self.uncompressed_extension.on_receive_frame(frame)
};
match r {
Ok(msg) => Ok(msg),
Err(e) => Err(crate::Error::ExtensionError(e.to_string().into())),
} }
} }
} }

@ -0,0 +1,160 @@
//! WebSocket compression
#[cfg(feature = "deflate")]
use crate::extensions::compression::deflate::{DeflateConfig, DeflateExt};
use crate::extensions::compression::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::Frame;
use crate::protocol::WebSocketConfig;
use crate::Message;
use http::{Request, Response};
use std::borrow::Cow;
use std::error::Error;
use std::fmt::{Display, Formatter};
/// A permessage-deflate WebSocket extension (RFC 7692).
#[cfg(feature = "deflate")]
pub mod deflate;
/// An uncompressed message handler for a WebSocket.
pub mod uncompressed;
///
#[derive(Copy, Clone, Debug)]
pub enum WsCompression {
///
None(Option<usize>),
///
#[cfg(feature = "deflate")]
Deflate(DeflateConfig),
}
/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`.
#[derive(Debug)]
pub enum CompressionSwitcher {
///
#[cfg(feature = "deflate")]
Compressed(DeflateExt),
///
Uncompressed(UncompressedExt),
}
impl CompressionSwitcher {
///
pub fn from_config(config: WsCompression) -> CompressionSwitcher {
match config {
WsCompression::None(size) => {
CompressionSwitcher::Uncompressed(UncompressedExt::new(size))
}
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => {
CompressionSwitcher::Compressed(DeflateExt::new(config))
}
}
}
}
impl Default for CompressionSwitcher {
fn default() -> Self {
CompressionSwitcher::Uncompressed(UncompressedExt::default())
}
}
#[derive(Debug)]
///
pub struct CompressionError(String);
impl Error for CompressionError {}
impl From<CompressionError> for crate::Error {
fn from(e: CompressionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
impl Display for CompressionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CompressionError")
.field("error", &self.0)
.finish()
}
}
impl WebSocketExtension for CompressionSwitcher {
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame),
}
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame),
}
}
}
///
pub fn build_compression_headers<T>(
request: Request<T>,
config: &mut Option<WebSocketConfig>,
) -> Request<T> {
match config {
Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request,
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config),
},
None => request,
}
}
///
pub fn verify_compression_resp_headers<T>(
_response: &Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
let result = deflate::on_response(_response, deflate_config)
.map_err(|e| CompressionError(e.to_string()));
match result {
Ok(true) => Ok(()),
Ok(false) => {
config.compression =
WsCompression::None(Some(deflate_config.max_message_size()));
Ok(())
}
Err(e) => Err(e),
}
}
},
None => Ok(()),
}
}
///
pub fn verify_compression_req_headers<T>(
_request: &Request<T>,
_response: &mut Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
deflate::on_receive_request(_request, _response, deflate_config)
.map_err(|e| CompressionError(e.to_string()))
}
},
None => Ok(()),
}
}

@ -35,7 +35,6 @@ impl UncompressedExt {
impl WebSocketExtension for UncompressedExt { impl WebSocketExtension for UncompressedExt {
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> { fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
let fin = frame.header().is_final; let fin = frame.header().is_final;
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {

@ -1,32 +1,9 @@
//! WebSocket extensions //! WebSocket extensions
use http::{Request, Response}; pub mod compression;
#[cfg(feature = "deflate")]
use crate::extensions::deflate::{DeflateConfig, DeflateExt};
use crate::extensions::uncompressed::UncompressedExt;
use crate::protocol::frame::Frame; use crate::protocol::frame::Frame;
use crate::protocol::WebSocketConfig;
use crate::Message; use crate::Message;
use std::borrow::Cow;
use std::error::Error;
use std::fmt::{Display, Formatter};
/// A permessage-deflate WebSocket extension (RFC 7692).
#[cfg(feature = "deflate")]
pub mod deflate;
/// An uncompressed message handler for a WebSocket.
pub mod uncompressed;
///
#[derive(Copy, Clone, Debug)]
pub enum WsCompression {
///
None(Option<usize>),
///
#[cfg(feature = "deflate")]
Deflate(DeflateConfig),
}
/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions /// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions
/// may be stacked by nesting them inside one another. /// may be stacked by nesting them inside one another.
@ -40,133 +17,3 @@ pub trait WebSocketExtension {
/// type `OpCode::Data`. /// type `OpCode::Data`.
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error>; fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error>;
} }
/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`.
#[derive(Debug)]
pub enum CompressionSwitcher {
///
#[cfg(feature = "deflate")]
Compressed(DeflateExt),
///
Uncompressed(UncompressedExt),
}
impl CompressionSwitcher {
///
pub fn from_config(config: WsCompression) -> CompressionSwitcher {
match config {
WsCompression::None(size) => {
CompressionSwitcher::Uncompressed(UncompressedExt::new(size))
}
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => {
CompressionSwitcher::Compressed(DeflateExt::new(config))
}
}
}
}
impl Default for CompressionSwitcher {
fn default() -> Self {
CompressionSwitcher::Uncompressed(UncompressedExt::default())
}
}
#[derive(Debug)]
///
pub struct CompressionError(String);
impl Error for CompressionError {}
impl From<CompressionError> for crate::Error {
fn from(e: CompressionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
impl Display for CompressionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CompressionError")
.field("error", &self.0)
.finish()
}
}
impl WebSocketExtension for CompressionSwitcher {
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame),
}
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame),
}
}
}
///
pub fn build_compression_headers<T>(
request: Request<T>,
config: &mut Option<WebSocketConfig>,
) -> Request<T> {
match config {
Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request,
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config),
},
None => request,
}
}
///
pub fn verify_compression_resp_headers<T>(
_response: &Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
let result = deflate::on_response(_response, deflate_config)
.map_err(|e| CompressionError(e.to_string()));
match result {
Ok(true) => Ok(()),
Ok(false) => {
config.compression =
WsCompression::None(Some(deflate_config.max_message_size()));
Ok(())
}
Err(e) => Err(e),
}
}
},
None => Ok(()),
}
}
///
pub fn verify_compression_req_headers<T>(
_request: &Request<T>,
_response: &mut Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
deflate::on_receive_request(_request, _response, deflate_config)
.map_err(|e| CompressionError(e.to_string()))
}
},
None => Ok(()),
}
}

@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::extensions::{build_compression_headers, verify_compression_resp_headers}; use crate::extensions::compression::{build_compression_headers, verify_compression_resp_headers};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type. /// Client request type.

@ -12,7 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::extensions::verify_compression_req_headers; use crate::extensions::compression::verify_compression_req_headers;
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Server request type. /// Server request type.

@ -16,7 +16,8 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec}; use self::frame::{Frame, FrameCodec};
use self::message::IncompleteMessage; use self::message::IncompleteMessage;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::extensions::{CompressionSwitcher, WebSocketExtension, WsCompression}; use crate::extensions::compression::{CompressionSwitcher, WsCompression};
use crate::extensions::WebSocketExtension;
use crate::util::NonBlockingResult; use crate::util::NonBlockingResult;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
@ -636,7 +637,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests { mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::WsCompression; use crate::extensions::compression::WsCompression;
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;

Loading…
Cancel
Save