commit
cd43267e17
@ -1,2 +1,5 @@ |
||||
target |
||||
Cargo.lock |
||||
|
||||
autobahn/client/* |
||||
autobahn/server/* |
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,922 @@ |
||||
//! Permessage-deflate extension
|
||||
|
||||
use std::fmt::{Display, Formatter}; |
||||
|
||||
use crate::extensions::uncompressed::UncompressedExt; |
||||
use crate::extensions::WebSocketExtension; |
||||
use crate::protocol::frame::coding::{Data, OpCode}; |
||||
use crate::protocol::frame::Frame; |
||||
use crate::protocol::MAX_MESSAGE_SIZE; |
||||
use crate::Message; |
||||
use bytes::BufMut; |
||||
use flate2::{ |
||||
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress, |
||||
FlushDecompress, Status, |
||||
}; |
||||
use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS}; |
||||
use http::{HeaderValue, Request, Response}; |
||||
use std::borrow::Cow; |
||||
use std::mem::replace; |
||||
use std::slice; |
||||
|
||||
/// The WebSocket Extension Identifier as per the IANA registry.
|
||||
const EXT_IDENT: &str = "permessage-deflate"; |
||||
|
||||
/// The minimum size of the LZ77 sliding window size.
|
||||
const LZ77_MIN_WINDOW_SIZE: u8 = 8; |
||||
|
||||
/// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter
|
||||
/// indicates that the client can receive messages compressed using an LZ77 sliding window of up to
|
||||
/// 32,768 bytes. RFC 7692 7.1.2.1.
|
||||
const LZ77_MAX_WINDOW_SIZE: u8 = 15; |
||||
|
||||
/// A permessage-deflate configuration.
|
||||
#[derive(Clone, Copy, Debug)] |
||||
pub struct DeflateConfig { |
||||
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big
|
||||
/// for all normal use-cases but small enough to prevent memory eating by a malicious user.
|
||||
max_message_size: usize, |
||||
/// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this
|
||||
/// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in
|
||||
/// range 8..15 inclusive.
|
||||
max_window_bits: u8, |
||||
/// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1.
|
||||
request_no_context_takeover: bool, |
||||
/// Whether to accept `no_context_takeover`.
|
||||
accept_no_context_takeover: bool, |
||||
// Whether the compressor should be reset after usage.
|
||||
compress_reset: bool, |
||||
// Whether the decompressor should be reset after usage.
|
||||
decompress_reset: bool, |
||||
/// The active compression level. The integer here is typically on a scale of 0-9 where 0 means
|
||||
/// "no compression" and 9 means "take as long as you'd like".
|
||||
compression_level: Compression, |
||||
} |
||||
|
||||
impl DeflateConfig { |
||||
/// Builds a new `DeflateConfig` using the `compression_level` and the defaults for all other
|
||||
/// members.
|
||||
pub fn with_compression_level(compression_level: Compression) -> DeflateConfig { |
||||
DeflateConfig { |
||||
compression_level, |
||||
..Default::default() |
||||
} |
||||
} |
||||
|
||||
/// Returns the maximum message size permitted.
|
||||
pub fn max_message_size(&self) -> usize { |
||||
self.max_message_size |
||||
} |
||||
|
||||
/// Returns the maximum LZ77 window size permitted.
|
||||
pub fn max_window_bits(&self) -> u8 { |
||||
self.max_window_bits |
||||
} |
||||
|
||||
/// Returns whether `no_context_takeover` has been requested.
|
||||
pub fn request_no_context_takeover(&self) -> bool { |
||||
self.request_no_context_takeover |
||||
} |
||||
|
||||
/// Returns whether this WebSocket will accept `no_context_takeover`.
|
||||
pub fn accept_no_context_takeover(&self) -> bool { |
||||
self.accept_no_context_takeover |
||||
} |
||||
|
||||
/// Returns whether or not the inner compressor is set to reset after completing a message.
|
||||
pub fn compress_reset(&self) -> bool { |
||||
self.compress_reset |
||||
} |
||||
|
||||
/// Returns whether or not the inner decompressor is set to reset after completing a message.
|
||||
pub fn decompress_reset(&self) -> bool { |
||||
self.decompress_reset |
||||
} |
||||
|
||||
/// Returns the active compression level.
|
||||
pub fn compression_level(&self) -> Compression { |
||||
self.compression_level |
||||
} |
||||
|
||||
/// Sets the maximum message size permitted.
|
||||
pub fn set_max_message_size(&mut self, max_message_size: Option<usize>) { |
||||
self.max_message_size = max_message_size.unwrap_or_else(usize::max_value); |
||||
} |
||||
|
||||
/// Sets the LZ77 sliding window size.
|
||||
pub fn set_max_window_bits(&mut self, max_window_bits: u8) { |
||||
assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits)); |
||||
self.max_window_bits = max_window_bits; |
||||
} |
||||
|
||||
/// Sets the WebSocket to request `no_context_takeover` if `true`.
|
||||
pub fn set_request_no_context_takeover(&mut self, request_no_context_takeover: bool) { |
||||
self.request_no_context_takeover = request_no_context_takeover; |
||||
} |
||||
|
||||
/// Sets the WebSocket to accept `no_context_takeover` if `true`.
|
||||
pub fn set_accept_no_context_takeover(&mut self, accept_no_context_takeover: bool) { |
||||
self.accept_no_context_takeover = accept_no_context_takeover; |
||||
} |
||||
} |
||||
|
||||
impl Default for DeflateConfig { |
||||
fn default() -> Self { |
||||
DeflateConfig { |
||||
max_message_size: MAX_MESSAGE_SIZE, |
||||
max_window_bits: LZ77_MAX_WINDOW_SIZE, |
||||
request_no_context_takeover: false, |
||||
accept_no_context_takeover: true, |
||||
compress_reset: false, |
||||
decompress_reset: false, |
||||
compression_level: Compression::best(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// A `DeflateConfig` builder.
|
||||
#[derive(Debug, Copy, Clone)] |
||||
pub struct DeflateConfigBuilder { |
||||
max_message_size: Option<usize>, |
||||
max_window_bits: u8, |
||||
request_no_context_takeover: bool, |
||||
accept_no_context_takeover: bool, |
||||
fragments_grow: bool, |
||||
compression_level: Compression, |
||||
} |
||||
|
||||
impl Default for DeflateConfigBuilder { |
||||
fn default() -> Self { |
||||
DeflateConfigBuilder { |
||||
max_message_size: Some(MAX_MESSAGE_SIZE), |
||||
max_window_bits: LZ77_MAX_WINDOW_SIZE, |
||||
request_no_context_takeover: false, |
||||
accept_no_context_takeover: true, |
||||
fragments_grow: true, |
||||
compression_level: Compression::fast(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl DeflateConfigBuilder { |
||||
/// Sets the maximum message size permitted.
|
||||
pub fn max_message_size(mut self, max_message_size: Option<usize>) -> DeflateConfigBuilder { |
||||
self.max_message_size = max_message_size; |
||||
self |
||||
} |
||||
|
||||
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`.
|
||||
pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { |
||||
assert!( |
||||
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), |
||||
"max window bits must be in range 8..=15" |
||||
); |
||||
self.max_window_bits = max_window_bits; |
||||
self |
||||
} |
||||
|
||||
/// Sets the WebSocket to request `no_context_takeover`.
|
||||
pub fn request_no_context_takeover( |
||||
mut self, |
||||
request_no_context_takeover: bool, |
||||
) -> DeflateConfigBuilder { |
||||
self.request_no_context_takeover = request_no_context_takeover; |
||||
self |
||||
} |
||||
|
||||
/// Sets the WebSocket to accept `no_context_takeover`.
|
||||
pub fn accept_no_context_takeover( |
||||
mut self, |
||||
accept_no_context_takeover: bool, |
||||
) -> DeflateConfigBuilder { |
||||
self.accept_no_context_takeover = accept_no_context_takeover; |
||||
self |
||||
} |
||||
|
||||
/// Consumes the builder and produces a `DeflateConfig.`
|
||||
pub fn build(self) -> DeflateConfig { |
||||
DeflateConfig { |
||||
max_message_size: self.max_message_size.unwrap_or_else(usize::max_value), |
||||
max_window_bits: self.max_window_bits, |
||||
request_no_context_takeover: self.request_no_context_takeover, |
||||
accept_no_context_takeover: self.accept_no_context_takeover, |
||||
compression_level: self.compression_level, |
||||
..Default::default() |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// A permessage-deflate encoding WebSocket extension.
|
||||
#[derive(Debug)] |
||||
pub struct DeflateExt { |
||||
/// Defines whether the extension is enabled. Following a successful handshake, this will be
|
||||
/// `true`.
|
||||
enabled: bool, |
||||
/// The configuration for the extension.
|
||||
config: DeflateConfig, |
||||
/// A stack of continuation frames awaiting `fin` and the total size of all of the fragments.
|
||||
fragment_buffer: FragmentBuffer, |
||||
/// The deflate decompressor.
|
||||
inflator: Inflator, |
||||
/// The deflate compressor.
|
||||
deflator: Deflator, |
||||
/// If this deflate extension is not used, messages will be forwarded to this extension.
|
||||
uncompressed_extension: UncompressedExt, |
||||
} |
||||
|
||||
impl DeflateExt { |
||||
/// Creates a `DeflateExt` instance using the provided configuration.
|
||||
pub fn new(config: DeflateConfig) -> DeflateExt { |
||||
DeflateExt { |
||||
enabled: false, |
||||
config, |
||||
fragment_buffer: FragmentBuffer::new(config.max_message_size), |
||||
inflator: Inflator::new(), |
||||
deflator: Deflator::new(Compression::fast()), |
||||
uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())), |
||||
} |
||||
} |
||||
|
||||
fn parse_window_parameter<'a>( |
||||
&mut self, |
||||
mut param_iter: impl Iterator<Item = &'a str>, |
||||
) -> Result<Option<u8>, String> { |
||||
if let Some(window_bits_str) = param_iter.next() { |
||||
match window_bits_str.trim().parse() { |
||||
Ok(window_bits) => { |
||||
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { |
||||
if window_bits != self.config.max_window_bits() { |
||||
self.config.max_window_bits = window_bits; |
||||
Ok(Some(window_bits)) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} else { |
||||
Err(format!("Invalid window parameter: {}", window_bits)) |
||||
} |
||||
} |
||||
Err(e) => Err(e.to_string()), |
||||
} |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
|
||||
fn decline<T>(&mut self, res: &mut Response<T>) { |
||||
self.enabled = false; |
||||
res.headers_mut().remove(EXT_IDENT); |
||||
} |
||||
} |
||||
|
||||
/// A permessage-deflate extension error.
|
||||
#[derive(Debug, Clone)] |
||||
pub enum DeflateExtensionError { |
||||
/// An error produced when deflating a message.
|
||||
DeflateError(String), |
||||
/// An error produced when inflating a message.
|
||||
InflateError(String), |
||||
/// An error produced during the WebSocket negotiation.
|
||||
NegotiationError(String), |
||||
/// Produced when fragment buffer grew beyond the maximum configured size.
|
||||
Capacity(Cow<'static, str>), |
||||
} |
||||
|
||||
impl Display for DeflateExtensionError { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
||||
match self { |
||||
DeflateExtensionError::DeflateError(m) => { |
||||
write!(f, "An error was produced during decompression: {}", m) |
||||
} |
||||
DeflateExtensionError::InflateError(m) => { |
||||
write!(f, "An error was produced during compression: {}", m) |
||||
} |
||||
DeflateExtensionError::NegotiationError(m) => { |
||||
write!(f, "An upgrade error was encountered: {}", m) |
||||
} |
||||
DeflateExtensionError::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl std::error::Error for DeflateExtensionError {} |
||||
|
||||
impl From<DeflateExtensionError> for crate::Error { |
||||
fn from(e: DeflateExtensionError) -> Self { |
||||
crate::Error::ExtensionError(Cow::from(e.to_string())) |
||||
} |
||||
} |
||||
|
||||
impl From<InvalidHeaderValue> for DeflateExtensionError { |
||||
fn from(e: InvalidHeaderValue) -> Self { |
||||
DeflateExtensionError::NegotiationError(e.to_string()) |
||||
} |
||||
} |
||||
|
||||
impl Default for DeflateExt { |
||||
fn default() -> Self { |
||||
DeflateExt::new(Default::default()) |
||||
} |
||||
} |
||||
|
||||
impl WebSocketExtension for DeflateExt { |
||||
type Error = DeflateExtensionError; |
||||
|
||||
fn new(max_message_size: Option<usize>) -> Self { |
||||
DeflateExt::new(DeflateConfig { |
||||
max_message_size: max_message_size.unwrap_or_else(usize::max_value), |
||||
..Default::default() |
||||
}) |
||||
} |
||||
|
||||
fn enabled(&self) -> bool { |
||||
self.enabled |
||||
} |
||||
|
||||
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> { |
||||
let mut header_value = String::from(EXT_IDENT); |
||||
let DeflateConfig { |
||||
max_window_bits, |
||||
request_no_context_takeover, |
||||
.. |
||||
} = self.config; |
||||
|
||||
if max_window_bits < LZ77_MAX_WINDOW_SIZE { |
||||
header_value.push_str(&format!( |
||||
"; client_max_window_bits={}; server_max_window_bits={}", |
||||
max_window_bits, max_window_bits |
||||
)) |
||||
} else { |
||||
header_value.push_str("; client_max_window_bits") |
||||
} |
||||
|
||||
if request_no_context_takeover { |
||||
header_value.push_str("; server_no_context_takeover") |
||||
} |
||||
|
||||
request.headers_mut().append( |
||||
SEC_WEBSOCKET_EXTENSIONS, |
||||
HeaderValue::from_str(&header_value).unwrap(), |
||||
); |
||||
|
||||
request |
||||
} |
||||
|
||||
fn on_receive_request<T>( |
||||
&mut self, |
||||
request: &Request<T>, |
||||
response: &mut Response<T>, |
||||
) -> Result<(), Self::Error> { |
||||
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { |
||||
return match header.to_str() { |
||||
Ok(header) => { |
||||
let mut response_str = String::with_capacity(header.len()); |
||||
let mut server_takeover = false; |
||||
let mut client_takeover = false; |
||||
let mut server_max_bits = false; |
||||
let mut client_max_bits = false; |
||||
|
||||
for param in header.split(';') { |
||||
match param.trim().to_lowercase().as_str() { |
||||
"permessage-deflate" => response_str.push_str("permessage-deflate"), |
||||
"server_no_context_takeover" => { |
||||
if server_takeover { |
||||
self.decline(response); |
||||
} else { |
||||
server_takeover = true; |
||||
if self.config.accept_no_context_takeover() { |
||||
self.config.compress_reset = true; |
||||
response_str.push_str("; server_no_context_takeover"); |
||||
} |
||||
} |
||||
} |
||||
"client_no_context_takeover" => { |
||||
if client_takeover { |
||||
self.decline(response); |
||||
} else { |
||||
client_takeover = true; |
||||
self.config.decompress_reset = true; |
||||
response_str.push_str("; client_no_context_takeover"); |
||||
} |
||||
} |
||||
param if param.starts_with("server_max_window_bits") => { |
||||
if server_max_bits { |
||||
self.decline(response); |
||||
} else { |
||||
server_max_bits = true; |
||||
|
||||
match self.parse_window_parameter(param.split('=').skip(1)) { |
||||
Ok(Some(bits)) => { |
||||
self.deflator = Deflator::new_with_window_bits( |
||||
self.config.compression_level, |
||||
bits, |
||||
); |
||||
response_str.push_str("; "); |
||||
response_str.push_str(param) |
||||
} |
||||
Ok(None) => {} |
||||
Err(_) => { |
||||
self.decline(response); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
param if param.starts_with("client_max_window_bits") => { |
||||
if client_max_bits { |
||||
self.decline(response); |
||||
} else { |
||||
client_max_bits = true; |
||||
|
||||
match self.parse_window_parameter(param.split('=').skip(1)) { |
||||
Ok(Some(bits)) => { |
||||
self.inflator = Inflator::new_with_window_bits(bits); |
||||
|
||||
response_str.push_str("; "); |
||||
response_str.push_str(param); |
||||
continue; |
||||
} |
||||
Ok(None) => {} |
||||
Err(_) => { |
||||
self.decline(response); |
||||
} |
||||
} |
||||
|
||||
response_str.push_str("; "); |
||||
response_str.push_str(&format!( |
||||
"client_max_window_bits={}", |
||||
self.config.max_window_bits() |
||||
)) |
||||
} |
||||
} |
||||
_ => { |
||||
self.decline(response); |
||||
} |
||||
} |
||||
} |
||||
|
||||
if !response_str.contains("client_no_context_takeover") |
||||
&& self.config.request_no_context_takeover() |
||||
{ |
||||
self.config.decompress_reset = true; |
||||
response_str.push_str("; client_no_context_takeover"); |
||||
} |
||||
|
||||
if !response_str.contains("server_max_window_bits") { |
||||
response_str.push_str("; "); |
||||
response_str.push_str(&format!( |
||||
"server_max_window_bits={}", |
||||
self.config.max_window_bits() |
||||
)) |
||||
} |
||||
|
||||
if !response_str.contains("client_max_window_bits") |
||||
&& self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE |
||||
{ |
||||
continue; |
||||
} |
||||
|
||||
response.headers_mut().insert( |
||||
SEC_WEBSOCKET_EXTENSIONS, |
||||
HeaderValue::from_str(&response_str)?, |
||||
); |
||||
|
||||
self.enabled = true; |
||||
|
||||
Ok(()) |
||||
} |
||||
Err(e) => { |
||||
self.enabled = false; |
||||
Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Failed to parse request header: {}", |
||||
e, |
||||
))) |
||||
} |
||||
}; |
||||
} |
||||
|
||||
self.decline(response); |
||||
Ok(()) |
||||
} |
||||
|
||||
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> { |
||||
let mut extension_name = false; |
||||
let mut server_takeover = false; |
||||
let mut client_takeover = false; |
||||
let mut server_max_window_bits = false; |
||||
let mut client_max_window_bits = false; |
||||
|
||||
for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() { |
||||
match header.to_str() { |
||||
Ok(header) => { |
||||
for param in header.split(';') { |
||||
match param.trim().to_lowercase().as_str() { |
||||
"permessage-deflate" => { |
||||
if extension_name { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Duplicate extension parameter: permessage-deflate" |
||||
))); |
||||
} else { |
||||
self.enabled = true; |
||||
extension_name = true; |
||||
} |
||||
} |
||||
"server_no_context_takeover" => { |
||||
if server_takeover { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Duplicate extension parameter: server_no_context_takeover" |
||||
))); |
||||
} else { |
||||
server_takeover = true; |
||||
self.config.decompress_reset = true; |
||||
} |
||||
} |
||||
"client_no_context_takeover" => { |
||||
if client_takeover { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Duplicate extension parameter: client_no_context_takeover" |
||||
))); |
||||
} else { |
||||
client_takeover = true; |
||||
|
||||
if self.config.accept_no_context_takeover() { |
||||
self.config.compress_reset = true; |
||||
} else { |
||||
return Err(DeflateExtensionError::NegotiationError( |
||||
format!("The client requires context takeover."), |
||||
)); |
||||
} |
||||
} |
||||
} |
||||
param if param.starts_with("server_max_window_bits") => { |
||||
if server_max_window_bits { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Duplicate extension parameter: server_max_window_bits" |
||||
))); |
||||
} else { |
||||
server_max_window_bits = true; |
||||
|
||||
match self.parse_window_parameter(param.split("=").skip(1)) { |
||||
Ok(Some(bits)) => { |
||||
self.inflator = Inflator::new_with_window_bits(bits); |
||||
} |
||||
Ok(None) => {} |
||||
Err(e) => { |
||||
return Err(DeflateExtensionError::NegotiationError( |
||||
format!( |
||||
"server_max_window_bits parameter error: {}", |
||||
e |
||||
), |
||||
)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
param if param.starts_with("client_max_window_bits") => { |
||||
if client_max_window_bits { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Duplicate extension parameter: client_max_window_bits" |
||||
))); |
||||
} else { |
||||
client_max_window_bits = true; |
||||
|
||||
match self.parse_window_parameter(param.split("=").skip(1)) { |
||||
Ok(Some(bits)) => { |
||||
self.deflator = Deflator::new_with_window_bits( |
||||
self.config.compression_level, |
||||
bits, |
||||
); |
||||
} |
||||
Ok(None) => {} |
||||
Err(e) => { |
||||
return Err(DeflateExtensionError::NegotiationError( |
||||
format!( |
||||
"client_max_window_bits parameter error: {}", |
||||
e |
||||
), |
||||
)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
p => { |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Unknown permessage-deflate parameter: {}", |
||||
p |
||||
))); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
Err(e) => { |
||||
self.enabled = false; |
||||
return Err(DeflateExtensionError::NegotiationError(format!( |
||||
"Failed to parse extension parameter: {}", |
||||
e |
||||
))); |
||||
} |
||||
} |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> { |
||||
if self.enabled { |
||||
if let OpCode::Data(_) = frame.header().opcode { |
||||
let mut compressed = Vec::with_capacity(frame.payload().len()); |
||||
self.deflator.compress(frame.payload(), &mut compressed)?; |
||||
|
||||
let len = compressed.len(); |
||||
compressed.truncate(len - 4); |
||||
|
||||
*frame.payload_mut() = compressed; |
||||
frame.header_mut().rsv1 = true; |
||||
|
||||
if self.config.compress_reset() { |
||||
self.deflator.reset(); |
||||
} |
||||
} |
||||
} |
||||
|
||||
Ok(frame) |
||||
} |
||||
|
||||
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> { |
||||
let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) { |
||||
if !frame.header().is_final { |
||||
self.fragment_buffer |
||||
.try_push_frame(frame) |
||||
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?; |
||||
Ok(None) |
||||
} else { |
||||
let mut compressed = if self.fragment_buffer.is_empty() { |
||||
Vec::with_capacity(frame.payload().len()) |
||||
} else { |
||||
Vec::with_capacity(self.fragment_buffer.len() + frame.payload().len()) |
||||
}; |
||||
|
||||
let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); |
||||
|
||||
let opcode = match frame.header().opcode { |
||||
OpCode::Data(Data::Continue) => { |
||||
self.fragment_buffer |
||||
.try_push_frame(frame) |
||||
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?; |
||||
|
||||
let opcode = self.fragment_buffer.first().unwrap().header().opcode; |
||||
|
||||
self.fragment_buffer.reset().into_iter().for_each(|f| { |
||||
compressed.extend(f.into_data()); |
||||
}); |
||||
|
||||
opcode |
||||
} |
||||
_ => { |
||||
compressed.put_slice(frame.payload()); |
||||
frame.header().opcode |
||||
} |
||||
}; |
||||
|
||||
compressed.extend(&[0, 0, 255, 255]); |
||||
|
||||
self.inflator.decompress(&compressed, &mut decompressed)?; |
||||
|
||||
if self.config.decompress_reset() { |
||||
self.inflator.reset(false); |
||||
} |
||||
|
||||
self.uncompressed_extension.on_receive_frame(Frame::message( |
||||
decompressed, |
||||
opcode, |
||||
true, |
||||
)) |
||||
} |
||||
} else { |
||||
self.uncompressed_extension.on_receive_frame(frame) |
||||
}; |
||||
|
||||
match r { |
||||
Ok(msg) => Ok(msg), |
||||
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl From<DecompressError> for DeflateExtensionError { |
||||
fn from(e: DecompressError) -> Self { |
||||
DeflateExtensionError::InflateError(e.to_string()) |
||||
} |
||||
} |
||||
|
||||
impl From<CompressError> for DeflateExtensionError { |
||||
fn from(e: CompressError) -> Self { |
||||
DeflateExtensionError::DeflateError(e.to_string()) |
||||
} |
||||
} |
||||
|
||||
#[derive(Debug)] |
||||
struct Deflator { |
||||
compress: Compress, |
||||
} |
||||
|
||||
impl Deflator { |
||||
fn new(compresion: Compression) -> Deflator { |
||||
Deflator { |
||||
compress: Compress::new(compresion, false), |
||||
} |
||||
} |
||||
|
||||
fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator { |
||||
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
|
||||
if window_size == 8 { |
||||
window_size = 9; |
||||
} |
||||
|
||||
Deflator { |
||||
compress: Compress::new_with_window_bits(compression, false, window_size), |
||||
} |
||||
} |
||||
|
||||
fn reset(&mut self) { |
||||
self.compress.reset() |
||||
} |
||||
|
||||
fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), CompressError> { |
||||
let mut read_buff = Vec::from(input); |
||||
let mut output_size; |
||||
|
||||
loop { |
||||
output_size = output.len(); |
||||
|
||||
if output_size == output.capacity() { |
||||
output.reserve(input.len()); |
||||
} |
||||
|
||||
let before_out = self.compress.total_out(); |
||||
let before_in = self.compress.total_in(); |
||||
|
||||
let out_slice = unsafe { |
||||
slice::from_raw_parts_mut( |
||||
output.as_mut_ptr().offset(output_size as isize), |
||||
output.capacity() - output_size, |
||||
) |
||||
}; |
||||
|
||||
let status = self |
||||
.compress |
||||
.compress(&read_buff, out_slice, FlushCompress::Sync)?; |
||||
|
||||
let consumed = (self.compress.total_in() - before_in) as usize; |
||||
read_buff = read_buff.split_off(consumed); |
||||
|
||||
unsafe { |
||||
output.set_len((self.compress.total_out() - before_out) as usize + output_size); |
||||
} |
||||
|
||||
match status { |
||||
Status::Ok | Status::BufError => { |
||||
if before_out == self.compress.total_out() && read_buff.is_empty() { |
||||
return Ok(()); |
||||
} |
||||
} |
||||
s => panic!("Compression error: {:?}", s), |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[derive(Debug)] |
||||
struct Inflator { |
||||
decompress: Decompress, |
||||
} |
||||
|
||||
impl Inflator { |
||||
fn new() -> Inflator { |
||||
Inflator { |
||||
decompress: Decompress::new(false), |
||||
} |
||||
} |
||||
|
||||
fn new_with_window_bits(mut window_size: u8) -> Inflator { |
||||
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
|
||||
if window_size == 8 { |
||||
window_size = 9; |
||||
} |
||||
|
||||
Inflator { |
||||
decompress: Decompress::new_with_window_bits(false, window_size), |
||||
} |
||||
} |
||||
|
||||
fn reset(&mut self, zlib_header: bool) { |
||||
self.decompress.reset(zlib_header) |
||||
} |
||||
|
||||
fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), DecompressError> { |
||||
let mut read_buff = Vec::from(input); |
||||
let mut output_size; |
||||
|
||||
loop { |
||||
output_size = output.len(); |
||||
|
||||
if output_size == output.capacity() { |
||||
output.reserve(input.len()); |
||||
} |
||||
|
||||
let before_out = self.decompress.total_out(); |
||||
let before_in = self.decompress.total_in(); |
||||
|
||||
let out_slice = unsafe { |
||||
slice::from_raw_parts_mut( |
||||
output.as_mut_ptr().offset(output_size as isize), |
||||
output.capacity() - output_size, |
||||
) |
||||
}; |
||||
|
||||
let status = |
||||
self.decompress |
||||
.decompress(&read_buff, out_slice, FlushDecompress::Sync)?; |
||||
|
||||
let consumed = (self.decompress.total_in() - before_in) as usize; |
||||
read_buff = read_buff.split_off(consumed); |
||||
|
||||
unsafe { |
||||
output.set_len((self.decompress.total_out() - before_out) as usize + output_size); |
||||
} |
||||
|
||||
match status { |
||||
Status::Ok | Status::BufError => { |
||||
if before_out == self.decompress.total_out() && read_buff.is_empty() { |
||||
return Ok(()); |
||||
} |
||||
} |
||||
s => panic!("Decompression error: {:?}", s), |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// A buffer for holding continuation frames. Ensures that the total length of all of the frame's
|
||||
/// payloads does not exceed `max_len`.
|
||||
///
|
||||
/// Defaults to an initial capacity of ten frames.
|
||||
#[derive(Debug)] |
||||
struct FragmentBuffer { |
||||
fragments: Vec<Frame>, |
||||
fragments_len: usize, |
||||
max_len: usize, |
||||
} |
||||
|
||||
impl FragmentBuffer { |
||||
/// Creates a new fragment buffer that will permit a maximum length of `max_len`.
|
||||
fn new(max_len: usize) -> FragmentBuffer { |
||||
FragmentBuffer { |
||||
fragments: Vec::with_capacity(10), |
||||
fragments_len: 0, |
||||
max_len, |
||||
} |
||||
} |
||||
|
||||
/// Attempts to push a frame into the buffer. This will fail if the new length of the buffer's
|
||||
/// frames exceeds the maximum capacity of `max_len`.
|
||||
fn try_push_frame(&mut self, frame: Frame) -> Result<(), String> { |
||||
let FragmentBuffer { |
||||
fragments, |
||||
fragments_len, |
||||
max_len, |
||||
} = self; |
||||
|
||||
*fragments_len += frame.payload().len(); |
||||
|
||||
if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len { |
||||
return Err(format!( |
||||
"Message too big: {} + {} > {}", |
||||
fragments_len, fragments_len, max_len |
||||
) |
||||
.into()); |
||||
} else { |
||||
fragments.push(frame); |
||||
Ok(()) |
||||
} |
||||
} |
||||
|
||||
/// Returns the total length of all of the frames that have been pushed into the buffer.
|
||||
fn len(&self) -> usize { |
||||
self.fragments_len |
||||
} |
||||
|
||||
/// Returns whether the buffer is empty.
|
||||
fn is_empty(&self) -> bool { |
||||
self.fragments.is_empty() |
||||
} |
||||
|
||||
/// Returns the first element of the fragments slice, or `None` if it is empty.
|
||||
fn first(&self) -> Option<&Frame> { |
||||
self.fragments.first() |
||||
} |
||||
|
||||
/// Drains the buffer and resets it to an initial capacity of 10 elements.
|
||||
fn reset(&mut self) -> Vec<Frame> { |
||||
self.fragments_len = 0; |
||||
replace(&mut self.fragments, Vec::with_capacity(10)) |
||||
} |
||||
} |
@ -0,0 +1,56 @@ |
||||
//! WebSocket extensions
|
||||
|
||||
use http::{Request, Response}; |
||||
|
||||
use crate::protocol::frame::Frame; |
||||
use crate::Message; |
||||
|
||||
/// A permessage-deflate WebSocket extension (RFC 7692).
|
||||
#[cfg(feature = "deflate")] |
||||
pub mod deflate; |
||||
/// An uncompressed message handler for a WebSocket.
|
||||
pub mod uncompressed; |
||||
|
||||
/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions
|
||||
/// may be stacked by nesting them inside one another.
|
||||
pub trait WebSocketExtension { |
||||
/// An error type that the extension produces.
|
||||
type Error: Into<crate::Error>; |
||||
|
||||
/// Constructs a new WebSocket extension that will permit messages of the provided size.
|
||||
fn new(max_message_size: Option<usize>) -> Self; |
||||
|
||||
/// Returns whether or not the extension is enabled.
|
||||
fn enabled(&self) -> bool { |
||||
false |
||||
} |
||||
|
||||
/// For WebSocket clients, this will be called when a `Request` is being constructed.
|
||||
fn on_make_request<T>(&mut self, request: Request<T>) -> Request<T> { |
||||
request |
||||
} |
||||
|
||||
/// For WebSocket server, this will be called when a `Request` has been received.
|
||||
fn on_receive_request<T>( |
||||
&mut self, |
||||
_request: &Request<T>, |
||||
_response: &mut Response<T>, |
||||
) -> Result<(), Self::Error> { |
||||
Ok(()) |
||||
} |
||||
|
||||
/// For WebSocket clients, this will be called when a response from the server has been
|
||||
/// received. If an error is produced, then subsequent calls to `rsv1()` should return `false`.
|
||||
fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> { |
||||
Ok(()) |
||||
} |
||||
|
||||
/// Called when a frame is about to be sent.
|
||||
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> { |
||||
Ok(frame) |
||||
} |
||||
|
||||
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
|
||||
/// type `OpCode::Data`.
|
||||
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error>; |
||||
} |
@ -0,0 +1,104 @@ |
||||
use crate::extensions::WebSocketExtension; |
||||
use crate::protocol::frame::coding::{Data, OpCode}; |
||||
use crate::protocol::frame::Frame; |
||||
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; |
||||
use crate::{Error, Message}; |
||||
use crate::protocol::MAX_MESSAGE_SIZE; |
||||
|
||||
/// An uncompressed message handler for a WebSocket.
|
||||
#[derive(Debug)] |
||||
pub struct UncompressedExt { |
||||
incomplete: Option<IncompleteMessage>, |
||||
max_message_size: Option<usize>, |
||||
} |
||||
|
||||
impl Default for UncompressedExt { |
||||
fn default() -> Self { |
||||
UncompressedExt { |
||||
incomplete: None, |
||||
max_message_size: Some(MAX_MESSAGE_SIZE) |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl UncompressedExt { |
||||
/// Builds a new `UncompressedExt` that will permit a maximum message size of `max_message_size`
|
||||
/// or will be unbounded if `None`.
|
||||
pub fn new(max_message_size: Option<usize>) -> UncompressedExt { |
||||
UncompressedExt { |
||||
incomplete: None, |
||||
max_message_size, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl WebSocketExtension for UncompressedExt { |
||||
type Error = Error; |
||||
|
||||
fn new(max_message_size: Option<usize>) -> Self { |
||||
UncompressedExt { |
||||
incomplete: None, |
||||
max_message_size, |
||||
} |
||||
} |
||||
|
||||
fn enabled(&self) -> bool { |
||||
true |
||||
} |
||||
|
||||
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> { |
||||
let fin = frame.header().is_final; |
||||
|
||||
let hdr = frame.header(); |
||||
|
||||
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { |
||||
return Err(Error::Protocol( |
||||
"Reserved bits are non-zero and no WebSocket extensions are enabled".into(), |
||||
)); |
||||
} |
||||
|
||||
match frame.header().opcode { |
||||
OpCode::Data(data) => match data { |
||||
Data::Continue => { |
||||
if let Some(ref mut msg) = self.incomplete { |
||||
msg.extend(frame.into_data(), self.max_message_size)?; |
||||
} else { |
||||
return Err(Error::Protocol( |
||||
"Continue frame but nothing to continue".into(), |
||||
)); |
||||
} |
||||
if fin { |
||||
Ok(Some(self.incomplete.take().unwrap().complete()?)) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
c if self.incomplete.is_some() => Err(Error::Protocol( |
||||
format!("Received {} while waiting for more fragments", c).into(), |
||||
)), |
||||
Data::Text | Data::Binary => { |
||||
let msg = { |
||||
let message_type = match data { |
||||
Data::Text => IncompleteMessageType::Text, |
||||
Data::Binary => IncompleteMessageType::Binary, |
||||
_ => panic!("Bug: message is not text nor binary"), |
||||
}; |
||||
let mut m = IncompleteMessage::new(message_type); |
||||
m.extend(frame.into_data(), self.max_message_size)?; |
||||
m |
||||
}; |
||||
if fin { |
||||
Ok(Some(msg.complete()?)) |
||||
} else { |
||||
self.incomplete = Some(msg); |
||||
Ok(None) |
||||
} |
||||
} |
||||
Data::Reserved(i) => Err(Error::Protocol( |
||||
format!("Unknown data frame type {}", i).into(), |
||||
)), |
||||
}, |
||||
_ => unreachable!(), |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue