Updates request header parsing

pull/144/head
SirCipher 5 years ago
parent 672572e00a
commit 779514704e
  1. 490
      src/extensions/compression/deflate.rs
  2. 2
      src/extensions/compression/mod.rs
  3. 3
      src/extensions/mod.rs

@ -32,7 +32,7 @@ const LZ77_MIN_WINDOW_SIZE: u8 = 8;
const LZ77_MAX_WINDOW_SIZE: u8 = 15; const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration. /// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug, PartialEq)]
pub struct DeflateConfig { pub struct DeflateConfig {
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user. /// for all normal use-cases but small enough to prevent memory eating by a malicious user.
@ -263,6 +263,8 @@ fn parse_window_parameter<'a>(
max_window_bits: u8, max_window_bits: u8,
) -> Result<Option<u8>, String> { ) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() { if let Some(window_bits_str) = param_iter.next() {
let window_bits_str = window_bits_str.replace("\"", "");
match window_bits_str.trim().parse() { match window_bits_str.trim().parse() {
Ok(window_bits) => { Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
@ -282,12 +284,8 @@ fn parse_window_parameter<'a>(
} }
} }
fn decline<T>(res: &mut Response<T>) {
res.headers_mut().remove(EXT_IDENT);
}
/// A permessage-deflate extension error. /// A permessage-deflate extension error.
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq)]
pub enum DeflateExtensionError { pub enum DeflateExtensionError {
/// An error produced when deflating a message. /// An error produced when deflating a message.
DeflateError(String), DeflateError(String),
@ -299,6 +297,12 @@ pub enum DeflateExtensionError {
Capacity(Cow<'static, str>), Capacity(Cow<'static, str>),
} }
impl DeflateExtensionError {
fn malformatted() -> DeflateExtensionError {
DeflateExtensionError::NegotiationError("Malformatted header value".into())
}
}
impl Display for DeflateExtensionError { impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -342,10 +346,11 @@ pub fn on_response<T>(
Ok(header) => { Ok(header) => {
for param in header.split(';') { for param in header.split(';') {
match param.trim().to_lowercase().as_str() { match param.trim().to_lowercase().as_str() {
"permessage-deflate" => { EXT_IDENT => {
if seen_extension_name { if seen_extension_name {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate" "Duplicate extension parameter: {}",
EXT_IDENT
))); )));
} else { } else {
enabled = true; enabled = true;
@ -455,7 +460,7 @@ pub fn on_response<T>(
} }
/// ///
pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> { pub fn on_make_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> {
let mut header_value = String::from(EXT_IDENT); let mut header_value = String::from(EXT_IDENT);
let DeflateConfig { let DeflateConfig {
@ -488,135 +493,163 @@ pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request
request request
} }
/// fn validate_req_extensions(
pub fn on_receive_request<T>( header: &str,
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig, config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> { ) -> Result<Option<String>, DeflateExtensionError> {
let mut enabled = false; let mut response_str = String::with_capacity(header.len());
let mut param_iter = header.split(';');
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len());
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
for param in header.split(';') { match param_iter.next() {
match param.trim().to_lowercase().as_str() { Some(name) if name.trim() == EXT_IDENT => {
"permessage-deflate" => { response_str.push_str(EXT_IDENT);
enabled = true; }
response_str.push_str("permessage-deflate"); _ => {
} return Ok(None);
"server_no_context_takeover" => { }
if server_takeover { }
decline(response);
} else {
server_takeover = true;
if config.accept_no_context_takeover() {
config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if client_takeover {
decline(response);
} else {
client_takeover = true;
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
decline(response);
} else {
server_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.server_max_window_bits,
) {
Ok(Some(bits)) => {
config.server_max_window_bits = bits;
response_str.push_str("; "); let mut server_takeover = false;
response_str.push_str(param) let mut client_takeover = false;
} let mut server_max_bits = false;
Ok(None) => {} let mut client_max_bits = false;
Err(_) => {
decline(response);
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
decline(response);
} else {
client_max_bits = true;
match parse_window_parameter( while let Some(param) = param_iter.next() {
param.split('=').skip(1), match param.trim().to_lowercase().as_str() {
config.client_max_window_bits, "server_no_context_takeover" => {
) { if server_takeover {
Ok(Some(bits)) => { return Err(DeflateExtensionError::malformatted());
config.client_max_window_bits = bits; } else {
response_str.push_str("; "); server_takeover = true;
response_str.push_str(param); if config.accept_no_context_takeover() {
config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::malformatted());
} else {
client_takeover = true;
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
return Err(DeflateExtensionError::malformatted());
} else {
server_max_bits = true;
continue; match parse_window_parameter(
} param.split('=').skip(1),
Ok(None) => {} config.server_max_window_bits,
Err(_) => { ) {
decline(response); Ok(Some(bits)) => {
} config.server_max_window_bits = bits;
}
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(&format!( response_str.push_str(param)
"client_max_window_bits={}",
config.client_max_window_bits()
))
}
} }
_ => { Ok(None) => {}
decline(response); Err(_) => {
return Err(DeflateExtensionError::malformatted());
} }
} }
} }
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
return Err(DeflateExtensionError::malformatted());
} else {
client_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.client_max_window_bits,
) {
Ok(Some(bits)) => {
config.client_max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(param);
continue;
}
Ok(None) => {}
Err(_) => {
return Err(DeflateExtensionError::malformatted());
}
}
if !response_str.contains("client_no_context_takeover")
&& config.request_no_context_takeover()
{
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; "); response_str.push_str("; ");
response_str.push_str(&format!( response_str.push_str(&format!(
"server_max_window_bits={}", "client_max_window_bits={}",
config.server_max_window_bits() config.client_max_window_bits()
)) ))
} }
}
p => {
return Err(DeflateExtensionError::NegotiationError(
format!("Unknown permessage-deflate parameter: {}", p).into(),
))
}
}
}
if !response_str.contains("client_max_window_bits") if !response_str.contains("client_no_context_takeover") && config.request_no_context_takeover()
&& config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE {
{ config.decompress_reset = true;
continue; response_str.push_str("; client_no_context_takeover");
} }
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
config.server_max_window_bits()
))
}
response.headers_mut().insert( if !response_str.contains("client_max_window_bits")
SEC_WEBSOCKET_EXTENSIONS, && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE
HeaderValue::from_str(&response_str)?, {
); return Ok(None);
}
Ok(Some(response_str))
}
Ok(enabled) ///
pub fn on_receive_request<T>(
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
for header in header.split(',') {
let mut parser_config = config.clone();
match validate_req_extensions(header, &mut parser_config) {
Ok(Some(response_str)) => {
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
*config = parser_config;
return Ok(true);
}
Ok(None) => continue,
Err(e) => {
response.headers_mut().remove(EXT_IDENT);
return Err(e);
}
}
}
Ok(false)
} }
Err(e) => Err(DeflateExtensionError::NegotiationError(format!( Err(e) => Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}", "Failed to parse request header: {}",
@ -625,7 +658,6 @@ pub fn on_receive_request<T>(
}; };
} }
decline(response);
Ok(false) Ok(false)
} }
@ -964,3 +996,219 @@ impl FragmentBuffer {
) )
} }
} }
#[cfg(test)]
mod tests {
use crate::extensions::compression::deflate::{
on_receive_request, DeflateConfig, DeflateExtensionError, LZ77_MIN_WINDOW_SIZE,
};
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
#[test]
fn config_unchanged_on_err() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn config_unchanged_on_mismatch() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: LZ77_MIN_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn parses_named_parameters() {
let s ="permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 8,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: true,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=\"8\"");
}
#[test]
fn splits() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
#[test]
fn splits_on_new_line() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
}

@ -117,7 +117,7 @@ pub fn build_compression_headers<T>(
Some(ref mut config) => match &config.compression { Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request, WsCompression::None(_) => request,
#[cfg(feature = "deflate")] #[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config), WsCompression::Deflate(config) => deflate::on_make_request(request, config),
}, },
None => request, None => request,
} }

@ -14,8 +14,7 @@ pub trait WebSocketExtension {
Ok(frame) Ok(frame)
} }
/// Called when a frame has been received and unmasked. The frame provided frame will be of the /// Called when a WebSocket frame has been received.
/// type `OpCode::Data`.
fn on_receive_frame( fn on_receive_frame(
&mut self, &mut self,
data_opcode: Data, data_opcode: Data,

Loading…
Cancel
Save