parent
006cec72ea
commit
e93f9a01e6
@ -0,0 +1,155 @@ |
||||
//!
|
||||
|
||||
use std::fmt::{Debug, Display, Formatter}; |
||||
|
||||
use crate::extensions::deflate::{DeflateConfig, DeflateExtension}; |
||||
use crate::extensions::WebSocketExtension; |
||||
use crate::protocol::frame::Frame; |
||||
use http::header::SEC_WEBSOCKET_EXTENSIONS; |
||||
use http::{HeaderValue, Request, Response}; |
||||
|
||||
#[derive(Copy, Clone, Debug)] |
||||
pub enum CompressionConfig { |
||||
Uncompressed, |
||||
Deflate(DeflateConfig), |
||||
} |
||||
|
||||
impl CompressionConfig { |
||||
pub fn into_strategy(self) -> CompressionStrategy { |
||||
match self { |
||||
Self::Uncompressed => CompressionStrategy::Uncompressed, |
||||
Self::Deflate(_config) => CompressionStrategy::Deflate(DeflateExtension::new()), |
||||
} |
||||
} |
||||
|
||||
pub fn uncompressed() -> CompressionConfig { |
||||
CompressionConfig::Uncompressed |
||||
} |
||||
|
||||
pub fn deflate() -> CompressionConfig { |
||||
CompressionConfig::Deflate(DeflateConfig::default()) |
||||
} |
||||
} |
||||
|
||||
pub enum CompressionStrategy { |
||||
Uncompressed, |
||||
Deflate(DeflateExtension), |
||||
} |
||||
|
||||
#[derive(Debug, Clone)] |
||||
pub struct CompressionExtensionError(String); |
||||
|
||||
impl std::error::Error for CompressionExtensionError {} |
||||
|
||||
impl Display for CompressionExtensionError { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
||||
write!(f, "{}", self.0) |
||||
} |
||||
} |
||||
|
||||
impl From<CompressionExtensionError> for crate::Error { |
||||
fn from(e: CompressionExtensionError) -> Self { |
||||
crate::Error::ExtensionError(Box::new(e)) |
||||
} |
||||
} |
||||
|
||||
impl WebSocketExtension for CompressionStrategy { |
||||
type Error = CompressionExtensionError; |
||||
|
||||
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> { |
||||
match self { |
||||
Self::Uncompressed => request, |
||||
Self::Deflate(de) => de.on_request(request), |
||||
} |
||||
} |
||||
|
||||
fn on_response<T>(&mut self, response: &Response<T>) { |
||||
match self { |
||||
Self::Uncompressed => {} |
||||
Self::Deflate(de) => de.on_response(response), |
||||
} |
||||
} |
||||
|
||||
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> { |
||||
match self { |
||||
Self::Uncompressed => Ok(frame), |
||||
Self::Deflate(de) => de |
||||
.on_send_frame(frame) |
||||
.map_err(|e| CompressionExtensionError(e.to_string())), |
||||
} |
||||
} |
||||
|
||||
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Frame>, Self::Error> { |
||||
match self { |
||||
Self::Uncompressed => Ok(Some(frame)), |
||||
Self::Deflate(de) => de |
||||
.on_receive_frame(frame) |
||||
.map_err(|e| CompressionExtensionError(e.to_string())), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl Debug for CompressionStrategy { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
||||
match self { |
||||
Self::Uncompressed => f.debug_struct("Uncompressed").finish(), |
||||
Self::Deflate(_) => f.debug_struct("Deflate").finish(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl CompressionConfig { |
||||
fn as_header_value(&self) -> Option<HeaderValue> { |
||||
match self { |
||||
Self::Uncompressed => None, |
||||
Self::Deflate(_) => Some(HeaderValue::from_static("permessage-deflate")), |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[derive(Debug, Clone, Copy)] |
||||
pub struct CompressionSelectorError(&'static str); |
||||
|
||||
impl std::error::Error for CompressionSelectorError {} |
||||
|
||||
impl Display for CompressionSelectorError { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
||||
write!(f, "{}", self.0) |
||||
} |
||||
} |
||||
|
||||
impl From<CompressionSelectorError> for crate::Error { |
||||
fn from(e: CompressionSelectorError) -> Self { |
||||
crate::Error::ExtensionError(Box::new(e)) |
||||
} |
||||
} |
||||
|
||||
impl WebSocketExtension for CompressionConfig { |
||||
type Error = CompressionSelectorError; |
||||
|
||||
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> { |
||||
if let Some(header_value) = self.as_header_value() { |
||||
request |
||||
.headers_mut() |
||||
.append(SEC_WEBSOCKET_EXTENSIONS, header_value); |
||||
} |
||||
|
||||
request |
||||
} |
||||
|
||||
fn on_response<T>(&mut self, response: &Response<T>) { |
||||
let mut iter = response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter(); |
||||
|
||||
let self_header = match self.as_header_value() { |
||||
Some(hv) => hv, |
||||
None => return, |
||||
}; |
||||
|
||||
match iter.next() { |
||||
Some(hv) if hv == self_header => {} |
||||
_ => { |
||||
*self = CompressionConfig::Uncompressed; |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,278 @@ |
||||
//! Permessage-deflate extension
|
||||
|
||||
use std::fmt::{Display, Formatter}; |
||||
|
||||
use crate::extensions::WebSocketExtension; |
||||
use crate::protocol::frame::coding::{Data, OpCode}; |
||||
use crate::protocol::frame::Frame; |
||||
use flate2::{Compress, CompressError, Compression, Decompress, DecompressError}; |
||||
use std::mem::replace; |
||||
|
||||
pub struct DeflateExtension { |
||||
pub(crate) config: DeflateConfig, |
||||
pub(crate) fragments: Vec<Frame>, |
||||
inflator: Inflator, |
||||
deflator: Deflator, |
||||
} |
||||
|
||||
impl DeflateExtension { |
||||
pub fn new() -> DeflateExtension { |
||||
DeflateExtension { |
||||
config: Default::default(), |
||||
fragments: vec![], |
||||
inflator: Inflator::new(), |
||||
deflator: Deflator::new(Compression::best()), |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[derive(Clone, Copy, Debug)] |
||||
pub struct DeflateConfig { |
||||
/// The max size of the sliding window. If the other endpoint selects a smaller size, that size
|
||||
/// will be used instead. This must be an integer between 9 and 15 inclusive.
|
||||
/// Default: 15
|
||||
pub max_window_bits: u8, |
||||
/// Indicates whether to ask the other endpoint to reset the sliding window for each message.
|
||||
/// Default: false
|
||||
pub request_no_context_takeover: bool, |
||||
/// Indicates whether this endpoint will agree to reset the sliding window for each message it
|
||||
/// compresses. If this endpoint won't agree to reset the sliding window, then the handshake
|
||||
/// will fail if this endpoint is a client and the server requests no context takeover.
|
||||
/// Default: true
|
||||
pub accept_no_context_takeover: bool, |
||||
/// The number of WebSocket frames to store when defragmenting an incoming fragmented
|
||||
/// compressed message.
|
||||
/// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to
|
||||
/// allow for differences between compressed and uncompressed messages.
|
||||
/// Default: 10
|
||||
pub fragments_capacity: usize, |
||||
/// Indicates whether the extension handler will reallocate if the `fragments_capacity` is
|
||||
/// exceeded. If this is not true, a capacity error will be triggered instead.
|
||||
/// Default: true
|
||||
pub fragments_grow: bool, |
||||
compress_reset: bool, |
||||
decompress_reset: bool, |
||||
} |
||||
|
||||
impl Default for DeflateConfig { |
||||
fn default() -> Self { |
||||
DeflateConfig { |
||||
max_window_bits: 15, |
||||
request_no_context_takeover: false, |
||||
accept_no_context_takeover: true, |
||||
fragments_capacity: 10, |
||||
fragments_grow: true, |
||||
compress_reset: false, |
||||
decompress_reset: false, |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[derive(Debug, Clone)] |
||||
pub enum DeflateExtensionError { |
||||
DeflateError(String), |
||||
InflateError(String), |
||||
} |
||||
|
||||
impl Display for DeflateExtensionError { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
||||
match self { |
||||
DeflateExtensionError::DeflateError(m) => write!(f, "{}", m), |
||||
DeflateExtensionError::InflateError(m) => write!(f, "{}", m), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl std::error::Error for DeflateExtensionError {} |
||||
|
||||
impl From<DeflateExtensionError> for crate::Error { |
||||
fn from(e: DeflateExtensionError) -> Self { |
||||
crate::Error::ExtensionError(Box::new(e)) |
||||
} |
||||
} |
||||
|
||||
impl WebSocketExtension for DeflateExtension { |
||||
type Error = DeflateExtensionError; |
||||
|
||||
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> { |
||||
if let OpCode::Data(_) = frame.header().opcode { |
||||
frame.header_mut().rsv1 = true; |
||||
|
||||
let mut compressed = Vec::with_capacity(frame.payload().len()); |
||||
self.deflator.compress(frame.payload(), &mut compressed)?; |
||||
|
||||
let len = compressed.len(); |
||||
compressed.truncate(len - 4); |
||||
|
||||
*frame.payload_mut() = compressed; |
||||
|
||||
if self.config.compress_reset { |
||||
self.deflator.reset(); |
||||
} |
||||
} |
||||
|
||||
Ok(frame) |
||||
} |
||||
|
||||
fn on_receive_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>, Self::Error> { |
||||
if frame.header().rsv1 { |
||||
frame.header_mut().rsv1 = false; |
||||
|
||||
if !frame.header().is_final { |
||||
self.fragments.push(frame); |
||||
return Ok(None); |
||||
} else { |
||||
if let OpCode::Data(Data::Continue) = frame.header().opcode { |
||||
if !self.config.fragments_grow |
||||
&& self.config.fragments_capacity == self.fragments.len() |
||||
{ |
||||
return Err(DeflateExtensionError::DeflateError( |
||||
"Exceeded max fragments.".into(), |
||||
)); |
||||
} else { |
||||
self.fragments.push(frame); |
||||
} |
||||
|
||||
let opcode = self.fragments.first().unwrap().header().opcode; |
||||
let size = self |
||||
.fragments |
||||
.iter() |
||||
.fold(0, |len, frame| len + frame.payload().len()); |
||||
let mut compressed = Vec::with_capacity(size); |
||||
let decompressed = Vec::with_capacity(size * 2); |
||||
|
||||
replace( |
||||
&mut self.fragments, |
||||
Vec::with_capacity(self.config.fragments_capacity), |
||||
) |
||||
.into_iter() |
||||
.for_each(|f| { |
||||
compressed.extend(f.into_data()); |
||||
}); |
||||
|
||||
compressed.extend(&[0, 0, 255, 255]); |
||||
frame = Frame::message(decompressed, opcode, true); |
||||
} else { |
||||
frame.payload_mut().extend(&[0, 0, 255, 255]); |
||||
|
||||
let mut decompress_output = Vec::with_capacity(frame.payload().len() * 2); |
||||
self.inflator |
||||
.decompress(frame.payload(), &mut decompress_output)?; |
||||
|
||||
*frame.payload_mut() = decompress_output; |
||||
} |
||||
|
||||
if self.config.decompress_reset { |
||||
self.inflator.reset(false); |
||||
} |
||||
} |
||||
} |
||||
|
||||
Ok(Some(frame)) |
||||
} |
||||
} |
||||
|
||||
impl From<DecompressError> for DeflateExtensionError { |
||||
fn from(e: DecompressError) -> Self { |
||||
DeflateExtensionError::InflateError(e.to_string()) |
||||
} |
||||
} |
||||
|
||||
impl From<CompressError> for DeflateExtensionError { |
||||
fn from(e: CompressError) -> Self { |
||||
DeflateExtensionError::DeflateError(e.to_string()) |
||||
} |
||||
} |
||||
|
||||
struct Deflator { |
||||
compress: Compress, |
||||
} |
||||
|
||||
impl Deflator { |
||||
pub fn new(compresion: Compression) -> Deflator { |
||||
Deflator { |
||||
compress: Compress::new(compresion, false), |
||||
} |
||||
} |
||||
|
||||
fn reset(&mut self) { |
||||
self.compress.reset() |
||||
} |
||||
|
||||
pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<usize, CompressError> { |
||||
loop { |
||||
let before_in = self.compress.total_in(); |
||||
output.reserve(256); |
||||
let status = self |
||||
.compress |
||||
.compress_vec(input, output, flate2::FlushCompress::Sync)?; |
||||
let written = (self.compress.total_in() - before_in) as usize; |
||||
|
||||
if written != 0 || status == flate2::Status::StreamEnd { |
||||
return Ok(written); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
struct Inflator { |
||||
decompress: Decompress, |
||||
} |
||||
|
||||
impl Inflator { |
||||
pub fn new() -> Inflator { |
||||
Inflator { |
||||
decompress: Decompress::new(false), |
||||
} |
||||
} |
||||
|
||||
fn reset(&mut self, zlib_header: bool) { |
||||
self.decompress.reset(zlib_header) |
||||
} |
||||
|
||||
pub fn decompress( |
||||
&mut self, |
||||
input: &[u8], |
||||
output: &mut Vec<u8>, |
||||
) -> Result<usize, DecompressError> { |
||||
let mut read_buff = Vec::from(input); |
||||
let mut eof = false; |
||||
|
||||
loop { |
||||
if read_buff.is_empty() { |
||||
eof = true; |
||||
} |
||||
|
||||
if !eof && output.is_empty() { |
||||
output.reserve(256); |
||||
|
||||
unsafe { |
||||
output.set_len(output.capacity()); |
||||
} |
||||
} |
||||
|
||||
let before_out = self.decompress.total_out(); |
||||
let before_in = self.decompress.total_in(); |
||||
|
||||
let decompression_strategy = if eof { |
||||
flate2::FlushDecompress::Finish |
||||
} else { |
||||
flate2::FlushDecompress::None |
||||
}; |
||||
|
||||
let status = self |
||||
.decompress |
||||
.decompress(&read_buff, output, decompression_strategy)?; |
||||
|
||||
let consumed = (self.decompress.total_in() - before_in) as usize; |
||||
read_buff = read_buff.split_off(consumed); |
||||
|
||||
let read = (self.decompress.total_out() - before_out) as usize; |
||||
|
||||
if read != 0 || status == flate2::Status::StreamEnd { |
||||
output.truncate(read); |
||||
return Ok(read); |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,26 @@ |
||||
//! WebSocket extensions
|
||||
|
||||
use http::{Request, Response}; |
||||
|
||||
use crate::protocol::frame::Frame; |
||||
|
||||
pub mod compression; |
||||
pub mod deflate; |
||||
|
||||
pub trait WebSocketExtension { |
||||
type Error: Into<crate::Error>; |
||||
|
||||
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> { |
||||
request |
||||
} |
||||
|
||||
fn on_response<T>(&mut self, _response: &Response<T>) {} |
||||
|
||||
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> { |
||||
Ok(frame) |
||||
} |
||||
|
||||
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Frame>, Self::Error> { |
||||
Ok(Some(frame)) |
||||
} |
||||
} |
Loading…
Reference in new issue