Updates request header parsing

pull/144/head
SirCipher 5 years ago
parent 672572e00a
commit 779514704e
  1. 322
      src/extensions/compression/deflate.rs
  2. 2
      src/extensions/compression/mod.rs
  3. 3
      src/extensions/mod.rs

@ -32,7 +32,7 @@ const LZ77_MIN_WINDOW_SIZE: u8 = 8;
const LZ77_MAX_WINDOW_SIZE: u8 = 15; const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration. /// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug, PartialEq)]
pub struct DeflateConfig { pub struct DeflateConfig {
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big /// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user. /// for all normal use-cases but small enough to prevent memory eating by a malicious user.
@ -263,6 +263,8 @@ fn parse_window_parameter<'a>(
max_window_bits: u8, max_window_bits: u8,
) -> Result<Option<u8>, String> { ) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() { if let Some(window_bits_str) = param_iter.next() {
let window_bits_str = window_bits_str.replace("\"", "");
match window_bits_str.trim().parse() { match window_bits_str.trim().parse() {
Ok(window_bits) => { Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
@ -282,12 +284,8 @@ fn parse_window_parameter<'a>(
} }
} }
fn decline<T>(res: &mut Response<T>) {
res.headers_mut().remove(EXT_IDENT);
}
/// A permessage-deflate extension error. /// A permessage-deflate extension error.
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq)]
pub enum DeflateExtensionError { pub enum DeflateExtensionError {
/// An error produced when deflating a message. /// An error produced when deflating a message.
DeflateError(String), DeflateError(String),
@ -299,6 +297,12 @@ pub enum DeflateExtensionError {
Capacity(Cow<'static, str>), Capacity(Cow<'static, str>),
} }
impl DeflateExtensionError {
fn malformatted() -> DeflateExtensionError {
DeflateExtensionError::NegotiationError("Malformatted header value".into())
}
}
impl Display for DeflateExtensionError { impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -342,10 +346,11 @@ pub fn on_response<T>(
Ok(header) => { Ok(header) => {
for param in header.split(';') { for param in header.split(';') {
match param.trim().to_lowercase().as_str() { match param.trim().to_lowercase().as_str() {
"permessage-deflate" => { EXT_IDENT => {
if seen_extension_name { if seen_extension_name {
return Err(DeflateExtensionError::NegotiationError(format!( return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate" "Duplicate extension parameter: {}",
EXT_IDENT
))); )));
} else { } else {
enabled = true; enabled = true;
@ -455,7 +460,7 @@ pub fn on_response<T>(
} }
/// ///
pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> { pub fn on_make_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> {
let mut header_value = String::from(EXT_IDENT); let mut header_value = String::from(EXT_IDENT);
let DeflateConfig { let DeflateConfig {
@ -488,32 +493,32 @@ pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request
request request
} }
/// fn validate_req_extensions(
pub fn on_receive_request<T>( header: &str,
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig, config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> { ) -> Result<Option<String>, DeflateExtensionError> {
let mut enabled = false;
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len()); let mut response_str = String::with_capacity(header.len());
let mut param_iter = header.split(';');
match param_iter.next() {
Some(name) if name.trim() == EXT_IDENT => {
response_str.push_str(EXT_IDENT);
}
_ => {
return Ok(None);
}
}
let mut server_takeover = false; let mut server_takeover = false;
let mut client_takeover = false; let mut client_takeover = false;
let mut server_max_bits = false; let mut server_max_bits = false;
let mut client_max_bits = false; let mut client_max_bits = false;
for param in header.split(';') { while let Some(param) = param_iter.next() {
match param.trim().to_lowercase().as_str() { match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
enabled = true;
response_str.push_str("permessage-deflate");
}
"server_no_context_takeover" => { "server_no_context_takeover" => {
if server_takeover { if server_takeover {
decline(response); return Err(DeflateExtensionError::malformatted());
} else { } else {
server_takeover = true; server_takeover = true;
if config.accept_no_context_takeover() { if config.accept_no_context_takeover() {
@ -524,7 +529,7 @@ pub fn on_receive_request<T>(
} }
"client_no_context_takeover" => { "client_no_context_takeover" => {
if client_takeover { if client_takeover {
decline(response); return Err(DeflateExtensionError::malformatted());
} else { } else {
client_takeover = true; client_takeover = true;
config.decompress_reset = true; config.decompress_reset = true;
@ -533,7 +538,7 @@ pub fn on_receive_request<T>(
} }
param if param.starts_with("server_max_window_bits") => { param if param.starts_with("server_max_window_bits") => {
if server_max_bits { if server_max_bits {
decline(response); return Err(DeflateExtensionError::malformatted());
} else { } else {
server_max_bits = true; server_max_bits = true;
@ -549,14 +554,14 @@ pub fn on_receive_request<T>(
} }
Ok(None) => {} Ok(None) => {}
Err(_) => { Err(_) => {
decline(response); return Err(DeflateExtensionError::malformatted());
} }
} }
} }
} }
param if param.starts_with("client_max_window_bits") => { param if param.starts_with("client_max_window_bits") => {
if client_max_bits { if client_max_bits {
decline(response); return Err(DeflateExtensionError::malformatted());
} else { } else {
client_max_bits = true; client_max_bits = true;
@ -573,7 +578,7 @@ pub fn on_receive_request<T>(
} }
Ok(None) => {} Ok(None) => {}
Err(_) => { Err(_) => {
decline(response); return Err(DeflateExtensionError::malformatted());
} }
} }
@ -584,14 +589,15 @@ pub fn on_receive_request<T>(
)) ))
} }
} }
_ => { p => {
decline(response); return Err(DeflateExtensionError::NegotiationError(
format!("Unknown permessage-deflate parameter: {}", p).into(),
))
} }
} }
} }
if !response_str.contains("client_no_context_takeover") if !response_str.contains("client_no_context_takeover") && config.request_no_context_takeover()
&& config.request_no_context_takeover()
{ {
config.decompress_reset = true; config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover"); response_str.push_str("; client_no_context_takeover");
@ -608,15 +614,42 @@ pub fn on_receive_request<T>(
if !response_str.contains("client_max_window_bits") if !response_str.contains("client_max_window_bits")
&& config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE && config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE
{ {
continue; return Ok(None);
}
Ok(Some(response_str))
} }
///
pub fn on_receive_request<T>(
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
for header in header.split(',') {
let mut parser_config = config.clone();
match validate_req_extensions(header, &mut parser_config) {
Ok(Some(response_str)) => {
response.headers_mut().insert( response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?, HeaderValue::from_str(&response_str)?,
); );
Ok(enabled) *config = parser_config;
return Ok(true);
}
Ok(None) => continue,
Err(e) => {
response.headers_mut().remove(EXT_IDENT);
return Err(e);
}
}
}
Ok(false)
} }
Err(e) => Err(DeflateExtensionError::NegotiationError(format!( Err(e) => Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}", "Failed to parse request header: {}",
@ -625,7 +658,6 @@ pub fn on_receive_request<T>(
}; };
} }
decline(response);
Ok(false) Ok(false)
} }
@ -964,3 +996,219 @@ impl FragmentBuffer {
) )
} }
} }
#[cfg(test)]
mod tests {
use crate::extensions::compression::deflate::{
on_receive_request, DeflateConfig, DeflateExtensionError, LZ77_MIN_WINDOW_SIZE,
};
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
#[test]
fn config_unchanged_on_err() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn config_unchanged_on_mismatch() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: LZ77_MIN_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn parses_named_parameters() {
let s ="permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 8,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: true,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=\"8\"");
}
#[test]
fn splits() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
#[test]
fn splits_on_new_line() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
}

@ -117,7 +117,7 @@ pub fn build_compression_headers<T>(
Some(ref mut config) => match &config.compression { Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request, WsCompression::None(_) => request,
#[cfg(feature = "deflate")] #[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config), WsCompression::Deflate(config) => deflate::on_make_request(request, config),
}, },
None => request, None => request,
} }

@ -14,8 +14,7 @@ pub trait WebSocketExtension {
Ok(frame) Ok(frame)
} }
/// Called when a frame has been received and unmasked. The frame provided frame will be of the /// Called when a WebSocket frame has been received.
/// type `OpCode::Data`.
fn on_receive_frame( fn on_receive_frame(
&mut self, &mut self,
data_opcode: Data, data_opcode: Data,

Loading…
Cancel
Save