Updates request header parsing

pull/144/head
SirCipher 5 years ago
parent 672572e00a
commit 779514704e
  1. 490
      src/extensions/compression/deflate.rs
  2. 2
      src/extensions/compression/mod.rs
  3. 3
      src/extensions/mod.rs

@ -32,7 +32,7 @@ const LZ77_MIN_WINDOW_SIZE: u8 = 8;
const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DeflateConfig {
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user.
@ -263,6 +263,8 @@ fn parse_window_parameter<'a>(
max_window_bits: u8,
) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() {
let window_bits_str = window_bits_str.replace("\"", "");
match window_bits_str.trim().parse() {
Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
@ -282,12 +284,8 @@ fn parse_window_parameter<'a>(
}
}
fn decline<T>(res: &mut Response<T>) {
res.headers_mut().remove(EXT_IDENT);
}
/// A permessage-deflate extension error.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum DeflateExtensionError {
/// An error produced when deflating a message.
DeflateError(String),
@ -299,6 +297,12 @@ pub enum DeflateExtensionError {
Capacity(Cow<'static, str>),
}
impl DeflateExtensionError {
fn malformatted() -> DeflateExtensionError {
DeflateExtensionError::NegotiationError("Malformatted header value".into())
}
}
impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
@ -342,10 +346,11 @@ pub fn on_response<T>(
Ok(header) => {
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
EXT_IDENT => {
if seen_extension_name {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate"
"Duplicate extension parameter: {}",
EXT_IDENT
)));
} else {
enabled = true;
@ -455,7 +460,7 @@ pub fn on_response<T>(
}
///
pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> {
pub fn on_make_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> {
let mut header_value = String::from(EXT_IDENT);
let DeflateConfig {
@ -488,135 +493,163 @@ pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request
request
}
///
pub fn on_receive_request<T>(
request: &Request<T>,
response: &mut Response<T>,
fn validate_req_extensions(
header: &str,
config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> {
let mut enabled = false;
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len());
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
) -> Result<Option<String>, DeflateExtensionError> {
let mut response_str = String::with_capacity(header.len());
let mut param_iter = header.split(';');
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
enabled = true;
response_str.push_str("permessage-deflate");
}
"server_no_context_takeover" => {
if server_takeover {
decline(response);
} else {
server_takeover = true;
if config.accept_no_context_takeover() {
config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if client_takeover {
decline(response);
} else {
client_takeover = true;
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
decline(response);
} else {
server_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.server_max_window_bits,
) {
Ok(Some(bits)) => {
config.server_max_window_bits = bits;
match param_iter.next() {
Some(name) if name.trim() == EXT_IDENT => {
response_str.push_str(EXT_IDENT);
}
_ => {
return Ok(None);
}
}
response_str.push_str("; ");
response_str.push_str(param)
}
Ok(None) => {}
Err(_) => {
decline(response);
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
decline(response);
} else {
client_max_bits = true;
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
match parse_window_parameter(
param.split('=').skip(1),
config.client_max_window_bits,
) {
Ok(Some(bits)) => {
config.client_max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(param);
while let Some(param) = param_iter.next() {
match param.trim().to_lowercase().as_str() {
"server_no_context_takeover" => {
if server_takeover {
return Err(DeflateExtensionError::malformatted());
} else {
server_takeover = true;
if config.accept_no_context_takeover() {
config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::malformatted());
} else {
client_takeover = true;
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
return Err(DeflateExtensionError::malformatted());
} else {
server_max_bits = true;
continue;
}
Ok(None) => {}
Err(_) => {
decline(response);
}
}
match parse_window_parameter(
param.split('=').skip(1),
config.server_max_window_bits,
) {
Ok(Some(bits)) => {
config.server_max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(&format!(
"client_max_window_bits={}",
config.client_max_window_bits()
))
}
response_str.push_str("; ");
response_str.push_str(param)
}
_ => {
decline(response);
Ok(None) => {}
Err(_) => {
return Err(DeflateExtensionError::malformatted());
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
return Err(DeflateExtensionError::malformatted());
} else {
client_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.client_max_window_bits,
) {
Ok(Some(bits)) => {
config.client_max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(param);
continue;
}
Ok(None) => {}
Err(_) => {
return Err(DeflateExtensionError::malformatted());
}
}
if !response_str.contains("client_no_context_takeover")
&& config.request_no_context_takeover()
{
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
config.server_max_window_bits()
"client_max_window_bits={}",
config.client_max_window_bits()
))
}
}
p => {
return Err(DeflateExtensionError::NegotiationError(
format!("Unknown permessage-deflate parameter: {}", p).into(),
))
}
}
}
if !response_str.contains("client_max_window_bits")
&& config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
continue;
}
if !response_str.contains("client_no_context_takeover") && config.request_no_context_takeover()
{
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
config.server_max_window_bits()
))
}
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
if !response_str.contains("client_max_window_bits")
&& config.client_max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
return Ok(None);
}
Ok(Some(response_str))
}
Ok(enabled)
///
pub fn on_receive_request<T>(
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
for header in header.split(',') {
let mut parser_config = config.clone();
match validate_req_extensions(header, &mut parser_config) {
Ok(Some(response_str)) => {
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
*config = parser_config;
return Ok(true);
}
Ok(None) => continue,
Err(e) => {
response.headers_mut().remove(EXT_IDENT);
return Err(e);
}
}
}
Ok(false)
}
Err(e) => Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}",
@ -625,7 +658,6 @@ pub fn on_receive_request<T>(
};
}
decline(response);
Ok(false)
}
@ -964,3 +996,219 @@ impl FragmentBuffer {
)
}
}
#[cfg(test)]
mod tests {
use crate::extensions::compression::deflate::{
on_receive_request, DeflateConfig, DeflateExtensionError, LZ77_MIN_WINDOW_SIZE,
};
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
#[test]
fn config_unchanged_on_err() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"80000\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn config_unchanged_on_mismatch() {
let s ="permessage-deflate; unknown_parameter=\"invalid\"; client_no_context_takeover; server_no_context_takeover";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let initial_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: LZ77_MIN_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let mut parsed_config = initial_config.clone();
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(
r,
Err(DeflateExtensionError::NegotiationError(
"Unknown permessage-deflate parameter: unknown_parameter=\"invalid\"".into()
))
);
assert_eq!(initial_config, parsed_config);
}
#[test]
fn parses_named_parameters() {
let s ="permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=\"8\"";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 8,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: true,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_no_context_takeover; server_max_window_bits=\"8\"");
}
#[test]
fn splits() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8, no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
#[test]
fn splits_on_new_line() {
let s ="not-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover; server_max_window_bits=8,\\n\\r\\t \\ no-permessage-deflate; client_no_context_takeover; client_max_window_bits; server_no_context_takeover, permessage-deflate; client_no_context_takeover; client_max_window_bits";
let mut request = Request::new(());
request
.headers_mut()
.insert(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static(s));
let mut response = Response::new(());
let mut parsed_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Default::default(),
};
let r = on_receive_request(&request, &mut response, &mut parsed_config);
assert_eq!(r, Ok(true));
let expected_config = DeflateConfig {
max_message_size: None,
server_max_window_bits: 10,
client_max_window_bits: 11,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: true,
compression_level: Default::default(),
};
assert_eq!(parsed_config, expected_config);
let parsed_header = response
.headers()
.get(SEC_WEBSOCKET_EXTENSIONS)
.expect("Missing header")
.to_str()
.expect("Failed to parse header");
assert_eq!(parsed_header,"permessage-deflate; client_no_context_takeover; client_max_window_bits=11; server_max_window_bits=10");
}
}

@ -117,7 +117,7 @@ pub fn build_compression_headers<T>(
Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request,
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config),
WsCompression::Deflate(config) => deflate::on_make_request(request, config),
},
None => request,
}

@ -14,8 +14,7 @@ pub trait WebSocketExtension {
Ok(frame)
}
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
/// type `OpCode::Data`.
/// Called when a WebSocket frame has been received.
fn on_receive_frame(
&mut self,
data_opcode: Data,

Loading…
Cancel
Save