Integrates permessage-deflate into servers

pull/144/head
SirCipher 5 years ago
parent d6f49547b5
commit 77ffd34cd2
  1. 2806
      autobahn/server-results.json
  2. 2
      scripts/autobahn-server.sh
  3. 168
      src/extensions/deflate.rs
  4. 10
      src/extensions/mod.rs
  5. 2
      src/handshake/client.rs
  6. 13
      src/handshake/server.rs

File diff suppressed because it is too large Load Diff

@ -14,7 +14,7 @@ trap cleanup TERM EXIT
function test_diff() { function test_diff() {
if ! diff -q \ if ! diff -q \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/client-results.json') \ <(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server-results.json') \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server/index.json') <(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server/index.json')
then then
echo Difference in results, either this is a regression or \ echo Difference in results, either this is a regression or \

@ -13,7 +13,7 @@ use flate2::{
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress, Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress,
FlushDecompress, Status, FlushDecompress, Status,
}; };
use http::header::SEC_WEBSOCKET_EXTENSIONS; use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS};
use http::{HeaderValue, Request, Response}; use http::{HeaderValue, Request, Response};
use std::mem::replace; use std::mem::replace;
use std::slice; use std::slice;
@ -100,6 +100,11 @@ impl DeflateExt {
Ok(None) Ok(None)
} }
} }
fn decline<T>(&mut self, res: &mut Response<T>) {
self.enabled = false;
res.headers_mut().remove(EXT_NAME);
}
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
@ -165,6 +170,12 @@ impl From<DeflateExtensionError> for crate::Error {
} }
} }
impl From<InvalidHeaderValue> for DeflateExtensionError {
fn from(e: InvalidHeaderValue) -> Self {
DeflateExtensionError::NegotiationError(e.to_string())
}
}
const EXT_NAME: &str = "permessage-deflate"; const EXT_NAME: &str = "permessage-deflate";
impl WebSocketExtension for DeflateExt { impl WebSocketExtension for DeflateExt {
@ -182,7 +193,7 @@ impl WebSocketExtension for DeflateExt {
} }
} }
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> { fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_NAME); let mut header_value = String::from(EXT_NAME);
let DeflateConfig { let DeflateConfig {
max_window_bits, max_window_bits,
@ -211,6 +222,159 @@ impl WebSocketExtension for DeflateExt {
request request
} }
fn on_receive_request<T>(
&mut self,
request: &Request<T>,
response: &mut Response<T>,
) -> Result<(), Self::Error> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
match header.to_str() {
Ok(header) => {
let mut res_ext = String::with_capacity(header.len());
let mut s_takeover = false;
let mut c_takeover = false;
let mut s_max = false;
let mut c_max = false;
for param in header.split(';') {
match param.trim() {
"permessage-deflate" => res_ext.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if s_takeover {
self.decline(response);
} else {
s_takeover = true;
if self.config.accept_no_context_takeover {
self.config.compress_reset = true;
res_ext.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if c_takeover {
self.decline(response);
} else {
c_takeover = true;
self.config.decompress_reset = true;
res_ext.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if s_max {
self.decline(response);
} else {
s_max = true;
let mut param_iter = param.split('=');
param_iter.next(); // we already know the name
if let Some(window_bits_str) = param_iter.next() {
if let Ok(window_bits) = window_bits_str.trim().parse() {
if window_bits >= 9 && window_bits <= 15 {
if window_bits < self.config.max_window_bits {
self.deflator = Deflator {
compress: Compress::new_with_window_bits(
self.config.compression_level,
false,
window_bits,
),
};
res_ext.push_str("; ");
res_ext.push_str(param)
}
} else {
self.decline(response);
}
} else {
self.decline(response);
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if c_max {
self.decline(response);
} else {
c_max = true;
let mut param_iter = param.split('=');
param_iter.next(); // we already know the name
if let Some(window_bits_str) = param_iter.next() {
if let Ok(window_bits) = window_bits_str.trim().parse() {
if window_bits >= 9 && window_bits <= 15 {
if window_bits < self.config.max_window_bits {
self.inflator = Inflator {
decompress:
Decompress::new_with_window_bits(
false,
window_bits,
),
};
res_ext.push_str("; ");
res_ext.push_str(param);
continue;
}
} else {
self.decline(response);
}
} else {
self.decline(response);
}
}
res_ext.push_str("; ");
res_ext.push_str(&format!(
"client_max_window_bits={}",
self.config.max_window_bits
))
}
}
_ => {
// decline all extension offers because we got a bad parameter
self.decline(response);
}
}
}
if !res_ext.contains("client_no_context_takeover")
&& self.config.request_no_context_takeover
{
self.config.decompress_reset = true;
res_ext.push_str("; client_no_context_takeover");
}
if !res_ext.contains("server_max_window_bits") {
res_ext.push_str("; ");
res_ext.push_str(&format!(
"server_max_window_bits={}",
self.config.max_window_bits
))
}
if !res_ext.contains("client_max_window_bits")
&& self.config.max_window_bits < 15
{
continue;
}
response
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_str(&res_ext)?);
self.enabled = true;
return Ok(());
}
Err(e) => {
self.enabled = false;
return Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse header: {}",
e,
)));
}
}
}
self.decline(response);
Ok(())
}
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> { fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> {
let mut extension_name = false; let mut extension_name = false;
let mut server_takeover = false; let mut server_takeover = false;

@ -20,10 +20,18 @@ pub trait WebSocketExtension: Default + Clone {
false false
} }
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> { fn on_make_request<T>(&mut self, request: Request<T>) -> Request<T> {
request request
} }
fn on_receive_request<T>(
&mut self,
_request: &Request<T>,
_response: &mut Response<T>,
) -> Result<(), Self::Error> {
Ok(())
}
fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> { fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> {
Ok(()) Ok(())
} }

@ -122,7 +122,7 @@ where
E: WebSocketExtension, E: WebSocketExtension,
{ {
let request = match config { let request = match config {
Some(ref mut config) => config.encoder.on_request(request), Some(ref mut config) => config.encoder.on_make_request(request),
None => request, None => request,
}; };
let mut req = Vec::new(); let mut req = Vec::new();

@ -245,16 +245,23 @@ where
Ok(match finish { Ok(match finish {
StageResult::DoneReading { StageResult::DoneReading {
stream, stream,
result, result: request,
tail, tail,
} => { } => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol("Junk after client request".into()));
} }
let response = create_response(&result)?; let mut response = create_response(&request)?;
if let Some(ref mut config) = self.config {
if let Err(e) = config.encoder.on_receive_request(&request, &mut response) {
return Err(e.into());
}
}
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&request, response)
} else { } else {
Ok(response) Ok(response)
}; };

Loading…
Cancel
Save