Make `DeflateContext` private and add `Extensions` container

pull/235/head
kazk 4 years ago
parent 78322fed68
commit c62eccc8df
  1. 10
      src/extensions/mod.rs
  2. 28
      src/handshake/client.rs
  3. 18
      src/handshake/server.rs
  4. 67
      src/protocol/mod.rs

@ -2,9 +2,17 @@
// Only `permessage-deflate` is supported at the moment. // Only `permessage-deflate` is supported at the moment.
mod compression; mod compression;
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError}; use compression::deflate::DeflateContext;
pub use compression::deflate::{DeflateConfig, DeflateError};
use http::HeaderValue; use http::HeaderValue;
/// Container for configured extensions.
#[derive(Debug, Default)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
pub(crate) compression: Option<DeflateContext>,
}
/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values. /// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
pub(crate) fn iter_all<'a>( pub(crate) fn iter_all<'a>(
values: impl Iterator<Item = &'a HeaderValue>, values: impl Iterator<Item = &'a HeaderValue>,

@ -17,7 +17,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result, UrlError}, error::{Error, ProtocolError, Result, UrlError},
extensions::{self, DeflateContext}, extensions::{self, Extensions},
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -83,14 +83,15 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
let (result, pmce) = self.verify_data.verify_response(result, &self.config)?; let (result, extensions) =
self.verify_data.verify_response(result, &self.config)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read_with_compression( let websocket = WebSocket::from_partially_read_with_extensions(
stream, stream,
tail, tail,
Role::Client, Role::Client,
self.config, self.config,
pmce, extensions,
); );
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
@ -161,7 +162,7 @@ impl VerifyData {
&self, &self,
response: Response, response: Response,
config: &Option<WebSocketConfig>, config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<DeflateContext>)> { ) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -201,15 +202,15 @@ impl VerifyData {
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
} }
let mut pmce = None; let mut extensions = None;
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter(); let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions.next() { if let Some(value) = extensions_values.next() {
if extensions.next().is_some() { if extensions_values.next().is_some() {
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse)); return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
} }
@ -223,12 +224,15 @@ impl VerifyData {
} }
// Already had PMCE configured // Already had PMCE configured
if pmce.is_some() { if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict( return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(), name.to_string(),
))); )));
} }
pmce = Some(compression.accept_response(params)?);
extensions = Some(Extensions {
compression: Some(compression.accept_response(params)?),
});
} }
} else if let Some((name, _)) = exts.next() { } else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something // The client didn't request anything, but got something
@ -243,7 +247,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok((response, pmce)) Ok((response, extensions))
} }
} }

@ -20,7 +20,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions, extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -203,8 +203,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>, error_response: Option<ErrorResponse>,
// Negotiated Per-Message Compression Extension context for server. // Negotiated extension context for server.
pmce: Option<extensions::DeflateContext>, extensions: Option<Extensions>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -222,7 +222,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback), callback: Some(callback),
config, config,
error_response: None, error_response: None,
pmce: None, extensions: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -246,10 +246,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let mut response = create_response(&result)?; let mut response = create_response(&result)?;
if let Some(config) = &self.config { if let Some(config) = &self.config {
let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter(); let values = result.headers().get_all("Sec-WebSocket-Extensions").iter();
if let Some((agreed, pmce)) = config.accept_offers(extensions) { if let Some((agreed, extensions)) = config.accept_offers(values) {
self.pmce = Some(pmce);
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed); response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
self.extensions = Some(extensions);
} }
} }
@ -292,11 +292,11 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(err)); return Err(Error::Http(err));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket_with_compression( let websocket = WebSocket::from_raw_socket_with_extensions(
stream, stream,
Role::Server, Role::Server,
self.config, self.config,
self.pmce.take(), self.extensions.take(),
); );
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }

@ -23,7 +23,7 @@ use self::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::{self, DeflateContext}, extensions::{self, Extensions},
util::NonBlockingResult, util::NonBlockingResult,
}; };
@ -81,15 +81,14 @@ impl WebSocketConfig {
self.compression.map(|c| c.generate_offer()) self.compression.map(|c| c.generate_offer())
} }
// TODO Replace `DeflateContext` with something more general // This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
// This can be used with `WebSocket::from_raw_socket_with_compression` for integration. /// Returns negotiation response based on offers and `Extensions` to manage extensions.
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub fn accept_offers<'a>( pub fn accept_offers<'a>(
&'a self, &'a self,
extensions: impl Iterator<Item = &'a HeaderValue>, extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, DeflateContext)> { ) -> Option<(HeaderValue, Extensions)> {
if let Some(compression) = &self.compression { if let Some(compression) = &self.compression {
let extensions = crate::extensions::iter_all(extensions); let extensions = extensions::iter_all(extensions);
let offers = let offers =
extensions.filter_map( extensions.filter_map(
|(k, v)| { |(k, v)| {
@ -100,7 +99,12 @@ impl WebSocketConfig {
} }
}, },
); );
compression.accept_offer(offers)
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
compression
.accept_offer(offers)
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
} else { } else {
None None
} }
@ -130,14 +134,14 @@ impl<Stream> WebSocket<Stream> {
} }
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_compression( pub fn from_raw_socket_with_extensions(
stream: Stream, stream: Stream,
role: Role, role: Role,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>, extensions: Option<Extensions>,
) -> Self { ) -> Self {
let mut context = WebSocketContext::new(role, config); let mut context = WebSocketContext::new(role, config);
context.pmce = pmce; context.extensions = extensions;
WebSocket { socket: stream, context } WebSocket { socket: stream, context }
} }
@ -158,17 +162,17 @@ impl<Stream> WebSocket<Stream> {
} }
} }
pub(crate) fn from_partially_read_with_compression( pub(crate) fn from_partially_read_with_extensions(
stream: Stream, stream: Stream,
part: Vec<u8>, part: Vec<u8>,
role: Role, role: Role,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>, extensions: Option<Extensions>,
) -> Self { ) -> Self {
WebSocket { WebSocket {
socket: stream, socket: stream,
context: WebSocketContext::from_partially_read_with_compression( context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, pmce, part, role, config, extensions,
), ),
} }
} }
@ -306,8 +310,8 @@ pub struct WebSocketContext {
pong: Option<Frame>, pong: Option<Frame>,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
/// Per-Message Compression Extension. Only deflate is supported at the moment. // Container for extensions.
pub(crate) pmce: Option<extensions::DeflateContext>, pub(crate) extensions: Option<Extensions>,
} }
impl WebSocketContext { impl WebSocketContext {
@ -321,7 +325,7 @@ impl WebSocketContext {
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
pong: None, pong: None,
config: config.unwrap_or_else(WebSocketConfig::default), config: config.unwrap_or_else(WebSocketConfig::default),
pmce: None, extensions: None,
} }
} }
@ -333,15 +337,15 @@ impl WebSocketContext {
} }
} }
pub(crate) fn from_partially_read_with_compression( pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>, part: Vec<u8>,
role: Role, role: Role,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>, extensions: Option<Extensions>,
) -> Self { ) -> Self {
WebSocketContext { WebSocketContext {
frame: FrameCodec::from_partially_read(part), frame: FrameCodec::from_partially_read(part),
pmce, extensions,
..WebSocketContext::new(role, config) ..WebSocketContext::new(role, config)
} }
} }
@ -447,11 +451,12 @@ impl WebSocketContext {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind"); debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata); let opcode = OpCode::Data(opdata);
let is_final = true; let is_final = true;
let frame = if let Some(pmce) = self.pmce.as_mut() { let frame =
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final) if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
} else { Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
Frame::message(data, opcode, is_final) } else {
}; Frame::message(data, opcode, is_final)
};
Ok(frame) Ok(frame)
} }
@ -533,7 +538,7 @@ impl WebSocketContext {
// Connection_. // Connection_.
let is_compressed = { let is_compressed = {
let hdr = frame.header(); let hdr = frame.header();
if (hdr.rsv1 && self.pmce.is_none()) || hdr.rsv2 || hdr.rsv3 { if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
@ -606,8 +611,9 @@ impl WebSocketContext {
if let Some(ref mut msg) = self.incomplete { if let Some(ref mut msg) = self.incomplete {
let data = if msg.compressed() { let data = if msg.compressed() {
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap // `msg.compressed` is only set when compression is enabled so it's safe to unwrap
self.pmce self.extensions
.as_mut() .as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap() .unwrap()
.decompress(frame.into_data(), fin)? .decompress(frame.into_data(), fin)?
} else { } else {
@ -637,8 +643,9 @@ impl WebSocketContext {
}; };
let mut m = IncompleteMessage::new(message_type, is_compressed); let mut m = IncompleteMessage::new(message_type, is_compressed);
let data = if is_compressed { let data = if is_compressed {
self.pmce self.extensions
.as_mut() .as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap() .unwrap()
.decompress(frame.into_data(), fin)? .decompress(frame.into_data(), fin)?
} else { } else {
@ -729,6 +736,10 @@ impl WebSocketContext {
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state) self.frame.write_frame(stream, frame).check_connection_reset(self.state)
} }
fn has_compression(&self) -> bool {
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
} }
/// The current connection state. /// The current connection state.

Loading…
Cancel
Save