diff --git a/src/error.rs b/src/error.rs index e1d86a0..a502135 100644 --- a/src/error.rs +++ b/src/error.rs @@ -233,6 +233,10 @@ pub enum ProtocolError { /// The negotiation response included an extension more than once. #[error("Extension negotiation response had conflicting extension: {0}")] ExtensionConflict(String), + // https://datatracker.ietf.org/doc/html/rfc6455#section-11.3.2 + /// `Sec-WebSocket-Extensions` header appeared multiple times in HTTP response + #[error("Sec-WebSocket-Extensions header must not appear more than once in response")] + MultipleExtensionsHeaderInResponse, } /// Indicates the specific type/cause of URL error. diff --git a/src/extensions/compression/deflate.rs b/src/extensions/compression/deflate.rs index 0a65ff3..9f4930d 100644 --- a/src/extensions/compression/deflate.rs +++ b/src/extensions/compression/deflate.rs @@ -4,10 +4,7 @@ use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, use http::HeaderValue; use thiserror::Error; -use crate::{ - extensions::{self, Param}, - protocol::Role, -}; +use crate::{extensions, protocol::Role}; const PER_MESSAGE_DEFLATE: &str = "permessage-deflate"; const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover"; @@ -62,17 +59,20 @@ impl DeflateConfig { pub(crate) fn generate_offer(&self) -> HeaderValue { let mut offers = Vec::new(); if self.server_no_context_takeover { - offers.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); + offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER)); } if self.client_no_context_takeover { - offers.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); + offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER)); } to_header_value(&offers) } // This can be used for `WebSocket::from_raw_socket_with_compression`. /// Returns negotiation response based on offers and `DeflateContext` to manage per message compression. - pub fn negotiation_response(&self, extensions: &str) -> Option<(HeaderValue, DeflateContext)> { + pub fn negotiation_response<'a>( + &'a self, + extensions: impl Iterator, + ) -> Option<(HeaderValue, DeflateContext)> { // Accept the first valid offer for `permessage-deflate`. // A server MUST decline an extension negotiation offer for this // extension if any of the following conditions are met: @@ -84,7 +84,7 @@ impl DeflateConfig { // the same name. // * The server doesn't support the offered configuration. 'outer: for (_, offer) in - extensions::parse_header(extensions).iter().filter(|(k, _)| k == self.name()) + extensions::iter_all(extensions).filter(|&(k, _)| k == self.name()) { let mut config = DeflateConfig { compression: self.compression, ..DeflateConfig::default() }; @@ -92,8 +92,8 @@ impl DeflateConfig { let mut seen_server_no_context_takeover = false; let mut seen_client_no_context_takeover = false; let mut seen_client_max_window_bits = false; - for param in offer { - match param.name() { + for (key, _val) in offer { + match key { SERVER_NO_CONTEXT_TAKEOVER => { // Invalid offer with multiple params with same name is declined. if seen_server_no_context_takeover { @@ -101,7 +101,7 @@ impl DeflateConfig { } seen_server_no_context_takeover = true; config.server_no_context_takeover = true; - agreed.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); + agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER)); } CLIENT_NO_CONTEXT_TAKEOVER => { @@ -111,7 +111,7 @@ impl DeflateConfig { } seen_client_no_context_takeover = true; config.client_no_context_takeover = true; - agreed.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); + agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER)); } // Max window bits are not supported at the moment. @@ -142,11 +142,14 @@ impl DeflateConfig { None } - pub(crate) fn accept_response(&self, agreed: &[Param]) -> Result { + pub(crate) fn accept_response<'a>( + &'a self, + agreed: impl Iterator)>, + ) -> Result { let mut config = DeflateConfig { compression: self.compression, ..DeflateConfig::default() }; - for param in agreed { - match param.name() { + for (key, _val) in agreed { + match key { SERVER_NO_CONTEXT_TAKEOVER => { config.server_no_context_takeover = true; } @@ -276,15 +279,11 @@ impl DeflateContext { } } -fn to_header_value(params: &[Param]) -> HeaderValue { +fn to_header_value(params: &[HeaderValue]) -> HeaderValue { let mut value = Vec::new(); write!(value, "{}", PER_MESSAGE_DEFLATE).unwrap(); for param in params { - if let Some(v) = param.value() { - write!(value, "; {}={}", param.name(), v).unwrap(); - } else { - write!(value, "; {}", param.name()).unwrap(); - } + write!(value, "; {}", param.to_str().unwrap()).unwrap(); } - HeaderValue::from_bytes(&value).unwrap() + HeaderValue::from_bytes(&value).expect("joining HeaderValue should be valid") } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 1e9dd16..95ee2bc 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,81 +1,131 @@ //! WebSocket extensions. // Only `permessage-deflate` is supported at the moment. -use std::borrow::Cow; - mod compression; pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError}; +use http::HeaderValue; -/// Extension parameter. -#[derive(Clone, Debug, Eq, PartialEq)] -pub(crate) struct Param<'a> { - name: Cow<'a, str>, - value: Option>, +/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values. +pub(crate) fn iter_all<'a>( + values: impl Iterator, +) -> impl Iterator)>)> { + values + .filter_map(|h| h.to_str().ok()) + .map(|value_str| { + split_iter(value_str, ',').filter_map(|offer| { + // Parameters are separted by semicolons. + // The first element is the name of the extension. + let mut iter = split_iter(offer.trim(), ';').map(str::trim); + let name = iter.next()?; + let params = iter.filter_map(|kv| { + let mut it = kv.splitn(2, '='); + let key = it.next()?.trim(); + let val = it.next().map(|v| v.trim().trim_matches('"')); + Some((key, val)) + }); + Some((name, params)) + }) + }) + .flatten() } -impl<'a> Param<'a> { - /// Create a new parameter with name. - pub fn new(name: impl Into>) -> Self { - Param { name: name.into(), value: None } - } +fn split_iter(input: &str, sep: char) -> impl Iterator { + let mut in_quotes = false; + let mut prev = None; + input.split(move |c| { + if in_quotes { + if c == '"' && prev != Some('\\') { + in_quotes = false; + } + prev = Some(c); + false + } else if c == sep { + prev = Some(c); + true + } else { + if c == '"' { + in_quotes = true; + } + prev = Some(c); + false + } + }) +} - /// Consume itself to create a parameter with value. - pub fn with_value(mut self, value: impl Into>) -> Self { - self.value = Some(value.into()); - self - } +#[cfg(test)] +mod tests { + use http::{header::SEC_WEBSOCKET_EXTENSIONS, HeaderMap}; - /// Get the name of the parameter. - pub fn name(&self) -> &str { - &self.name - } + use super::*; + + // Make sure comma separated offers and multiple headers are equivalent + fn test_iteration<'a>( + mut iter: impl Iterator)>)>, + ) { + let (name, mut params) = iter.next().unwrap(); + assert_eq!(name, "permessage-deflate"); + assert_eq!(params.next(), Some(("client_max_window_bits", None))); + assert_eq!(params.next(), Some(("server_max_window_bits", Some("10")))); + assert!(params.next().is_none()); - /// Get the optional value of the parameter. - pub fn value(&self) -> Option<&str> { - self.value.as_ref().map(|v| v.as_ref()) + let (name, mut params) = iter.next().unwrap(); + assert_eq!(name, "permessage-deflate"); + assert_eq!(params.next(), Some(("client_max_window_bits", None))); + assert!(params.next().is_none()); + + assert!(iter.next().is_none()); } -} -// NOTE This doesn't support quoted values -/// Parse `Sec-WebSocket-Extensions` offer/response. -pub(crate) fn parse_header(exts: &str) -> Vec<(Cow<'_, str>, Vec>)> { - let mut collected = Vec::new(); - // ext-name; a; b=c, ext-name; x, y=z - for ext in exts.split(',') { - let mut parts = ext.split(';'); - if let Some(name) = parts.next().map(str::trim) { - let mut params = Vec::new(); - for p in parts { - let mut kv = p.splitn(2, '='); - if let Some(key) = kv.next().map(str::trim) { - let param = if let Some(value) = kv.next().map(str::trim) { - Param::new(key).with_value(value) - } else { - Param::new(key) - }; - params.push(param); - } - } - collected.push((Cow::from(name), params)); - } + #[test] + fn iter_single() { + let mut hm = HeaderMap::new(); + hm.append( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static( + "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits", + ), + ); + test_iteration(iter_all(std::iter::once(hm.get(SEC_WEBSOCKET_EXTENSIONS).unwrap()))); } - collected -} -#[test] -fn test_parse_extensions() { - let extensions = "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits"; - assert_eq!( - parse_header(extensions), - vec![ - ( - Cow::from("permessage-deflate"), - vec![ - Param::new("client_max_window_bits"), - Param::new("server_max_window_bits").with_value("10") - ] + #[test] + fn iter_multiple() { + let mut hm = HeaderMap::new(); + hm.append( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static( + "permessage-deflate; client_max_window_bits; server_max_window_bits=10", ), - (Cow::from("permessage-deflate"), vec![Param::new("client_max_window_bits")]) - ] - ); + ); + hm.append( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-deflate; client_max_window_bits"), + ); + test_iteration(iter_all(hm.get_all(SEC_WEBSOCKET_EXTENSIONS).iter())); + } } + +// TODO More strict parsing +// https://datatracker.ietf.org/doc/html/rfc6455#section-4.3 +// Sec-WebSocket-Extensions = extension-list +// extension-list = 1#extension +// extension = extension-token *( ";" extension-param ) +// extension-token = registered-token +// registered-token = token +// extension-param = token [ "=" (token | quoted-string) ] +// ;When using the quoted-string syntax variant, the value +// ;after quoted-string unescaping MUST conform to the +// ;'token' ABNF. +// +// token = 1* +// CHAR = +// CTL = +// separators = "(" | ")" | "<" | ">" | "@" +// | "," | ";" | ":" | "\" | <"> +// | "/" | "[" | "]" | "?" | "=" +// | "{" | "}" | SP | HT +// SP = +// HT = +// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) +// qdtext = > +// quoted-pair = "\" CHAR diff --git a/src/handshake/client.rs b/src/handshake/client.rs index af838ee..d428ca8 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -208,11 +208,13 @@ impl VerifyData { // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if let Some(exts) = headers - .get("Sec-WebSocket-Extensions") - .and_then(|h| h.to_str().ok()) - .map(extensions::parse_header) - { + let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter(); + if let Some(value) = extensions.next() { + if extensions.next().is_some() { + return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse)); + } + + let mut exts = extensions::iter_all(std::iter::once(value)); if let Some(compression) = &config.and_then(|c| c.compression) { for (name, params) in exts { if name != compression.name() { @@ -227,10 +229,9 @@ impl VerifyData { name.to_string(), ))); } - - pmce = Some(compression.accept_response(¶ms)?); + pmce = Some(compression.accept_response(params)?); } - } else if let Some((name, _)) = exts.get(0) { + } else if let Some((name, _)) = exts.next() { // The client didn't request anything, but got something return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string()))); } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 85b2c55..4451441 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -246,17 +246,10 @@ impl HandshakeRole for ServerHandshake { let mut response = create_response(&result)?; if let Some(compression) = &self.config.and_then(|c| c.compression) { - for extensions in result - .headers() - .get_all("Sec-WebSocket-Extensions") - .iter() - .filter_map(|h| h.to_str().ok()) - { - if let Some((agreed, pmce)) = compression.negotiation_response(extensions) { - self.pmce = Some(pmce); - response.headers_mut().insert("Sec-WebSocket-Extensions", agreed); - break; - } + let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter(); + if let Some((agreed, pmce)) = compression.negotiation_response(extensions) { + self.pmce = Some(pmce); + response.headers_mut().insert("Sec-WebSocket-Extensions", agreed); } }