Make `DeflateContext` private and add `Extensions` container

pull/235/head
kazk 4 years ago
parent 78322fed68
commit c62eccc8df
  1. 10
      src/extensions/mod.rs
  2. 28
      src/handshake/client.rs
  3. 18
      src/handshake/server.rs
  4. 67
      src/protocol/mod.rs

@ -2,9 +2,17 @@
// Only `permessage-deflate` is supported at the moment.
mod compression;
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError};
use compression::deflate::DeflateContext;
pub use compression::deflate::{DeflateConfig, DeflateError};
use http::HeaderValue;
/// Container for configured extensions.
#[derive(Debug, Default)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
pub(crate) compression: Option<DeflateContext>,
}
/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
pub(crate) fn iter_all<'a>(
values: impl Iterator<Item = &'a HeaderValue>,

@ -17,7 +17,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
extensions::{self, DeflateContext},
extensions::{self, Extensions},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -83,14 +83,15 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let (result, pmce) = self.verify_data.verify_response(result, &self.config)?;
let (result, extensions) =
self.verify_data.verify_response(result, &self.config)?;
debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read_with_compression(
let websocket = WebSocket::from_partially_read_with_extensions(
stream,
tail,
Role::Client,
self.config,
pmce,
extensions,
);
ProcessingResult::Done((websocket, result))
}
@ -161,7 +162,7 @@ impl VerifyData {
&self,
response: Response,
config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<DeflateContext>)> {
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -201,15 +202,15 @@ impl VerifyData {
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
}
let mut pmce = None;
let mut extensions = None;
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions.next() {
if extensions.next().is_some() {
let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
if let Some(value) = extensions_values.next() {
if extensions_values.next().is_some() {
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
}
@ -223,12 +224,15 @@ impl VerifyData {
}
// Already had PMCE configured
if pmce.is_some() {
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
name.to_string(),
)));
}
pmce = Some(compression.accept_response(params)?);
extensions = Some(Extensions {
compression: Some(compression.accept_response(params)?),
});
}
} else if let Some((name, _)) = exts.next() {
// The client didn't request anything, but got something
@ -243,7 +247,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok((response, pmce))
Ok((response, extensions))
}
}

@ -20,7 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions,
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -203,8 +203,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated Per-Message Compression Extension context for server.
pmce: Option<extensions::DeflateContext>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -222,7 +222,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
pmce: None,
extensions: None,
_marker: PhantomData,
},
}
@ -246,10 +246,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let mut response = create_response(&result)?;
if let Some(config) = &self.config {
let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter();
if let Some((agreed, pmce)) = config.accept_offers(extensions) {
self.pmce = Some(pmce);
let values = result.headers().get_all("Sec-WebSocket-Extensions").iter();
if let Some((agreed, extensions)) = config.accept_offers(values) {
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
self.extensions = Some(extensions);
}
}
@ -292,11 +292,11 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket_with_compression(
let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
self.pmce.take(),
self.extensions.take(),
);
ProcessingResult::Done(websocket)
}

@ -23,7 +23,7 @@ use self::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::{self, DeflateContext},
extensions::{self, Extensions},
util::NonBlockingResult,
};
@ -81,15 +81,14 @@ impl WebSocketConfig {
self.compression.map(|c| c.generate_offer())
}
// TODO Replace `DeflateContext` with something more general
// This can be used with `WebSocket::from_raw_socket_with_compression` for integration.
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers<'a>(
&'a self,
extensions: impl Iterator<Item = &'a HeaderValue>,
) -> Option<(HeaderValue, DeflateContext)> {
) -> Option<(HeaderValue, Extensions)> {
if let Some(compression) = &self.compression {
let extensions = crate::extensions::iter_all(extensions);
let extensions = extensions::iter_all(extensions);
let offers =
extensions.filter_map(
|(k, v)| {
@ -100,7 +99,12 @@ impl WebSocketConfig {
}
},
);
compression.accept_offer(offers)
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
compression
.accept_offer(offers)
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
} else {
None
}
@ -130,14 +134,14 @@ impl<Stream> WebSocket<Stream> {
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_compression(
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.pmce = pmce;
context.extensions = extensions;
WebSocket { socket: stream, context }
}
@ -158,17 +162,17 @@ impl<Stream> WebSocket<Stream> {
}
}
pub(crate) fn from_partially_read_with_compression(
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_compression(
part, role, config, pmce,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
@ -306,8 +310,8 @@ pub struct WebSocketContext {
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
/// Per-Message Compression Extension. Only deflate is supported at the moment.
pub(crate) pmce: Option<extensions::DeflateContext>,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
}
impl WebSocketContext {
@ -321,7 +325,7 @@ impl WebSocketContext {
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_else(WebSocketConfig::default),
pmce: None,
extensions: None,
}
}
@ -333,15 +337,15 @@ impl WebSocketContext {
}
}
pub(crate) fn from_partially_read_with_compression(
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
pmce: Option<DeflateContext>,
extensions: Option<Extensions>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
pmce,
extensions,
..WebSocketContext::new(role, config)
}
}
@ -447,11 +451,12 @@ impl WebSocketContext {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
let frame = if let Some(pmce) = self.pmce.as_mut() {
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
} else {
Frame::message(data, opcode, is_final)
};
let frame =
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
} else {
Frame::message(data, opcode, is_final)
};
Ok(frame)
}
@ -533,7 +538,7 @@ impl WebSocketContext {
// Connection_.
let is_compressed = {
let hdr = frame.header();
if (hdr.rsv1 && self.pmce.is_none()) || hdr.rsv2 || hdr.rsv3 {
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
@ -606,8 +611,9 @@ impl WebSocketContext {
if let Some(ref mut msg) = self.incomplete {
let data = if msg.compressed() {
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap
self.pmce
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(frame.into_data(), fin)?
} else {
@ -637,8 +643,9 @@ impl WebSocketContext {
};
let mut m = IncompleteMessage::new(message_type, is_compressed);
let data = if is_compressed {
self.pmce
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(frame.into_data(), fin)?
} else {
@ -729,6 +736,10 @@ impl WebSocketContext {
trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
fn has_compression(&self) -> bool {
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
}
/// The current connection state.

Loading…
Cancel
Save