Compare commits

..

No commits in common. 'master' and 'permessage-deflate' have entirely different histories.

  1. 70
      .github/workflows/ci.yml
  2. 3
      .gitignore
  3. 15
      .travis.yml
  4. 22
      CHANGELOG.md
  5. 58
      Cargo.toml
  6. 8
      README.md
  7. 576
      autobahn/expected-results.json
  8. 3
      benches/buffer.rs
  9. 75
      benches/write.rs
  10. 20
      examples/autobahn-client.rs
  11. 18
      examples/autobahn-server.rs
  12. 4
      examples/client.rs
  13. 4
      examples/server.rs
  14. 8
      examples/srv_accept_unmasked_frames.rs
  15. 2
      fuzz/fuzz_targets/read_message_client.rs
  16. 2
      fuzz/fuzz_targets/read_message_server.rs
  17. 2
      scripts/autobahn-client.sh
  18. 2
      scripts/autobahn-server.sh
  19. 1
      src/client.rs
  20. 29
      src/error.rs
  21. 442
      src/extensions/compression/deflate.rs
  22. 4
      src/extensions/compression/mod.rs
  23. 18
      src/extensions/mod.rs
  24. 135
      src/handshake/client.rs
  25. 26
      src/handshake/server.rs
  26. 5
      src/lib.rs
  27. 12
      src/protocol/frame/frame.rs
  28. 122
      src/protocol/frame/mod.rs
  29. 24
      src/protocol/message.rs
  30. 602
      src/protocol/mod.rs
  31. 12
      src/tls.rs
  32. 41
      tests/connection_reset.rs
  33. 9
      tests/no_send_after_close.rs
  34. 15
      tests/receive_after_init_close.rs
  35. 68
      tests/write.rs

@ -1,70 +0,0 @@
name: CI
on: [push, pull_request]
jobs:
fmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@nightly
with:
components: rustfmt
- run: cargo fmt --all --check
test:
name: Test
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Install cargo-hack
uses: taiki-e/install-action@cargo-hack
- name: Check
run: cargo hack check --feature-powerset --all-targets
- name: Test
run: cargo test --release
autobahn:
name: Autobahn tests
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
- beta
- nightly
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Running Autobahn TestSuite for client
run: ./scripts/autobahn-client.sh
- name: Running Autobahn TestSuite for server
run: ./scripts/autobahn-server.sh

3
.gitignore vendored

@ -1,3 +1,4 @@
target
Cargo.lock
.vscode
autobahn/client/
autobahn/server/

@ -0,0 +1,15 @@
language: rust
rust:
- stable
services:
- docker
before_script:
- export PATH="$PATH:$HOME/.cargo/bin"
script:
- cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -1,25 +1,3 @@
# Unreleased (0.20.0)
- Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing.
This allows writing a batch of messages and flushing once, improving performance.
- Add `WebSocket::read`, `write`, `send`, `flush`. Deprecate `read_message`, `write_message`, `write_pending`.
- Add `FrameSocket::read`, `write`, `send`, `flush`. Remove `read_frame`, `write_frame`, `write_pending`.
Note: Previous use of `write_frame` may be replaced with `send`.
- Add `WebSocketContext::read`, `write`, `flush`. Remove `read_message`, `write_message`, `write_pending`.
Note: Previous use of `write_message` may be replaced with `write` + `flush`.
- Remove `send_queue`, replaced with using the frame write buffer to achieve similar results.
* Add `WebSocketConfig::max_write_buffer_size`. Deprecate `max_send_queue`.
* Add `Error::WriteBufferFull`. Remove `Error::SendQueueFull`.
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
# 0.19.0
- Update TLS dependencies.
- Exchanging `base64` for `data-encoding`.
# 0.18.0
- Make handshake dependencies optional with a new `handshake` feature (now a default one!).

@ -7,9 +7,9 @@ authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0"
readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.19.0"
documentation = "https://docs.rs/tungstenite/0.18.0"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.19.0"
version = "0.18.0"
edition = "2018"
rust-version = "1.51"
include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
@ -24,7 +24,16 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls"]
__rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies]
data-encoding = { version = "2", optional = true }
@ -38,6 +47,11 @@ sha1 = { version = "0.10", optional = true }
thiserror = "1.0.23"
url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5"
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate]
optional = true
@ -46,18 +60,22 @@ version = "0.2.3"
[dependencies.rustls]
optional = true
version = "0.21.0"
version = "0.20.0"
[dependencies.rustls-native-certs]
optional = true
version = "0.6.0"
[dependencies.webpki]
optional = true
version = "0.22"
[dependencies.webpki-roots]
optional = true
version = "0.23"
version = "0.22"
[dev-dependencies]
criterion = "0.5.0"
criterion = "0.4.0"
env_logger = "0.10.0"
input_buffer = "0.5.0"
net2 = "0.2.37"
@ -66,31 +84,3 @@ rand = "0.8.4"
[[bench]]
name = "buffer"
harness = false
[[bench]]
name = "write"
harness = false
[[example]]
name = "client"
required-features = ["handshake"]
[[example]]
name = "server"
required-features = ["handshake"]
[[example]]
name = "autobahn-client"
required-features = ["handshake"]
[[example]]
name = "autobahn-server"
required-features = ["handshake"]
[[example]]
name = "callback-error"
required-features = ["handshake"]
[[example]]
name = "srv_accept_unmasked_frames"
required-features = ["handshake"]

@ -14,11 +14,11 @@ fn main () {
spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap();
loop {
let msg = websocket.read().unwrap();
let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap();
websocket.write_message(msg).unwrap();
}
}
});
@ -36,7 +36,7 @@ take a look at [`tokio-tungstenite`](https://github.com/snapview/tokio-tungsteni
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)
[![Build Status](https://github.com/snapview/tungstenite-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/snapview/tungstenite-rs/actions)
[![Build Status](https://travis-ci.org/snapview/tungstenite-rs.svg?branch=master)](https://travis-ci.org/snapview/tungstenite-rs)
[Documentation](https://docs.rs/tungstenite)
@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing
-------

File diff suppressed because it is too large Load Diff

@ -1,4 +1,5 @@
use std::io::{Cursor, Read, Result as IoResult};
use std::io::Result as IoResult;
use std::io::{Cursor, Read};
use bytes::Buf;
use criterion::*;

@ -1,75 +0,0 @@
//! Benchmarks for write performance.
use criterion::{BatchSize, Criterion};
use std::{
hint,
io::{self, Read, Write},
time::{Duration, Instant},
};
use tungstenite::{Message, WebSocket};
const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// `Write` impl that simulates slowish writes and slow flushes.
///
/// Each `write` can buffer up to 8 MiB before flushing but takes an additional **~80ns**
/// to simulate stuff going on in the underlying stream.
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
}
// simulate io
spin(Duration::from_nanos(80));
self.0.extend(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.0.is_empty() {
// simulate io
spin(Duration::from_micros(8));
self.0.clear();
}
Ok(())
}
}
fn spin(duration: Duration) {
let a = Instant::now();
while a.elapsed() < duration {
hint::spin_loop();
}
}
fn benchmark(c: &mut Criterion) {
// Writes 100k small json text messages then flushes
c.bench_function("write 100k small texts then flush", |b| {
let mut ws = WebSocket::from_raw_socket(
MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)),
tungstenite::protocol::Role::Server,
None,
);
b.iter_batched(
|| (0..100_000).map(|i| Message::Text(format!("{{\"id\":{i}}}"))),
|batch| {
for msg in batch {
ws.write(msg).unwrap();
}
ws.flush().unwrap();
},
BatchSize::SmallInput,
)
});
}
criterion::criterion_group!(write_benches, benchmark);
criterion::criterion_main!(write_benches);

@ -1,13 +1,16 @@
use log::*;
use url::Url;
use tungstenite::{connect, Error, Message, Result};
use tungstenite::{
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read()?;
let msg = socket.read_message()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
}
@ -24,11 +27,18 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?;
let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop {
match socket.read()? {
match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?;
socket.write_message(msg)?;
}
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
}

@ -4,7 +4,10 @@ use std::{
};
use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::{
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -14,12 +17,19 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test");
loop {
match socket.read()? {
match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?;
socket.write_message(msg)?;
}
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
}

@ -14,9 +14,9 @@ fn main() {
println!("* {}", header);
}
socket.send(Message::Text("Hello WebSocket".into())).unwrap();
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop {
let msg = socket.read().expect("Error reading message");
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);
}
// socket.close(None);

@ -28,9 +28,9 @@ fn main() {
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();
loop {
let msg = websocket.read().unwrap();
let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap();
websocket.write_message(msg).unwrap();
}
}
});

@ -27,18 +27,22 @@ fn main() {
};
let config = Some(WebSocketConfig {
max_send_queue: None,
max_message_size: None,
max_frame_size: None,
// This setting allows to accept client frames which are not masked
// This is not in compliance with RFC 6455 but might be handy in some
// rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames
accept_unmasked_frames: true,
..<_>::default()
#[cfg(feature = "deflate")]
compression: None,
});
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
loop {
let msg = websocket.read().unwrap();
let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() {
println!("received message {}", msg);
}

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
socket.read().ok();
socket.read_message().ok();
});

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
socket.read().ok();
socket.read_message().ok();
});

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3
cargo run --release --example autobahn-client
cargo run --release --example autobahn-client --features=deflate
test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi
}
cargo run --release --example autobahn-server & WSSERVER_PID=$!
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3
docker run --rm \

@ -54,7 +54,6 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,

@ -53,9 +53,9 @@ pub enum Error {
/// Protocol violation.
#[error("WebSocket protocol error: {0}")]
Protocol(#[from] ProtocolError),
/// Message write buffer is full.
#[error("Write buffer is full")]
WriteBufferFull(Message),
/// Message send queue full.
#[error("Send queue is full")]
SendQueueFull(Message),
/// UTF coding error.
#[error("UTF-8 encoding error")]
Utf8,
@ -70,6 +70,10 @@ pub enum Error {
#[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] crate::extensions::DeflateError),
}
impl From<str::Utf8Error> for Error {
@ -206,6 +210,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
@ -218,6 +225,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
@ -230,6 +240,15 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
InvalidExtensionsHeader,
}
/// Indicates the specific type/cause of URL error.
@ -271,6 +290,10 @@ pub enum TlsError {
#[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")]
Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error.
#[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name")]

@ -0,0 +1,442 @@
use std::convert::TryFrom;
use bytes::BytesMut;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use headers::WebsocketExtension;
use http::HeaderValue;
use thiserror::Error;
use crate::protocol::Role;
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Errors from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress(#[source] std::io::Error),
/// Decompress failed
#[error("Failed to decompress")]
Decompress(#[source] std::io::Error),
/// Extension negotiation failed.
#[error("Extension negotiation failed")]
Negotiation(#[source] NegotiationError),
}
/// Errors from `permessage-deflate` extension negotiation.
#[derive(Debug, Error)]
pub enum NegotiationError {
/// Unknown parameter in a negotiation response.
#[error("Unknown parameter in a negotiation response: {0}")]
UnknownParameter(String),
/// Duplicate parameter in a negotiation response.
#[error("Duplicate parameter in a negotiation response: {0}")]
DuplicateParameter(String),
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
UnexpectedClientMaxWindowBits,
/// Received unsupported `server_max_window_bits` in a negotiation response.
#[error("Received unsupported server_max_window_bits in a negotiation response")]
ServerMaxWindowBitsNotSupported,
/// Invalid `client_max_window_bits` value in a negotiation response.
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
InvalidClientMaxWindowBitsValue(String),
/// Invalid `server_max_window_bits` value in a negotiation response.
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
InvalidServerMaxWindowBitsValue(String),
/// Missing `server_max_window_bits` value in a negotiation response.
#[error("Missing server_max_window_bits value in a negotiation response")]
MissingServerMaxWindowBitsValue,
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
// > a client informs the peer server of a hint that even if the server doesn't include the
// > "client_no_context_takeover" extension parameter in the corresponding
// > extension negotiation response to the offer, the client is not going
// > to use context takeover.
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
if self.client_no_context_takeover {
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub(crate) fn accept_offer(
&self,
offers: &headers::SecWebsocketExtensions,
) -> Option<(WebsocketExtension, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
// 2. The negotiation offer contains an extension parameter with an invalid value.
// 3. The negotiation offer contains multiple extension parameters with the same name.
// 4. The server doesn't support the offered configuration.
offers.iter().find_map(|extension| {
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for (key, val) in params {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
return None;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
return None;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `server_max_window_bits` requires a value in range [8, 15].
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
} else {
return None;
}
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
return None;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `client_max_window_bits` requires a value in range [8, 15] or no value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
}
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
return None;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
return None;
}
}
}
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
} else {
None
}
})
}
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = DeflateConfig {
compression: self.compression,
// If this was hinted in the offer, the client won't use context takeover
// even if the response doesn't include it.
// See `generate_offer`.
client_no_context_takeover: self.client_no_context_takeover,
..DeflateConfig::default()
};
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
// A client MUST _Fail the WebSocket Connection_ if the peer server
// accepted an extension negotiation offer for this extension with an
// extension negotiation response meeting any of the following
// conditions:
// 1. The negotiation response contains an extension parameter not defined for use in a response.
// 2. The negotiation response contains an extension parameter with an invalid value.
// 3. The negotiation response contains multiple extension parameters with the same name.
// 4. The client does not support the configuration that the response represents.
for (key, val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_server_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_server_no_context_takeover = true;
// A server MAY include the "server_no_context_takeover" extension
// parameter in an extension negotiation response even if the extension
// negotiation offer being accepted by the extension negotiation
// response didn't include the "server_no_context_takeover" extension
// parameter.
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_client_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_client_no_context_takeover = true;
// The server may include this parameter in the response and the client MUST support it.
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
));
}
} else {
return Err(DeflateError::Negotiation(
NegotiationError::MissingServerMaxWindowBitsValue,
));
}
// A server may include the "server_max_window_bits" extension parameter
// in an extension negotiation response even if the extension
// negotiation offer being accepted by the response didn't include the
// "server_max_window_bits" extension parameter.
//
// However, but we need to fail the connection because we don't support it (condition 4).
return Err(DeflateError::Negotiation(
NegotiationError::ServerMaxWindowBitsNotSupported,
));
}
CLIENT_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
));
}
}
// Fail the connection because the parameter is invalid when the client didn't offer.
//
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
return Err(DeflateError::Negotiation(
NegotiationError::UnexpectedClientMaxWindowBits,
));
}
// Response with unknown parameter MUST fail the WebSocket connection.
_ => {
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
key.to_owned(),
)));
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
fn is_valid_max_window_bits(bits: &str) -> bool {
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
}
#[cfg(test)]
mod tests {
use super::is_valid_max_window_bits;
#[test]
fn valid_max_window_bits() {
for bits in 8..=15 {
assert!(is_valid_max_window_bits(&bits.to_string()));
}
}
#[test]
fn invalid_max_window_bits() {
assert!(!is_valid_max_window_bits(""));
assert!(!is_valid_max_window_bits("0"));
assert!(!is_valid_max_window_bits("08"));
assert!(!is_valid_max_window_bits("+8"));
assert!(!is_valid_max_window_bits("-8"));
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
for param in params {
buf.extend_from_slice(b"; ");
buf.extend_from_slice(param.as_bytes());
}
let header = HeaderValue::from_maybe_shared(buf.freeze())
.expect("semicolon separated HeaderValue is valid");
WebsocketExtension::try_from(header).expect("valid extension")
}

@ -0,0 +1,4 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -0,0 +1,18 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -5,6 +5,7 @@ use std::{
marker::PhantomData,
};
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -19,6 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -56,7 +58,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let (request, key) = generate_request(request, &config)?;
let machine = HandshakeMachine::start_write(stream, request);
@ -83,18 +85,24 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) {
Ok(r) => r,
Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail);
return Err(Error::Http(e));
}
Err(e) => return Err(e),
};
let (result, extensions) =
match self.verify_data.verify_response(result, &self.config) {
Ok(r) => r,
Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail);
return Err(Error::Http(e));
}
Err(e) => return Err(e),
};
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read_with_extensions(
stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result))
}
})
@ -102,7 +110,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
pub fn generate_request(
mut request: Request,
config: &Option<WebSocketConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new();
write!(
req,
@ -173,6 +184,9 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
}
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
}
writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok((req, key))
@ -186,7 +200,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> {
pub fn verify_response(
&self,
response: Response,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -231,7 +249,14 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
let extensions = if let Some(agreed) = headers
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
{
verify_extensions(&agreed, _config)?
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -240,10 +265,49 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(response)
Ok((response, extensions))
}
}
fn verify_extensions(
agreed_extensions: &headers::SecWebsocketExtensions,
_config: &Option<WebSocketConfig>,
) -> Result<Option<Extensions>> {
#[cfg(feature = "deflate")]
{
if let Some(compression) = _config.and_then(|c| c.compression) {
let mut extensions = None;
for extension in agreed_extensions.iter() {
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
// > the client MUST _Fail the WebSocket Connection_.
if extension.name() != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
extension.name().to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
extension.name().to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(extension.params())?),
});
}
return Ok(extensions);
}
}
if let Some(extension) = agreed_extensions.iter().next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
}
Ok(None)
}
impl TryParse for Response {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -258,7 +322,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -286,6 +350,8 @@ pub fn generate_key() -> String {
mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test]
fn random_keys() {
@ -322,7 +388,7 @@ mod tests {
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -330,7 +396,7 @@ mod tests {
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -338,11 +404,40 @@ mod tests {
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n",
host = "localhost",
key = key
)
.into_bytes();
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -354,6 +449,6 @@ mod tests {
#[test]
fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err());
assert!(generate_request(request, &None).is_err());
}
}

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult,
};
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -20,6 +21,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
extensions: None,
_marker: PhantomData,
},
}
@ -240,7 +245,19 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let mut response = create_response(&result)?;
if let Some(config) = &self.config {
if let Some((agreed, extensions)) = result
.headers()
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
.and_then(|values| config.accept_offers(&values))
{
response.headers_mut().typed_insert(agreed);
self.extensions = Some(extensions);
}
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
@ -283,7 +300,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body)));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
self.extensions.take(),
);
ProcessingResult::Done(websocket)
}
}

@ -19,13 +19,14 @@ pub mod buffer;
#[cfg(feature = "handshake")]
pub mod client;
pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")]
pub mod handshake;
pub mod protocol;
#[cfg(feature = "handshake")]
mod server;
pub mod stream;
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls;
pub mod util;
@ -44,5 +45,5 @@ pub use crate::{
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
};
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_tls, client_tls_with_config, Connector};

@ -311,6 +311,18 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
}
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {

@ -6,14 +6,15 @@ pub mod coding;
mod frame;
mod mask;
use crate::{
error::{CapacityError, Error, Result},
Message, ReadBuffer,
};
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
use log::*;
pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::{
error::{CapacityError, Error, Result},
ReadBuffer,
};
/// A reader and writer for WebSocket frames.
#[derive(Debug)]
@ -56,7 +57,7 @@ where
Stream: Read,
{
/// Read a frame from stream.
pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
self.codec.read_frame(&mut self.stream, max_size)
}
}
@ -65,28 +66,18 @@ impl<Stream> FrameSocket<Stream>
where
Stream: Write,
{
/// Writes and immediately flushes a frame.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, frame: Frame) -> Result<()> {
self.write(frame)?;
self.flush()
}
/// Write a frame to stream.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
///
/// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
/// is returned.
/// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
pub fn write(&mut self, frame: Frame) -> Result<()> {
self.codec.buffer_frame(&mut self.stream, frame)
/// This function guarantees that the frame is queued regardless of any errors.
/// There is no need to resend the frame. In order to handle WouldBlock or Incomplete,
/// call write_pending() afterwards.
pub fn write_frame(&mut self, frame: Frame) -> Result<()> {
self.codec.write_frame(&mut self.stream, frame)
}
/// Flush writes.
pub fn flush(&mut self) -> Result<()> {
self.codec.write_out_buffer(&mut self.stream)?;
Ok(self.stream.flush()?)
/// Complete pending write, if any.
pub fn write_pending(&mut self) -> Result<()> {
self.codec.write_pending(&mut self.stream)
}
}
@ -97,14 +88,6 @@ pub(super) struct FrameCodec {
in_buffer: ReadBuffer,
/// Buffer to send packets to the network.
out_buffer: Vec<u8>,
/// Capacity limit for `out_buffer`.
max_out_buffer_len: usize,
/// Buffer target length to reach before writing to the stream
/// on calls to `buffer_frame`.
///
/// Setting this to non-zero will buffer small writes from hitting
/// the stream.
out_buffer_write_len: usize,
/// Header and remaining size of the incoming packet being processed.
header: Option<(FrameHeader, u64)>,
}
@ -112,13 +95,7 @@ pub(super) struct FrameCodec {
impl FrameCodec {
/// Create a new frame codec.
pub(super) fn new() -> Self {
Self {
in_buffer: ReadBuffer::new(),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None,
}
Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None }
}
/// Create a new frame codec from partially read data.
@ -126,23 +103,10 @@ impl FrameCodec {
Self {
in_buffer: ReadBuffer::from_partially_read(part),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None,
}
}
/// Sets a maximum size for the out buffer.
pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
self.max_out_buffer_len = max;
}
/// Sets [`Self::buffer_frame`] buffer target length to reach before
/// writing to the stream.
pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
self.out_buffer_write_len = len;
}
/// Read a frame from the provided stream.
pub(super) fn read_frame<Stream>(
&mut self,
@ -201,37 +165,19 @@ impl FrameCodec {
Ok(Some(frame))
}
/// Writes a frame into the `out_buffer`.
/// If the out buffer size is over the `out_buffer_write_len` will also write
/// the out buffer into the provided `stream`.
///
/// To ensure buffered frames are written call [`Self::write_out_buffer`].
///
/// May write to the stream, will **not** flush.
pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
/// Write a frame to the provided stream.
pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
where
Stream: Write,
{
if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
return Err(Error::WriteBufferFull(Message::Frame(frame)));
}
trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
if self.out_buffer.len() > self.out_buffer_write_len {
self.write_out_buffer(stream)
} else {
Ok(())
}
self.write_pending(stream)
}
/// Writes the out_buffer to the provided stream.
///
/// Does **not** flush.
pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
/// Complete pending write, if any.
pub(super) fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
Stream: Write,
{
@ -247,11 +193,19 @@ impl FrameCodec {
}
self.out_buffer.drain(0..len);
}
stream.flush()?;
Ok(())
}
}
#[cfg(test)]
impl FrameCodec {
/// Returns the size of the output buffer.
pub(super) fn output_buffer_len(&self) -> usize {
self.out_buffer.len()
}
}
#[cfg(test)]
mod tests {
@ -270,11 +224,11 @@ mod tests {
let mut sock = FrameSocket::new(raw);
assert_eq!(
sock.read(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read(None).unwrap().is_none());
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]);
@ -285,7 +239,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(
sock.read(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
@ -295,10 +249,10 @@ mod tests {
let mut sock = FrameSocket::new(Vec::new());
let frame = Frame::ping(vec![0x04, 0x05]);
sock.send(frame).unwrap();
sock.write_frame(frame).unwrap();
let frame = Frame::pong(vec![0x01]);
sock.send(frame).unwrap();
sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
@ -310,7 +264,7 @@ mod tests {
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]);
let mut sock = FrameSocket::new(raw);
let _ = sock.read(None); // should not crash
let _ = sock.read_frame(None); // should not crash
}
#[test]
@ -318,7 +272,7 @@ mod tests {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw);
assert!(matches!(
sock.read(Some(5)),
sock.read_frame(Some(5)),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
));
}

@ -84,6 +84,8 @@ use self::string_collect::StringCollector;
#[derive(Debug)]
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
}
#[derive(Debug)]
@ -94,6 +96,7 @@ enum IncompleteMessageCollector {
impl IncompleteMessage {
/// Create new.
#[cfg(not(feature = "deflate"))]
pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
@ -105,6 +108,25 @@ impl IncompleteMessage {
}
}
/// Create new.
#[cfg(feature = "deflate")]
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer.
pub fn len(&self) -> usize {
match self.collector {
@ -185,7 +207,7 @@ impl Message {
Message::Text(string.into())
}
/// Create a new binary WebSocket message by converting to `Vec<u8>`.
/// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message
where
B: Into<Vec<u8>>,

@ -6,6 +6,13 @@ mod message;
pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{
frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
@ -15,13 +22,9 @@ use self::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::Extensions,
util::NonBlockingResult,
};
use log::*;
use std::{
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -35,21 +38,10 @@ pub enum Role {
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
/// Does nothing, instead use `max_write_buffer_size`.
#[deprecated]
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>,
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
/// The default value is 128 KiB.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited.
///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size).
pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user.
@ -65,18 +57,76 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
}
impl Default for WebSocketConfig {
fn default() -> Self {
#[allow(deprecated)]
WebSocketConfig {
max_send_queue: None,
write_buffer_size: 128 * 1024,
max_write_buffer_size: usize::MAX,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
}
impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
#[cfg(feature = "deflate")]
{
let mut offers = Vec::new();
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
offers.push(compression);
}
if offers.is_empty() {
None
} else {
Some(headers::SecWebsocketExtensions::new(offers))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers(
&self,
_offers: &headers::SecWebsocketExtensions,
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
#[cfg(feature = "deflate")]
{
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
let mut agreed_extensions = Vec::new();
let mut extensions = Extensions::default();
if let Some(compression) = &self.compression {
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
agreed_extensions.push(agreed);
extensions.compression = Some(compression);
}
}
if agreed_extensions.is_empty() {
None
} else {
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
}
@ -85,8 +135,6 @@ impl Default for WebSocketConfig {
///
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
///
/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
#[derive(Debug)]
pub struct WebSocket<Stream> {
/// The underlying socket.
@ -105,6 +153,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
@ -122,6 +182,21 @@ impl<Stream> WebSocket<Stream> {
}
}
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.socket
@ -160,116 +235,82 @@ impl<Stream> WebSocket<Stream> {
impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible.
///
/// This will also queue responses to ping and close messages. These responses
/// will be written and flushed on the next call to [`read`](Self::read),
/// [`write`](Self::write) or [`flush`](Self::flush).
/// This will queue responses to ping and close messages to be sent. It will call
/// `write_pending` before trying to read in order to make sure that those responses
/// make progress even if you never call `write_pending`. That does mean that they
/// get sent out earliest on the next call to `read_message`, `write_message` or `write_pending`.
///
/// # Closing the connection
/// ## Closing the connection
/// When the remote endpoint decides to close the connection this will return
/// the close message with an optional close frame.
///
/// You should continue calling [`read`](Self::read), [`write`](Self::write) or
/// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`]
/// is returned. Once that happens it is safe to drop the underlying connection.
pub fn read(&mut self) -> Result<Message> {
self.context.read(&mut self.socket)
}
/// Writes and immediately flushes a message.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, message: Message) -> Result<()> {
self.write(message)?;
self.flush()
/// You should continue calling `read_message`, `write_message` or `write_pending` to drive
/// the reply to the close frame until [Error::ConnectionClosed] is returned. Once that happens
/// it is safe to drop the underlying connection.
pub fn read_message(&mut self) -> Result<Message> {
self.context.read_message(&mut self.socket)
}
/// Write a message to the provided stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// Send a message to stream, if possible.
///
/// In the event of stream write failure the message frame will be stored
/// in the write buffer and will try again on the next call to [`write`](Self::write)
/// or [`flush`](Self::flush).
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// requests. A Pong reply will jump the queue because the
/// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent
/// as soon as is practical.
///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
///
/// This call will generally not flush. However, if there are queued automatic messages
/// they will be written and eagerly flushed.
///
/// For example, upon receiving ping messages tungstenite queues pong replies automatically.
/// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
/// will write & flush the pong reply. This means you should not respond to ping frames manually.
/// Note that upon receiving a ping message, tungstenite cues a pong reply automatically.
/// When you call either `read_message`, `write_message` or `write_pending` next it will try to send
/// that pong out if the underlying connection can take more data. This means you should not
/// respond to ping frames manually.
///
/// You can however send pong frames manually in order to indicate a unidirectional heartbeat
/// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
/// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
/// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
/// ping will not be sent as it will be replaced by your custom pong message.
///
/// # Errors
/// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
/// along with the equivalent passed message frame.
/// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
/// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
/// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
/// error on your part.
/// - [`Error::Io`] is returned if the underlying connection returns an error
/// if `read_message` returns a ping, you should call `write_pending` until it doesn't return
/// WouldBlock before passing a pong to `write_message`, otherwise the response to the
/// ping will not be sent, but rather replaced by your custom pong message.
///
/// ## Errors
/// - If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed].
/// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`,
/// [Error::AlreadyClosed] will be returned. This indicates a program error on your part.
/// - [Error::Io] is returned if the underlying connection returns an error
/// (consider these fatal except for WouldBlock).
/// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
pub fn write(&mut self, message: Message) -> Result<()> {
self.context.write(&mut self.socket, message)
/// - [Error::Capacity] if your message size is bigger than the configured max message size.
pub fn write_message(&mut self, message: Message) -> Result<()> {
self.context.write_message(&mut self.socket, message)
}
/// Flush writes.
///
/// Ensures all messages previously passed to [`write`](Self::write) and automatic
/// queued pong responses are written & flushed into the underlying stream.
pub fn flush(&mut self) -> Result<()> {
self.context.flush(&mut self.socket)
/// Flush the pending send queue.
pub fn write_pending(&mut self) -> Result<()> {
self.context.write_pending(&mut self.socket)
}
/// Close the connection.
///
/// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is
/// the same as calling `write(Message::Close(..))`.
/// the same as calling `write_message(Message::Close(..))`.
///
/// After queuing the close frame you should continue calling [`read`](Self::read) or
/// [`flush`](Self::flush) to drive the close handshake to completion.
/// After queuing the close frame you should continue calling `read_message` or
/// `write_pending` to drive the close handshake to completion.
///
/// The websocket RFC defines that the underlying connection should be closed
/// by the server. Tungstenite takes care of this asymmetry for you.
///
/// When the close handshake is finished (we have both sent and received
/// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
/// a close message), `read_message` or `write_pending` will return
/// [Error::ConnectionClosed] if this endpoint is the server.
///
/// If this endpoint is a client, [Error::ConnectionClosed] will only be
/// returned after the server has closed the underlying connection.
///
/// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
/// is returned from [`read`](Self::read) or [`flush`](Self::flush).
/// is returned from `read_message` or `write_pending`.
pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
self.context.close(&mut self.socket, code)
}
/// Old name for [`read`](Self::read).
#[deprecated(note = "Use `read`")]
pub fn read_message(&mut self) -> Result<Message> {
self.read()
}
/// Old name for [`send`](Self::send).
#[deprecated(note = "Use `send`")]
pub fn write_message(&mut self, message: Message) -> Result<()> {
self.send(message)
}
/// Old name for [`flush`](Self::flush).
#[deprecated(note = "Use `flush`")]
pub fn write_pending(&mut self) -> Result<()> {
self.flush()
}
}
/// A context for managing WebSocket stream.
@ -283,41 +324,55 @@ pub struct WebSocketContext {
state: WebSocketState,
/// Receive: an incomplete message being processed.
incomplete: Option<IncompleteMessage>,
/// Send in addition to regular messages E.g. "pong" or "close".
additional_send: Option<Frame>,
/// Send: a data send queue.
send_queue: VecDeque<Frame>,
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
}
impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::new(), config.unwrap_or_default())
WebSocketContext {
role,
frame: FrameCodec::new(),
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_default(),
extensions: None,
}
}
/// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default())
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
}
}
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
frame.set_max_out_buffer_len(config.max_write_buffer_size);
frame.set_out_buffer_write_len(config.write_buffer_size);
Self {
role,
frame,
state: WebSocketState::Active,
incomplete: None,
additional_send: None,
config,
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
extensions,
..WebSocketContext::new(role, config)
}
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config);
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
set_func(&mut self.config)
}
/// Read the configuration.
@ -344,23 +399,17 @@ impl WebSocketContext {
///
/// This function sends pong and close responses automatically.
/// However, it never blocks on write.
pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
where
Stream: Read + Write,
{
// Do not read from already closed connections.
self.state.check_not_terminated()?;
self.state.check_active()?;
loop {
if self.additional_send.is_some() {
// Since we may get ping or close, we need to reply to the messages even during read.
// Thus we flush but ignore its blocking.
self.flush(stream).no_block()?;
} else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed);
}
// Since we may get ping or close, we need to reply to the messages even during read.
// Thus we call write_pending() but ignore its blocking.
self.write_pending(stream).no_block()?;
// If we get here, either write blocks or we have nothing to write.
// Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? {
@ -370,94 +419,89 @@ impl WebSocketContext {
}
}
/// Write a message to the provided stream.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// Send a message to the provided stream, if possible.
///
/// In the event of stream write failure the message frame will be stored
/// in the write buffer and will try again on the next call to [`write`](Self::write)
/// or [`flush`](Self::flush).
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
pub fn write<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
/// Note that only the last pong frame is stored to be sent, and only the
/// most recent pong frame is sent if multiple pong frames are queued.
pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where
Stream: Read + Write,
{
// When terminated, return AlreadyClosed.
self.state.check_not_terminated()?;
self.state.check_active()?;
// Do not write after sending a close frame.
if !self.state.is_active() {
return Err(Error::Protocol(ProtocolError::SendAfterClosing));
}
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
// Try to make some room for the new message.
// Do not return here if write would block, ignore WouldBlock silently
// since we must queue the message anyway.
self.write_pending(stream).no_block()?;
}
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.set_additional(Frame::pong(data));
// Note: user pongs can be user flushed so no need to flush here
return self._write(stream, None).map(|_| ());
self.pong = Some(Frame::pong(data));
return self.write_pending(stream);
}
Message::Close(code) => return self.close(stream, code),
Message::Frame(f) => f,
};
let should_flush = self._write(stream, Some(frame))?;
if should_flush {
self.flush(stream)?;
}
Ok(())
self.send_queue.push_back(frame);
self.write_pending(stream)
}
/// Flush writes.
///
/// Ensures all messages previously passed to [`write`](Self::write) and automatically
/// queued pong responses are written & flushed into the `stream`.
#[inline]
pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
Stream: Read + Write,
{
self._write(stream, None)?;
self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?)
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
#[cfg(feature = "deflate")]
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
}
Ok(Frame::message(data, opcode, is_final))
}
/// Writes any data in the out_buffer, `additional_send` and given `data`.
///
/// Does **not** flush.
///
/// Returns true if the write contents indicate we should flush immediately.
fn _write<Stream>(&mut self, stream: &mut Stream, data: Option<Frame>) -> Result<bool>
/// Flush the pending send queue.
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
Stream: Read + Write,
{
if let Some(data) = data {
self.buffer_frame(stream, data)?;
}
// First, make sure we have no pending frame sending.
self.frame.write_pending(stream)?;
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
// response, unless it already received a Close frame. It SHOULD
// respond with Pong frame as soon as is practical. (RFC 6455)
let should_flush = if let Some(msg) = self.additional_send.take() {
trace!("Sending pong/close");
match self.buffer_frame(stream, msg) {
Err(Error::WriteBufferFull(Message::Frame(msg))) => {
// if an system message would exceed the buffer put it back in
// `additional_send` for retry. Otherwise returning this error
// may not make sense to the user, e.g. calling `flush`.
self.set_additional(msg);
false
}
Err(err) => return Err(err),
Ok(_) => true,
}
} else {
false
};
if let Some(pong) = self.pong.take() {
trace!("Sending pong reply");
self.send_one_frame(stream, pong)?;
}
// If we have any unsent frames, send them.
trace!("Frames still in queue: {}", self.send_queue.len());
while let Some(data) = self.send_queue.pop_front() {
self.send_one_frame(stream, data)?;
}
// If we get to this point, the send queue is empty and the underlying socket is still
// willing to take more data.
// If we're closing and there is nothing to send anymore, we should close the connection.
if self.role == Role::Server && !self.state.can_read() {
@ -467,11 +511,10 @@ impl WebSocketContext {
// maximum segment lifetimes (2MSL), while there is no corresponding
// server impact as a TIME_WAIT connection is immediately reopened upon
// a new SYN with a higher seq number). (RFC 6455)
self.frame.write_out_buffer(stream)?;
self.state = WebSocketState::Terminated;
Err(Error::ConnectionClosed)
} else {
Ok(should_flush)
Ok(())
}
}
@ -479,7 +522,7 @@ impl WebSocketContext {
///
/// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is
/// the same as calling `send(Message::Close(..))`.
/// the same as calling `write(Message::Close(..))`.
pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()>
where
Stream: Read + Write,
@ -487,9 +530,11 @@ impl WebSocketContext {
if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(code);
self._write(stream, Some(frame))?;
self.send_queue.push_back(frame);
} else {
// Already closed, nothing to do.
}
self.flush(stream)
self.write_pending(stream)
}
/// Try to decode one message frame. May return None.
@ -510,12 +555,14 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
{
let is_compressed = {
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
}
hdr.rsv1
};
match self.role {
Role::Server => {
@ -550,6 +597,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
}
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -558,7 +609,7 @@ impl WebSocketContext {
let data = frame.into_data();
// No ping processing after we sent a close frame.
if self.state.is_active() {
self.set_additional(Frame::pong(data.clone()));
self.pong = Some(Frame::pong(data.clone()));
}
Ok(Some(Message::Ping(data)))
}
@ -570,39 +621,34 @@ impl WebSocketContext {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
if self.incomplete.is_some() && is_compressed {
return Err(Error::Protocol(
ProtocolError::UnexpectedContinueFrame,
ProtocolError::CompressedContinueFrame,
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
let msg = self
.incomplete
.take()
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
self.extend_incomplete(msg, frame.into_data(), fin)
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
}
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
#[cfg(feature = "deflate")]
let msg = IncompleteMessage::new(message_type, is_compressed);
#[cfg(not(feature = "deflate"))]
let msg = IncompleteMessage::new(message_type);
self.extend_incomplete(msg, frame.into_data(), fin)
}
OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
@ -621,6 +667,32 @@ impl WebSocketContext {
}
}
fn extend_incomplete(
&mut self,
mut msg: IncompleteMessage,
data: Vec<u8>,
is_final: bool,
) -> Result<Option<Message>> {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(data, is_final)?
} else {
data
};
msg.extend(data, self.config.max_message_size)?;
if is_final {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
/// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
@ -642,7 +714,7 @@ impl WebSocketContext {
let reply = Frame::close(close.clone());
debug!("Replying to close with {:?}", reply);
self.set_additional(reply);
self.send_queue.push_back(reply);
Some(close)
}
@ -659,8 +731,8 @@ impl WebSocketContext {
}
}
/// Write a single frame into the write-buffer.
fn buffer_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
/// Send a single pending frame.
fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
where
Stream: Read + Write,
{
@ -674,17 +746,17 @@ impl WebSocketContext {
}
trace!("Sending frame: {:?}", frame);
self.frame.buffer_frame(stream, frame).check_connection_reset(self.state)
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
/// Replace `additional_send` if it is currently a `Pong` message.
fn set_additional(&mut self, add: Frame) {
let empty_or_pong = self
.additional_send
.as_ref()
.map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong));
if empty_or_pong {
self.additional_send.replace(add);
fn has_compression(&self) -> bool {
#[cfg(feature = "deflate")]
{
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
}
#[cfg(not(feature = "deflate"))]
{
false
}
}
}
@ -718,7 +790,7 @@ impl WebSocketState {
}
/// Check if the state is active, return error if not.
fn check_not_terminated(self) -> Result<()> {
fn check_active(self) -> Result<()> {
match self {
WebSocketState::Terminated => Err(Error::AlreadyClosed),
_ => Ok(()),
@ -770,6 +842,64 @@ mod tests {
}
}
struct WouldBlockStreamMoc;
impl io::Write for WouldBlockStreamMoc {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
fn flush(&mut self) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
impl io::Read for WouldBlockStreamMoc {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
#[test]
fn queue_logic() {
// Create a socket with the queue size of 1.
let mut socket = WebSocket::from_raw_socket(
WouldBlockStreamMoc,
Role::Client,
Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }),
);
// Test message that we're going to send.
let message = Message::Binary(vec![0xFF; 1024]);
// Helper to check the error.
let assert_would_block = |error| {
if let Error::Io(io_error) = error {
assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock);
} else {
panic!("Expected WouldBlock error");
}
};
// The first attempt of writing must not fail, since the queue is empty at start.
// But since the underlying mock object always returns `WouldBlock`, so is the result.
assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err()));
// Any subsequent attempts must return an error telling that the queue is full.
for _i in 0..100 {
assert!(matches!(
socket.write_message(message.clone()).unwrap_err(),
Error::SendQueueFull(..)
));
}
// The size of the output buffer must not be bigger than the size of that message
// that we managed to write to the output buffer at first. Since we could not make
// any progress (because of the logic of the moc buffer), the size remains unchanged.
if socket.context.frame.output_buffer_len() > message.len() {
panic!("Too many frames in the queue");
}
}
#[test]
fn receive_messages() {
let incoming = Cursor::new(vec![
@ -778,10 +908,10 @@ mod tests {
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
}
#[test]
@ -794,7 +924,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!(
socket.read(),
socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
));
}
@ -806,7 +936,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!(
socket.read(),
socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
));
}

@ -104,13 +104,11 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")]
{
let native_certs = rustls_native_certs::load_native_certs()?;
let der_certs: Vec<Vec<u8>> =
native_certs.into_iter().map(|cert| cert.0).collect();
let total_number = der_certs.len();
let (number_added, number_ignored) =
root_store.add_parsable_certificates(&der_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
for cert in rustls_native_certs::load_native_certs()? {
root_store
.add(&rustls::Certificate(cert.0))
.map_err(TlsError::Webpki)?;
}
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{

@ -1,7 +1,6 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closed from the server's point of view and drop the underlying tcp socket.
#![cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
use std::{
net::{TcpListener, TcpStream},
@ -52,27 +51,27 @@ fn test_server_close() {
do_test(
3012,
|mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive close from server
let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
},
|mut srv_sock| {
let message = srv_sock.read().unwrap();
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed
let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
@ -86,26 +85,26 @@ fn test_evil_server_close() {
do_test(
3013,
|mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1));
let message = cli_sock.read().unwrap(); // receive close from server
let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
},
|mut srv_sock| {
let message = srv_sock.read().unwrap();
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
@ -119,32 +118,32 @@ fn test_client_close() {
do_test(
3014,
|mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive answer from server
let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server");
cli_sock.close(None).unwrap(); // send close to server
let message = cli_sock.read().unwrap(); // receive acknowledgement from server
let message = cli_sock.read_message().unwrap(); // receive acknowledgement from server
assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
},
|mut srv_sock| {
let message = srv_sock.read().unwrap();
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.send(Message::Text("From Server".into())).unwrap();
srv_sock.write_message(Message::Text("From Server".into())).unwrap();
let message = srv_sock.read().unwrap(); // receive close from client
let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed
let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{
net::TcpListener,
process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration,
};
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url;
@ -29,10 +28,10 @@ fn test_no_send_after_close() {
let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
let message = client.read().unwrap(); // receive close from server
let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
@ -44,7 +43,7 @@ fn test_no_send_after_close() {
client_handler.close(None).unwrap(); // send close to client
let err = client_handler.send(Message::Text("Hello WebSocket".into()));
let err = client_handler.write_message(Message::Text("Hello WebSocket".into()));
assert!(err.is_err());

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{
net::TcpListener,
process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration,
};
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, Error, Message};
use url::Url;
@ -29,12 +28,12 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server
let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
@ -47,12 +46,12 @@ fn test_receive_after_init_close() {
client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap();
let message = client_handler.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
assert!(client_handler.read_message().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),

@ -1,68 +0,0 @@
use std::io::{self, Read, Write};
use tungstenite::{protocol::WebSocketConfig, Message, WebSocket};
/// `Write` impl that records call stats and drops the data.
#[derive(Debug, Default)]
struct MockWrite {
written_bytes: usize,
write_count: usize,
flush_count: usize,
}
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written_bytes += buf.len();
self.write_count += 1;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_count += 1;
Ok(())
}
}
/// Test for write buffering and flushing behaviour.
#[test]
fn write_flush_behaviour() {
const SEND_ME_LEN: usize = 10;
const BATCH_ME_LEN: usize = 11;
const WRITE_BUFFER_SIZE: usize = 600;
let mut ws = WebSocket::from_raw_socket(
MockWrite::default(),
tungstenite::protocol::Role::Server,
Some(WebSocketConfig { write_buffer_size: WRITE_BUFFER_SIZE, ..<_>::default() }),
);
assert_eq!(ws.get_ref().written_bytes, 0);
assert_eq!(ws.get_ref().write_count, 0);
assert_eq!(ws.get_ref().flush_count, 0);
// `send` writes & flushes immediately
ws.send(Message::Text("Send me!".into())).unwrap();
assert_eq!(ws.get_ref().written_bytes, SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 1);
assert_eq!(ws.get_ref().flush_count, 1);
// send a batch of messages
for msg in (0..100).map(|_| Message::Text("Batch me!".into())) {
ws.write(msg).unwrap();
}
// after 55 writes the out_buffer will exceed write_buffer_size=600
// and so do a single underlying write (not flushing).
assert_eq!(ws.get_ref().written_bytes, 55 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 2);
assert_eq!(ws.get_ref().flush_count, 1);
// flushing will perform a single write for the remaining out_buffer & flush.
ws.flush().unwrap();
assert_eq!(ws.get_ref().written_bytes, 100 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 3);
assert_eq!(ws.get_ref().flush_count, 2);
}
Loading…
Cancel
Save