Remove `parse_header` and `Param`

pull/235/head
kazk 4 years ago
parent 73ef209ac6
commit 40eb9235d9
  1. 4
      src/error.rs
  2. 43
      src/extensions/compression/deflate.rs
  3. 178
      src/extensions/mod.rs
  4. 17
      src/handshake/client.rs
  5. 15
      src/handshake/server.rs

@ -233,6 +233,10 @@ pub enum ProtocolError {
/// The negotiation response included an extension more than once. /// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")] #[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String), ExtensionConflict(String),
// https://datatracker.ietf.org/doc/html/rfc6455#section-11.3.2
/// `Sec-WebSocket-Extensions` header appeared multiple times in HTTP response
#[error("Sec-WebSocket-Extensions header must not appear more than once in response")]
MultipleExtensionsHeaderInResponse,
} }
/// Indicates the specific type/cause of URL error. /// Indicates the specific type/cause of URL error.

@ -4,10 +4,7 @@ use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress,
use http::HeaderValue; use http::HeaderValue;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{extensions, protocol::Role};
extensions::{self, Param},
protocol::Role,
};
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate"; const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover"; const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
@ -62,17 +59,20 @@ impl DeflateConfig {
pub(crate) fn generate_offer(&self) -> HeaderValue { pub(crate) fn generate_offer(&self) -> HeaderValue {
let mut offers = Vec::new(); let mut offers = Vec::new();
if self.server_no_context_takeover { if self.server_no_context_takeover {
offers.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
} }
if self.client_no_context_takeover { if self.client_no_context_takeover {
offers.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
} }
to_header_value(&offers) to_header_value(&offers)
} }
// This can be used for `WebSocket::from_raw_socket_with_compression`. // This can be used for `WebSocket::from_raw_socket_with_compression`.
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression. /// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub fn negotiation_response(&self, extensions: &str) -> Option<(HeaderValue, DeflateContext)> { pub fn negotiation_response<'a>(
&'a self,
extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`. // Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this // A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met: // extension if any of the following conditions are met:
@ -84,7 +84,7 @@ impl DeflateConfig {
// the same name. // the same name.
// * The server doesn't support the offered configuration. // * The server doesn't support the offered configuration.
'outer: for (_, offer) in 'outer: for (_, offer) in
extensions::parse_header(extensions).iter().filter(|(k, _)| k == self.name()) extensions::iter_all(extensions).filter(|&(k, _)| k == self.name())
{ {
let mut config = let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() }; DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
@ -92,8 +92,8 @@ impl DeflateConfig {
let mut seen_server_no_context_takeover = false; let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false; let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false; let mut seen_client_max_window_bits = false;
for param in offer { for (key, _val) in offer {
match param.name() { match key {
SERVER_NO_CONTEXT_TAKEOVER => { SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined. // Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover { if seen_server_no_context_takeover {
@ -101,7 +101,7 @@ impl DeflateConfig {
} }
seen_server_no_context_takeover = true; seen_server_no_context_takeover = true;
config.server_no_context_takeover = true; config.server_no_context_takeover = true;
agreed.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
} }
CLIENT_NO_CONTEXT_TAKEOVER => { CLIENT_NO_CONTEXT_TAKEOVER => {
@ -111,7 +111,7 @@ impl DeflateConfig {
} }
seen_client_no_context_takeover = true; seen_client_no_context_takeover = true;
config.client_no_context_takeover = true; config.client_no_context_takeover = true;
agreed.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
} }
// Max window bits are not supported at the moment. // Max window bits are not supported at the moment.
@ -142,11 +142,14 @@ impl DeflateConfig {
None None
} }
pub(crate) fn accept_response(&self, agreed: &[Param]) -> Result<DeflateContext, DeflateError> { pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() }; DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
for param in agreed { for (key, _val) in agreed {
match param.name() { match key {
SERVER_NO_CONTEXT_TAKEOVER => { SERVER_NO_CONTEXT_TAKEOVER => {
config.server_no_context_takeover = true; config.server_no_context_takeover = true;
} }
@ -276,15 +279,11 @@ impl DeflateContext {
} }
} }
fn to_header_value(params: &[Param]) -> HeaderValue { fn to_header_value(params: &[HeaderValue]) -> HeaderValue {
let mut value = Vec::new(); let mut value = Vec::new();
write!(value, "{}", PER_MESSAGE_DEFLATE).unwrap(); write!(value, "{}", PER_MESSAGE_DEFLATE).unwrap();
for param in params { for param in params {
if let Some(v) = param.value() { write!(value, "; {}", param.to_str().unwrap()).unwrap();
write!(value, "; {}={}", param.name(), v).unwrap();
} else {
write!(value, "; {}", param.name()).unwrap();
}
} }
HeaderValue::from_bytes(&value).unwrap() HeaderValue::from_bytes(&value).expect("joining HeaderValue should be valid")
} }

@ -1,81 +1,131 @@
//! WebSocket extensions. //! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment. // Only `permessage-deflate` is supported at the moment.
use std::borrow::Cow;
mod compression; mod compression;
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError}; pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError};
use http::HeaderValue;
/// Extension parameter. /// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) fn iter_all<'a>(
pub(crate) struct Param<'a> { values: impl Iterator<Item = &'a HeaderValue>,
name: Cow<'a, str>, ) -> impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)> {
value: Option<Cow<'a, str>>, values
.filter_map(|h| h.to_str().ok())
.map(|value_str| {
split_iter(value_str, ',').filter_map(|offer| {
// Parameters are separted by semicolons.
// The first element is the name of the extension.
let mut iter = split_iter(offer.trim(), ';').map(str::trim);
let name = iter.next()?;
let params = iter.filter_map(|kv| {
let mut it = kv.splitn(2, '=');
let key = it.next()?.trim();
let val = it.next().map(|v| v.trim().trim_matches('"'));
Some((key, val))
});
Some((name, params))
})
})
.flatten()
} }
impl<'a> Param<'a> { fn split_iter(input: &str, sep: char) -> impl Iterator<Item = &str> {
/// Create a new parameter with name. let mut in_quotes = false;
pub fn new(name: impl Into<Cow<'a, str>>) -> Self { let mut prev = None;
Param { name: name.into(), value: None } input.split(move |c| {
} if in_quotes {
if c == '"' && prev != Some('\\') {
in_quotes = false;
}
prev = Some(c);
false
} else if c == sep {
prev = Some(c);
true
} else {
if c == '"' {
in_quotes = true;
}
prev = Some(c);
false
}
})
}
/// Consume itself to create a parameter with value. #[cfg(test)]
pub fn with_value(mut self, value: impl Into<Cow<'a, str>>) -> Self { mod tests {
self.value = Some(value.into()); use http::{header::SEC_WEBSOCKET_EXTENSIONS, HeaderMap};
self
}
/// Get the name of the parameter. use super::*;
pub fn name(&self) -> &str {
&self.name // Make sure comma separated offers and multiple headers are equivalent
} fn test_iteration<'a>(
mut iter: impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)>,
) {
let (name, mut params) = iter.next().unwrap();
assert_eq!(name, "permessage-deflate");
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
assert_eq!(params.next(), Some(("server_max_window_bits", Some("10"))));
assert!(params.next().is_none());
/// Get the optional value of the parameter. let (name, mut params) = iter.next().unwrap();
pub fn value(&self) -> Option<&str> { assert_eq!(name, "permessage-deflate");
self.value.as_ref().map(|v| v.as_ref()) assert_eq!(params.next(), Some(("client_max_window_bits", None)));
assert!(params.next().is_none());
assert!(iter.next().is_none());
} }
}
// NOTE This doesn't support quoted values #[test]
/// Parse `Sec-WebSocket-Extensions` offer/response. fn iter_single() {
pub(crate) fn parse_header(exts: &str) -> Vec<(Cow<'_, str>, Vec<Param<'_>>)> { let mut hm = HeaderMap::new();
let mut collected = Vec::new(); hm.append(
// ext-name; a; b=c, ext-name; x, y=z SEC_WEBSOCKET_EXTENSIONS,
for ext in exts.split(',') { HeaderValue::from_static(
let mut parts = ext.split(';'); "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits",
if let Some(name) = parts.next().map(str::trim) { ),
let mut params = Vec::new(); );
for p in parts { test_iteration(iter_all(std::iter::once(hm.get(SEC_WEBSOCKET_EXTENSIONS).unwrap())));
let mut kv = p.splitn(2, '=');
if let Some(key) = kv.next().map(str::trim) {
let param = if let Some(value) = kv.next().map(str::trim) {
Param::new(key).with_value(value)
} else {
Param::new(key)
};
params.push(param);
}
}
collected.push((Cow::from(name), params));
}
} }
collected
}
#[test] #[test]
fn test_parse_extensions() { fn iter_multiple() {
let extensions = "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits"; let mut hm = HeaderMap::new();
assert_eq!( hm.append(
parse_header(extensions), SEC_WEBSOCKET_EXTENSIONS,
vec![ HeaderValue::from_static(
( "permessage-deflate; client_max_window_bits; server_max_window_bits=10",
Cow::from("permessage-deflate"),
vec![
Param::new("client_max_window_bits"),
Param::new("server_max_window_bits").with_value("10")
]
), ),
(Cow::from("permessage-deflate"), vec![Param::new("client_max_window_bits")]) );
] hm.append(
); SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_static("permessage-deflate; client_max_window_bits"),
);
test_iteration(iter_all(hm.get_all(SEC_WEBSOCKET_EXTENSIONS).iter()));
}
} }
// TODO More strict parsing
// https://datatracker.ietf.org/doc/html/rfc6455#section-4.3
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
//
// token = 1*<any CHAR except CTLs or separators>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// separators = "(" | ")" | "<" | ">" | "@"
// | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "="
// | "{" | "}" | SP | HT
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
// qdtext = <any TEXT except <">>
// quoted-pair = "\" CHAR

@ -208,11 +208,13 @@ impl VerifyData {
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
if let Some(exts) = headers let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter();
.get("Sec-WebSocket-Extensions") if let Some(value) = extensions.next() {
.and_then(|h| h.to_str().ok()) if extensions.next().is_some() {
.map(extensions::parse_header) return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
{ }
let mut exts = extensions::iter_all(std::iter::once(value));
if let Some(compression) = &config.and_then(|c| c.compression) { if let Some(compression) = &config.and_then(|c| c.compression) {
for (name, params) in exts { for (name, params) in exts {
if name != compression.name() { if name != compression.name() {
@ -227,10 +229,9 @@ impl VerifyData {
name.to_string(), name.to_string(),
))); )));
} }
pmce = Some(compression.accept_response(params)?);
pmce = Some(compression.accept_response(&params)?);
} }
} else if let Some((name, _)) = exts.get(0) { } else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something // The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string()))); return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
} }

@ -246,17 +246,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let mut response = create_response(&result)?; let mut response = create_response(&result)?;
if let Some(compression) = &self.config.and_then(|c| c.compression) { if let Some(compression) = &self.config.and_then(|c| c.compression) {
for extensions in result let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter();
.headers() if let Some((agreed, pmce)) = compression.negotiation_response(extensions) {
.get_all("Sec-WebSocket-Extensions") self.pmce = Some(pmce);
.iter() response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
.filter_map(|h| h.to_str().ok())
{
if let Some((agreed, pmce)) = compression.negotiation_response(extensions) {
self.pmce = Some(pmce);
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
break;
}
} }
} }

Loading…
Cancel
Save