|
|
|
@ -13,7 +13,7 @@ use flate2::{ |
|
|
|
|
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress, |
|
|
|
|
FlushDecompress, Status, |
|
|
|
|
}; |
|
|
|
|
use http::header::SEC_WEBSOCKET_EXTENSIONS; |
|
|
|
|
use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS}; |
|
|
|
|
use http::{HeaderValue, Request, Response}; |
|
|
|
|
use std::mem::replace; |
|
|
|
|
use std::slice; |
|
|
|
@ -100,6 +100,11 @@ impl DeflateExt { |
|
|
|
|
Ok(None) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn decline<T>(&mut self, res: &mut Response<T>) { |
|
|
|
|
self.enabled = false; |
|
|
|
|
res.headers_mut().remove(EXT_NAME); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug)] |
|
|
|
@ -165,6 +170,12 @@ impl From<DeflateExtensionError> for crate::Error { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl From<InvalidHeaderValue> for DeflateExtensionError { |
|
|
|
|
fn from(e: InvalidHeaderValue) -> Self { |
|
|
|
|
DeflateExtensionError::NegotiationError(e.to_string()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
const EXT_NAME: &str = "permessage-deflate"; |
|
|
|
|
|
|
|
|
|
impl WebSocketExtension for DeflateExt { |
|
|
|
@ -182,7 +193,7 @@ impl WebSocketExtension for DeflateExt { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> { |
|
|
|
|
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> { |
|
|
|
|
let mut header_value = String::from(EXT_NAME); |
|
|
|
|
let DeflateConfig { |
|
|
|
|
max_window_bits, |
|
|
|
@ -211,6 +222,159 @@ impl WebSocketExtension for DeflateExt { |
|
|
|
|
request |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn on_receive_request<T>( |
|
|
|
|
&mut self, |
|
|
|
|
request: &Request<T>, |
|
|
|
|
response: &mut Response<T>, |
|
|
|
|
) -> Result<(), Self::Error> { |
|
|
|
|
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) { |
|
|
|
|
match header.to_str() { |
|
|
|
|
Ok(header) => { |
|
|
|
|
let mut res_ext = String::with_capacity(header.len()); |
|
|
|
|
let mut s_takeover = false; |
|
|
|
|
let mut c_takeover = false; |
|
|
|
|
let mut s_max = false; |
|
|
|
|
let mut c_max = false; |
|
|
|
|
|
|
|
|
|
for param in header.split(';') { |
|
|
|
|
match param.trim() { |
|
|
|
|
"permessage-deflate" => res_ext.push_str("permessage-deflate"), |
|
|
|
|
"server_no_context_takeover" => { |
|
|
|
|
if s_takeover { |
|
|
|
|
self.decline(response); |
|
|
|
|
} else { |
|
|
|
|
s_takeover = true; |
|
|
|
|
if self.config.accept_no_context_takeover { |
|
|
|
|
self.config.compress_reset = true; |
|
|
|
|
res_ext.push_str("; server_no_context_takeover"); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
"client_no_context_takeover" => { |
|
|
|
|
if c_takeover { |
|
|
|
|
self.decline(response); |
|
|
|
|
} else { |
|
|
|
|
c_takeover = true; |
|
|
|
|
self.config.decompress_reset = true; |
|
|
|
|
res_ext.push_str("; client_no_context_takeover"); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
param if param.starts_with("server_max_window_bits") => { |
|
|
|
|
if s_max { |
|
|
|
|
self.decline(response); |
|
|
|
|
} else { |
|
|
|
|
s_max = true; |
|
|
|
|
let mut param_iter = param.split('='); |
|
|
|
|
param_iter.next(); // we already know the name
|
|
|
|
|
if let Some(window_bits_str) = param_iter.next() { |
|
|
|
|
if let Ok(window_bits) = window_bits_str.trim().parse() { |
|
|
|
|
if window_bits >= 9 && window_bits <= 15 { |
|
|
|
|
if window_bits < self.config.max_window_bits { |
|
|
|
|
self.deflator = Deflator { |
|
|
|
|
compress: Compress::new_with_window_bits( |
|
|
|
|
self.config.compression_level, |
|
|
|
|
false, |
|
|
|
|
window_bits, |
|
|
|
|
), |
|
|
|
|
}; |
|
|
|
|
res_ext.push_str("; "); |
|
|
|
|
res_ext.push_str(param) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
self.decline(response); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
self.decline(response); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
param if param.starts_with("client_max_window_bits") => { |
|
|
|
|
if c_max { |
|
|
|
|
self.decline(response); |
|
|
|
|
} else { |
|
|
|
|
c_max = true; |
|
|
|
|
let mut param_iter = param.split('='); |
|
|
|
|
param_iter.next(); // we already know the name
|
|
|
|
|
if let Some(window_bits_str) = param_iter.next() { |
|
|
|
|
if let Ok(window_bits) = window_bits_str.trim().parse() { |
|
|
|
|
if window_bits >= 9 && window_bits <= 15 { |
|
|
|
|
if window_bits < self.config.max_window_bits { |
|
|
|
|
self.inflator = Inflator { |
|
|
|
|
decompress: |
|
|
|
|
Decompress::new_with_window_bits( |
|
|
|
|
false, |
|
|
|
|
window_bits, |
|
|
|
|
), |
|
|
|
|
}; |
|
|
|
|
res_ext.push_str("; "); |
|
|
|
|
res_ext.push_str(param); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
self.decline(response); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
self.decline(response); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
res_ext.push_str("; "); |
|
|
|
|
res_ext.push_str(&format!( |
|
|
|
|
"client_max_window_bits={}", |
|
|
|
|
self.config.max_window_bits |
|
|
|
|
)) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
_ => { |
|
|
|
|
// decline all extension offers because we got a bad parameter
|
|
|
|
|
self.decline(response); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !res_ext.contains("client_no_context_takeover") |
|
|
|
|
&& self.config.request_no_context_takeover |
|
|
|
|
{ |
|
|
|
|
self.config.decompress_reset = true; |
|
|
|
|
res_ext.push_str("; client_no_context_takeover"); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !res_ext.contains("server_max_window_bits") { |
|
|
|
|
res_ext.push_str("; "); |
|
|
|
|
res_ext.push_str(&format!( |
|
|
|
|
"server_max_window_bits={}", |
|
|
|
|
self.config.max_window_bits |
|
|
|
|
)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !res_ext.contains("client_max_window_bits") |
|
|
|
|
&& self.config.max_window_bits < 15 |
|
|
|
|
{ |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
response |
|
|
|
|
.headers_mut() |
|
|
|
|
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_str(&res_ext)?); |
|
|
|
|
|
|
|
|
|
self.enabled = true; |
|
|
|
|
|
|
|
|
|
return Ok(()); |
|
|
|
|
} |
|
|
|
|
Err(e) => { |
|
|
|
|
self.enabled = false; |
|
|
|
|
return Err(DeflateExtensionError::NegotiationError(format!( |
|
|
|
|
"Failed to parse header: {}", |
|
|
|
|
e, |
|
|
|
|
))); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
self.decline(response); |
|
|
|
|
Ok(()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> { |
|
|
|
|
let mut extension_name = false; |
|
|
|
|
let mut server_takeover = false; |
|
|
|
|