Refactors deflate extension handling

pull/144/head
SirCipher 5 years ago
parent e24c00db32
commit 0127b55160
  1. 30
      examples/autobahn-client.rs
  2. 27
      src/client.rs
  3. 590
      src/ext/deflate.rs
  4. 33
      src/ext/mod.rs
  5. 96
      src/ext/uncompressed.rs
  6. 6
      src/extensions/compression.rs
  7. 40
      src/extensions/deflate.rs
  8. 2
      src/extensions/mod.rs
  9. 63
      src/handshake/client.rs
  10. 30
      src/handshake/server.rs
  11. 2
      src/lib.rs
  12. 159
      src/protocol/mod.rs
  13. 27
      src/server.rs
  14. 45
      tests/connection_reset.rs

@ -2,21 +2,17 @@ use log::*;
use url::Url;
use tungstenite::client::connect_with_config;
use tungstenite::extensions::compression::CompressionConfig;
use tungstenite::ext::deflate::DeflateExtension;
use tungstenite::ext::uncompressed::UncompressedExt;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::{connect, Error, Message, Result};
use tungstenite::{connect, Error, Message, Result, WebSocket};
const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect_with_config(
let (mut socket, _): (WebSocket<_, UncompressedExt>, _) = connect_with_config(
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
Some(WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
compression_config: CompressionConfig::deflate(),
}),
None,
)?;
let msg = socket.read_message()?;
socket.close(None)?;
@ -48,7 +44,7 @@ fn run_test(case: u32) -> Result<()> {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
compression_config: CompressionConfig::deflate(),
encoder: DeflateExtension::default(),
}),
)?;
@ -67,16 +63,16 @@ fn main() {
env_logger::init();
let total = get_case_count().unwrap();
let _total = get_case_count().unwrap();
for case in 1..=total {
if let Err(e) = run_test(case) {
match e {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
err => error!("test: {}", err),
}
// for case in 1..=total {
if let Err(e) = run_test(334) {
match e {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
err => error!("test: {}", err),
}
}
// }
update_reports().unwrap();
}

@ -66,6 +66,8 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::ext::uncompressed::UncompressedExt;
use crate::ext::WebSocketExtension;
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
@ -86,10 +88,13 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req: IntoClientRequest>(
pub fn connect_with_config<Req: IntoClientRequest, E>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
config: Option<WebSocketConfig<E>>,
) -> Result<(WebSocket<AutoStream, E>, Response)>
where
E: WebSocketExtension,
{
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
@ -122,7 +127,9 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream, UncompressedExt>, Response)> {
connect_with_config(request, None)
}
@ -159,14 +166,15 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<Stream, Req>(
pub fn client_with_config<Stream, Req, E>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
config: Option<WebSocketConfig<E>>,
) -> StdResult<(WebSocket<Stream, E>, Response), HandshakeError<ClientHandshake<Stream, E>>>
where
Stream: Read + Write,
Req: IntoClientRequest,
E: WebSocketExtension,
{
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}
@ -179,7 +187,10 @@ where
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<
(WebSocket<Stream, UncompressedExt>, Response),
HandshakeError<ClientHandshake<Stream, UncompressedExt>>,
>
where
Stream: Read + Write,
Req: IntoClientRequest,

@ -0,0 +1,590 @@
//! Permessage-deflate extension
use std::fmt::{Display, Formatter};
use crate::ext::uncompressed::UncompressedExt;
use crate::ext::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::{Error, Message};
use flate2::{
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress,
FlushDecompress, Status,
};
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
use std::mem::replace;
use std::slice;
pub struct DeflateExtension {
enabled: bool,
config: DeflateConfig,
fragments: Vec<Frame>,
inflator: Inflator,
deflator: Deflator,
uncompressed_extension: UncompressedExt,
}
impl Clone for DeflateExtension {
fn clone(&self) -> Self {
DeflateExtension {
enabled: self.enabled,
config: self.config,
fragments: vec![],
inflator: Inflator::new(),
deflator: Deflator::new(self.config.compression_level),
uncompressed_extension: UncompressedExt::new(self.config.max_message_size),
}
}
}
impl Default for DeflateExtension {
fn default() -> Self {
DeflateExtension::new(Default::default())
}
}
impl DeflateExtension {
pub fn new(config: DeflateConfig) -> DeflateExtension {
DeflateExtension {
enabled: false,
config,
fragments: vec![],
inflator: Inflator::new(),
deflator: Deflator::new(Compression::fast()),
uncompressed_extension: UncompressedExt::new(config.max_message_size),
}
}
fn complete_message(&self, data: Vec<u8>, opcode: OpCode) -> Result<Message, Error> {
let message_type = match opcode {
OpCode::Data(Data::Text) => IncompleteMessageType::Text,
OpCode::Data(Data::Binary) => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut incomplete_message = IncompleteMessage::new(message_type);
incomplete_message.extend(data, self.config.max_message_size)?;
incomplete_message.complete()
}
}
#[derive(Clone, Copy, Debug)]
pub struct DeflateConfig {
pub max_message_size: Option<usize>,
pub max_window_bits: u8,
pub request_no_context_takeover: bool,
pub accept_no_context_takeover: bool,
pub fragments_capacity: usize,
pub fragments_grow: bool,
pub compress_reset: bool,
pub decompress_reset: bool,
pub compression_level: Compression,
}
impl Default for DeflateConfig {
fn default() -> Self {
DeflateConfig {
max_message_size: Some(MAX_MESSAGE_SIZE),
max_window_bits: 15,
request_no_context_takeover: false,
accept_no_context_takeover: true,
fragments_capacity: 10,
fragments_grow: true,
compress_reset: false,
decompress_reset: false,
compression_level: Compression::best(),
}
}
}
#[derive(Debug, Clone)]
pub enum DeflateExtensionError {
DeflateError(String),
InflateError(String),
}
impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DeflateExtensionError::DeflateError(m) => write!(f, "{}", m),
DeflateExtensionError::InflateError(m) => write!(f, "{}", m),
}
}
}
impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Box::new(e))
}
}
const EXT_NAME: &str = "permessage-deflate";
impl WebSocketExtension for DeflateExtension {
type Error = DeflateExtensionError;
fn enabled(&self) -> bool {
self.enabled
}
fn rsv1(&self) -> bool {
true
}
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_NAME);
let DeflateConfig {
max_window_bits,
request_no_context_takeover,
..
} = self.config;
if max_window_bits < 15 {
header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits
))
} else {
header_value.push_str("; client_max_window_bits")
}
if request_no_context_takeover {
header_value.push_str("; server_no_context_takeover")
}
request.headers_mut().append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&header_value).unwrap(),
);
request
}
fn on_response<T>(&mut self, response: &Response<T>) {
let mut iter = response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter();
let self_header = HeaderValue::from_static(EXT_NAME);
match iter.next() {
Some(hv) if hv == self_header => {
self.enabled = true;
}
_ => {
self.enabled = false;
}
}
}
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> {
if self.enabled {
if let OpCode::Data(_) = frame.header().opcode {
frame.header_mut().rsv1 = true;
let mut compressed = Vec::with_capacity(frame.payload().len());
self.deflator.compress(frame.payload(), &mut compressed)?;
let len = compressed.len();
compressed.truncate(len - 4);
*frame.payload_mut() = compressed;
if self.config.compress_reset {
self.deflator.reset();
}
}
}
Ok(frame)
}
fn on_receive_frame(&mut self, mut frame: Frame) -> Result<Option<Message>, Self::Error> {
match frame.header().opcode {
OpCode::Control(_) => unreachable!(),
_ => {
if self.enabled && (!self.fragments.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragments.push(frame);
return Ok(None);
} else {
let message = if let OpCode::Data(Data::Continue) = frame.header().opcode {
if !self.config.fragments_grow
&& self.config.fragments_capacity == self.fragments.len()
{
return Err(DeflateExtensionError::DeflateError(
"Exceeded max fragments.".into(),
));
} else {
self.fragments.push(frame);
}
let opcode = self.fragments.first().unwrap().header().opcode;
let size = self
.fragments
.iter()
.fold(0, |len, frame| len + frame.payload().len());
let mut compressed = Vec::with_capacity(size);
let mut decompressed = Vec::with_capacity(size * 2);
replace(
&mut self.fragments,
Vec::with_capacity(self.config.fragments_capacity),
)
.into_iter()
.for_each(|f| {
compressed.extend(f.into_data());
});
compressed.extend(&[0, 0, 255, 255]);
self.inflator.decompress(&compressed, &mut decompressed)?;
self.complete_message(decompressed, opcode)
} else {
frame.payload_mut().extend(&[0, 0, 255, 255]);
let mut decompress_output =
Vec::with_capacity(frame.payload().len() * 2);
self.inflator
.decompress(frame.payload(), &mut decompress_output)?;
self.complete_message(decompress_output, frame.header().opcode)
};
if self.config.decompress_reset {
self.inflator.reset(false);
}
match message {
Ok(message) => Ok(Some(message)),
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())),
}
}
} else {
self.uncompressed_extension
.on_receive_frame(frame)
.map_err(|e| DeflateExtensionError::DeflateError(e.to_string()))
}
}
}
}
}
impl From<DecompressError> for DeflateExtensionError {
fn from(e: DecompressError) -> Self {
DeflateExtensionError::InflateError(e.to_string())
}
}
impl From<CompressError> for DeflateExtensionError {
fn from(e: CompressError) -> Self {
DeflateExtensionError::DeflateError(e.to_string())
}
}
struct Deflator {
compress: Compress,
}
impl Deflator {
pub fn new(compresion: Compression) -> Deflator {
Deflator {
compress: Compress::new(compresion, false),
}
}
fn reset(&mut self) {
self.compress.reset()
}
pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<usize, CompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
loop {
output_size = output.len();
if output_size == output.capacity() {
output.reserve(input.len());
}
let before_out = self.compress.total_out();
let before_in = self.compress.total_in();
let status = self
.compress
.compress_vec(&read_buff, output, FlushCompress::Sync)?;
let consumed = (self.compress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
let new_size = (self.compress.total_out() - before_out) as usize + output_size;
unsafe {
output.set_len(new_size);
}
match status {
Status::Ok | Status::BufError => {
if before_out == self.compress.total_out() && read_buff.is_empty() {
return Ok(consumed);
}
}
s => panic!(s),
}
}
}
}
struct Inflator {
decompress: Decompress,
}
impl Inflator {
pub fn new() -> Inflator {
Inflator {
decompress: Decompress::new(false),
}
}
fn reset(&mut self, zlib_header: bool) {
self.decompress.reset(zlib_header)
}
pub fn decompress(
&mut self,
input: &[u8],
output: &mut Vec<u8>,
) -> Result<usize, DecompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
loop {
output_size = output.len();
if output_size == output.capacity() {
output.reserve(input.len());
}
let before_out = self.decompress.total_out();
let before_in = self.decompress.total_in();
let out_slice = unsafe {
slice::from_raw_parts_mut(
output.as_mut_ptr().offset(output_size as isize),
output.capacity() - output_size,
)
};
let status =
self.decompress
.decompress(&read_buff, out_slice, FlushDecompress::Sync)?;
let consumed = (self.decompress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
unsafe {
output.set_len((self.decompress.total_out() - before_out) as usize + output_size);
}
match status {
Status::Ok | Status::BufError => {
if before_out == self.decompress.total_out() && read_buff.is_empty() {
return Ok(consumed);
} else {
continue;
}
}
s => panic!(s),
}
}
}
}
#[test]
fn t() {
let v1 = vec![
37, 80, 68, 70, 45, 49, 46, 50, 10, 37, 199, 236, 143, 162, 10, 54, 32, 48, 32, 111, 98,
106, 10, 60, 60, 47, 76, 101, 110, 103, 116, 104, 32, 55, 32, 48, 32, 82, 47, 70, 105, 108,
116, 101, 114, 32, 47, 70, 108, 97, 116, 101, 68, 101, 99, 111, 100, 101, 62, 62, 10, 115,
116, 114, 101, 97, 109, 10, 120, 156, 125, 83, 193, 110, 212, 48, 16, 189, 231, 192, 55,
248, 88, 36, 118, 106, 123, 198, 118, 124, 44, 168, 160, 10, 1, 106, 27, 46, 220, 178, 219,
236, 174, 209, 38, 105, 55, 217, 138, 254, 61, 99, 39, 54, 55, 148, 67, 148, 55, 227, 241,
123, 111, 94, 36, 40, 33, 227, 179, 190, 119, 125, 245, 82, 93, 63, 144, 56, 76, 213, 139,
80, 169, 148, 95, 187, 94, 124, 108, 184, 200, 159, 202, 129, 86, 198, 136, 102, 95, 45,
231, 148, 80, 74, 129, 247, 86, 88, 167, 193, 213, 74, 52, 125, 117, 117, 251, 110, 23,
186, 97, 126, 223, 252, 174, 156, 132, 218, 58, 238, 108, 158, 170, 171, 38, 34, 220, 47,
9, 253, 10, 157, 219, 48, 132, 225, 16, 11, 198, 131, 38, 107, 214, 194, 184, 143, 152, 86,
96, 12, 230, 243, 63, 182, 83, 119, 126, 109, 183, 167, 46, 214, 188, 6, 35, 109, 157, 107,
207, 169, 31, 161, 70, 149, 251, 187, 115, 59, 143, 17, 38, 214, 41, 203, 232, 115, 132,
54, 154, 9, 123, 227, 197, 70, 27, 208, 104, 83, 229, 91, 234, 214, 22, 156, 215, 121, 240,
83, 119, 154, 34, 138, 22, 208, 35, 173, 232, 101, 90, 89, 147, 5, 203, 221, 43, 252, 105,
28, 230, 238, 79, 82, 110, 107, 168, 165, 202, 58, 191, 156, 219, 231, 99, 154, 195, 70,
162, 80, 4, 72, 86, 71, 35, 55, 74, 19, 72, 75, 98, 227, 16, 8, 213, 63, 167, 106, 240, 78,
187, 98, 200, 54, 180, 11, 17, 3, 158, 36, 230, 193, 144, 76, 245, 96, 139, 73, 43, 32, 61,
81, 49, 98, 154, 187, 48, 36, 221, 150, 117, 96, 188, 142, 164, 101, 179, 86, 225, 45, 55,
36, 95, 136, 87, 76, 70, 229, 141, 29, 187, 41, 100, 218, 150, 19, 1, 94, 122, 181, 208,
38, 15, 198, 90, 158, 131, 96, 105, 161, 125, 55, 76, 115, 152, 47, 115, 218, 15, 17, 215,
41, 27, 176, 31, 211, 120, 197, 118, 213, 42, 147, 191, 185, 204, 227, 48, 246, 227, 37,
93, 97, 17, 60, 22, 189, 11, 221, 58, 134, 74, 231, 133, 206, 221, 233, 20, 14, 93, 170,
32, 113, 22, 139, 13, 75, 216, 56, 107, 117, 238, 125, 124, 99, 73, 253, 244, 33, 81, 225,
21, 27, 157, 151, 244, 57, 93, 6, 90, 162, 163, 18, 195, 203, 112, 28, 247, 171, 5, 158,
55, 237, 242, 156, 155, 240, 184, 228, 197, 32, 24, 66, 177, 81, 53, 40, 52, 171, 109, 243,
177, 235, 219, 57, 236, 218, 83, 26, 202, 39, 149, 207, 155, 248, 143, 27, 215, 41, 192,
32, 157, 183, 217, 234, 95, 55, 119, 95, 19, 89, 100, 27, 56, 173, 43, 252, 115, 8, 175,
41, 148, 252, 123, 169, 34, 129, 23, 26, 146, 100, 45, 1, 109, 249, 193, 222, 150, 177,
200, 185, 118, 37, 145, 167, 241, 48, 36, 10, 156, 52, 13, 202, 241, 202, 120, 11, 44, 61,
150, 191, 167, 196, 43, 4, 148, 117, 118, 62, 221, 103, 128, 76, 137, 79, 191, 54, 89, 93,
20, 108, 23, 19, 209, 74, 83, 56, 165, 46, 230, 105, 40, 139, 210, 82, 234, 136, 222, 54,
226, 190, 186, 175, 254, 2, 247, 54, 15, 175, 101, 110, 100, 115, 116, 114, 101, 97, 109,
10, 101, 110, 100, 111, 98, 106, 10, 55, 32, 48, 32, 111, 98, 106, 10, 53, 55, 56, 10, 101,
110, 100, 111, 98, 106, 10, 49, 55, 32, 48, 32, 111, 98, 106, 10, 60, 60, 47, 82, 52, 10,
52, 32, 48, 32, 82, 62, 62, 10, 101, 110, 100, 111, 98, 106, 10, 49, 56, 32, 48, 32, 111,
98, 106, 10, 60, 60, 47, 82, 49, 54, 10, 49, 54, 32, 48, 32, 82, 47, 82, 49, 51, 10, 49,
51, 32, 48, 32, 82, 47, 82, 49, 48, 10, 49, 48, 32, 48, 32, 82, 62, 62, 10, 101, 110, 100,
111, 98, 106, 10, 50, 51, 32, 48, 32, 111, 98, 106, 10, 60, 60, 47, 76, 101, 110, 103, 116,
104, 32, 50, 52, 32, 48, 32, 82, 47, 70, 105, 108, 116, 101, 114, 32, 47, 70, 108, 97, 116,
101, 68, 101, 99, 111, 100, 101, 62, 62, 10, 115, 116, 114, 101, 97, 109, 10, 120, 156,
173, 86, 75, 143, 27, 69, 16, 190, 91, 252, 8, 31, 55, 82, 92, 116, 87, 191, 143, 139, 128,
16, 16, 129, 100, 205, 5, 229, 50, 182, 219, 235, 129, 241, 12, 59, 211, 78, 216, 83, 254,
58, 53, 253, 178, 29, 22, 177, 72, 200, 7, 75, 213, 221, 85, 213, 223, 163, 122, 24, 240,
37, 155, 127, 249, 127, 123, 92, 60, 44, 30, 150, 60, 198, 202, 223, 246, 184, 252, 106,
189, 248, 242, 29, 114, 138, 128, 99, 142, 47, 215, 251, 69, 58, 192, 151, 232, 12, 48, 52,
75, 173, 20, 232, 229, 250, 184, 184, 185, 221, 189, 88, 255, 182, 224, 10, 164, 148, 130,
246, 172, 119, 139, 155, 15, 237, 52, 140, 211, 28, 167, 60, 102, 233, 192, 105, 212, 114,
206, 179, 226, 22, 193, 162, 94, 174, 148, 5, 235, 116, 220, 206, 33, 166, 64, 224, 218,
230, 12, 183, 187, 152, 99, 142, 175, 242, 194, 138, 107, 48, 50, 174, 254, 60, 14, 251,
120, 6, 21, 37, 81, 58, 31, 250, 122, 76, 137, 44, 104, 167, 100, 14, 190, 107, 218, 222,
199, 68, 2, 1, 37,
];
let v2 = vec![
170, 28, 191, 219, 206, 49, 7, 218, 136, 90, 245, 48, 54, 187, 180, 119, 101, 53, 40, 102,
48, 85, 21, 60, 46, 255, 234, 251, 124, 87, 163, 152, 203, 103, 194, 120, 58, 198, 86, 36,
56, 44, 137, 246, 177, 30, 8, 110, 84, 9, 125, 138, 105, 25, 160, 149, 90, 150, 230, 78,
99, 190, 185, 145, 186, 220, 226, 182, 191, 247, 31, 99, 70, 7, 232, 24, 207, 225, 166,
223, 5, 31, 195, 6, 180, 174, 45, 191, 238, 247, 195, 120, 108, 66, 251, 251, 188, 38, 29,
40, 94, 75, 254, 16, 107, 26, 234, 95, 137, 66, 205, 208, 245, 9, 84, 195, 40, 141, 185,
68, 245, 151, 190, 253, 16, 11, 48, 48, 136, 229, 128, 31, 167, 54, 124, 202, 184, 90, 2,
171, 244, 19, 10, 105, 86, 62, 167, 160, 154, 91, 35, 222, 133, 0, 105, 18, 158, 248, 95,
121, 63, 19, 108, 88, 197, 240, 59, 63, 110, 98, 219, 28, 148, 115, 181, 235, 144, 169, 34,
60, 74, 43, 223, 55, 254, 62, 211, 75, 242, 181, 220, 94, 93, 255, 117, 63, 133, 54, 156,
226, 57, 161, 1, 141, 82, 255, 7, 155, 167, 48, 244, 195, 49, 50, 39, 45, 8, 235, 10, 161,
109, 196, 197, 2, 99, 206, 149, 12, 193, 119, 93, 123, 159, 100, 38, 56, 32, 199, 42, 179,
152, 129, 115, 208, 66, 151, 216, 221, 227, 20, 124, 74, 189, 226, 138, 131, 96, 242, 74,
175, 223, 206, 43, 228, 83, 99, 116, 41, 58, 54, 167, 254, 48, 236, 19, 10, 82, 1, 39, 71,
148, 78, 219, 187, 236, 216, 107, 231, 51, 13, 179, 164, 86, 142, 236, 163, 92, 226, 193,
111, 187, 102, 108, 98, 14, 13, 66, 10, 83, 186, 108, 135, 254, 73, 219, 35, 19, 169, 189,
11, 219, 175, 15, 109, 156, 17, 72, 19, 129, 155, 210, 70, 56, 248, 41, 199, 13, 88, 81,
161, 73, 49, 106, 140, 217, 10, 192, 166, 153, 252, 46, 111, 53, 115, 218, 44, 185, 100,
83, 1, 142, 243, 18, 139, 142, 210, 116, 43, 201, 10, 135, 195, 24, 61, 195, 13, 40, 89,
46, 176, 109, 198, 177, 77, 57, 73, 168, 2, 109, 221, 156, 132, 65, 155, 141, 49, 252, 115,
15, 160, 170, 10, 166, 254, 243, 70, 68, 135, 213, 165, 73, 91, 73, 6, 228, 47, 83, 149,
186, 79, 106, 39, 165, 10, 212, 226, 90, 53, 195, 41, 94, 154, 48, 147, 6, 245, 19, 186,
225, 255, 174, 155, 228, 72, 193, 207, 131, 239, 113, 90, 69, 201, 208, 220, 3, 205, 240,
210, 3, 36, 167, 233, 101, 4, 84, 3, 141, 248, 130, 202, 115, 132, 164, 221, 133, 144, 94,
102, 67, 90, 39, 75, 219, 175, 60, 141, 169, 212, 160, 132, 106, 143, 199, 120, 28, 36,
205, 155, 146, 25, 114, 199, 146, 85, 56, 222, 12, 25, 34, 203, 108, 137, 253, 209, 36,
139, 83, 21, 77, 30, 44, 52, 237, 179, 77, 132, 149, 21, 130, 172, 51, 238, 64, 158, 153,
190, 210, 153, 113, 213, 196, 135, 102, 202, 252, 89, 86, 135, 252, 38, 1, 64, 229, 207,
115, 49, 63, 6, 116, 88, 98, 233, 125, 58, 109, 142, 109, 8, 73, 64, 210, 64, 137, 251,
110, 242, 31, 15, 126, 76, 252, 75, 16, 26, 241, 154, 255, 149, 48, 114, 30, 187, 87, 6,
78, 120, 145, 230, 17, 249, 21, 98, 22, 184, 19, 149, 227, 129, 238, 50, 230, 231, 80, 11,
83, 194, 59, 127, 63, 250, 88, 81, 208, 172, 231, 12, 171, 238, 51, 70, 150, 179, 130, 198,
195, 169, 233, 218, 47, 182, 77, 49, 176, 162, 1, 104, 47, 158, 158, 60, 118, 157, 169, 69,
155, 174, 203, 46, 147, 66, 60, 203, 101, 181, 216, 19, 187, 50, 232, 124, 118, 124, 73,
182, 27, 250, 244, 224, 145, 17, 233, 225, 185, 166, 2, 213, 252, 25, 242, 143, 144, 28,
255, 62, 91, 31, 39, 223, 237, 243, 188, 33, 113, 21, 52, 78, 125, 231, 167, 41, 191, 183,
12, 207, 2, 247, 36, 110, 223, 111, 51, 151, 212, 174, 172, 158, 13, 67, 238, 246, 194,
199, 217, 247, 196, 227, 60, 206, 244, 165, 171, 182, 121, 32, 145, 238, 101, 113, 97, 24,
155, 241, 49, 115, 134, 88, 57, 75, 230, 158, 199, 60, 183, 159, 101, 38, 4, 185, 168, 38,
11, 254, 207, 0, 233, 41, 35, 114, 221, 60, 89, 221, 60, 134, 210, 100, 77, 0, 11, 165,
206, 52, 108, 218, 132, 49, 213, 147, 231, 9, 250, 10, 202, 168, 100, 188, 32, 252, 211,
38, 199, 148, 172, 229, 232, 43, 32, 248, 212, 28, 125, 70, 113, 85, 129, 126, 127, 243,
99, 19, 210, 179, 76, 11, 78, 212, 239, 149, 55, 245, 173, 86, 213, 157, 2, 37, 231, 136,
239, 95, 204, 43, 223, 172, 151, 111, 23, 111, 23, 127, 1, 75, 50, 131, 211, 101, 110, 100,
115, 116, 114, 101, 97, 109, 10, 101, 110, 100, 111, 98, 106, 10, 50, 52, 32, 48, 32, 111,
98, 106, 10, 49, 48, 54, 56, 10, 101, 110, 100, 111, 98, 106, 10, 50, 56, 32, 48, 32, 111,
98, 106, 10, 60, 60, 47, 82, 50, 55, 10, 50, 55, 32, 48, 32, 82, 47, 82, 50, 49, 10, 50,
49, 32, 48, 32, 82, 62, 62, 10, 101, 110, 100, 111, 98, 106, 10, 51, 51, 32, 48, 32, 111,
98, 106, 10, 60, 60, 47, 76, 101, 110, 103, 116, 104, 32, 51, 52, 32, 48, 32, 82, 47, 70,
105, 108, 116, 101, 114, 32, 47, 70, 108, 97, 116, 101, 68, 101, 99, 111, 100, 101, 62, 62,
10, 115, 116, 114, 101, 97, 109, 10, 120, 156, 205, 90, 77, 111, 28, 199, 17, 189, 211, 70,
126, 3, 143, 10, 64, 78, 250, 187, 123, 114, 147,
];
let v3 = vec![
45, 195, 49, 18, 66, 182, 41, 36, 128, 161, 203, 112, 183, 197, 29, 115, 119, 134, 158,
153, 149, 44, 93, 244, 215, 83, 61, 93, 85, 61, 187, 92, 146, 138, 147, 67, 160, 131, 128,
222, 222, 238, 234, 170, 87, 175, 94, 213, 82, 84, 242, 92, 164, 127, 248, 255, 106, 119,
246, 219, 217, 111, 231, 114, 94, 163, 255, 86, 187, 243, 111, 222, 156, 253, 229, 103, 45,
207, 93, 85, 123, 29, 236, 249, 155, 119, 103, 249, 11, 242, 220, 56, 89, 185, 90, 158,
123, 171, 170, 90, 195, 71, 187, 179, 23, 109, 219, 254, 249, 205, 175, 240, 21, 37, 225,
144, 170, 22, 240, 57, 124, 229, 82, 73, 83, 57, 239, 207, 47, 157, 171, 172, 128, 181,
245, 217, 139, 151, 55, 227, 52, 52, 105, 187, 174, 43, 39, 67, 128, 83, 211, 250, 106,
194, 35, 252, 121, 93, 213, 78, 57, 51, 31, 33, 131, 175, 148, 134, 35, 180, 174, 132, 208,
243, 214, 215, 55, 99, 28, 222, 207, 71, 136, 74, 41, 235, 240, 136, 230, 102, 27, 211,
170, 2, 19, 106, 235, 113, 245, 245, 125, 90, 147, 186, 178, 166, 86, 184, 22, 135, 102,
234, 135, 121, 111, 93, 105, 175, 52, 174, 95, 245, 243, 94, 3, 107, 78, 226, 218, 58, 110,
199, 121, 21, 204, 178, 129, 238, 122, 251, 226, 245, 235, 171, 241, 237, 159, 211, 39, 70,
84, 78, 208, 209, 205, 48, 219, 32, 125, 229, 116, 160, 51, 110, 99, 151, 174, 108, 223,
227, 187, 131, 18, 116, 208, 188, 59, 84, 222, 215, 154, 44, 30, 63, 142, 83, 220, 205,
151, 106, 95, 213, 134, 142, 254, 176, 105, 87, 104, 137, 147, 224, 159, 188, 186, 73, 75,
112, 164, 119, 202, 146, 47, 155, 14, 247, 41, 79, 38, 236, 122, 114, 131, 174, 117, 121,
26, 218, 42, 44, 47, 142, 83, 127, 153, 86, 47, 181, 23, 149, 181, 245, 249, 165, 116, 149,
55, 249, 224, 249, 185, 149, 52, 198, 208, 185, 155, 102, 156, 178, 89, 16, 12, 169, 29,
217, 48, 181, 187, 120, 9, 113, 106, 227, 252, 16, 11, 62, 242, 130, 108, 94, 55, 211, 12,
1, 165, 33, 82, 53, 125, 165, 233, 214, 24, 62, 23, 84, 77, 230, 196, 223, 246, 177, 91,
197, 177, 154, 143, 1, 20, 0, 36, 240, 179, 55, 105, 201, 85, 218, 90, 65, 48, 26, 154,
182, 107, 187, 219, 217, 30, 91, 137, 224, 233, 156, 20, 174, 121, 85, 87, 210, 209, 238,
102, 215, 239, 59, 132, 129, 144, 146, 28, 61, 229, 128, 171, 10, 192, 78, 91, 39, 242,
159, 9, 53, 173, 69, 120, 250, 14, 194, 154, 239, 131, 39, 10, 45, 232, 190, 109, 219, 197,
102, 200, 142, 116, 170, 82, 128, 199, 217, 145, 58, 167, 65, 159, 81, 41, 82, 224, 221, 1,
42, 201, 74, 45, 28, 125, 240, 110, 232, 119, 232, 46, 107, 180, 58, 237, 97, 3, 72, 182,
140, 138, 177, 217, 221, 231, 108, 128, 163, 148, 177, 245, 194, 243, 179, 39, 21, 164,
164, 175, 41, 138, 111, 54, 148, 57, 198, 27, 114, 239, 253, 208, 67, 70, 205, 55, 107, 64,
168, 181, 116, 115, 59, 162, 241, 66, 26, 58, 120, 218, 52, 115, 6, 43, 153, 206, 37, 203,
247, 227, 190, 217, 110, 63, 206, 71, 152, 74, 120, 69, 174, 187, 153, 61, 86, 5, 17, 56,
249, 250, 105, 131, 104, 244, 194, 74, 62, 150, 210, 73, 90, 175, 158, 9, 91, 90, 2, 186,
80, 202, 208, 253, 253, 59, 132, 114, 128, 148, 81, 75, 40, 55, 25, 202, 174, 86, 154, 54,
191, 207, 75, 94, 72, 70, 95, 211, 110, 137, 85, 52, 240, 158, 208, 100, 236, 180, 4, 26,
88, 172, 30, 98, 91, 129, 203, 252, 17, 182, 33, 37, 141, 97, 140, 244, 251, 1, 31, 23, 4,
3, 117, 213, 239, 238, 247, 19, 160, 170, 239, 154, 57, 61, 29, 184, 169, 228, 195, 16, 71,
248, 218, 10, 99, 174, 128, 74, 216, 87, 68, 61, 240, 70, 41, 101, 217, 63, 13, 237, 106,
138, 235, 57, 236, 38, 164, 8, 135, 7, 97, 183, 198, 81, 216, 119, 240, 178, 121, 213, 0,
15, 51, 101, 174, 250, 14, 185, 209, 25, 206, 241, 161, 37, 170, 128, 199, 250, 176, 244,
239, 205, 62, 189, 96, 196, 200, 187, 84, 69, 150, 33, 1, 240, 88, 29, 10, 120, 16, 82, 16,
38, 193, 233, 13, 177, 31, 243, 58, 128, 21, 8, 93, 151, 119, 254, 21, 205, 134, 74, 67,
187, 223, 190, 144, 153, 142, 165, 133, 138, 33, 121, 243, 186, 185, 207, 222, 196, 36,
177, 146, 243, 247, 164, 41, 125, 230, 83, 41, 109, 160, 183, 247, 88, 1, 128, 249, 130,
165, 131, 57, 41, 225, 233, 154, 147, 178, 237, 222, 245, 195, 142, 239, 3, 190, 146, 198,
233, 242, 160, 126, 248, 136, 92, 9, 97, 80, 7, 119, 74, 89, 1, 125, 50, 53, 65, 244, 63,
126, 138, 232, 191, 224, 56, 60, 64, 100, 204, 42, 2, 158, 185, 100, 21, 34, 36, 126, 172,
176, 252, 176, 5, 137, 132, 192, 169, 252, 117, 215, 78, 25, 3, 182, 242, 53, 87, 139, 4,
226, 11, 92, 14, 158, 151, 223, 190, 80, 232, 99, 96, 221, 90, 149, 58, 242, 30, 253, 83,
43, 74, 131, 184, 237, 239, 119, 177, 67, 34, 118, 11, 39, 228, 44, 181, 214, 50, 119, 96,
28, 146, 176, 224, 220, 104, 114, 93, 116, 86, 177, 111, 187, 248, 1, 195, 30, 36, 187, 99,
81, 74, 132, 226, 116, 3, 204, 239, 87, 211, 126, 136, 23, 89, 85, 232, 3, 85, 97, 224, 81,
53, 25, 138, 53, 13, 42, 32, 23, 247, 190, 155, 226, 239, 19, 198, 9, 184, 131, 214, 111,
135, 108, 148, 178,
];
let mut compressor = Deflator::new(Compression::best());
let mut f = |v: Vec<_>| {
let mut compressed = Vec::with_capacity(v.len());
let r = compressor.compress(&v, &mut compressed);
println!("{:?}", r);
let len = compressed.len();
compressed.truncate(len - 4);
println!("Output capacity: {}", compressed.capacity());
println!("Compressed to: {:?}", compressed.len());
};
f(v1);
f(v2);
f(v3);
}

@ -0,0 +1,33 @@
//! WebSocket extensions
use http::{Request, Response};
use crate::protocol::frame::Frame;
use crate::Message;
pub mod deflate;
pub mod uncompressed;
pub trait WebSocketExtension: Default + Clone {
type Error: Into<crate::Error>;
fn enabled(&self) -> bool {
false
}
fn rsv1(&self) -> bool {
false
}
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {
request
}
fn on_response<T>(&mut self, _response: &Response<T>) {}
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> {
Ok(frame)
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error>;
}

@ -0,0 +1,96 @@
use crate::ext::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::{Error, Message};
#[derive(Debug)]
pub struct UncompressedExt {
incomplete: Option<IncompleteMessage>,
max_message_size: Option<usize>,
}
impl UncompressedExt {
pub fn new(max_message_size: Option<usize>) -> UncompressedExt {
UncompressedExt {
incomplete: None,
max_message_size,
}
}
}
impl Clone for UncompressedExt {
fn clone(&self) -> Self {
Self::default()
}
}
impl Default for UncompressedExt {
fn default() -> Self {
UncompressedExt {
incomplete: None,
max_message_size: Some(MAX_MESSAGE_SIZE),
}
}
}
impl WebSocketExtension for UncompressedExt {
type Error = Error;
fn enabled(&self) -> bool {
true
}
fn rsv1(&self) -> bool {
false
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
let fin = frame.header().is_final;
match frame.header().opcode {
OpCode::Data(data) => match data {
Data::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.max_message_size)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
Data::Text | Data::Binary => {
let msg = {
let message_type = match data {
Data::Text => IncompleteMessageType::Text,
Data::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
Data::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
},
_ => unreachable!(),
}
}
}

@ -3,7 +3,7 @@
use std::fmt::{Debug, Display, Formatter};
use crate::extensions::deflate::{DeflateConfig, DeflateExtension};
use crate::extensions::WebSocketExtension;
use crate::extensions::WebSocketExtensionOld;
use crate::protocol::frame::Frame;
use http::header::SEC_WEBSOCKET_EXTENSIONS;
use http::{HeaderValue, Request, Response};
@ -69,7 +69,7 @@ impl CompressionStrategy {
}
}
impl WebSocketExtension for CompressionStrategy {
impl WebSocketExtensionOld for CompressionStrategy {
type Error = CompressionExtensionError;
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {
@ -140,7 +140,7 @@ impl From<CompressionSelectorError> for crate::Error {
}
}
impl WebSocketExtension for CompressionConfig {
impl WebSocketExtensionOld for CompressionConfig {
type Error = CompressionSelectorError;
fn on_request<T>(&mut self, mut request: Request<T>) -> Request<T> {

@ -2,7 +2,7 @@
use std::fmt::{Display, Formatter};
use crate::extensions::WebSocketExtension;
use crate::extensions::WebSocketExtensionOld;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use flate2::{
@ -95,23 +95,19 @@ impl From<DeflateExtensionError> for crate::Error {
}
}
impl WebSocketExtension for DeflateExtension {
impl WebSocketExtensionOld for DeflateExtension {
type Error = DeflateExtensionError;
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> {
if let OpCode::Data(_) = frame.header().opcode {
frame.header_mut().rsv1 = true;
// println!("Compressing: {:?}", frame.payload());
let mut compressed = Vec::with_capacity(frame.payload().len() * 2);
let mut compressed = Vec::with_capacity(frame.payload().len());
self.deflator.compress(frame.payload(), &mut compressed)?;
let len = compressed.len();
compressed.truncate(len - 4);
println!("Compressed to: {:?}", compressed.len());
*frame.payload_mut() = compressed;
if self.config.compress_reset {
@ -216,21 +212,6 @@ impl Deflator {
self.compress.reset()
}
// pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<usize, CompressError> {
// loop {
// let before_in = self.compress.total_in();
// output.reserve(256);
// let status = self
// .compress
// .compress_vec(input, output, flate2::FlushCompress::Sync)?;
// let written = (self.compress.total_in() - before_in) as usize;
//
// if written != 0 || status == flate2::Status::StreamEnd {
// return Ok(written);
// }
// }
// }
pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<usize, CompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
@ -518,18 +499,3 @@ fn t() {
f(v2);
f(v3);
}
// #[test]
// fn t() {
// let mut decompressor = Inflator::new();
//
//
//
// let mut buffer = Vec::with_capacity(v2.len() * 2);
//
// let r = decompressor.decompress(&v2, &mut buffer);
//
// println!("String: {:?}", String::from_utf8(buffer.to_vec()));
//
// println!("{:?}", r);
// }

@ -7,7 +7,7 @@ use crate::protocol::frame::Frame;
pub mod compression;
pub mod deflate;
pub trait WebSocketExtension {
pub trait WebSocketExtensionOld {
type Error: Into<crate::Error>;
fn on_request<T>(&mut self, request: Request<T>) -> Request<T> {

@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::ext::WebSocketExtension;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type.
@ -22,18 +22,24 @@ pub type Response = HttpResponse<()>;
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S> {
pub struct ClientHandshake<S, E>
where
E: WebSocketExtension,
{
verify_data: VerifyData,
config: Option<WebSocketConfig>,
config: Option<WebSocketConfig<E>>,
_marker: PhantomData<S>,
}
impl<S: Read + Write> ClientHandshake<S> {
impl<S: Read + Write, E> ClientHandshake<S, E>
where
E: WebSocketExtension,
{
/// Initiate a client handshake.
pub fn start(
stream: S,
request: Request,
mut config: Option<WebSocketConfig>,
mut config: Option<WebSocketConfig<E>>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
@ -74,10 +80,14 @@ impl<S: Read + Write> ClientHandshake<S> {
}
}
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
impl<S: Read + Write, E> HandshakeRole for ClientHandshake<S, E>
where
E: WebSocketExtension,
{
type IncomingData = Response;
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
type FinalResult = (WebSocket<S, E>, Response);
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -95,7 +105,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
.verify_response(&result, &mut self.config)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
WebSocket::from_partially_read(stream, tail, Role::Client, self.config.clone());
ProcessingResult::Done((websocket, result))
}
})
@ -103,13 +113,16 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Generate client request.
fn generate_request(
fn generate_request<E>(
request: Request,
key: &str,
config: &mut Option<WebSocketConfig>,
) -> Result<Vec<u8>> {
let request = match &config {
Some(mut config) => config.compression_config.on_request(request),
config: &mut Option<WebSocketConfig<E>>,
) -> Result<Vec<u8>>
where
E: WebSocketExtension,
{
let request = match config {
Some(ref mut config) => config.encoder.on_request(request),
None => request,
};
let mut req = Vec::new();
@ -168,11 +181,14 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(
pub fn verify_response<E>(
&self,
response: &Response,
config: &mut Option<WebSocketConfig>,
) -> Result<()> {
config: &mut Option<WebSocketConfig<E>>,
) -> Result<()>
where
E: WebSocketExtension,
{
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -229,7 +245,7 @@ impl VerifyData {
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if let Some(config) = config {
config.compression_config.on_response(response);
config.encoder.on_response(response);
}
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
@ -288,6 +304,7 @@ mod tests {
use super::super::machine::TryParse;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use crate::ext::uncompressed::UncompressedExt;
#[test]
fn random_keys() {
@ -317,7 +334,9 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -336,7 +355,9 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -355,7 +376,9 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -12,6 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::ext::WebSocketExtension;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Server request type.
@ -39,7 +40,10 @@ pub fn create_response(request: &Request) -> Result<Response> {
.headers()
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("Upgrade"))
})
.unwrap_or(false)
{
return Err(Error::Protocol(
@ -188,25 +192,31 @@ impl Callback for NoCallback {
/// Server handshake role.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
pub struct ServerHandshake<S, C, E>
where
E: WebSocketExtension,
{
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
config: Option<WebSocketConfig<E>>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
impl<S: Read + Write, C: Callback, E> ServerHandshake<S, C, E>
where
E: WebSocketExtension,
{
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig<E>>) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
@ -220,10 +230,13 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
}
}
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
impl<S: Read + Write, C: Callback, E> HandshakeRole for ServerHandshake<S, C, E>
where
E: WebSocketExtension,
{
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S>;
type FinalResult = WebSocket<S, E>;
fn stage_finished(
&mut self,
@ -278,7 +291,8 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket =
WebSocket::from_raw_socket(stream, Role::Server, self.config.clone());
ProcessingResult::Done(websocket)
}
}

@ -23,6 +23,8 @@ pub mod server;
pub mod stream;
pub mod util;
pub mod ext;
pub use crate::client::{client, connect};
pub use crate::error::{Error, Result};
pub use crate::handshake::client::ClientHandshake;

@ -2,7 +2,7 @@
pub mod frame;
mod message;
pub(crate) mod message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
@ -14,12 +14,14 @@ use std::mem::replace;
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::message::IncompleteMessage;
use crate::error::{Error, Result};
use crate::extensions::compression::{CompressionConfig, CompressionStrategy};
use crate::extensions::WebSocketExtension;
use crate::ext::uncompressed::UncompressedExt;
use crate::ext::WebSocketExtension;
use crate::util::NonBlockingResult;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
@ -31,7 +33,10 @@ pub enum Role {
/// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig {
pub struct WebSocketConfig<E = UncompressedExt>
where
E: WebSocketExtension,
{
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
@ -45,17 +50,20 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// Permessage compression strategy.
pub compression_config: CompressionConfig,
/// Per-message compression strategy.
pub encoder: E,
}
impl Default for WebSocketConfig {
impl<E> Default for WebSocketConfig<E>
where
E: WebSocketExtension,
{
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_message_size: Some(MAX_MESSAGE_SIZE),
max_frame_size: Some(16 << 20),
compression_config: CompressionConfig::Uncompressed,
encoder: Default::default(),
}
}
}
@ -65,20 +73,26 @@ impl Default for WebSocketConfig {
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)]
pub struct WebSocket<Stream> {
pub struct WebSocket<Stream, E>
where
E: WebSocketExtension,
{
/// The underlying socket.
socket: Stream,
/// The context for managing a WebSocket.
context: WebSocketContext,
context: WebSocketContext<E>,
}
impl<Stream> WebSocket<Stream> {
impl<Stream, E> WebSocket<Stream, E>
where
E: WebSocketExtension,
{
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig<E>>) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
@ -94,7 +108,7 @@ impl<Stream> WebSocket<Stream> {
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
config: Option<WebSocketConfig<E>>,
) -> Self {
WebSocket {
socket: stream,
@ -113,12 +127,12 @@ impl<Stream> WebSocket<Stream> {
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<E>)) {
self.context.set_config(set_func)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig {
pub fn get_config(&self) -> &WebSocketConfig<E> {
self.context.get_config()
}
@ -138,7 +152,10 @@ impl<Stream> WebSocket<Stream> {
}
}
impl<Stream: Read + Write> WebSocket<Stream> {
impl<Stream: Read + Write, E> WebSocket<Stream, E>
where
E: WebSocketExtension,
{
/// Read a message from stream, if possible.
///
/// This will queue responses to ping and close messages to be sent. It will call
@ -221,7 +238,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// A context for managing WebSocket stream.
#[derive(Debug)]
pub struct WebSocketContext {
pub struct WebSocketContext<E = UncompressedExt>
where
E: WebSocketExtension,
{
/// Server or client?
role: Role,
/// encoder/decoder of frame.
@ -235,16 +255,16 @@ pub struct WebSocketContext {
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
/// WebSocket compression strategy.
compressor: CompressionStrategy,
config: WebSocketConfig<E>,
}
impl WebSocketContext {
impl<E> WebSocketContext<E>
where
E: WebSocketExtension,
{
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
let config = config.unwrap_or_else(WebSocketConfig::default);
let compressor = config.compression_config.into_strategy();
pub fn new(role: Role, config: Option<WebSocketConfig<E>>) -> Self {
let config = config.unwrap_or_else(|| Default::default());
WebSocketContext {
role,
@ -254,12 +274,15 @@ impl WebSocketContext {
send_queue: VecDeque::new(),
pong: None,
config,
compressor,
}
}
/// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
pub fn from_partially_read(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<E>>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
@ -267,12 +290,12 @@ impl WebSocketContext {
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<E>)) {
set_func(&mut self.config)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig {
pub fn get_config(&self) -> &WebSocketConfig<E> {
&self.config
}
@ -442,7 +465,7 @@ impl WebSocketContext {
{
let hdr = frame.header();
if !self.compressor.is_enabled() && hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
if !self.get_config().encoder.rsv1() && hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(
"Reserved bits are non-zero and no WebSocket extensions are enabled".into(),
));
@ -500,64 +523,10 @@ impl WebSocketContext {
}
}
OpCode::Data(data) => {
let fin = frame.header().is_final;
let frame = match self.compressor.on_receive_frame(frame)? {
Some(frame) => frame,
None => return Ok(None),
};
match data {
OpData::Continue => {
if self.compressor.is_enabled() {
let message_type = match frame.header().opcode {
OpCode::Data(OpData::Text) => IncompleteMessageType::Text,
OpCode::Data(OpData::Binary) => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
self.incomplete = Some(IncompleteMessage::new(message_type));
}
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
OpData::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
}
}
_ => match self.config.encoder.on_receive_frame(frame) {
Ok(r) => Ok(r),
Err(e) => Err(e.into()),
},
} // match opcode
} else {
// Connection closed by peer
@ -627,7 +596,7 @@ impl WebSocketContext {
}
}
let frame = self.compressor.on_send_frame(frame)?;
// let frame = self.config.encoder.on_send_frame(frame)?;
trace!("Sending frame: {:?}", frame);
self.frame
@ -703,6 +672,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::ext::uncompressed::UncompressedExt;
use std::io;
use std::io::Cursor;
@ -730,7 +700,8 @@ mod tests {
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(
@ -753,7 +724,8 @@ mod tests {
max_message_size: Some(10),
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 7 + 6 > 10"
@ -767,7 +739,8 @@ mod tests {
max_message_size: Some(2),
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2"

@ -7,6 +7,8 @@ use crate::handshake::HandshakeError;
use crate::protocol::{WebSocket, WebSocketConfig};
use crate::ext::uncompressed::UncompressedExt;
use crate::ext::WebSocketExtension;
use std::io::{Read, Write};
/// Accept the given Stream as a WebSocket.
@ -18,10 +20,13 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write>(
pub fn accept_with_config<S: Read + Write, E>(
stream: S,
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
config: Option<WebSocketConfig<E>>,
) -> Result<WebSocket<S, E>, HandshakeError<ServerHandshake<S, NoCallback, E>>>
where
E: WebSocketExtension,
{
accept_hdr_with_config(stream, NoCallback, config)
}
@ -33,7 +38,10 @@ pub fn accept_with_config<S: Read + Write>(
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
) -> Result<
WebSocket<S, UncompressedExt>,
HandshakeError<ServerHandshake<S, NoCallback, UncompressedExt>>,
> {
accept_with_config(stream, None)
}
@ -45,11 +53,14 @@ pub fn accept<S: Read + Write>(
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
pub fn accept_hdr_with_config<S: Read + Write, C: Callback, E>(
stream: S,
callback: C,
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
config: Option<WebSocketConfig<E>>,
) -> Result<WebSocket<S, E>, HandshakeError<ServerHandshake<S, C, E>>>
where
E: WebSocketExtension,
{
ServerHandshake::start(stream, callback, config).handshake()
}
@ -61,6 +72,6 @@ pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
) -> Result<WebSocket<S, UncompressedExt>, HandshakeError<ServerHandshake<S, C, UncompressedExt>>> {
accept_hdr_with_config(stream, callback, None)
}

@ -1,22 +1,23 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpStream, TcpListener};
use std::net::{TcpListener, TcpStream};
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream};
use native_tls::TlsStream;
use url::Url;
use net2::TcpStreamExt;
use tungstenite::ext::uncompressed::UncompressedExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
type Sock<E> = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>, E>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where
CT: FnOnce(Sock) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream>),
CT: FnOnce(Sock<UncompressedExt>) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream, UncompressedExt>),
{
env_logger::try_init().ok();
@ -26,8 +27,8 @@ where
exit(1);
});
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is port already in use?");
let server =
TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?");
let client_thread = spawn(move || {
let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap())
@ -46,7 +47,8 @@ where
#[test]
fn test_server_close() {
do_test(3012,
do_test(
3012,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -75,12 +77,14 @@ fn test_server_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}
#[test]
fn test_evil_server_close() {
do_test(3013,
do_test(
3013,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -106,14 +110,19 @@ fn test_evil_server_close() {
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
srv_sock
.get_mut()
.set_linger(Some(Duration::from_secs(0)))
.unwrap();
drop(srv_sock);
});
},
);
}
#[test]
fn test_client_close() {
do_test(3014,
do_test(
3014,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -137,7 +146,9 @@ fn test_client_close() {
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.write_message(Message::Text("From Server".into())).unwrap();
srv_sock
.write_message(Message::Text("From Server".into()))
.unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close());
@ -147,6 +158,6 @@ fn test_client_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}

Loading…
Cancel
Save