Integrates permessage-deflate into servers

pull/144/head
SirCipher 5 years ago
parent d6f49547b5
commit 77ffd34cd2
  1. 2806
      autobahn/server-results.json
  2. 2
      scripts/autobahn-server.sh
  3. 168
      src/extensions/deflate.rs
  4. 10
      src/extensions/mod.rs
  5. 2
      src/handshake/client.rs
  6. 13
      src/handshake/server.rs

File diff suppressed because it is too large Load Diff

@ -14,7 +14,7 @@ trap cleanup TERM EXIT
function test_diff() {
if ! diff -q \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/client-results.json') \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server-results.json') \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server/index.json')
then
echo Difference in results, either this is a regression or \

@ -13,7 +13,7 @@ use flate2::{
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress,
FlushDecompress, Status,
};
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS};
use http::{HeaderValue, Request, Response};
use std::mem::replace;
use std::slice;
@ -100,6 +100,11 @@ impl DeflateExt {
Ok(None)
}
}
fn decline<T>(&mut self, res: &mut Response<T>) {
self.enabled = false;
res.headers_mut().remove(EXT_NAME);
}
}
#[derive(Clone, Copy, Debug)]
@ -165,6 +170,12 @@ impl From<DeflateExtensionError> for crate::Error {
}
}
impl From<InvalidHeaderValue> for DeflateExtensionError {
fn from(e: InvalidHeaderValue) -> Self {
DeflateExtensionError::NegotiationError(e.to_string())
}
}
const EXT_NAME: &str = "permessage-deflate";
impl WebSocketExtension for DeflateExt {
@ -182,7 +193,7 @@ impl WebSocketExtension for DeflateExt {
}
}
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_NAME);
let DeflateConfig {
max_window_bits,
@ -211,6 +222,159 @@ impl WebSocketExtension for DeflateExt {
request
}
fn on_receive_request<T>(
&mut self,
request: &Request<T>,
response: &mut Response<T>,
) -> Result<(), Self::Error> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
match header.to_str() {
Ok(header) => {
let mut res_ext = String::with_capacity(header.len());
let mut s_takeover = false;
let mut c_takeover = false;
let mut s_max = false;
let mut c_max = false;
for param in header.split(';') {
match param.trim() {
"permessage-deflate" => res_ext.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if s_takeover {
self.decline(response);
} else {
s_takeover = true;
if self.config.accept_no_context_takeover {
self.config.compress_reset = true;
res_ext.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if c_takeover {
self.decline(response);
} else {
c_takeover = true;
self.config.decompress_reset = true;
res_ext.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if s_max {
self.decline(response);
} else {
s_max = true;
let mut param_iter = param.split('=');
param_iter.next(); // we already know the name
if let Some(window_bits_str) = param_iter.next() {
if let Ok(window_bits) = window_bits_str.trim().parse() {
if window_bits >= 9 && window_bits <= 15 {
if window_bits < self.config.max_window_bits {
self.deflator = Deflator {
compress: Compress::new_with_window_bits(
self.config.compression_level,
false,
window_bits,
),
};
res_ext.push_str("; ");
res_ext.push_str(param)
}
} else {
self.decline(response);
}
} else {
self.decline(response);
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if c_max {
self.decline(response);
} else {
c_max = true;
let mut param_iter = param.split('=');
param_iter.next(); // we already know the name
if let Some(window_bits_str) = param_iter.next() {
if let Ok(window_bits) = window_bits_str.trim().parse() {
if window_bits >= 9 && window_bits <= 15 {
if window_bits < self.config.max_window_bits {
self.inflator = Inflator {
decompress:
Decompress::new_with_window_bits(
false,
window_bits,
),
};
res_ext.push_str("; ");
res_ext.push_str(param);
continue;
}
} else {
self.decline(response);
}
} else {
self.decline(response);
}
}
res_ext.push_str("; ");
res_ext.push_str(&format!(
"client_max_window_bits={}",
self.config.max_window_bits
))
}
}
_ => {
// decline all extension offers because we got a bad parameter
self.decline(response);
}
}
}
if !res_ext.contains("client_no_context_takeover")
&& self.config.request_no_context_takeover
{
self.config.decompress_reset = true;
res_ext.push_str("; client_no_context_takeover");
}
if !res_ext.contains("server_max_window_bits") {
res_ext.push_str("; ");
res_ext.push_str(&format!(
"server_max_window_bits={}",
self.config.max_window_bits
))
}
if !res_ext.contains("client_max_window_bits")
&& self.config.max_window_bits < 15
{
continue;
}
response
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_str(&res_ext)?);
self.enabled = true;
return Ok(());
}
Err(e) => {
self.enabled = false;
return Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse header: {}",
e,
)));
}
}
}
self.decline(response);
Ok(())
}
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> {
let mut extension_name = false;
let mut server_takeover = false;

@ -20,10 +20,18 @@ pub trait WebSocketExtension: Default + Clone {
false
}
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {
fn on_make_request<T>(&mut self, request: Request<T>) -> Request<T> {
request
}
fn on_receive_request<T>(
&mut self,
_request: &Request<T>,
_response: &mut Response<T>,
) -> Result<(), Self::Error> {
Ok(())
}
fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> {
Ok(())
}

@ -122,7 +122,7 @@ where
E: WebSocketExtension,
{
let request = match config {
Some(ref mut config) => config.encoder.on_request(request),
Some(ref mut config) => config.encoder.on_make_request(request),
None => request,
};
let mut req = Vec::new();

@ -245,16 +245,23 @@ where
Ok(match finish {
StageResult::DoneReading {
stream,
result,
result: request,
tail,
} => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()));
}
let response = create_response(&result)?;
let mut response = create_response(&request)?;
if let Some(ref mut config) = self.config {
if let Err(e) = config.encoder.on_receive_request(&request, &mut response) {
return Err(e.into());
}
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
callback.on_request(&request, response)
} else {
Ok(response)
};

Loading…
Cancel
Save