Merge pull request #1 from SirCipher/master

Permessage-deflate
pull/144/head
Thomas Klapwijk 5 years ago committed by GitHub
commit cd43267e17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      .gitignore
  2. 15
      Cargo.toml
  3. 2
      README.md
  4. 1218
      autobahn/client-results.json
  5. 3238
      autobahn/server-results.json
  6. 17
      examples/autobahn-client.rs
  7. 31
      examples/autobahn-server.rs
  8. 6
      fuzz/fuzz_targets/read_message_client.rs
  9. 9
      fuzz/fuzz_targets/read_message_server.rs
  10. 4
      scripts/autobahn-client.sh
  11. 6
      scripts/autobahn-server.sh
  12. 28
      src/client.rs
  13. 3
      src/error.rs
  14. 922
      src/extensions/deflate.rs
  15. 56
      src/extensions/mod.rs
  16. 104
      src/extensions/uncompressed.rs
  17. 103
      src/handshake/client.rs
  18. 56
      src/handshake/server.rs
  19. 2
      src/lib.rs
  20. 181
      src/protocol/mod.rs
  21. 32
      src/server.rs
  22. 45
      tests/connection_reset.rs

3
.gitignore vendored

@ -1,2 +1,5 @@
target
Cargo.lock
autobahn/client/*
autobahn/server/*

@ -16,6 +16,7 @@ edition = "2018"
default = ["tls"]
tls = ["native-tls"]
tls-vendored = ["native-tls", "native-tls/vendored"]
deflate = ["flate2"]
[dependencies]
base64 = "0.12.0"
@ -30,6 +31,12 @@ sha-1 = "0.9"
url = "2.1.0"
utf-8 = "0.7.5"
[dependencies.flate2]
optional = true
version = "1.0"
default-features = false
features = ["zlib"]
[dependencies.native-tls]
optional = true
version = "0.2.3"
@ -37,3 +44,11 @@ version = "0.2.3"
[dev-dependencies]
env_logger = "0.7.1"
net2 = "0.2.33"
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]

@ -58,7 +58,7 @@ Features
Tungstenite provides a complete implementation of the WebSocket specification.
TLS is supported on all platforms using native-tls.
There is no support for permessage-deflate at the moment. It's planned.
Permessage-deflate.
Testing
-------

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -1,6 +1,9 @@
use log::*;
use url::Url;
use tungstenite::client::connect_with_config;
use tungstenite::extensions::deflate::{DeflateConfigBuilder, DeflateExt};
use tungstenite::protocol::WebSocketConfig;
use tungstenite::{connect, Error, Message, Result};
const AGENT: &str = "Tungstenite";
@ -31,7 +34,19 @@ fn run_test(case: u32) -> Result<()> {
case, AGENT
))
.unwrap();
let (mut socket, _) = connect(case_url)?;
let deflate_config = DeflateConfigBuilder::default()
.max_message_size(None)
.build();
let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: DeflateExt::new(deflate_config),
}),
)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {

@ -2,8 +2,11 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use log::*;
use tungstenite::extensions::deflate::{DeflateExt, DeflateConfigBuilder};
use tungstenite::handshake::HandshakeRole;
use tungstenite::{accept, Error, HandshakeError, Message, Result};
use tungstenite::protocol::WebSocketConfig;
use tungstenite::server::accept_with_config;
use tungstenite::{Error, HandshakeError, Message, Result};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -13,7 +16,19 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let deflate_config = DeflateConfigBuilder::default()
.max_message_size(None)
.build();
let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: DeflateExt::new(deflate_config),
}),
)
.map_err(must_not_block)?;
info!("Running test");
loop {
match socket.read_message()? {
@ -32,12 +47,14 @@ fn main() {
for stream in server.incoming() {
spawn(move || match stream {
Ok(stream) => if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
Ok(stream) => {
if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
}
}
},
}
Err(e) => error!("Error accepting stream: {}", e),
});
}

@ -6,6 +6,7 @@ use std::io;
use std::io::Cursor;
use tungstenite::WebSocket;
use tungstenite::protocol::Role;
use tungstenite::extensions::uncompressed::UncompressedExt;
//use std::result::Result;
// FIXME: copypasted from tungstenite's protocol/mod.rs
@ -32,6 +33,7 @@ impl<Stream: io::Read> io::Read for WriteMoc<Stream> {
fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
socket.read_message().ok();
});
});

@ -1,11 +1,13 @@
#![no_main]
#[macro_use] extern crate libfuzzer_sys;
#[macro_use]
extern crate libfuzzer_sys;
extern crate tungstenite;
use std::io;
use std::io::Cursor;
use tungstenite::WebSocket;
use tungstenite::protocol::Role;
use tungstenite::WebSocket;
use tungstenite::extensions::uncompressed::UncompressedExt;
//use std::result::Result;
// FIXME: copypasted from tungstenite's protocol/mod.rs
@ -32,6 +34,7 @@ impl<Stream: io::Read> io::Read for WriteMoc<Stream> {
fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
socket.read_message().ok();
});

@ -23,10 +23,10 @@ function test_diff() {
fi
}
cargo build --release --example autobahn-client
cargo build --release --example autobahn-client --features deflate
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' & FUZZINGSERVER_PID=$!
sleep 3
echo "Server PID: ${FUZZINGSERVER_PID}"
cargo run --release --example autobahn-client
cargo run --release --example autobahn-client --features deflate
test_diff

@ -14,7 +14,7 @@ trap cleanup TERM EXIT
function test_diff() {
if ! diff -q \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/client-results.json') \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server-results.json') \
<(jq -S 'del(."Tungstenite" | .. | .duration?)' 'autobahn/server/index.json')
then
echo Difference in results, either this is a regression or \
@ -23,8 +23,8 @@ function test_diff() {
fi
}
cargo build --release --example autobahn-server
cargo run --release --example autobahn-server & WSSERVER_PID=$!
cargo build --release --example autobahn-server --features deflate
cargo run --release --example autobahn-server --features deflate & WSSERVER_PID=$!
echo "Server PID: ${WSSERVER_PID}"
sleep 3
wstest -m fuzzingclient -s 'autobahn/fuzzingclient.json'

@ -66,6 +66,8 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
@ -86,10 +88,14 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req: IntoClientRequest>(
pub fn connect_with_config<Req, Ext>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
config: Option<WebSocketConfig<Ext>>,
) -> Result<(WebSocket<AutoStream, Ext>, Response)>
where
Req: IntoClientRequest,
Ext: WebSocketExtension,
{
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
@ -122,7 +128,9 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream, UncompressedExt>, Response)> {
connect_with_config(request, None)
}
@ -159,14 +167,15 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<Stream, Req>(
pub fn client_with_config<Stream, Req, Ext>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
config: Option<WebSocketConfig<Ext>>,
) -> StdResult<(WebSocket<Stream, Ext>, Response), HandshakeError<ClientHandshake<Stream, Ext>>>
where
Stream: Read + Write,
Req: IntoClientRequest,
Ext: WebSocketExtension,
{
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}
@ -179,7 +188,10 @@ where
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<
(WebSocket<Stream, UncompressedExt>, Response),
HandshakeError<ClientHandshake<Stream, UncompressedExt>>,
>
where
Stream: Read + Write,
Req: IntoClientRequest,

@ -67,6 +67,8 @@ pub enum Error {
Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
/// An error from a WebSocket extension.
ExtensionError(Cow<'static, str>),
}
impl fmt::Display for Error {
@ -84,6 +86,7 @@ impl fmt::Display for Error {
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
Error::ExtensionError(ref e) => write!(f, "Extension error: {}", e),
}
}
}

@ -0,0 +1,922 @@
//! Permessage-deflate extension
use std::fmt::{Display, Formatter};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::Message;
use bytes::BufMut;
use flate2::{
Compress, CompressError, Compression, Decompress, DecompressError, FlushCompress,
FlushDecompress, Status,
};
use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS};
use http::{HeaderValue, Request, Response};
use std::borrow::Cow;
use std::mem::replace;
use std::slice;
/// The WebSocket Extension Identifier as per the IANA registry.
const EXT_IDENT: &str = "permessage-deflate";
/// The minimum size of the LZ77 sliding window size.
const LZ77_MIN_WINDOW_SIZE: u8 = 8;
/// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter
/// indicates that the client can receive messages compressed using an LZ77 sliding window of up to
/// 32,768 bytes. RFC 7692 7.1.2.1.
const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)]
pub struct DeflateConfig {
/// The maximum size of a message. The default value is 64 MiB which should be reasonably big
/// for all normal use-cases but small enough to prevent memory eating by a malicious user.
max_message_size: usize,
/// The LZ77 sliding window size. Negotiated during the HTTP upgrade. In client mode, this
/// conforms to RFC 7692 7.1.2.1. In server mode, this conforms to RFC 7692 7.1.2.2. Must be in
/// range 8..15 inclusive.
max_window_bits: u8,
/// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1.
request_no_context_takeover: bool,
/// Whether to accept `no_context_takeover`.
accept_no_context_takeover: bool,
// Whether the compressor should be reset after usage.
compress_reset: bool,
// Whether the decompressor should be reset after usage.
decompress_reset: bool,
/// The active compression level. The integer here is typically on a scale of 0-9 where 0 means
/// "no compression" and 9 means "take as long as you'd like".
compression_level: Compression,
}
impl DeflateConfig {
/// Builds a new `DeflateConfig` using the `compression_level` and the defaults for all other
/// members.
pub fn with_compression_level(compression_level: Compression) -> DeflateConfig {
DeflateConfig {
compression_level,
..Default::default()
}
}
/// Returns the maximum message size permitted.
pub fn max_message_size(&self) -> usize {
self.max_message_size
}
/// Returns the maximum LZ77 window size permitted.
pub fn max_window_bits(&self) -> u8 {
self.max_window_bits
}
/// Returns whether `no_context_takeover` has been requested.
pub fn request_no_context_takeover(&self) -> bool {
self.request_no_context_takeover
}
/// Returns whether this WebSocket will accept `no_context_takeover`.
pub fn accept_no_context_takeover(&self) -> bool {
self.accept_no_context_takeover
}
/// Returns whether or not the inner compressor is set to reset after completing a message.
pub fn compress_reset(&self) -> bool {
self.compress_reset
}
/// Returns whether or not the inner decompressor is set to reset after completing a message.
pub fn decompress_reset(&self) -> bool {
self.decompress_reset
}
/// Returns the active compression level.
pub fn compression_level(&self) -> Compression {
self.compression_level
}
/// Sets the maximum message size permitted.
pub fn set_max_message_size(&mut self, max_message_size: Option<usize>) {
self.max_message_size = max_message_size.unwrap_or_else(usize::max_value);
}
/// Sets the LZ77 sliding window size.
pub fn set_max_window_bits(&mut self, max_window_bits: u8) {
assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits));
self.max_window_bits = max_window_bits;
}
/// Sets the WebSocket to request `no_context_takeover` if `true`.
pub fn set_request_no_context_takeover(&mut self, request_no_context_takeover: bool) {
self.request_no_context_takeover = request_no_context_takeover;
}
/// Sets the WebSocket to accept `no_context_takeover` if `true`.
pub fn set_accept_no_context_takeover(&mut self, accept_no_context_takeover: bool) {
self.accept_no_context_takeover = accept_no_context_takeover;
}
}
impl Default for DeflateConfig {
fn default() -> Self {
DeflateConfig {
max_message_size: MAX_MESSAGE_SIZE,
max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
decompress_reset: false,
compression_level: Compression::best(),
}
}
}
/// A `DeflateConfig` builder.
#[derive(Debug, Copy, Clone)]
pub struct DeflateConfigBuilder {
max_message_size: Option<usize>,
max_window_bits: u8,
request_no_context_takeover: bool,
accept_no_context_takeover: bool,
fragments_grow: bool,
compression_level: Compression,
}
impl Default for DeflateConfigBuilder {
fn default() -> Self {
DeflateConfigBuilder {
max_message_size: Some(MAX_MESSAGE_SIZE),
max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
fragments_grow: true,
compression_level: Compression::fast(),
}
}
}
impl DeflateConfigBuilder {
/// Sets the maximum message size permitted.
pub fn max_message_size(mut self, max_message_size: Option<usize>) -> DeflateConfigBuilder {
self.max_message_size = max_message_size;
self
}
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `8..=15`.
pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder {
assert!(
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits),
"max window bits must be in range 8..=15"
);
self.max_window_bits = max_window_bits;
self
}
/// Sets the WebSocket to request `no_context_takeover`.
pub fn request_no_context_takeover(
mut self,
request_no_context_takeover: bool,
) -> DeflateConfigBuilder {
self.request_no_context_takeover = request_no_context_takeover;
self
}
/// Sets the WebSocket to accept `no_context_takeover`.
pub fn accept_no_context_takeover(
mut self,
accept_no_context_takeover: bool,
) -> DeflateConfigBuilder {
self.accept_no_context_takeover = accept_no_context_takeover;
self
}
/// Consumes the builder and produces a `DeflateConfig.`
pub fn build(self) -> DeflateConfig {
DeflateConfig {
max_message_size: self.max_message_size.unwrap_or_else(usize::max_value),
max_window_bits: self.max_window_bits,
request_no_context_takeover: self.request_no_context_takeover,
accept_no_context_takeover: self.accept_no_context_takeover,
compression_level: self.compression_level,
..Default::default()
}
}
}
/// A permessage-deflate encoding WebSocket extension.
#[derive(Debug)]
pub struct DeflateExt {
/// Defines whether the extension is enabled. Following a successful handshake, this will be
/// `true`.
enabled: bool,
/// The configuration for the extension.
config: DeflateConfig,
/// A stack of continuation frames awaiting `fin` and the total size of all of the fragments.
fragment_buffer: FragmentBuffer,
/// The deflate decompressor.
inflator: Inflator,
/// The deflate compressor.
deflator: Deflator,
/// If this deflate extension is not used, messages will be forwarded to this extension.
uncompressed_extension: UncompressedExt,
}
impl DeflateExt {
/// Creates a `DeflateExt` instance using the provided configuration.
pub fn new(config: DeflateConfig) -> DeflateExt {
DeflateExt {
enabled: false,
config,
fragment_buffer: FragmentBuffer::new(config.max_message_size),
inflator: Inflator::new(),
deflator: Deflator::new(Compression::fast()),
uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())),
}
}
fn parse_window_parameter<'a>(
&mut self,
mut param_iter: impl Iterator<Item = &'a str>,
) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() {
match window_bits_str.trim().parse() {
Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
if window_bits != self.config.max_window_bits() {
self.config.max_window_bits = window_bits;
Ok(Some(window_bits))
} else {
Ok(None)
}
} else {
Err(format!("Invalid window parameter: {}", window_bits))
}
}
Err(e) => Err(e.to_string()),
}
} else {
Ok(None)
}
}
fn decline<T>(&mut self, res: &mut Response<T>) {
self.enabled = false;
res.headers_mut().remove(EXT_IDENT);
}
}
/// A permessage-deflate extension error.
#[derive(Debug, Clone)]
pub enum DeflateExtensionError {
/// An error produced when deflating a message.
DeflateError(String),
/// An error produced when inflating a message.
InflateError(String),
/// An error produced during the WebSocket negotiation.
NegotiationError(String),
/// Produced when fragment buffer grew beyond the maximum configured size.
Capacity(Cow<'static, str>),
}
impl Display for DeflateExtensionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DeflateExtensionError::DeflateError(m) => {
write!(f, "An error was produced during decompression: {}", m)
}
DeflateExtensionError::InflateError(m) => {
write!(f, "An error was produced during compression: {}", m)
}
DeflateExtensionError::NegotiationError(m) => {
write!(f, "An upgrade error was encountered: {}", m)
}
DeflateExtensionError::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
}
}
}
impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
impl From<InvalidHeaderValue> for DeflateExtensionError {
fn from(e: InvalidHeaderValue) -> Self {
DeflateExtensionError::NegotiationError(e.to_string())
}
}
impl Default for DeflateExt {
fn default() -> Self {
DeflateExt::new(Default::default())
}
}
impl WebSocketExtension for DeflateExt {
type Error = DeflateExtensionError;
fn new(max_message_size: Option<usize>) -> Self {
DeflateExt::new(DeflateConfig {
max_message_size: max_message_size.unwrap_or_else(usize::max_value),
..Default::default()
})
}
fn enabled(&self) -> bool {
self.enabled
}
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_IDENT);
let DeflateConfig {
max_window_bits,
request_no_context_takeover,
..
} = self.config;
if max_window_bits < LZ77_MAX_WINDOW_SIZE {
header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits
))
} else {
header_value.push_str("; client_max_window_bits")
}
if request_no_context_takeover {
header_value.push_str("; server_no_context_takeover")
}
request.headers_mut().append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&header_value).unwrap(),
);
request
}
fn on_receive_request<T>(
&mut self,
request: &Request<T>,
response: &mut Response<T>,
) -> Result<(), Self::Error> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len());
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => response_str.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if server_takeover {
self.decline(response);
} else {
server_takeover = true;
if self.config.accept_no_context_takeover() {
self.config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
}
"client_no_context_takeover" => {
if client_takeover {
self.decline(response);
} else {
client_takeover = true;
self.config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
self.decline(response);
} else {
server_max_bits = true;
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
response_str.push_str("; ");
response_str.push_str(param)
}
Ok(None) => {}
Err(_) => {
self.decline(response);
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
self.decline(response);
} else {
client_max_bits = true;
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator::new_with_window_bits(bits);
response_str.push_str("; ");
response_str.push_str(param);
continue;
}
Ok(None) => {}
Err(_) => {
self.decline(response);
}
}
response_str.push_str("; ");
response_str.push_str(&format!(
"client_max_window_bits={}",
self.config.max_window_bits()
))
}
}
_ => {
self.decline(response);
}
}
}
if !response_str.contains("client_no_context_takeover")
&& self.config.request_no_context_takeover()
{
self.config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
self.config.max_window_bits()
))
}
if !response_str.contains("client_max_window_bits")
&& self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
continue;
}
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
self.enabled = true;
Ok(())
}
Err(e) => {
self.enabled = false;
Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}",
e,
)))
}
};
}
self.decline(response);
Ok(())
}
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> {
let mut extension_name = false;
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_window_bits = false;
let mut client_max_window_bits = false;
for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() {
match header.to_str() {
Ok(header) => {
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
if extension_name {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate"
)));
} else {
self.enabled = true;
extension_name = true;
}
}
"server_no_context_takeover" => {
if server_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_no_context_takeover"
)));
} else {
server_takeover = true;
self.config.decompress_reset = true;
}
}
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_no_context_takeover"
)));
} else {
client_takeover = true;
if self.config.accept_no_context_takeover() {
self.config.compress_reset = true;
} else {
return Err(DeflateExtensionError::NegotiationError(
format!("The client requires context takeover."),
));
}
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_max_window_bits"
)));
} else {
server_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator::new_with_window_bits(bits);
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"server_max_window_bits parameter error: {}",
e
),
))
}
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_max_window_bits"
)));
} else {
client_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"client_max_window_bits parameter error: {}",
e
),
))
}
}
}
}
p => {
return Err(DeflateExtensionError::NegotiationError(format!(
"Unknown permessage-deflate parameter: {}",
p
)));
}
}
}
}
Err(e) => {
self.enabled = false;
return Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse extension parameter: {}",
e
)));
}
}
}
Ok(())
}
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> {
if self.enabled {
if let OpCode::Data(_) = frame.header().opcode {
let mut compressed = Vec::with_capacity(frame.payload().len());
self.deflator.compress(frame.payload(), &mut compressed)?;
let len = compressed.len();
compressed.truncate(len - 4);
*frame.payload_mut() = compressed;
frame.header_mut().rsv1 = true;
if self.config.compress_reset() {
self.deflator.reset();
}
}
}
Ok(frame)
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragment_buffer
.try_push_frame(frame)
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?;
Ok(None)
} else {
let mut compressed = if self.fragment_buffer.is_empty() {
Vec::with_capacity(frame.payload().len())
} else {
Vec::with_capacity(self.fragment_buffer.len() + frame.payload().len())
};
let mut decompressed = Vec::with_capacity(frame.payload().len() * 2);
let opcode = match frame.header().opcode {
OpCode::Data(Data::Continue) => {
self.fragment_buffer
.try_push_frame(frame)
.map_err(|s| DeflateExtensionError::Capacity(s.into()))?;
let opcode = self.fragment_buffer.first().unwrap().header().opcode;
self.fragment_buffer.reset().into_iter().for_each(|f| {
compressed.extend(f.into_data());
});
opcode
}
_ => {
compressed.put_slice(frame.payload());
frame.header().opcode
}
};
compressed.extend(&[0, 0, 255, 255]);
self.inflator.decompress(&compressed, &mut decompressed)?;
if self.config.decompress_reset() {
self.inflator.reset(false);
}
self.uncompressed_extension.on_receive_frame(Frame::message(
decompressed,
opcode,
true,
))
}
} else {
self.uncompressed_extension.on_receive_frame(frame)
};
match r {
Ok(msg) => Ok(msg),
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())),
}
}
}
impl From<DecompressError> for DeflateExtensionError {
fn from(e: DecompressError) -> Self {
DeflateExtensionError::InflateError(e.to_string())
}
}
impl From<CompressError> for DeflateExtensionError {
fn from(e: CompressError) -> Self {
DeflateExtensionError::DeflateError(e.to_string())
}
}
#[derive(Debug)]
struct Deflator {
compress: Compress,
}
impl Deflator {
fn new(compresion: Compression) -> Deflator {
Deflator {
compress: Compress::new(compresion, false),
}
}
fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
}
Deflator {
compress: Compress::new_with_window_bits(compression, false, window_size),
}
}
fn reset(&mut self) {
self.compress.reset()
}
fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), CompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
loop {
output_size = output.len();
if output_size == output.capacity() {
output.reserve(input.len());
}
let before_out = self.compress.total_out();
let before_in = self.compress.total_in();
let out_slice = unsafe {
slice::from_raw_parts_mut(
output.as_mut_ptr().offset(output_size as isize),
output.capacity() - output_size,
)
};
let status = self
.compress
.compress(&read_buff, out_slice, FlushCompress::Sync)?;
let consumed = (self.compress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
unsafe {
output.set_len((self.compress.total_out() - before_out) as usize + output_size);
}
match status {
Status::Ok | Status::BufError => {
if before_out == self.compress.total_out() && read_buff.is_empty() {
return Ok(());
}
}
s => panic!("Compression error: {:?}", s),
}
}
}
}
#[derive(Debug)]
struct Inflator {
decompress: Decompress,
}
impl Inflator {
fn new() -> Inflator {
Inflator {
decompress: Decompress::new(false),
}
}
fn new_with_window_bits(mut window_size: u8) -> Inflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
}
Inflator {
decompress: Decompress::new_with_window_bits(false, window_size),
}
}
fn reset(&mut self, zlib_header: bool) {
self.decompress.reset(zlib_header)
}
fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<(), DecompressError> {
let mut read_buff = Vec::from(input);
let mut output_size;
loop {
output_size = output.len();
if output_size == output.capacity() {
output.reserve(input.len());
}
let before_out = self.decompress.total_out();
let before_in = self.decompress.total_in();
let out_slice = unsafe {
slice::from_raw_parts_mut(
output.as_mut_ptr().offset(output_size as isize),
output.capacity() - output_size,
)
};
let status =
self.decompress
.decompress(&read_buff, out_slice, FlushDecompress::Sync)?;
let consumed = (self.decompress.total_in() - before_in) as usize;
read_buff = read_buff.split_off(consumed);
unsafe {
output.set_len((self.decompress.total_out() - before_out) as usize + output_size);
}
match status {
Status::Ok | Status::BufError => {
if before_out == self.decompress.total_out() && read_buff.is_empty() {
return Ok(());
}
}
s => panic!("Decompression error: {:?}", s),
}
}
}
}
/// A buffer for holding continuation frames. Ensures that the total length of all of the frame's
/// payloads does not exceed `max_len`.
///
/// Defaults to an initial capacity of ten frames.
#[derive(Debug)]
struct FragmentBuffer {
fragments: Vec<Frame>,
fragments_len: usize,
max_len: usize,
}
impl FragmentBuffer {
/// Creates a new fragment buffer that will permit a maximum length of `max_len`.
fn new(max_len: usize) -> FragmentBuffer {
FragmentBuffer {
fragments: Vec::with_capacity(10),
fragments_len: 0,
max_len,
}
}
/// Attempts to push a frame into the buffer. This will fail if the new length of the buffer's
/// frames exceeds the maximum capacity of `max_len`.
fn try_push_frame(&mut self, frame: Frame) -> Result<(), String> {
let FragmentBuffer {
fragments,
fragments_len,
max_len,
} = self;
*fragments_len += frame.payload().len();
if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len {
return Err(format!(
"Message too big: {} + {} > {}",
fragments_len, fragments_len, max_len
)
.into());
} else {
fragments.push(frame);
Ok(())
}
}
/// Returns the total length of all of the frames that have been pushed into the buffer.
fn len(&self) -> usize {
self.fragments_len
}
/// Returns whether the buffer is empty.
fn is_empty(&self) -> bool {
self.fragments.is_empty()
}
/// Returns the first element of the fragments slice, or `None` if it is empty.
fn first(&self) -> Option<&Frame> {
self.fragments.first()
}
/// Drains the buffer and resets it to an initial capacity of 10 elements.
fn reset(&mut self) -> Vec<Frame> {
self.fragments_len = 0;
replace(&mut self.fragments, Vec::with_capacity(10))
}
}

@ -0,0 +1,56 @@
//! WebSocket extensions
use http::{Request, Response};
use crate::protocol::frame::Frame;
use crate::Message;
/// A permessage-deflate WebSocket extension (RFC 7692).
#[cfg(feature = "deflate")]
pub mod deflate;
/// An uncompressed message handler for a WebSocket.
pub mod uncompressed;
/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions
/// may be stacked by nesting them inside one another.
pub trait WebSocketExtension {
/// An error type that the extension produces.
type Error: Into<crate::Error>;
/// Constructs a new WebSocket extension that will permit messages of the provided size.
fn new(max_message_size: Option<usize>) -> Self;
/// Returns whether or not the extension is enabled.
fn enabled(&self) -> bool {
false
}
/// For WebSocket clients, this will be called when a `Request` is being constructed.
fn on_make_request<T>(&mut self, request: Request<T>) -> Request<T> {
request
}
/// For WebSocket server, this will be called when a `Request` has been received.
fn on_receive_request<T>(
&mut self,
_request: &Request<T>,
_response: &mut Response<T>,
) -> Result<(), Self::Error> {
Ok(())
}
/// For WebSocket clients, this will be called when a response from the server has been
/// received. If an error is produced, then subsequent calls to `rsv1()` should return `false`.
fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> {
Ok(())
}
/// Called when a frame is about to be sent.
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> {
Ok(frame)
}
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
/// type `OpCode::Data`.
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error>;
}

@ -0,0 +1,104 @@
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::{Error, Message};
use crate::protocol::MAX_MESSAGE_SIZE;
/// An uncompressed message handler for a WebSocket.
#[derive(Debug)]
pub struct UncompressedExt {
incomplete: Option<IncompleteMessage>,
max_message_size: Option<usize>,
}
impl Default for UncompressedExt {
fn default() -> Self {
UncompressedExt {
incomplete: None,
max_message_size: Some(MAX_MESSAGE_SIZE)
}
}
}
impl UncompressedExt {
/// Builds a new `UncompressedExt` that will permit a maximum message size of `max_message_size`
/// or will be unbounded if `None`.
pub fn new(max_message_size: Option<usize>) -> UncompressedExt {
UncompressedExt {
incomplete: None,
max_message_size,
}
}
}
impl WebSocketExtension for UncompressedExt {
type Error = Error;
fn new(max_message_size: Option<usize>) -> Self {
UncompressedExt {
incomplete: None,
max_message_size,
}
}
fn enabled(&self) -> bool {
true
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
let fin = frame.header().is_final;
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(
"Reserved bits are non-zero and no WebSocket extensions are enabled".into(),
));
}
match frame.header().opcode {
OpCode::Data(data) => match data {
Data::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.max_message_size)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
Data::Text | Data::Binary => {
let msg = {
let message_type = match data {
Data::Text => IncompleteMessageType::Text,
Data::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
Data::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
},
_ => unreachable!(),
}
}
}

@ -11,6 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type.
@ -21,18 +22,25 @@ pub type Response = HttpResponse<()>;
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S> {
pub struct ClientHandshake<S, Extension>
where
Extension: WebSocketExtension,
{
verify_data: VerifyData,
config: Option<WebSocketConfig>,
config: Option<Option<WebSocketConfig<Extension>>>,
_marker: PhantomData<S>,
}
impl<S: Read + Write> ClientHandshake<S> {
impl<Stream, Ext> ClientHandshake<Stream, Ext>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Initiate a client handshake.
pub fn start(
stream: S,
stream: Stream,
request: Request,
config: Option<WebSocketConfig>,
mut config: Option<WebSocketConfig<Ext>>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
@ -52,7 +60,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let key = generate_key();
let machine = {
let req = generate_request(request, &key)?;
let req = generate_request(request, &key, &mut config)?;
HandshakeMachine::start_write(stream, req)
};
@ -60,7 +68,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData { accept_key },
config,
config: Some(config),
_marker: PhantomData,
}
};
@ -73,10 +81,15 @@ impl<S: Read + Write> ClientHandshake<S> {
}
}
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
impl<Stream, Ext> HandshakeRole for ClientHandshake<Stream, Ext>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
type IncomingData = Response;
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
type InternalStream = Stream;
type FinalResult = (WebSocket<Stream, Ext>, Response);
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -90,10 +103,11 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result,
tail,
} => {
self.verify_data.verify_response(&result)?;
let mut config = self.config.take().unwrap();
self.verify_data.verify_response(&result, &mut config)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config);
ProcessingResult::Done((websocket, result))
}
})
@ -101,20 +115,33 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Generate client request.
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
fn generate_request<Ext>(
request: Request,
key: &str,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<Vec<u8>>
where
Ext: WebSocketExtension,
{
let request = match config {
Some(ref mut config) => config.encoder.on_make_request(request),
None => request,
};
let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority()
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}
write!(
@ -138,7 +165,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
for (k, v) in request.headers() {
let mut k = k.as_str();
if k == "sec-websocket-protocol" {
if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol";
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
@ -156,7 +183,14 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
pub fn verify_response<Ext>(
&self,
response: &Response,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<()>
where
Ext: WebSocketExtension,
{
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -205,12 +239,18 @@ impl VerifyData {
"Key mismatch in Sec-WebSocket-Accept".into(),
));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
if let Some(config) = config {
if let Err(e) = config.encoder.on_response(response) {
return Err(e.into());
}
}
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -266,8 +306,9 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use crate::extensions::uncompressed::UncompressedExt;
#[test]
fn random_keys() {
@ -297,14 +338,18 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
@ -314,14 +359,18 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://user:pass@localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
@ -331,7 +380,9 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -12,6 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Server request type.
@ -39,7 +40,10 @@ pub fn create_response(request: &Request) -> Result<Response> {
.headers()
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("Upgrade"))
})
.unwrap_or(false)
{
return Err(Error::Protocol(
@ -187,31 +191,43 @@ impl Callback for NoCallback {
/// Server handshake role.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
pub struct ServerHandshake<S, C, Ext>
where
Ext: WebSocketExtension,
{
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
config: Option<Option<WebSocketConfig<Ext>>>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
impl<S, C, Ext> ServerHandshake<S, C, Ext>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
pub fn start(
stream: S,
callback: C,
config: Option<WebSocketConfig<Ext>>,
) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake {
callback: Some(callback),
config,
config: Some(config),
error_code: None,
_marker: PhantomData,
},
@ -219,10 +235,15 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
}
}
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
impl<S, C, Ext> HandshakeRole for ServerHandshake<S, C, Ext>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S>;
type FinalResult = WebSocket<S, Ext>;
fn stage_finished(
&mut self,
@ -231,16 +252,23 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
Ok(match finish {
StageResult::DoneReading {
stream,
result,
result: request,
tail,
} => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()));
}
let response = create_response(&result)?;
let mut response = create_response(&request)?;
if let Some(ref mut config) = self.config.as_mut().unwrap() {
if let Err(e) = config.encoder.on_receive_request(&request, &mut response) {
return Err(e.into());
}
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
callback.on_request(&request, response)
} else {
Ok(response)
};
@ -277,7 +305,11 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.take().unwrap(),
);
ProcessingResult::Done(websocket)
}
}

@ -22,6 +22,8 @@ pub mod server;
pub mod stream;
pub mod util;
pub mod extensions;
pub use crate::client::{client, connect};
pub use crate::error::{Error, Result};
pub use crate::handshake::client::ClientHandshake;

@ -2,7 +2,7 @@
pub mod frame;
mod message;
pub(crate) mod message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
@ -14,10 +14,14 @@ use std::mem::replace;
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::message::IncompleteMessage;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::util::NonBlockingResult;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
@ -28,29 +32,48 @@ pub enum Role {
}
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
#[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig<E = UncompressedExt>
where
E: WebSocketExtension,
{
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user.
pub max_message_size: Option<usize>,
/// The maximum size of a single message frame. `None` means no size limit. The limit is for
/// frame payload NOT including the frame header. The default value is 16 MiB which should
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// Per-message compression strategy.
pub encoder: E,
}
impl Default for WebSocketConfig {
impl<E> Default for WebSocketConfig<E>
where
E: WebSocketExtension,
{
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
encoder: E::new(Some(MAX_MESSAGE_SIZE)),
}
}
}
impl<E> WebSocketConfig<E>
where
E: WebSocketExtension,
{
/// Creates a `WebSocketConfig` instance using the default configuration and the provided
/// encoder for new connections.
pub fn default_with_encoder(encoder: E) -> WebSocketConfig<E> {
WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder,
}
}
}
@ -60,20 +83,30 @@ impl Default for WebSocketConfig {
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)]
pub struct WebSocket<Stream> {
pub struct WebSocket<Stream, Ext>
where
Ext: WebSocketExtension,
{
/// The underlying socket.
socket: Stream,
/// The context for managing a WebSocket.
context: WebSocketContext,
context: WebSocketContext<Ext>,
}
impl<Stream> WebSocket<Stream> {
impl<Stream, Ext> WebSocket<Stream, Ext>
where
Ext: WebSocketExtension,
{
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
pub fn from_raw_socket(
stream: Stream,
role: Role,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
@ -89,7 +122,7 @@ impl<Stream> WebSocket<Stream> {
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocket {
socket: stream,
@ -101,18 +134,19 @@ impl<Stream> WebSocket<Stream> {
pub fn get_ref(&self) -> &Stream {
&self.socket
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
&mut self.socket
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
self.context.set_config(set_func)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig {
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
self.context.get_config()
}
@ -132,7 +166,11 @@ impl<Stream> WebSocket<Stream> {
}
}
impl<Stream: Read + Write> WebSocket<Stream> {
impl<Stream, Ext> WebSocket<Stream, Ext>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Read a message from stream, if possible.
///
/// This will queue responses to ping and close messages to be sent. It will call
@ -215,7 +253,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// A context for managing WebSocket stream.
#[derive(Debug)]
pub struct WebSocketContext {
pub struct WebSocketContext<Ext = UncompressedExt>
where
Ext: WebSocketExtension,
{
/// Server or client?
role: Role,
/// encoder/decoder of frame.
@ -229,12 +270,17 @@ pub struct WebSocketContext {
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
config: WebSocketConfig<Ext>,
}
impl WebSocketContext {
impl<Ext> WebSocketContext<Ext>
where
Ext: WebSocketExtension,
{
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
pub fn new(role: Role, config: Option<WebSocketConfig<Ext>>) -> Self {
let config = config.unwrap_or_else(Default::default);
WebSocketContext {
role,
frame: FrameCodec::new(),
@ -242,12 +288,16 @@ impl WebSocketContext {
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_else(WebSocketConfig::default),
config,
}
}
/// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
pub fn from_partially_read(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
@ -255,12 +305,12 @@ impl WebSocketContext {
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
set_func(&mut self.config)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig {
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
&self.config
}
@ -426,17 +476,6 @@ impl WebSocketContext {
"Remote sent frame after having sent a Close Frame".into(),
));
}
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()));
}
}
match self.role {
Role::Server => {
@ -489,49 +528,10 @@ impl WebSocketContext {
}
}
OpCode::Data(data) => {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
OpData::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
}
}
_ => match self.config.encoder.on_receive_frame(frame) {
Ok(r) => Ok(r),
Err(e) => Err(e.into()),
},
} // match opcode
} else {
// Connection closed by peer
@ -601,6 +601,13 @@ impl WebSocketContext {
}
}
if frame.header().is_final {
frame = match self.config.encoder.on_send_frame(frame) {
Ok(frame) => frame,
Err(e) => return Err(e.into()),
};
}
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
@ -675,6 +682,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::UncompressedExt;
use std::io;
use std::io::Cursor;
@ -702,7 +710,8 @@ mod tests {
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(
@ -722,8 +731,9 @@ mod tests {
0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_message_size: Some(10),
..WebSocketConfig::default()
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: UncompressedExt::new(Some(10)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
@ -736,8 +746,9 @@ mod tests {
fn size_limiting_binary() {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig {
max_message_size: Some(2),
..WebSocketConfig::default()
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: UncompressedExt::new(Some(2)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(

@ -7,6 +7,8 @@ use crate::handshake::HandshakeError;
use crate::protocol::{WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use std::io::{Read, Write};
/// Accept the given Stream as a WebSocket.
@ -18,10 +20,14 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write>(
stream: S,
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
pub fn accept_with_config<Stream, Ext>(
stream: Stream,
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<Stream, Ext>, HandshakeError<ServerHandshake<Stream, NoCallback, Ext>>>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
accept_hdr_with_config(stream, NoCallback, config)
}
@ -33,7 +39,10 @@ pub fn accept_with_config<S: Read + Write>(
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
) -> Result<
WebSocket<S, UncompressedExt>,
HandshakeError<ServerHandshake<S, NoCallback, UncompressedExt>>,
> {
accept_with_config(stream, None)
}
@ -45,11 +54,16 @@ pub fn accept<S: Read + Write>(
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
pub fn accept_hdr_with_config<S, C, Ext>(
stream: S,
callback: C,
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<S, Ext>, HandshakeError<ServerHandshake<S, C, Ext>>>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
ServerHandshake::start(stream, callback, config).handshake()
}
@ -61,6 +75,6 @@ pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
) -> Result<WebSocket<S, UncompressedExt>, HandshakeError<ServerHandshake<S, C, UncompressedExt>>> {
accept_hdr_with_config(stream, callback, None)
}

@ -1,22 +1,23 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpStream, TcpListener};
use std::net::{TcpListener, TcpStream};
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream};
use native_tls::TlsStream;
use url::Url;
use net2::TcpStreamExt;
use tungstenite::extensions::uncompressed::UncompressedExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
type Sock<Ext> = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>, Ext>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where
CT: FnOnce(Sock) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream>),
CT: FnOnce(Sock<UncompressedExt>) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream, UncompressedExt>),
{
env_logger::try_init().ok();
@ -26,8 +27,8 @@ where
exit(1);
});
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is port already in use?");
let server =
TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?");
let client_thread = spawn(move || {
let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap())
@ -46,7 +47,8 @@ where
#[test]
fn test_server_close() {
do_test(3012,
do_test(
3012,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -75,12 +77,14 @@ fn test_server_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}
#[test]
fn test_evil_server_close() {
do_test(3013,
do_test(
3013,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -106,14 +110,19 @@ fn test_evil_server_close() {
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
srv_sock
.get_mut()
.set_linger(Some(Duration::from_secs(0)))
.unwrap();
drop(srv_sock);
});
},
);
}
#[test]
fn test_client_close() {
do_test(3014,
do_test(
3014,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -137,7 +146,9 @@ fn test_client_close() {
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.write_message(Message::Text("From Server".into())).unwrap();
srv_sock
.write_message(Message::Text("From Server".into()))
.unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close());
@ -147,6 +158,6 @@ fn test_client_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}

Loading…
Cancel
Save