Adds initial deflate implementation

pull/144/head
SirCipher 5 years ago
parent 006cec72ea
commit e93f9a01e6
  1. 3
      Cargo.toml
  2. 3
      src/error.rs
  3. 155
      src/extensions/compression.rs
  4. 278
      src/extensions/deflate.rs
  5. 26
      src/extensions/mod.rs
  6. 56
      src/handshake/client.rs
  7. 7
      src/lib.rs
  8. 1
      src/protocol/frame/mod.rs
  9. 35
      src/protocol/mod.rs

@ -29,11 +29,14 @@ rand = "0.7.2"
sha-1 = "0.9"
url = "2.1.0"
utf-8 = "0.7.5"
flate2 = { version = "1.0", features = ["zlib"], default-features = false }
[dependencies.native-tls]
optional = true
version = "0.2.3"
[dev-dependencies]
env_logger = "0.7.1"
net2 = "0.2.33"

@ -67,6 +67,8 @@ pub enum Error {
Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
/// An error from a WebSocket extension.
ExtensionError(Box<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for Error {
@ -84,6 +86,7 @@ impl fmt::Display for Error {
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
Error::ExtensionError(ref e) => write!(f, "{}", e),
}
}
}

@ -0,0 +1,155 @@
//!
use std::fmt::{Debug, Display, Formatter};
use crate::extensions::deflate::{DeflateConfig, DeflateExtension};
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::Frame;
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
#[derive(Copy, Clone, Debug)]
pub enum CompressionConfig {
Uncompressed,
Deflate(DeflateConfig),
}
impl CompressionConfig {
pub fn into_strategy(self) -> CompressionStrategy {
match self {
Self::Uncompressed => CompressionStrategy::Uncompressed,
Self::Deflate(_config) => CompressionStrategy::Deflate(DeflateExtension::new()),
}
}
pub fn uncompressed() -> CompressionConfig {
CompressionConfig::Uncompressed
}
pub fn deflate() -> CompressionConfig {
CompressionConfig::Deflate(DeflateConfig::default())
}
}
pub enum CompressionStrategy {
Uncompressed,
Deflate(DeflateExtension),
}
#[derive(Debug, Clone)]
pub struct CompressionExtensionError(String);
impl std::error::Error for CompressionExtensionError {}
impl Display for CompressionExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<CompressionExtensionError> for crate::Error {
fn from(e: CompressionExtensionError) -> Self {
crate::Error::ExtensionError(Box::new(e))
}
}
impl WebSocketExtension for CompressionStrategy {
type Error = CompressionExtensionError;
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {
match self {
Self::Uncompressed => request,
Self::Deflate(de) => de.on_request(request),
}
}
fn on_response<T>(&mut self, response: &Response<T>) {
match self {
Self::Uncompressed => {}
Self::Deflate(de) => de.on_response(response),
}
}
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> {
match self {
Self::Uncompressed => Ok(frame),
Self::Deflate(de) => de
.on_send_frame(frame)
.map_err(|e| CompressionExtensionError(e.to_string())),
}
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Frame>, Self::Error> {
match self {
Self::Uncompressed => Ok(Some(frame)),
Self::Deflate(de) => de
.on_receive_frame(frame)
.map_err(|e| CompressionExtensionError(e.to_string())),
}
}
}
impl Debug for CompressionStrategy {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Uncompressed => f.debug_struct("Uncompressed").finish(),
Self::Deflate(_) => f.debug_struct("Deflate").finish(),
}
}
}
impl CompressionConfig {
fn as_header_value(&self) -> Option<HeaderValue> {
match self {
Self::Uncompressed => None,
Self::Deflate(_) => Some(HeaderValue::from_static("permessage-deflate")),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CompressionSelectorError(&'static str);
impl std::error::Error for CompressionSelectorError {}
impl Display for CompressionSelectorError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<CompressionSelectorError> for crate::Error {
fn from(e: CompressionSelectorError) -> Self {
crate::Error::ExtensionError(Box::new(e))
}
}
impl WebSocketExtension for CompressionConfig {
type Error = CompressionSelectorError;
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
if let Some(header_value) = self.as_header_value() {
request
.headers_mut()
.append(SEC_WEBSOCKET_EXTENSIONS, header_value);
}
request
}
fn on_response<T>(&mut self, response: &Response<T>) {
let mut iter = response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter();
let self_header = match self.as_header_value() {
Some(hv) => hv,
None => return,
};
match iter.next() {
Some(hv) if hv == self_header => {}
_ => {
*self = CompressionConfig::Uncompressed;
}
}
}
}

@ -0,0 +1,278 @@
//! Permessage-deflate extension
use std::fmt::{Display, Formatter};
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use flate2::{Compress, CompressError, Compression, Decompress, DecompressError};
use std::mem::replace;
pub struct DeflateExtension {
pub(crate) config: DeflateConfig,
pub(crate) fragments: Vec<Frame>,
inflator: Inflator,
deflator: Deflator,
}
impl DeflateExtension {
pub fn new() -> DeflateExtension {
DeflateExtension {
config: Default::default(),
fragments: vec![],
inflator: Inflator::new(),
deflator: Deflator::new(Compression::best()),
}
}
}
#[derive(Clone, Copy, Debug)]
pub struct DeflateConfig {
/// The max size of the sliding window. If the other endpoint selects a smaller size, that size
/// will be used instead. This must be an integer between 9 and 15 inclusive.
/// Default: 15
pub max_window_bits: u8,
/// Indicates whether to ask the other endpoint to reset the sliding window for each message.
/// Default: false
pub request_no_context_takeover: bool,
/// Indicates whether this endpoint will agree to reset the sliding window for each message it
/// compresses. If this endpoint won't agree to reset the sliding window, then the handshake
/// will fail if this endpoint is a client and the server requests no context takeover.
/// Default: true
pub accept_no_context_takeover: bool,
/// The number of WebSocket frames to store when defragmenting an incoming fragmented
/// compressed message.
/// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to
/// allow for differences between compressed and uncompressed messages.
/// Default: 10
pub fragments_capacity: usize,
/// Indicates whether the extension handler will reallocate if the `fragments_capacity` is
/// exceeded. If this is not true, a capacity error will be triggered instead.
/// Default: true
pub fragments_grow: bool,
compress_reset: bool,
decompress_reset: bool,
}
impl Default for DeflateConfig {
fn default() -> Self {
DeflateConfig {
max_window_bits: 15,
request_no_context_takeover: false,
accept_no_context_takeover: true,
fragments_capacity: 10,
fragments_grow: true,
compress_reset: false,
decompress_reset: false,
}
}
}
#[derive(Debug, Clone)]
pub enum DeflateExtensionError {
DeflateError(String),
InflateError(String),
}
impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DeflateExtensionError::DeflateError(m) => write!(f, "{}", m),
DeflateExtensionError::InflateError(m) => write!(f, "{}", m),
}
}
}
impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Box::new(e))
}
}
impl WebSocketExtension for DeflateExtension {
type Error = DeflateExtensionError;
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> {
if let OpCode::Data(_) = frame.header().opcode {
frame.header_mut().rsv1 = true;
let mut compressed = Vec::with_capacity(frame.payload().len());
self.deflator.compress(frame.payload(), &mut compressed)?;
let len = compressed.len();
compressed.truncate(len - 4);
*frame.payload_mut() = compressed;
if self.config.compress_reset {
self.deflator.reset();
}
}
Ok(frame)
}
fn on_receive_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>, Self::Error> {
if frame.header().rsv1 {
frame.header_mut().rsv1 = false;
if !frame.header().is_final {
self.fragments.push(frame);
return Ok(None);
} else {
if let OpCode::Data(Data::Continue) = frame.header().opcode {
if !self.config.fragments_grow
&& self.config.fragments_capacity == self.fragments.len()
{
return Err(DeflateExtensionError::DeflateError(
"Exceeded max fragments.".into(),
));
} else {
self.fragments.push(frame);
}
let opcode = self.fragments.first().unwrap().header().opcode;
let size = self
.fragments
.iter()
.fold(0, |len, frame| len + frame.payload().len());
let mut compressed = Vec::with_capacity(size);
let decompressed = Vec::with_capacity(size * 2);
replace(
&mut self.fragments,
Vec::with_capacity(self.config.fragments_capacity),
)
.into_iter()
.for_each(|f| {
compressed.extend(f.into_data());
});
compressed.extend(&[0, 0, 255, 255]);
frame = Frame::message(decompressed, opcode, true);
} else {
frame.payload_mut().extend(&[0, 0, 255, 255]);
let mut decompress_output = Vec::with_capacity(frame.payload().len() * 2);
self.inflator
.decompress(frame.payload(), &mut decompress_output)?;
*frame.payload_mut() = decompress_output;
}
if self.config.decompress_reset {
self.inflator.reset(false);
}
}
}
Ok(Some(frame))
}
}
impl From<DecompressError> for DeflateExtensionError {
fn from(e: DecompressError) -> Self {
DeflateExtensionError::InflateError(e.to_string())
}
}
impl From<CompressError> for DeflateExtensionError {
fn from(e: CompressError) -> Self {
DeflateExtensionError::DeflateError(e.to_string())
}
}
struct Deflator {
compress: Compress,
}
impl Deflator {
pub fn new(compresion: Compression) -> Deflator {
Deflator {
compress: Compress::new(compresion, false),
}
}
fn reset(&mut self) {
self.compress.reset()
}
pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<usize, CompressError> {
loop {
let before_in = self.compress.total_in();
output.reserve(256);
let status = self
.compress
.compress_vec(input, output, flate2::FlushCompress::Sync)?;
let written = (self.compress.total_in() - before_in) as usize;
if written != 0 || status == flate2::Status::StreamEnd {
return Ok(written);
}
}
}
}
struct Inflator {
decompress: Decompress,
}
impl Inflator {
pub fn new() -> Inflator {
Inflator {
decompress: Decompress::new(false),
}
}
fn reset(&mut self, zlib_header: bool) {
self.decompress.reset(zlib_header)
}
pub fn decompress(
&mut self,
input: &[u8],
output: &mut Vec<u8>,
) -> Result<usize, DecompressError> {
let mut read_buff = Vec::from(input);
let mut eof = false;
loop {
if read_buff.is_empty() {
eof = true;
}
if !eof && output.is_empty() {
output.reserve(256);
unsafe {
output.set_len(output.capacity());
}
}
let before_out = self.decompress.total_out();
let before_in = self.decompress.total_in();
let decompression_strategy = if eof {
flate2::FlushDecompress::Finish
} else {
flate2::FlushDecompress::None
};
let status = self
.decompress
.decompress(&read_buff, output, decompression_strategy)?;
let consumed = (self.decompress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
let read = (self.decompress.total_out() - before_out) as usize;
if read != 0 || status == flate2::Status::StreamEnd {
output.truncate(read);
return Ok(read);
}
}
}
}

@ -0,0 +1,26 @@
//! WebSocket extensions
use http::{Request, Response};
use crate::protocol::frame::Frame;
pub mod compression;
pub mod deflate;
pub trait WebSocketExtension {
type Error: Into<crate::Error>;
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {
request
}
fn on_response<T>(&mut self, _response: &Response<T>) {}
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> {
Ok(frame)
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Frame>, Self::Error> {
Ok(Some(frame))
}
}

@ -11,6 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type.
@ -32,7 +33,7 @@ impl<S: Read + Write> ClientHandshake<S> {
pub fn start(
stream: S,
request: Request,
config: Option<WebSocketConfig>,
mut config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
@ -52,7 +53,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let key = generate_key();
let machine = {
let req = generate_request(request, &key)?;
let req = generate_request(request, &key, &mut config)?;
HandshakeMachine::start_write(stream, req)
};
@ -90,7 +91,8 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result,
tail,
} => {
self.verify_data.verify_response(&result)?;
self.verify_data
.verify_response(&result, &mut self.config)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
@ -101,20 +103,30 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Generate client request.
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
fn generate_request(
request: Request,
key: &str,
config: &mut Option<WebSocketConfig>,
) -> Result<Vec<u8>> {
let request = match &config {
Some(mut config) => config.compression_config.on_request(request),
None => request,
};
let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority()
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}
write!(
@ -138,7 +150,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
for (k, v) in request.headers() {
let mut k = k.as_str();
if k == "sec-websocket-protocol" {
if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol";
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
@ -156,7 +168,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
pub fn verify_response(
&self,
response: &Response,
config: &mut Option<WebSocketConfig>,
) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -205,12 +221,16 @@ impl VerifyData {
"Key mismatch in Sec-WebSocket-Accept".into(),
));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
if let Some(config) = config {
config.compression_config.on_response(response);
}
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -266,8 +286,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
#[test]
fn random_keys() {
@ -297,14 +317,16 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
@ -314,14 +336,16 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://user:pass@localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
@ -331,7 +355,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -1,8 +1,8 @@
//! Lightweight, flexible WebSockets for Rust.
#![deny(
missing_docs,
missing_copy_implementations,
missing_debug_implementations,
// missing_docs,
// missing_copy_implementations,
// missing_debug_implementations,
trivial_casts,
trivial_numeric_casts,
unstable_features,
@ -16,6 +16,7 @@ pub use http;
pub mod client;
pub mod error;
pub mod extensions;
pub mod handshake;
pub mod protocol;
pub mod server;

@ -187,6 +187,7 @@ impl FrameCodec {
frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending(stream)
}

@ -16,6 +16,8 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::{IncompleteMessage, IncompleteMessageType};
use crate::error::{Error, Result};
use crate::extensions::compression::{CompressionConfig, CompressionStrategy};
use crate::extensions::WebSocketExtension;
use crate::util::NonBlockingResult;
/// Indicates a Client or Server role of the websocket
@ -28,7 +30,7 @@ pub enum Role {
}
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
@ -43,6 +45,8 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// Permessage compression strategy.
pub compression_config: CompressionConfig,
}
impl Default for WebSocketConfig {
@ -51,6 +55,7 @@ impl Default for WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
compression_config: CompressionConfig::Uncompressed,
}
}
}
@ -101,6 +106,7 @@ impl<Stream> WebSocket<Stream> {
pub fn get_ref(&self) -> &Stream {
&self.socket
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
&mut self.socket
@ -230,11 +236,16 @@ pub struct WebSocketContext {
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
/// WebSocket compression strategy.
compressor: CompressionStrategy,
}
impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
let config = config.unwrap_or_else(WebSocketConfig::default);
let compressor = config.compression_config.into_strategy();
WebSocketContext {
role,
frame: FrameCodec::new(),
@ -242,7 +253,8 @@ impl WebSocketContext {
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_else(WebSocketConfig::default),
config,
compressor,
}
}
@ -426,17 +438,6 @@ impl WebSocketContext {
"Remote sent frame after having sent a Close Frame".into(),
));
}
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()));
}
}
match self.role {
Role::Server => {
@ -491,6 +492,12 @@ impl WebSocketContext {
OpCode::Data(data) => {
let fin = frame.header().is_final;
let compressor = &mut self.compressor;
let frame = match compressor.on_receive_frame(frame)? {
Some(frame) => frame,
None => return Ok(None),
};
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
@ -601,6 +608,8 @@ impl WebSocketContext {
}
}
let frame = self.compressor.on_send_frame(frame)?;
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)

Loading…
Cancel
Save