Compare commits

..

No commits in common. 'master' and 'permessage-deflate' have entirely different histories.

  1. 70
      .github/workflows/ci.yml
  2. 3
      .gitignore
  3. 15
      .travis.yml
  4. 22
      CHANGELOG.md
  5. 58
      Cargo.toml
  6. 8
      README.md
  7. 576
      autobahn/expected-results.json
  8. 3
      benches/buffer.rs
  9. 75
      benches/write.rs
  10. 20
      examples/autobahn-client.rs
  11. 18
      examples/autobahn-server.rs
  12. 4
      examples/client.rs
  13. 4
      examples/server.rs
  14. 8
      examples/srv_accept_unmasked_frames.rs
  15. 2
      fuzz/fuzz_targets/read_message_client.rs
  16. 2
      fuzz/fuzz_targets/read_message_server.rs
  17. 2
      scripts/autobahn-client.sh
  18. 2
      scripts/autobahn-server.sh
  19. 1
      src/client.rs
  20. 29
      src/error.rs
  21. 442
      src/extensions/compression/deflate.rs
  22. 4
      src/extensions/compression/mod.rs
  23. 18
      src/extensions/mod.rs
  24. 135
      src/handshake/client.rs
  25. 26
      src/handshake/server.rs
  26. 5
      src/lib.rs
  27. 12
      src/protocol/frame/frame.rs
  28. 122
      src/protocol/frame/mod.rs
  29. 24
      src/protocol/message.rs
  30. 602
      src/protocol/mod.rs
  31. 12
      src/tls.rs
  32. 41
      tests/connection_reset.rs
  33. 9
      tests/no_send_after_close.rs
  34. 15
      tests/receive_after_init_close.rs
  35. 68
      tests/write.rs

@ -1,70 +0,0 @@
name: CI
on: [push, pull_request]
jobs:
fmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@nightly
with:
components: rustfmt
- run: cargo fmt --all --check
test:
name: Test
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Install cargo-hack
uses: taiki-e/install-action@cargo-hack
- name: Check
run: cargo hack check --feature-powerset --all-targets
- name: Test
run: cargo test --release
autobahn:
name: Autobahn tests
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
- beta
- nightly
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Running Autobahn TestSuite for client
run: ./scripts/autobahn-client.sh
- name: Running Autobahn TestSuite for server
run: ./scripts/autobahn-server.sh

3
.gitignore vendored

@ -1,3 +1,4 @@
target target
Cargo.lock Cargo.lock
.vscode autobahn/client/
autobahn/server/

@ -0,0 +1,15 @@
language: rust
rust:
- stable
services:
- docker
before_script:
- export PATH="$PATH:$HOME/.cargo/bin"
script:
- cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -1,25 +1,3 @@
# Unreleased (0.20.0)
- Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing.
This allows writing a batch of messages and flushing once, improving performance.
- Add `WebSocket::read`, `write`, `send`, `flush`. Deprecate `read_message`, `write_message`, `write_pending`.
- Add `FrameSocket::read`, `write`, `send`, `flush`. Remove `read_frame`, `write_frame`, `write_pending`.
Note: Previous use of `write_frame` may be replaced with `send`.
- Add `WebSocketContext::read`, `write`, `flush`. Remove `read_message`, `write_message`, `write_pending`.
Note: Previous use of `write_message` may be replaced with `write` + `flush`.
- Remove `send_queue`, replaced with using the frame write buffer to achieve similar results.
* Add `WebSocketConfig::max_write_buffer_size`. Deprecate `max_send_queue`.
* Add `Error::WriteBufferFull`. Remove `Error::SendQueueFull`.
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
# 0.19.0
- Update TLS dependencies.
- Exchanging `base64` for `data-encoding`.
# 0.18.0 # 0.18.0
- Make handshake dependencies optional with a new `handshake` feature (now a default one!). - Make handshake dependencies optional with a new `handshake` feature (now a default one!).

@ -7,9 +7,9 @@ authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.19.0" documentation = "https://docs.rs/tungstenite/0.18.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.19.0" version = "0.18.0"
edition = "2018" edition = "2018"
rust-version = "1.51" rust-version = "1.51"
include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"] include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
@ -24,7 +24,16 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls"] __rustls-tls = ["rustls", "webpki"]
deflate = ["flate2"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies] [dependencies]
data-encoding = { version = "2", optional = true } data-encoding = { version = "2", optional = true }
@ -38,6 +47,11 @@ sha1 = { version = "0.10", optional = true }
thiserror = "1.0.23" thiserror = "1.0.23"
url = { version = "2.1.0", optional = true } url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5" utf-8 = "0.7.5"
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate] [dependencies.native-tls-crate]
optional = true optional = true
@ -46,18 +60,22 @@ version = "0.2.3"
[dependencies.rustls] [dependencies.rustls]
optional = true optional = true
version = "0.21.0" version = "0.20.0"
[dependencies.rustls-native-certs] [dependencies.rustls-native-certs]
optional = true optional = true
version = "0.6.0" version = "0.6.0"
[dependencies.webpki]
optional = true
version = "0.22"
[dependencies.webpki-roots] [dependencies.webpki-roots]
optional = true optional = true
version = "0.23" version = "0.22"
[dev-dependencies] [dev-dependencies]
criterion = "0.5.0" criterion = "0.4.0"
env_logger = "0.10.0" env_logger = "0.10.0"
input_buffer = "0.5.0" input_buffer = "0.5.0"
net2 = "0.2.37" net2 = "0.2.37"
@ -66,31 +84,3 @@ rand = "0.8.4"
[[bench]] [[bench]]
name = "buffer" name = "buffer"
harness = false harness = false
[[bench]]
name = "write"
harness = false
[[example]]
name = "client"
required-features = ["handshake"]
[[example]]
name = "server"
required-features = ["handshake"]
[[example]]
name = "autobahn-client"
required-features = ["handshake"]
[[example]]
name = "autobahn-server"
required-features = ["handshake"]
[[example]]
name = "callback-error"
required-features = ["handshake"]
[[example]]
name = "srv_accept_unmasked_frames"
required-features = ["handshake"]

@ -14,11 +14,11 @@ fn main () {
spawn (move || { spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap(); let mut websocket = accept(stream.unwrap()).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages. // We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap(); websocket.write_message(msg).unwrap();
} }
} }
}); });
@ -36,7 +36,7 @@ take a look at [`tokio-tungstenite`](https://github.com/snapview/tokio-tungsteni
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)
[![Build Status](https://github.com/snapview/tungstenite-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/snapview/tungstenite-rs/actions) [![Build Status](https://travis-ci.org/snapview/tungstenite-rs.svg?branch=master)](https://travis-ci.org/snapview/tungstenite-rs)
[Documentation](https://docs.rs/tungstenite) [Documentation](https://docs.rs/tungstenite)
@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features, By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints. otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing Testing
------- -------

File diff suppressed because it is too large Load Diff

@ -1,4 +1,5 @@
use std::io::{Cursor, Read, Result as IoResult}; use std::io::Result as IoResult;
use std::io::{Cursor, Read};
use bytes::Buf; use bytes::Buf;
use criterion::*; use criterion::*;

@ -1,75 +0,0 @@
//! Benchmarks for write performance.
use criterion::{BatchSize, Criterion};
use std::{
hint,
io::{self, Read, Write},
time::{Duration, Instant},
};
use tungstenite::{Message, WebSocket};
const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// `Write` impl that simulates slowish writes and slow flushes.
///
/// Each `write` can buffer up to 8 MiB before flushing but takes an additional **~80ns**
/// to simulate stuff going on in the underlying stream.
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
}
// simulate io
spin(Duration::from_nanos(80));
self.0.extend(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.0.is_empty() {
// simulate io
spin(Duration::from_micros(8));
self.0.clear();
}
Ok(())
}
}
fn spin(duration: Duration) {
let a = Instant::now();
while a.elapsed() < duration {
hint::spin_loop();
}
}
fn benchmark(c: &mut Criterion) {
// Writes 100k small json text messages then flushes
c.bench_function("write 100k small texts then flush", |b| {
let mut ws = WebSocket::from_raw_socket(
MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)),
tungstenite::protocol::Role::Server,
None,
);
b.iter_batched(
|| (0..100_000).map(|i| Message::Text(format!("{{\"id\":{i}}}"))),
|batch| {
for msg in batch {
ws.write(msg).unwrap();
}
ws.flush().unwrap();
},
BatchSize::SmallInput,
)
});
}
criterion::criterion_group!(write_benches, benchmark);
criterion::criterion_main!(write_benches);

@ -1,13 +1,16 @@
use log::*; use log::*;
use url::Url; use url::Url;
use tungstenite::{connect, Error, Message, Result}; use tungstenite::{
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite"; const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
} }
@ -24,11 +27,18 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect_with_config(
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop { loop {
match socket.read()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -4,7 +4,10 @@ use std::{
}; };
use log::*; use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result}; use tungstenite::{
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
@ -14,12 +17,19 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept_with_config(
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test"); info!("Running test");
loop { loop {
match socket.read()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -14,9 +14,9 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket.send(Message::Text("Hello WebSocket".into())).unwrap(); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop { loop {
let msg = socket.read().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);
} }
// socket.close(None); // socket.close(None);

@ -28,9 +28,9 @@ fn main() {
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap(); websocket.write_message(msg).unwrap();
} }
} }
}); });

@ -27,18 +27,22 @@ fn main() {
}; };
let config = Some(WebSocketConfig { let config = Some(WebSocketConfig {
max_send_queue: None,
max_message_size: None,
max_frame_size: None,
// This setting allows to accept client frames which are not masked // This setting allows to accept client frames which are not masked
// This is not in compliance with RFC 6455 but might be handy in some // This is not in compliance with RFC 6455 but might be handy in some
// rare cases where it is necessary to integrate with existing/legacy // rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames // clients which are sending unmasked frames
accept_unmasked_frames: true, accept_unmasked_frames: true,
..<_>::default() #[cfg(feature = "deflate")]
compression: None,
}); });
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
println!("received message {}", msg); println!("received message {}", msg);
} }

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
socket.read().ok(); socket.read_message().ok();
}); });

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
socket.read().ok(); socket.read_message().ok();
}); });

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3 sleep 3
cargo run --release --example autobahn-client cargo run --release --example autobahn-client --features=deflate
test_diff test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi fi
} }
cargo run --release --example autobahn-server & WSSERVER_PID=$! cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3 sleep 3
docker run --rm \ docker run --rm \

@ -54,7 +54,6 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,

@ -53,9 +53,9 @@ pub enum Error {
/// Protocol violation. /// Protocol violation.
#[error("WebSocket protocol error: {0}")] #[error("WebSocket protocol error: {0}")]
Protocol(#[from] ProtocolError), Protocol(#[from] ProtocolError),
/// Message write buffer is full. /// Message send queue full.
#[error("Write buffer is full")] #[error("Send queue is full")]
WriteBufferFull(Message), SendQueueFull(Message),
/// UTF coding error. /// UTF coding error.
#[error("UTF-8 encoding error")] #[error("UTF-8 encoding error")]
Utf8, Utf8,
@ -70,6 +70,10 @@ pub enum Error {
#[error("HTTP format error: {0}")] #[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error), HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] crate::extensions::DeflateError),
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -206,6 +210,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented. /// Control frames must not be fragmented.
#[error("Fragmented control frame")] #[error("Fragmented control frame")]
FragmentedControlFrame, FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less. /// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")] #[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig, ControlFrameTooBig,
@ -218,6 +225,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue. /// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")] #[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame, UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments. /// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")] #[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data), ExpectedFragment(Data),
@ -230,6 +240,15 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid. /// The payload for the closing frame is invalid.
#[error("Invalid close sequence")] #[error("Invalid close sequence")]
InvalidCloseSequence, InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
InvalidExtensionsHeader,
} }
/// Indicates the specific type/cause of URL error. /// Indicates the specific type/cause of URL error.
@ -271,6 +290,10 @@ pub enum TlsError {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")] #[error("rustls error: {0}")]
Rustls(#[from] rustls::Error), Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error. /// DNS name resolution error.
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name")] #[error("Invalid DNS name")]

@ -0,0 +1,442 @@
use std::convert::TryFrom;
use bytes::BytesMut;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use headers::WebsocketExtension;
use http::HeaderValue;
use thiserror::Error;
use crate::protocol::Role;
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Errors from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress(#[source] std::io::Error),
/// Decompress failed
#[error("Failed to decompress")]
Decompress(#[source] std::io::Error),
/// Extension negotiation failed.
#[error("Extension negotiation failed")]
Negotiation(#[source] NegotiationError),
}
/// Errors from `permessage-deflate` extension negotiation.
#[derive(Debug, Error)]
pub enum NegotiationError {
/// Unknown parameter in a negotiation response.
#[error("Unknown parameter in a negotiation response: {0}")]
UnknownParameter(String),
/// Duplicate parameter in a negotiation response.
#[error("Duplicate parameter in a negotiation response: {0}")]
DuplicateParameter(String),
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
UnexpectedClientMaxWindowBits,
/// Received unsupported `server_max_window_bits` in a negotiation response.
#[error("Received unsupported server_max_window_bits in a negotiation response")]
ServerMaxWindowBitsNotSupported,
/// Invalid `client_max_window_bits` value in a negotiation response.
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
InvalidClientMaxWindowBitsValue(String),
/// Invalid `server_max_window_bits` value in a negotiation response.
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
InvalidServerMaxWindowBitsValue(String),
/// Missing `server_max_window_bits` value in a negotiation response.
#[error("Missing server_max_window_bits value in a negotiation response")]
MissingServerMaxWindowBitsValue,
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
// > a client informs the peer server of a hint that even if the server doesn't include the
// > "client_no_context_takeover" extension parameter in the corresponding
// > extension negotiation response to the offer, the client is not going
// > to use context takeover.
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
if self.client_no_context_takeover {
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub(crate) fn accept_offer(
&self,
offers: &headers::SecWebsocketExtensions,
) -> Option<(WebsocketExtension, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
// 2. The negotiation offer contains an extension parameter with an invalid value.
// 3. The negotiation offer contains multiple extension parameters with the same name.
// 4. The server doesn't support the offered configuration.
offers.iter().find_map(|extension| {
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for (key, val) in params {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
return None;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
return None;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `server_max_window_bits` requires a value in range [8, 15].
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
} else {
return None;
}
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
return None;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `client_max_window_bits` requires a value in range [8, 15] or no value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
}
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
return None;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
return None;
}
}
}
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
} else {
None
}
})
}
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = DeflateConfig {
compression: self.compression,
// If this was hinted in the offer, the client won't use context takeover
// even if the response doesn't include it.
// See `generate_offer`.
client_no_context_takeover: self.client_no_context_takeover,
..DeflateConfig::default()
};
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
// A client MUST _Fail the WebSocket Connection_ if the peer server
// accepted an extension negotiation offer for this extension with an
// extension negotiation response meeting any of the following
// conditions:
// 1. The negotiation response contains an extension parameter not defined for use in a response.
// 2. The negotiation response contains an extension parameter with an invalid value.
// 3. The negotiation response contains multiple extension parameters with the same name.
// 4. The client does not support the configuration that the response represents.
for (key, val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_server_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_server_no_context_takeover = true;
// A server MAY include the "server_no_context_takeover" extension
// parameter in an extension negotiation response even if the extension
// negotiation offer being accepted by the extension negotiation
// response didn't include the "server_no_context_takeover" extension
// parameter.
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_client_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_client_no_context_takeover = true;
// The server may include this parameter in the response and the client MUST support it.
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
));
}
} else {
return Err(DeflateError::Negotiation(
NegotiationError::MissingServerMaxWindowBitsValue,
));
}
// A server may include the "server_max_window_bits" extension parameter
// in an extension negotiation response even if the extension
// negotiation offer being accepted by the response didn't include the
// "server_max_window_bits" extension parameter.
//
// However, but we need to fail the connection because we don't support it (condition 4).
return Err(DeflateError::Negotiation(
NegotiationError::ServerMaxWindowBitsNotSupported,
));
}
CLIENT_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
));
}
}
// Fail the connection because the parameter is invalid when the client didn't offer.
//
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
return Err(DeflateError::Negotiation(
NegotiationError::UnexpectedClientMaxWindowBits,
));
}
// Response with unknown parameter MUST fail the WebSocket connection.
_ => {
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
key.to_owned(),
)));
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
fn is_valid_max_window_bits(bits: &str) -> bool {
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
}
#[cfg(test)]
mod tests {
use super::is_valid_max_window_bits;
#[test]
fn valid_max_window_bits() {
for bits in 8..=15 {
assert!(is_valid_max_window_bits(&bits.to_string()));
}
}
#[test]
fn invalid_max_window_bits() {
assert!(!is_valid_max_window_bits(""));
assert!(!is_valid_max_window_bits("0"));
assert!(!is_valid_max_window_bits("08"));
assert!(!is_valid_max_window_bits("+8"));
assert!(!is_valid_max_window_bits("-8"));
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
for param in params {
buf.extend_from_slice(b"; ");
buf.extend_from_slice(param.as_bytes());
}
let header = HeaderValue::from_maybe_shared(buf.freeze())
.expect("semicolon separated HeaderValue is valid");
WebsocketExtension::try_from(header).expect("valid extension")
}

@ -0,0 +1,4 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -0,0 +1,18 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -5,6 +5,7 @@ use std::{
marker::PhantomData, marker::PhantomData,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -19,6 +20,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result, UrlError}, error::{Error, ProtocolError, Result, UrlError},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -56,7 +58,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC. // Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request). // Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?; let (request, key) = generate_request(request, &config)?;
let machine = HandshakeMachine::start_write(stream, request); let machine = HandshakeMachine::start_write(stream, request);
@ -83,18 +85,24 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) { let (result, extensions) =
Ok(r) => r, match self.verify_data.verify_response(result, &self.config) {
Err(Error::Http(mut e)) => { Ok(r) => r,
*e.body_mut() = Some(tail); Err(Error::Http(mut e)) => {
return Err(Error::Http(e)); *e.body_mut() = Some(tail);
} return Err(Error::Http(e));
Err(e) => return Err(e), }
}; Err(e) => return Err(e),
};
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket = WebSocket::from_partially_read_with_extensions(
WebSocket::from_partially_read(stream, tail, Role::Client, self.config); stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
}) })
@ -102,7 +110,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
} }
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it. /// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> { pub fn generate_request(
mut request: Request,
config: &Option<WebSocketConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new(); let mut req = Vec::new();
write!( write!(
req, req,
@ -173,6 +184,9 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap(); writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
} }
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
}
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req)); trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok((req, key)) Ok((req, key))
@ -186,7 +200,11 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> { pub fn verify_response(
&self,
response: Response,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -231,7 +249,14 @@ impl VerifyData {
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO let extensions = if let Some(agreed) = headers
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
{
verify_extensions(&agreed, _config)?
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field // 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was // and this header field indicates the use of a subprotocol that was
@ -240,10 +265,49 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok(response) Ok((response, extensions))
} }
} }
fn verify_extensions(
agreed_extensions: &headers::SecWebsocketExtensions,
_config: &Option<WebSocketConfig>,
) -> Result<Option<Extensions>> {
#[cfg(feature = "deflate")]
{
if let Some(compression) = _config.and_then(|c| c.compression) {
let mut extensions = None;
for extension in agreed_extensions.iter() {
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
// > the client MUST _Fail the WebSocket Connection_.
if extension.name() != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
extension.name().to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
extension.name().to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(extension.params())?),
});
}
return Ok(extensions);
}
}
if let Some(extension) = agreed_extensions.iter().next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
}
Ok(None)
}
impl TryParse for Response { impl TryParse for Response {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -258,7 +322,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -286,6 +350,8 @@ pub fn generate_key() -> String {
mod tests { mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response}; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest; use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -322,7 +388,7 @@ mod tests {
#[test] #[test]
fn request_formatting() { fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost", &key); let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -330,7 +396,7 @@ mod tests {
#[test] #[test]
fn request_formatting_with_host() { fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -338,11 +404,40 @@ mod tests {
#[test] #[test]
fn request_formatting_with_at() { fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap(); let (request, key) = generate_request(request, &None).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n",
host = "localhost",
key = key
)
.into_bytes();
assert_eq!(&request[..], &correct[..]);
}
#[test] #[test]
fn response_parsing() { fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -354,6 +449,6 @@ mod tests {
#[test] #[test]
fn invalid_custom_request() { fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap(); let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err()); assert!(generate_request(request, &None).is_err());
} }
} }

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult, result::Result as StdResult,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -20,6 +21,7 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>, error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback), callback: Some(callback),
config, config,
error_response: None, error_response: None,
extensions: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -240,7 +245,19 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?; let mut response = create_response(&result)?;
if let Some(config) = &self.config {
if let Some((agreed, extensions)) = result
.headers()
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
.and_then(|values| config.accept_offers(&values))
{
response.headers_mut().typed_insert(agreed);
self.extensions = Some(extensions);
}
}
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result, response)
} else { } else {
@ -283,7 +300,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body))); return Err(Error::Http(http::Response::from_parts(parts, body)));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
self.extensions.take(),
);
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }
} }

@ -19,13 +19,14 @@ pub mod buffer;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod client; pub mod client;
pub mod error; pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod handshake; pub mod handshake;
pub mod protocol; pub mod protocol;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
mod server; mod server;
pub mod stream; pub mod stream;
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))] #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls; mod tls;
pub mod util; pub mod util;
@ -44,5 +45,5 @@ pub use crate::{
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
}; };
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))] #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_tls, client_tls_with_config, Connector}; pub use tls::{client_tls, client_tls_with_config, Connector};

@ -311,6 +311,18 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
} }
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame. /// Create a new Pong control frame.
#[inline] #[inline]
pub fn pong(data: Vec<u8>) -> Frame { pub fn pong(data: Vec<u8>) -> Frame {

@ -6,14 +6,15 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
use crate::{
error::{CapacityError, Error, Result},
Message, ReadBuffer,
};
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
use log::*;
pub use self::frame::{CloseFrame, Frame, FrameHeader}; pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::{
error::{CapacityError, Error, Result},
ReadBuffer,
};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
@ -56,7 +57,7 @@ where
Stream: Read, Stream: Read,
{ {
/// Read a frame from stream. /// Read a frame from stream.
pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> { pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
self.codec.read_frame(&mut self.stream, max_size) self.codec.read_frame(&mut self.stream, max_size)
} }
} }
@ -65,28 +66,18 @@ impl<Stream> FrameSocket<Stream>
where where
Stream: Write, Stream: Write,
{ {
/// Writes and immediately flushes a frame.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, frame: Frame) -> Result<()> {
self.write(frame)?;
self.flush()
}
/// Write a frame to stream. /// Write a frame to stream.
/// ///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes. /// This function guarantees that the frame is queued regardless of any errors.
/// /// There is no need to resend the frame. In order to handle WouldBlock or Incomplete,
/// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`] /// call write_pending() afterwards.
/// is returned. pub fn write_frame(&mut self, frame: Frame) -> Result<()> {
/// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards. self.codec.write_frame(&mut self.stream, frame)
pub fn write(&mut self, frame: Frame) -> Result<()> {
self.codec.buffer_frame(&mut self.stream, frame)
} }
/// Flush writes. /// Complete pending write, if any.
pub fn flush(&mut self) -> Result<()> { pub fn write_pending(&mut self) -> Result<()> {
self.codec.write_out_buffer(&mut self.stream)?; self.codec.write_pending(&mut self.stream)
Ok(self.stream.flush()?)
} }
} }
@ -97,14 +88,6 @@ pub(super) struct FrameCodec {
in_buffer: ReadBuffer, in_buffer: ReadBuffer,
/// Buffer to send packets to the network. /// Buffer to send packets to the network.
out_buffer: Vec<u8>, out_buffer: Vec<u8>,
/// Capacity limit for `out_buffer`.
max_out_buffer_len: usize,
/// Buffer target length to reach before writing to the stream
/// on calls to `buffer_frame`.
///
/// Setting this to non-zero will buffer small writes from hitting
/// the stream.
out_buffer_write_len: usize,
/// Header and remaining size of the incoming packet being processed. /// Header and remaining size of the incoming packet being processed.
header: Option<(FrameHeader, u64)>, header: Option<(FrameHeader, u64)>,
} }
@ -112,13 +95,7 @@ pub(super) struct FrameCodec {
impl FrameCodec { impl FrameCodec {
/// Create a new frame codec. /// Create a new frame codec.
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None }
in_buffer: ReadBuffer::new(),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None,
}
} }
/// Create a new frame codec from partially read data. /// Create a new frame codec from partially read data.
@ -126,23 +103,10 @@ impl FrameCodec {
Self { Self {
in_buffer: ReadBuffer::from_partially_read(part), in_buffer: ReadBuffer::from_partially_read(part),
out_buffer: Vec::new(), out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None, header: None,
} }
} }
/// Sets a maximum size for the out buffer.
pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
self.max_out_buffer_len = max;
}
/// Sets [`Self::buffer_frame`] buffer target length to reach before
/// writing to the stream.
pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
self.out_buffer_write_len = len;
}
/// Read a frame from the provided stream. /// Read a frame from the provided stream.
pub(super) fn read_frame<Stream>( pub(super) fn read_frame<Stream>(
&mut self, &mut self,
@ -201,37 +165,19 @@ impl FrameCodec {
Ok(Some(frame)) Ok(Some(frame))
} }
/// Writes a frame into the `out_buffer`. /// Write a frame to the provided stream.
/// If the out buffer size is over the `out_buffer_write_len` will also write pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
/// the out buffer into the provided `stream`.
///
/// To ensure buffered frames are written call [`Self::write_out_buffer`].
///
/// May write to the stream, will **not** flush.
pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
return Err(Error::WriteBufferFull(Message::Frame(frame)));
}
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
self.write_pending(stream)
if self.out_buffer.len() > self.out_buffer_write_len {
self.write_out_buffer(stream)
} else {
Ok(())
}
} }
/// Writes the out_buffer to the provided stream. /// Complete pending write, if any.
/// pub(super) fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
/// Does **not** flush.
pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
@ -247,11 +193,19 @@ impl FrameCodec {
} }
self.out_buffer.drain(0..len); self.out_buffer.drain(0..len);
} }
stream.flush()?;
Ok(()) Ok(())
} }
} }
#[cfg(test)]
impl FrameCodec {
/// Returns the size of the output buffer.
pub(super) fn output_buffer_len(&self) -> usize {
self.out_buffer.len()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -270,11 +224,11 @@ mod tests {
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!( assert_eq!(
sock.read(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read(None).unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]); assert_eq!(rest, vec![0x99]);
@ -285,7 +239,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!( assert_eq!(
sock.read(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
} }
@ -295,10 +249,10 @@ mod tests {
let mut sock = FrameSocket::new(Vec::new()); let mut sock = FrameSocket::new(Vec::new());
let frame = Frame::ping(vec![0x04, 0x05]); let frame = Frame::ping(vec![0x04, 0x05]);
sock.send(frame).unwrap(); sock.write_frame(frame).unwrap();
let frame = Frame::pong(vec![0x01]); let frame = Frame::pong(vec![0x01]);
sock.send(frame).unwrap(); sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner(); let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
@ -310,7 +264,7 @@ mod tests {
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
let _ = sock.read(None); // should not crash let _ = sock.read_frame(None); // should not crash
} }
#[test] #[test]
@ -318,7 +272,7 @@ mod tests {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert!(matches!( assert!(matches!(
sock.read(Some(5)), sock.read_frame(Some(5)),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
)); ));
} }

@ -84,6 +84,8 @@ use self::string_collect::StringCollector;
#[derive(Debug)] #[derive(Debug)]
pub struct IncompleteMessage { pub struct IncompleteMessage {
collector: IncompleteMessageCollector, collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -94,6 +96,7 @@ enum IncompleteMessageCollector {
impl IncompleteMessage { impl IncompleteMessage {
/// Create new. /// Create new.
#[cfg(not(feature = "deflate"))]
pub fn new(message_type: IncompleteMessageType) -> Self { pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage { IncompleteMessage {
collector: match message_type { collector: match message_type {
@ -105,6 +108,25 @@ impl IncompleteMessage {
} }
} }
/// Create new.
#[cfg(feature = "deflate")]
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer. /// Get the current filled size of the buffer.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self.collector { match self.collector {
@ -185,7 +207,7 @@ impl Message {
Message::Text(string.into()) Message::Text(string.into())
} }
/// Create a new binary WebSocket message by converting to `Vec<u8>`. /// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message pub fn binary<B>(bin: B) -> Message
where where
B: Into<Vec<u8>>, B: Into<Vec<u8>>,

@ -6,6 +6,13 @@ mod message;
pub use self::{frame::CloseFrame, message::Message}; pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{ use self::{
frame::{ frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
@ -15,13 +22,9 @@ use self::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
util::NonBlockingResult, util::NonBlockingResult,
}; };
use log::*;
use std::{
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -35,21 +38,10 @@ pub enum Role {
/// The configuration for WebSocket connection. /// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig { pub struct WebSocketConfig {
/// Does nothing, instead use `max_write_buffer_size`. /// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
#[deprecated] /// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>, pub max_send_queue: Option<usize>,
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
/// The default value is 128 KiB.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited.
///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size).
pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent /// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user. /// memory eating by a malicious user.
@ -65,18 +57,76 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC. /// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455. /// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool, pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
fn default() -> Self { fn default() -> Self {
#[allow(deprecated)]
WebSocketConfig { WebSocketConfig {
max_send_queue: None, max_send_queue: None,
write_buffer_size: 128 * 1024,
max_write_buffer_size: usize::MAX,
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
accept_unmasked_frames: false, accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
}
impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
#[cfg(feature = "deflate")]
{
let mut offers = Vec::new();
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
offers.push(compression);
}
if offers.is_empty() {
None
} else {
Some(headers::SecWebsocketExtensions::new(offers))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers(
&self,
_offers: &headers::SecWebsocketExtensions,
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
#[cfg(feature = "deflate")]
{
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
let mut agreed_extensions = Vec::new();
let mut extensions = Extensions::default();
if let Some(compression) = &self.compression {
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
agreed_extensions.push(agreed);
extensions.compression = Some(compression);
}
}
if agreed_extensions.is_empty() {
None
} else {
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
}
}
#[cfg(not(feature = "deflate"))]
{
None
} }
} }
} }
@ -85,8 +135,6 @@ impl Default for WebSocketConfig {
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions. /// It may be created by calling `connect`, `accept` or `client` functions.
///
/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocket<Stream> { pub struct WebSocket<Stream> {
/// The underlying socket. /// The underlying socket.
@ -105,6 +153,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) } WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
} }
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
/// ///
/// Call this function if you're using Tungstenite as a part of a web framework /// Call this function if you're using Tungstenite as a part of a web framework
@ -122,6 +182,21 @@ impl<Stream> WebSocket<Stream> {
} }
} }
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
&self.socket &self.socket
@ -160,116 +235,82 @@ impl<Stream> WebSocket<Stream> {
impl<Stream: Read + Write> WebSocket<Stream> { impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
/// This will also queue responses to ping and close messages. These responses /// This will queue responses to ping and close messages to be sent. It will call
/// will be written and flushed on the next call to [`read`](Self::read), /// `write_pending` before trying to read in order to make sure that those responses
/// [`write`](Self::write) or [`flush`](Self::flush). /// make progress even if you never call `write_pending`. That does mean that they
/// get sent out earliest on the next call to `read_message`, `write_message` or `write_pending`.
/// ///
/// # Closing the connection /// ## Closing the connection
/// When the remote endpoint decides to close the connection this will return /// When the remote endpoint decides to close the connection this will return
/// the close message with an optional close frame. /// the close message with an optional close frame.
/// ///
/// You should continue calling [`read`](Self::read), [`write`](Self::write) or /// You should continue calling `read_message`, `write_message` or `write_pending` to drive
/// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`] /// the reply to the close frame until [Error::ConnectionClosed] is returned. Once that happens
/// is returned. Once that happens it is safe to drop the underlying connection. /// it is safe to drop the underlying connection.
pub fn read(&mut self) -> Result<Message> { pub fn read_message(&mut self) -> Result<Message> {
self.context.read(&mut self.socket) self.context.read_message(&mut self.socket)
}
/// Writes and immediately flushes a message.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, message: Message) -> Result<()> {
self.write(message)?;
self.flush()
} }
/// Write a message to the provided stream, if possible. /// Send a message to stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// ///
/// In the event of stream write failure the message frame will be stored /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// in the write buffer and will try again on the next call to [`write`](Self::write) /// requests. A Pong reply will jump the queue because the
/// or [`flush`](Self::flush). /// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent
/// as soon as is practical.
/// ///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] /// Note that upon receiving a ping message, tungstenite cues a pong reply automatically.
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. /// When you call either `read_message`, `write_message` or `write_pending` next it will try to send
/// /// that pong out if the underlying connection can take more data. This means you should not
/// This call will generally not flush. However, if there are queued automatic messages /// respond to ping frames manually.
/// they will be written and eagerly flushed.
///
/// For example, upon receiving ping messages tungstenite queues pong replies automatically.
/// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
/// will write & flush the pong reply. This means you should not respond to ping frames manually.
/// ///
/// You can however send pong frames manually in order to indicate a unidirectional heartbeat /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
/// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
/// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing /// if `read_message` returns a ping, you should call `write_pending` until it doesn't return
/// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the /// WouldBlock before passing a pong to `write_message`, otherwise the response to the
/// ping will not be sent as it will be replaced by your custom pong message. /// ping will not be sent, but rather replaced by your custom pong message.
/// ///
/// # Errors /// ## Errors
/// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned /// - If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the equivalent passed message frame. /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`]. /// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed].
/// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from /// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`,
/// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program /// [Error::AlreadyClosed] will be returned. This indicates a program error on your part.
/// error on your part. /// - [Error::Io] is returned if the underlying connection returns an error
/// - [`Error::Io`] is returned if the underlying connection returns an error
/// (consider these fatal except for WouldBlock). /// (consider these fatal except for WouldBlock).
/// - [`Error::Capacity`] if your message size is bigger than the configured max message size. /// - [Error::Capacity] if your message size is bigger than the configured max message size.
pub fn write(&mut self, message: Message) -> Result<()> { pub fn write_message(&mut self, message: Message) -> Result<()> {
self.context.write(&mut self.socket, message) self.context.write_message(&mut self.socket, message)
} }
/// Flush writes. /// Flush the pending send queue.
/// pub fn write_pending(&mut self) -> Result<()> {
/// Ensures all messages previously passed to [`write`](Self::write) and automatic self.context.write_pending(&mut self.socket)
/// queued pong responses are written & flushed into the underlying stream.
pub fn flush(&mut self) -> Result<()> {
self.context.flush(&mut self.socket)
} }
/// Close the connection. /// Close the connection.
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `write(Message::Close(..))`. /// the same as calling `write_message(Message::Close(..))`.
/// ///
/// After queuing the close frame you should continue calling [`read`](Self::read) or /// After queuing the close frame you should continue calling `read_message` or
/// [`flush`](Self::flush) to drive the close handshake to completion. /// `write_pending` to drive the close handshake to completion.
/// ///
/// The websocket RFC defines that the underlying connection should be closed /// The websocket RFC defines that the underlying connection should be closed
/// by the server. Tungstenite takes care of this asymmetry for you. /// by the server. Tungstenite takes care of this asymmetry for you.
/// ///
/// When the close handshake is finished (we have both sent and received /// When the close handshake is finished (we have both sent and received
/// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return /// a close message), `read_message` or `write_pending` will return
/// [Error::ConnectionClosed] if this endpoint is the server. /// [Error::ConnectionClosed] if this endpoint is the server.
/// ///
/// If this endpoint is a client, [Error::ConnectionClosed] will only be /// If this endpoint is a client, [Error::ConnectionClosed] will only be
/// returned after the server has closed the underlying connection. /// returned after the server has closed the underlying connection.
/// ///
/// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed] /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
/// is returned from [`read`](Self::read) or [`flush`](Self::flush). /// is returned from `read_message` or `write_pending`.
pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> { pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
self.context.close(&mut self.socket, code) self.context.close(&mut self.socket, code)
} }
/// Old name for [`read`](Self::read).
#[deprecated(note = "Use `read`")]
pub fn read_message(&mut self) -> Result<Message> {
self.read()
}
/// Old name for [`send`](Self::send).
#[deprecated(note = "Use `send`")]
pub fn write_message(&mut self, message: Message) -> Result<()> {
self.send(message)
}
/// Old name for [`flush`](Self::flush).
#[deprecated(note = "Use `flush`")]
pub fn write_pending(&mut self) -> Result<()> {
self.flush()
}
} }
/// A context for managing WebSocket stream. /// A context for managing WebSocket stream.
@ -283,41 +324,55 @@ pub struct WebSocketContext {
state: WebSocketState, state: WebSocketState,
/// Receive: an incomplete message being processed. /// Receive: an incomplete message being processed.
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send in addition to regular messages E.g. "pong" or "close". /// Send: a data send queue.
additional_send: Option<Frame>, send_queue: VecDeque<Frame>,
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
} }
impl WebSocketContext { impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream. /// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self { pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::new(), config.unwrap_or_default()) WebSocketContext {
role,
frame: FrameCodec::new(),
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_default(),
extensions: None,
}
} }
/// Create a WebSocket context that manages an post-handshake stream. /// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default()) WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
}
} }
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self { pub(crate) fn from_partially_read_with_extensions(
frame.set_max_out_buffer_len(config.max_write_buffer_size); part: Vec<u8>,
frame.set_out_buffer_write_len(config.write_buffer_size); role: Role,
Self { config: Option<WebSocketConfig>,
role, extensions: Option<Extensions>,
frame, ) -> Self {
state: WebSocketState::Active, WebSocketContext {
incomplete: None, frame: FrameCodec::from_partially_read(part),
additional_send: None, extensions,
config, ..WebSocketContext::new(role, config)
} }
} }
/// Change the configuration. /// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config); set_func(&mut self.config)
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
} }
/// Read the configuration. /// Read the configuration.
@ -344,23 +399,17 @@ impl WebSocketContext {
/// ///
/// This function sends pong and close responses automatically. /// This function sends pong and close responses automatically.
/// However, it never blocks on write. /// However, it never blocks on write.
pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message> pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// Do not read from already closed connections. // Do not read from already closed connections.
self.state.check_not_terminated()?; self.state.check_active()?;
loop { loop {
if self.additional_send.is_some() { // Since we may get ping or close, we need to reply to the messages even during read.
// Since we may get ping or close, we need to reply to the messages even during read. // Thus we call write_pending() but ignore its blocking.
// Thus we flush but ignore its blocking. self.write_pending(stream).no_block()?;
self.flush(stream).no_block()?;
} else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed);
}
// If we get here, either write blocks or we have nothing to write. // If we get here, either write blocks or we have nothing to write.
// Thus if read blocks, just let it return WouldBlock. // Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? { if let Some(message) = self.read_message_frame(stream)? {
@ -370,94 +419,89 @@ impl WebSocketContext {
} }
} }
/// Write a message to the provided stream. /// Send a message to the provided stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// ///
/// In the event of stream write failure the message frame will be stored /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// in the write buffer and will try again on the next call to [`write`](Self::write) /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// or [`flush`](Self::flush). /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// ///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] /// Note that only the last pong frame is stored to be sent, and only the
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. /// most recent pong frame is sent if multiple pong frames are queued.
pub fn write<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()> pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// When terminated, return AlreadyClosed. // When terminated, return AlreadyClosed.
self.state.check_not_terminated()?; self.state.check_active()?;
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol(ProtocolError::SendAfterClosing)); return Err(Error::Protocol(ProtocolError::SendAfterClosing));
} }
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
// Try to make some room for the new message.
// Do not return here if write would block, ignore WouldBlock silently
// since we must queue the message anyway.
self.write_pending(stream).no_block()?;
}
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message { let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
self.set_additional(Frame::pong(data)); self.pong = Some(Frame::pong(data));
// Note: user pongs can be user flushed so no need to flush here return self.write_pending(stream);
return self._write(stream, None).map(|_| ());
} }
Message::Close(code) => return self.close(stream, code), Message::Close(code) => return self.close(stream, code),
Message::Frame(f) => f, Message::Frame(f) => f,
}; };
let should_flush = self._write(stream, Some(frame))?; self.send_queue.push_back(frame);
if should_flush { self.write_pending(stream)
self.flush(stream)?;
}
Ok(())
} }
/// Flush writes. fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
/// debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
/// Ensures all messages previously passed to [`write`](Self::write) and automatically let opcode = OpCode::Data(opdata);
/// queued pong responses are written & flushed into the `stream`. let is_final = true;
#[inline] #[cfg(feature = "deflate")]
pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<()> if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
where return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
Stream: Read + Write, }
{ Ok(Frame::message(data, opcode, is_final))
self._write(stream, None)?;
self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?)
} }
/// Writes any data in the out_buffer, `additional_send` and given `data`. /// Flush the pending send queue.
/// pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
/// Does **not** flush.
///
/// Returns true if the write contents indicate we should flush immediately.
fn _write<Stream>(&mut self, stream: &mut Stream, data: Option<Frame>) -> Result<bool>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
if let Some(data) = data { // First, make sure we have no pending frame sending.
self.buffer_frame(stream, data)?; self.frame.write_pending(stream)?;
}
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
// response, unless it already received a Close frame. It SHOULD // response, unless it already received a Close frame. It SHOULD
// respond with Pong frame as soon as is practical. (RFC 6455) // respond with Pong frame as soon as is practical. (RFC 6455)
let should_flush = if let Some(msg) = self.additional_send.take() { if let Some(pong) = self.pong.take() {
trace!("Sending pong/close"); trace!("Sending pong reply");
match self.buffer_frame(stream, msg) { self.send_one_frame(stream, pong)?;
Err(Error::WriteBufferFull(Message::Frame(msg))) => { }
// if an system message would exceed the buffer put it back in // If we have any unsent frames, send them.
// `additional_send` for retry. Otherwise returning this error trace!("Frames still in queue: {}", self.send_queue.len());
// may not make sense to the user, e.g. calling `flush`. while let Some(data) = self.send_queue.pop_front() {
self.set_additional(msg); self.send_one_frame(stream, data)?;
false }
}
Err(err) => return Err(err), // If we get to this point, the send queue is empty and the underlying socket is still
Ok(_) => true, // willing to take more data.
}
} else {
false
};
// If we're closing and there is nothing to send anymore, we should close the connection. // If we're closing and there is nothing to send anymore, we should close the connection.
if self.role == Role::Server && !self.state.can_read() { if self.role == Role::Server && !self.state.can_read() {
@ -467,11 +511,10 @@ impl WebSocketContext {
// maximum segment lifetimes (2MSL), while there is no corresponding // maximum segment lifetimes (2MSL), while there is no corresponding
// server impact as a TIME_WAIT connection is immediately reopened upon // server impact as a TIME_WAIT connection is immediately reopened upon
// a new SYN with a higher seq number). (RFC 6455) // a new SYN with a higher seq number). (RFC 6455)
self.frame.write_out_buffer(stream)?;
self.state = WebSocketState::Terminated; self.state = WebSocketState::Terminated;
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} else { } else {
Ok(should_flush) Ok(())
} }
} }
@ -479,7 +522,7 @@ impl WebSocketContext {
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `send(Message::Close(..))`. /// the same as calling `write(Message::Close(..))`.
pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()> pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
@ -487,9 +530,11 @@ impl WebSocketContext {
if let WebSocketState::Active = self.state { if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs; self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(code); let frame = Frame::close(code);
self._write(stream, Some(frame))?; self.send_queue.push_back(frame);
} else {
// Already closed, nothing to do.
} }
self.flush(stream) self.write_pending(stream)
} }
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
@ -510,12 +555,14 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero // the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket // value, the receiving endpoint MUST _Fail the WebSocket
// Connection_. // Connection_.
{ let is_compressed = {
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
}
hdr.rsv1
};
match self.role { match self.role {
Role::Server => { Role::Server => {
@ -550,6 +597,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
} }
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => { OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -558,7 +609,7 @@ impl WebSocketContext {
let data = frame.into_data(); let data = frame.into_data();
// No ping processing after we sent a close frame. // No ping processing after we sent a close frame.
if self.state.is_active() { if self.state.is_active() {
self.set_additional(Frame::pong(data.clone())); self.pong = Some(Frame::pong(data.clone()));
} }
Ok(Some(Message::Ping(data))) Ok(Some(Message::Ping(data)))
} }
@ -570,39 +621,34 @@ impl WebSocketContext {
let fin = frame.header().is_final; let fin = frame.header().is_final;
match data { match data {
OpData::Continue => { OpData::Continue => {
if let Some(ref mut msg) = self.incomplete { if self.incomplete.is_some() && is_compressed {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol( return Err(Error::Protocol(
ProtocolError::UnexpectedContinueFrame, ProtocolError::CompressedContinueFrame,
)); ));
} }
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?)) let msg = self
} else { .incomplete
Ok(None) .take()
} .ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
self.extend_incomplete(msg, frame.into_data(), fin)
} }
c if self.incomplete.is_some() => { c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
} }
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = { let message_type = match data {
let message_type = match data { OpData::Text => IncompleteMessageType::Text,
OpData::Text => IncompleteMessageType::Text, OpData::Binary => IncompleteMessageType::Binary,
OpData::Binary => IncompleteMessageType::Binary, _ => panic!("Bug: message is not text nor binary"),
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
}; };
if fin { #[cfg(feature = "deflate")]
Ok(Some(msg.complete()?)) let msg = IncompleteMessage::new(message_type, is_compressed);
} else { #[cfg(not(feature = "deflate"))]
self.incomplete = Some(msg); let msg = IncompleteMessage::new(message_type);
Ok(None) self.extend_incomplete(msg, frame.into_data(), fin)
}
} }
OpData::Reserved(i) => { OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
@ -621,6 +667,32 @@ impl WebSocketContext {
} }
} }
fn extend_incomplete(
&mut self,
mut msg: IncompleteMessage,
data: Vec<u8>,
is_final: bool,
) -> Result<Option<Message>> {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(data, is_final)?
} else {
data
};
msg.extend(data, self.config.max_message_size)?;
if is_final {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
/// Received a close frame. Tells if we need to return a close frame to the user. /// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)] #[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> { fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
@ -642,7 +714,7 @@ impl WebSocketContext {
let reply = Frame::close(close.clone()); let reply = Frame::close(close.clone());
debug!("Replying to close with {:?}", reply); debug!("Replying to close with {:?}", reply);
self.set_additional(reply); self.send_queue.push_back(reply);
Some(close) Some(close)
} }
@ -659,8 +731,8 @@ impl WebSocketContext {
} }
} }
/// Write a single frame into the write-buffer. /// Send a single pending frame.
fn buffer_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
@ -674,17 +746,17 @@ impl WebSocketContext {
} }
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.buffer_frame(stream, frame).check_connection_reset(self.state) self.frame.write_frame(stream, frame).check_connection_reset(self.state)
} }
/// Replace `additional_send` if it is currently a `Pong` message. fn has_compression(&self) -> bool {
fn set_additional(&mut self, add: Frame) { #[cfg(feature = "deflate")]
let empty_or_pong = self {
.additional_send self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
.as_ref() }
.map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong)); #[cfg(not(feature = "deflate"))]
if empty_or_pong { {
self.additional_send.replace(add); false
} }
} }
} }
@ -718,7 +790,7 @@ impl WebSocketState {
} }
/// Check if the state is active, return error if not. /// Check if the state is active, return error if not.
fn check_not_terminated(self) -> Result<()> { fn check_active(self) -> Result<()> {
match self { match self {
WebSocketState::Terminated => Err(Error::AlreadyClosed), WebSocketState::Terminated => Err(Error::AlreadyClosed),
_ => Ok(()), _ => Ok(()),
@ -770,6 +842,64 @@ mod tests {
} }
} }
struct WouldBlockStreamMoc;
impl io::Write for WouldBlockStreamMoc {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
fn flush(&mut self) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
impl io::Read for WouldBlockStreamMoc {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
#[test]
fn queue_logic() {
// Create a socket with the queue size of 1.
let mut socket = WebSocket::from_raw_socket(
WouldBlockStreamMoc,
Role::Client,
Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }),
);
// Test message that we're going to send.
let message = Message::Binary(vec![0xFF; 1024]);
// Helper to check the error.
let assert_would_block = |error| {
if let Error::Io(io_error) = error {
assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock);
} else {
panic!("Expected WouldBlock error");
}
};
// The first attempt of writing must not fail, since the queue is empty at start.
// But since the underlying mock object always returns `WouldBlock`, so is the result.
assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err()));
// Any subsequent attempts must return an error telling that the queue is full.
for _i in 0..100 {
assert!(matches!(
socket.write_message(message.clone()).unwrap_err(),
Error::SendQueueFull(..)
));
}
// The size of the output buffer must not be bigger than the size of that message
// that we managed to write to the output buffer at first. Since we could not make
// any progress (because of the logic of the moc buffer), the size remains unchanged.
if socket.context.frame.output_buffer_len() > message.len() {
panic!("Too many frames in the queue");
}
}
#[test] #[test]
fn receive_messages() { fn receive_messages() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![
@ -778,10 +908,10 @@ mod tests {
0x03, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
} }
#[test] #[test]
@ -794,7 +924,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read(), socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
)); ));
} }
@ -806,7 +936,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read(), socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
)); ));
} }

@ -104,13 +104,11 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")] #[cfg(feature = "rustls-tls-native-roots")]
{ {
let native_certs = rustls_native_certs::load_native_certs()?; for cert in rustls_native_certs::load_native_certs()? {
let der_certs: Vec<Vec<u8>> = root_store
native_certs.into_iter().map(|cert| cert.0).collect(); .add(&rustls::Certificate(cert.0))
let total_number = der_certs.len(); .map_err(TlsError::Webpki)?;
let (number_added, number_ignored) = }
root_store.add_parsable_certificates(&der_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
} }
#[cfg(feature = "rustls-tls-webpki-roots")] #[cfg(feature = "rustls-tls-webpki-roots")]
{ {

@ -1,7 +1,6 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closed from the server's point of view and drop the underlying tcp socket. //! is closed from the server's point of view and drop the underlying tcp socket.
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
#![cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
use std::{ use std::{
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
@ -52,27 +51,27 @@ fn test_server_close() {
do_test( do_test(
3012, 3012,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -86,26 +85,26 @@ fn test_evil_server_close() {
do_test( do_test(
3013, 3013,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
let message = cli_sock.read().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed` // and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
@ -119,32 +118,32 @@ fn test_client_close() {
do_test( do_test(
3014, 3014,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive answer from server let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server"); assert_eq!(message.into_data(), b"From Server");
cli_sock.close(None).unwrap(); // send close to server cli_sock.close(None).unwrap(); // send close to server
let message = cli_sock.read().unwrap(); // receive acknowledgement from server let message = cli_sock.read_message().unwrap(); // receive acknowledgement from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.send(Message::Text("From Server".into())).unwrap(); srv_sock.write_message(Message::Text("From Server".into())).unwrap();
let message = srv_sock.read().unwrap(); // receive close from client let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, error::ProtocolError, Error, Message}; use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url; use url::Url;
@ -29,10 +28,10 @@ fn test_no_send_after_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
let message = client.read().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -44,7 +43,7 @@ fn test_no_send_after_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
let err = client_handler.send(Message::Text("Hello WebSocket".into())); let err = client_handler.write_message(Message::Text("Hello WebSocket".into()));
assert!(err.is_err()); assert!(err.is_err());

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;
@ -29,12 +28,12 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap(); client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -47,12 +46,12 @@ fn test_receive_after_init_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close // This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap(); let message = client_handler.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement assert!(client_handler.read_message().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -1,68 +0,0 @@
use std::io::{self, Read, Write};
use tungstenite::{protocol::WebSocketConfig, Message, WebSocket};
/// `Write` impl that records call stats and drops the data.
#[derive(Debug, Default)]
struct MockWrite {
written_bytes: usize,
write_count: usize,
flush_count: usize,
}
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written_bytes += buf.len();
self.write_count += 1;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_count += 1;
Ok(())
}
}
/// Test for write buffering and flushing behaviour.
#[test]
fn write_flush_behaviour() {
const SEND_ME_LEN: usize = 10;
const BATCH_ME_LEN: usize = 11;
const WRITE_BUFFER_SIZE: usize = 600;
let mut ws = WebSocket::from_raw_socket(
MockWrite::default(),
tungstenite::protocol::Role::Server,
Some(WebSocketConfig { write_buffer_size: WRITE_BUFFER_SIZE, ..<_>::default() }),
);
assert_eq!(ws.get_ref().written_bytes, 0);
assert_eq!(ws.get_ref().write_count, 0);
assert_eq!(ws.get_ref().flush_count, 0);
// `send` writes & flushes immediately
ws.send(Message::Text("Send me!".into())).unwrap();
assert_eq!(ws.get_ref().written_bytes, SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 1);
assert_eq!(ws.get_ref().flush_count, 1);
// send a batch of messages
for msg in (0..100).map(|_| Message::Text("Batch me!".into())) {
ws.write(msg).unwrap();
}
// after 55 writes the out_buffer will exceed write_buffer_size=600
// and so do a single underlying write (not flushing).
assert_eq!(ws.get_ref().written_bytes, 55 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 2);
assert_eq!(ws.get_ref().flush_count, 1);
// flushing will perform a single write for the remaining out_buffer & flush.
ws.flush().unwrap();
assert_eq!(ws.get_ref().written_bytes, 100 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 3);
assert_eq!(ws.get_ref().flush_count, 2);
}
Loading…
Cancel
Save