Remove `parse_header` and `Param`

pull/235/head
kazk 4 years ago
parent 73ef209ac6
commit 40eb9235d9
  1. 4
      src/error.rs
  2. 43
      src/extensions/compression/deflate.rs
  3. 178
      src/extensions/mod.rs
  4. 17
      src/handshake/client.rs
  5. 15
      src/handshake/server.rs

@ -233,6 +233,10 @@ pub enum ProtocolError {
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
// https://datatracker.ietf.org/doc/html/rfc6455#section-11.3.2
/// `Sec-WebSocket-Extensions` header appeared multiple times in HTTP response
#[error("Sec-WebSocket-Extensions header must not appear more than once in response")]
MultipleExtensionsHeaderInResponse,
}
/// Indicates the specific type/cause of URL error.

@ -4,10 +4,7 @@ use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress,
use http::HeaderValue;
use thiserror::Error;
use crate::{
extensions::{self, Param},
protocol::Role,
};
use crate::{extensions, protocol::Role};
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
@ -62,17 +59,20 @@ impl DeflateConfig {
pub(crate) fn generate_offer(&self) -> HeaderValue {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER));
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
if self.client_no_context_takeover {
offers.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER));
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
// This can be used for `WebSocket::from_raw_socket_with_compression`.
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub fn negotiation_response(&self, extensions: &str) -> Option<(HeaderValue, DeflateContext)> {
pub fn negotiation_response<'a>(
&'a self,
extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
@ -84,7 +84,7 @@ impl DeflateConfig {
// the same name.
// * The server doesn't support the offered configuration.
'outer: for (_, offer) in
extensions::parse_header(extensions).iter().filter(|(k, _)| k == self.name())
extensions::iter_all(extensions).filter(|&(k, _)| k == self.name())
{
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
@ -92,8 +92,8 @@ impl DeflateConfig {
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for param in offer {
match param.name() {
for (key, _val) in offer {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
@ -101,7 +101,7 @@ impl DeflateConfig {
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER));
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
@ -111,7 +111,7 @@ impl DeflateConfig {
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER));
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
@ -142,11 +142,14 @@ impl DeflateConfig {
None
}
pub(crate) fn accept_response(&self, agreed: &[Param]) -> Result<DeflateContext, DeflateError> {
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
for param in agreed {
match param.name() {
for (key, _val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
config.server_no_context_takeover = true;
}
@ -276,15 +279,11 @@ impl DeflateContext {
}
}
fn to_header_value(params: &[Param]) -> HeaderValue {
fn to_header_value(params: &[HeaderValue]) -> HeaderValue {
let mut value = Vec::new();
write!(value, "{}", PER_MESSAGE_DEFLATE).unwrap();
for param in params {
if let Some(v) = param.value() {
write!(value, "; {}={}", param.name(), v).unwrap();
} else {
write!(value, "; {}", param.name()).unwrap();
}
write!(value, "; {}", param.to_str().unwrap()).unwrap();
}
HeaderValue::from_bytes(&value).unwrap()
HeaderValue::from_bytes(&value).expect("joining HeaderValue should be valid")
}

@ -1,81 +1,131 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
use std::borrow::Cow;
mod compression;
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError};
use http::HeaderValue;
/// Extension parameter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Param<'a> {
name: Cow<'a, str>,
value: Option<Cow<'a, str>>,
/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
pub(crate) fn iter_all<'a>(
values: impl Iterator<Item = &'a HeaderValue>,
) -> impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)> {
values
.filter_map(|h| h.to_str().ok())
.map(|value_str| {
split_iter(value_str, ',').filter_map(|offer| {
// Parameters are separted by semicolons.
// The first element is the name of the extension.
let mut iter = split_iter(offer.trim(), ';').map(str::trim);
let name = iter.next()?;
let params = iter.filter_map(|kv| {
let mut it = kv.splitn(2, '=');
let key = it.next()?.trim();
let val = it.next().map(|v| v.trim().trim_matches('"'));
Some((key, val))
});
Some((name, params))
})
})
.flatten()
}
impl<'a> Param<'a> {
/// Create a new parameter with name.
pub fn new(name: impl Into<Cow<'a, str>>) -> Self {
Param { name: name.into(), value: None }
}
fn split_iter(input: &str, sep: char) -> impl Iterator<Item = &str> {
let mut in_quotes = false;
let mut prev = None;
input.split(move |c| {
if in_quotes {
if c == '"' && prev != Some('\\') {
in_quotes = false;
}
prev = Some(c);
false
} else if c == sep {
prev = Some(c);
true
} else {
if c == '"' {
in_quotes = true;
}
prev = Some(c);
false
}
})
}
/// Consume itself to create a parameter with value.
pub fn with_value(mut self, value: impl Into<Cow<'a, str>>) -> Self {
self.value = Some(value.into());
self
}
#[cfg(test)]
mod tests {
use http::{header::SEC_WEBSOCKET_EXTENSIONS, HeaderMap};
/// Get the name of the parameter.
pub fn name(&self) -> &str {
&self.name
}
use super::*;
// Make sure comma separated offers and multiple headers are equivalent
fn test_iteration<'a>(
mut iter: impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)>,
) {
let (name, mut params) = iter.next().unwrap();
assert_eq!(name, "permessage-deflate");
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
assert_eq!(params.next(), Some(("server_max_window_bits", Some("10"))));
assert!(params.next().is_none());
/// Get the optional value of the parameter.
pub fn value(&self) -> Option<&str> {
self.value.as_ref().map(|v| v.as_ref())
let (name, mut params) = iter.next().unwrap();
assert_eq!(name, "permessage-deflate");
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
assert!(params.next().is_none());
assert!(iter.next().is_none());
}
}
// NOTE This doesn't support quoted values
/// Parse `Sec-WebSocket-Extensions` offer/response.
pub(crate) fn parse_header(exts: &str) -> Vec<(Cow<'_, str>, Vec<Param<'_>>)> {
let mut collected = Vec::new();
// ext-name; a; b=c, ext-name; x, y=z
for ext in exts.split(',') {
let mut parts = ext.split(';');
if let Some(name) = parts.next().map(str::trim) {
let mut params = Vec::new();
for p in parts {
let mut kv = p.splitn(2, '=');
if let Some(key) = kv.next().map(str::trim) {
let param = if let Some(value) = kv.next().map(str::trim) {
Param::new(key).with_value(value)
} else {
Param::new(key)
};
params.push(param);
}
}
collected.push((Cow::from(name), params));
}
#[test]
fn iter_single() {
let mut hm = HeaderMap::new();
hm.append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_static(
"permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits",
),
);
test_iteration(iter_all(std::iter::once(hm.get(SEC_WEBSOCKET_EXTENSIONS).unwrap())));
}
collected
}
#[test]
fn test_parse_extensions() {
let extensions = "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits";
assert_eq!(
parse_header(extensions),
vec![
(
Cow::from("permessage-deflate"),
vec![
Param::new("client_max_window_bits"),
Param::new("server_max_window_bits").with_value("10")
]
#[test]
fn iter_multiple() {
let mut hm = HeaderMap::new();
hm.append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_static(
"permessage-deflate; client_max_window_bits; server_max_window_bits=10",
),
(Cow::from("permessage-deflate"), vec![Param::new("client_max_window_bits")])
]
);
);
hm.append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_static("permessage-deflate; client_max_window_bits"),
);
test_iteration(iter_all(hm.get_all(SEC_WEBSOCKET_EXTENSIONS).iter()));
}
}
// TODO More strict parsing
// https://datatracker.ietf.org/doc/html/rfc6455#section-4.3
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
//
// token = 1*<any CHAR except CTLs or separators>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// separators = "(" | ")" | "<" | ">" | "@"
// | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "="
// | "{" | "}" | SP | HT
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
// qdtext = <any TEXT except <">>
// quoted-pair = "\" CHAR

@ -208,11 +208,13 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if let Some(exts) = headers
.get("Sec-WebSocket-Extensions")
.and_then(|h| h.to_str().ok())
.map(extensions::parse_header)
{
let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions.next() {
if extensions.next().is_some() {
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
}
let mut exts = extensions::iter_all(std::iter::once(value));
if let Some(compression) = &config.and_then(|c| c.compression) {
for (name, params) in exts {
if name != compression.name() {
@ -227,10 +229,9 @@ impl VerifyData {
name.to_string(),
)));
}
pmce = Some(compression.accept_response(&params)?);
pmce = Some(compression.accept_response(params)?);
}
} else if let Some((name, _)) = exts.get(0) {
} else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
}

@ -246,17 +246,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let mut response = create_response(&result)?;
if let Some(compression) = &self.config.and_then(|c| c.compression) {
for extensions in result
.headers()
.get_all("Sec-WebSocket-Extensions")
.iter()
.filter_map(|h| h.to_str().ok())
{
if let Some((agreed, pmce)) = compression.negotiation_response(extensions) {
self.pmce = Some(pmce);
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
break;
}
let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter();
if let Some((agreed, pmce)) = compression.negotiation_response(extensions) {
self.pmce = Some(pmce);
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
}
}

Loading…
Cancel
Save