Compare commits

..

40 Commits

Author SHA1 Message Date
Daniel Abramov 57d9e23939
Merge pull request #360 from doylemark/master 2 years ago
Mark Doyle 9533c02280 return correct protocol error when missing http version 2 years ago
Daniel Abramov d7b559d724
Merge pull request #359 from mtbc/allow-IPv6-SocketAddr 2 years ago
Mark T. B. Carroll 8901dcc535
remove [] enclosing IPv6 host address 2 years ago
Alexey Galakhov 8f23e1765e
Merge pull request #358 from alexheretic/buffer-writes 2 years ago
Alex Butler f6a610f925 Add write_flush_behaviour test 2 years ago
Alex Butler dea67d6cca Fix doc typo 2 years ago
Alex Butler 41818166cf refactor WebSocketContext new 2 years ago
Alex Butler 0cada00fb5 Refactor write_one_frame -> buffer_frame 2 years ago
Alex Butler f33bb2cb97 Ensure out_buffer written when !can_read 2 years ago
Alex Butler 1b47964f18 split write and write_out_buffer internals 2 years ago
Alex Butler 2ef5b9a5e2 Buffer writes before writing to the underlying stream 2 years ago
Alex Butler 2cf7cfef04 Rework write 100k bench to have a slow writes & even slower flushes 2 years ago
Alexey Galakhov 5a3115c09b
Merge pull request #357 from alexheretic/flush-writes-less 2 years ago
Alex Butler 06e55a4ef2 Refactor additional_send writing 2 years ago
Alex Butler 84a54b76e6 Rename methods to `read`, `send`, `write` & `flush` 2 years ago
Daniel Abramov 79b39eb146
Merge pull request #356 from snapview/dependabot/cargo/criterion-0.5.0 2 years ago
Alex Butler 0203a1849b Remove send_queue, use out_buffer instead 2 years ago
Alex Butler 483d229707 Remove implicit write flushing 2 years ago
Alex Butler d298089bf3 Add write 100k micro-bench 2 years ago
dependabot[bot] 7242a22b91
Update criterion requirement from 0.4.0 to 0.5.0 2 years ago
Daniel Abramov 371f823044
Merge pull request #354 from CBenoit/fix-error-on-bad-root-cert 2 years ago
Benoît CORTIER ee3ffc9e9d
Gracefully handle invalid native root certificates 2 years ago
Daniel Abramov e5efe537b8
Merge pull request #351 from nickelc/deps/webpki 2 years ago
Constantin Nickel 8a436e7550 Remove unused `TlsError::Webpki` error variant 2 years ago
Daniel Abramov 314feea305
Merge pull request #348 from atouchet/trv 2 years ago
Alex Touchet 50d5a37bdc
Switch build status badge to GitHub Actions 2 years ago
Daniel Abramov 79fa37888f
Merge pull request #347 from snapview/github-actions 2 years ago
Daniel Abramov 746d938412 Use `cargo fmt` from nightly 2 years ago
Daniel Abramov a4863d3f10 Make code compile with any feature set 2 years ago
Daniel Abramov 7e4a15446d Properly activate features for examples and tests 2 years ago
Daniel Abramov 87e9f576af Make `cargo fmt` happy 2 years ago
Daniel Abramov e758f7dc2a Exchange Travis CI for GitHub Actions 2 years ago
Daniel Abramov 869a67ca0b Bump version 2 years ago
Daniel Abramov a873befaae
Merge pull request #345 from mlemesle/fix/webpki-error-variant 2 years ago
Martin Lemesle 1f6c62d301 Fix not compiling features rustls-tls-native-roots and rustls-tls-webpki-roots 2 years ago
dependabot[bot] 92d65e1104
Update webpki-roots requirement from 0.22 to 0.23 (#343) 2 years ago
Daniel Abramov 67e25fdd68
Merge pull request #341 from snapview/dependabot/cargo/rustls-0.21.0 2 years ago
dependabot[bot] 1422d47ec0
Update rustls requirement from 0.20.0 to 0.21.0 2 years ago
Daniel Abramov 42b8797e8b Revert "Add `permessage-deflate` support" 2 years ago
  1. 70
      .github/workflows/ci.yml
  2. 3
      .gitignore
  3. 15
      .travis.yml
  4. 22
      CHANGELOG.md
  5. 58
      Cargo.toml
  6. 8
      README.md
  7. 576
      autobahn/expected-results.json
  8. 3
      benches/buffer.rs
  9. 75
      benches/write.rs
  10. 20
      examples/autobahn-client.rs
  11. 18
      examples/autobahn-server.rs
  12. 4
      examples/client.rs
  13. 4
      examples/server.rs
  14. 8
      examples/srv_accept_unmasked_frames.rs
  15. 2
      fuzz/fuzz_targets/read_message_client.rs
  16. 2
      fuzz/fuzz_targets/read_message_server.rs
  17. 2
      scripts/autobahn-client.sh
  18. 2
      scripts/autobahn-server.sh
  19. 1
      src/client.rs
  20. 29
      src/error.rs
  21. 442
      src/extensions/compression/deflate.rs
  22. 4
      src/extensions/compression/mod.rs
  23. 18
      src/extensions/mod.rs
  24. 121
      src/handshake/client.rs
  25. 26
      src/handshake/server.rs
  26. 5
      src/lib.rs
  27. 12
      src/protocol/frame/frame.rs
  28. 122
      src/protocol/frame/mod.rs
  29. 24
      src/protocol/message.rs
  30. 588
      src/protocol/mod.rs
  31. 12
      src/tls.rs
  32. 41
      tests/connection_reset.rs
  33. 9
      tests/no_send_after_close.rs
  34. 15
      tests/receive_after_init_close.rs
  35. 68
      tests/write.rs

@ -0,0 +1,70 @@
name: CI
on: [push, pull_request]
jobs:
fmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@nightly
with:
components: rustfmt
- run: cargo fmt --all --check
test:
name: Test
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Install cargo-hack
uses: taiki-e/install-action@cargo-hack
- name: Check
run: cargo hack check --feature-powerset --all-targets
- name: Test
run: cargo test --release
autobahn:
name: Autobahn tests
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
- beta
- nightly
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Running Autobahn TestSuite for client
run: ./scripts/autobahn-client.sh
- name: Running Autobahn TestSuite for server
run: ./scripts/autobahn-server.sh

3
.gitignore vendored

@ -1,4 +1,3 @@
target target
Cargo.lock Cargo.lock
autobahn/client/ .vscode
autobahn/server/

@ -1,15 +0,0 @@
language: rust
rust:
- stable
services:
- docker
before_script:
- export PATH="$PATH:$HOME/.cargo/bin"
script:
- cargo test --release
- cargo test --release --features=deflate
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -1,3 +1,25 @@
# Unreleased (0.20.0)
- Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing.
This allows writing a batch of messages and flushing once, improving performance.
- Add `WebSocket::read`, `write`, `send`, `flush`. Deprecate `read_message`, `write_message`, `write_pending`.
- Add `FrameSocket::read`, `write`, `send`, `flush`. Remove `read_frame`, `write_frame`, `write_pending`.
Note: Previous use of `write_frame` may be replaced with `send`.
- Add `WebSocketContext::read`, `write`, `flush`. Remove `read_message`, `write_message`, `write_pending`.
Note: Previous use of `write_message` may be replaced with `write` + `flush`.
- Remove `send_queue`, replaced with using the frame write buffer to achieve similar results.
* Add `WebSocketConfig::max_write_buffer_size`. Deprecate `max_send_queue`.
* Add `Error::WriteBufferFull`. Remove `Error::SendQueueFull`.
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
# 0.19.0
- Update TLS dependencies.
- Exchanging `base64` for `data-encoding`.
# 0.18.0 # 0.18.0
- Make handshake dependencies optional with a new `handshake` feature (now a default one!). - Make handshake dependencies optional with a new `handshake` feature (now a default one!).

@ -7,9 +7,9 @@ authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.18.0" documentation = "https://docs.rs/tungstenite/0.19.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.18.0" version = "0.19.0"
edition = "2018" edition = "2018"
rust-version = "1.51" rust-version = "1.51"
include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"] include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
@ -24,16 +24,7 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"] __rustls-tls = ["rustls"]
deflate = ["flate2"]
[[example]]
name = "autobahn-client"
required-features = ["deflate"]
[[example]]
name = "autobahn-server"
required-features = ["deflate"]
[dependencies] [dependencies]
data-encoding = { version = "2", optional = true } data-encoding = { version = "2", optional = true }
@ -47,11 +38,6 @@ sha1 = { version = "0.10", optional = true }
thiserror = "1.0.23" thiserror = "1.0.23"
url = { version = "2.1.0", optional = true } url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5" utf-8 = "0.7.5"
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
[dependencies.flate2]
optional = true
version = "1.0"
[dependencies.native-tls-crate] [dependencies.native-tls-crate]
optional = true optional = true
@ -60,22 +46,18 @@ version = "0.2.3"
[dependencies.rustls] [dependencies.rustls]
optional = true optional = true
version = "0.20.0" version = "0.21.0"
[dependencies.rustls-native-certs] [dependencies.rustls-native-certs]
optional = true optional = true
version = "0.6.0" version = "0.6.0"
[dependencies.webpki]
optional = true
version = "0.22"
[dependencies.webpki-roots] [dependencies.webpki-roots]
optional = true optional = true
version = "0.22" version = "0.23"
[dev-dependencies] [dev-dependencies]
criterion = "0.4.0" criterion = "0.5.0"
env_logger = "0.10.0" env_logger = "0.10.0"
input_buffer = "0.5.0" input_buffer = "0.5.0"
net2 = "0.2.37" net2 = "0.2.37"
@ -84,3 +66,31 @@ rand = "0.8.4"
[[bench]] [[bench]]
name = "buffer" name = "buffer"
harness = false harness = false
[[bench]]
name = "write"
harness = false
[[example]]
name = "client"
required-features = ["handshake"]
[[example]]
name = "server"
required-features = ["handshake"]
[[example]]
name = "autobahn-client"
required-features = ["handshake"]
[[example]]
name = "autobahn-server"
required-features = ["handshake"]
[[example]]
name = "callback-error"
required-features = ["handshake"]
[[example]]
name = "srv_accept_unmasked_frames"
required-features = ["handshake"]

@ -14,11 +14,11 @@ fn main () {
spawn (move || { spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap(); let mut websocket = accept(stream.unwrap()).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read().unwrap();
// We do not want to send back ping/pong messages. // We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap(); websocket.send(msg).unwrap();
} }
} }
}); });
@ -36,7 +36,7 @@ take a look at [`tokio-tungstenite`](https://github.com/snapview/tokio-tungsteni
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)
[![Build Status](https://travis-ci.org/snapview/tungstenite-rs.svg?branch=master)](https://travis-ci.org/snapview/tungstenite-rs) [![Build Status](https://github.com/snapview/tungstenite-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/snapview/tungstenite-rs/actions)
[Documentation](https://docs.rs/tungstenite) [Documentation](https://docs.rs/tungstenite)
@ -72,6 +72,8 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features, By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints. otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing Testing
------- -------

File diff suppressed because it is too large Load Diff

@ -1,5 +1,4 @@
use std::io::Result as IoResult; use std::io::{Cursor, Read, Result as IoResult};
use std::io::{Cursor, Read};
use bytes::Buf; use bytes::Buf;
use criterion::*; use criterion::*;

@ -0,0 +1,75 @@
//! Benchmarks for write performance.
use criterion::{BatchSize, Criterion};
use std::{
hint,
io::{self, Read, Write},
time::{Duration, Instant},
};
use tungstenite::{Message, WebSocket};
const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// `Write` impl that simulates slowish writes and slow flushes.
///
/// Each `write` can buffer up to 8 MiB before flushing but takes an additional **~80ns**
/// to simulate stuff going on in the underlying stream.
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
}
// simulate io
spin(Duration::from_nanos(80));
self.0.extend(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.0.is_empty() {
// simulate io
spin(Duration::from_micros(8));
self.0.clear();
}
Ok(())
}
}
fn spin(duration: Duration) {
let a = Instant::now();
while a.elapsed() < duration {
hint::spin_loop();
}
}
fn benchmark(c: &mut Criterion) {
// Writes 100k small json text messages then flushes
c.bench_function("write 100k small texts then flush", |b| {
let mut ws = WebSocket::from_raw_socket(
MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)),
tungstenite::protocol::Role::Server,
None,
);
b.iter_batched(
|| (0..100_000).map(|i| Message::Text(format!("{{\"id\":{i}}}"))),
|batch| {
for msg in batch {
ws.write(msg).unwrap();
}
ws.flush().unwrap();
},
BatchSize::SmallInput,
)
});
}
criterion::criterion_group!(write_benches, benchmark);
criterion::criterion_main!(write_benches);

@ -1,16 +1,13 @@
use log::*; use log::*;
use url::Url; use url::Url;
use tungstenite::{ use tungstenite::{connect, Error, Message, Result};
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
Error, Message, Result,
};
const AGENT: &str = "Tungstenite"; const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read_message()?; let msg = socket.read()?;
socket.close(None)?; socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
} }
@ -27,18 +24,11 @@ fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect_with_config( let (mut socket, _) = connect(case_url)?;
case_url,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
3,
)?;
loop { loop {
match socket.read_message()? { match socket.read()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?; socket.send(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -4,10 +4,7 @@ use std::{
}; };
use log::*; use log::*;
use tungstenite::{ use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
@ -17,19 +14,12 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept_with_config( let mut socket = accept(stream).map_err(must_not_block)?;
stream,
Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.map_err(must_not_block)?;
info!("Running test"); info!("Running test");
loop { loop {
match socket.read_message()? { match socket.read()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?; socket.send(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -14,9 +14,9 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); socket.send(Message::Text("Hello WebSocket".into())).unwrap();
loop { loop {
let msg = socket.read_message().expect("Error reading message"); let msg = socket.read().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);
} }
// socket.close(None); // socket.close(None);

@ -28,9 +28,9 @@ fn main() {
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap(); websocket.send(msg).unwrap();
} }
} }
}); });

@ -27,22 +27,18 @@ fn main() {
}; };
let config = Some(WebSocketConfig { let config = Some(WebSocketConfig {
max_send_queue: None,
max_message_size: None,
max_frame_size: None,
// This setting allows to accept client frames which are not masked // This setting allows to accept client frames which are not masked
// This is not in compliance with RFC 6455 but might be handy in some // This is not in compliance with RFC 6455 but might be handy in some
// rare cases where it is necessary to integrate with existing/legacy // rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames // clients which are sending unmasked frames
accept_unmasked_frames: true, accept_unmasked_frames: true,
#[cfg(feature = "deflate")] ..<_>::default()
compression: None,
}); });
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
println!("received message {}", msg); println!("received message {}", msg);
} }

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
socket.read_message().ok(); socket.read().ok();
}); });

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
socket.read_message().ok(); socket.read().ok();
}); });

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json' wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3 sleep 3
cargo run --release --example autobahn-client --features=deflate cargo run --release --example autobahn-client
test_diff test_diff

@ -22,7 +22,7 @@ function test_diff() {
fi fi
} }
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$! cargo run --release --example autobahn-server & WSSERVER_PID=$!
sleep 3 sleep 3
docker run --rm \ docker run --rm \

@ -54,6 +54,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,

@ -53,9 +53,9 @@ pub enum Error {
/// Protocol violation. /// Protocol violation.
#[error("WebSocket protocol error: {0}")] #[error("WebSocket protocol error: {0}")]
Protocol(#[from] ProtocolError), Protocol(#[from] ProtocolError),
/// Message send queue full. /// Message write buffer is full.
#[error("Send queue is full")] #[error("Write buffer is full")]
SendQueueFull(Message), WriteBufferFull(Message),
/// UTF coding error. /// UTF coding error.
#[error("UTF-8 encoding error")] #[error("UTF-8 encoding error")]
Utf8, Utf8,
@ -70,10 +70,6 @@ pub enum Error {
#[error("HTTP format error: {0}")] #[error("HTTP format error: {0}")]
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
HttpFormat(#[from] http::Error), HttpFormat(#[from] http::Error),
/// Error from `permessage-deflate` extension.
#[cfg(feature = "deflate")]
#[error("Deflate error: {0}")]
Deflate(#[from] crate::extensions::DeflateError),
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -210,9 +206,6 @@ pub enum ProtocolError {
/// Control frames must not be fragmented. /// Control frames must not be fragmented.
#[error("Fragmented control frame")] #[error("Fragmented control frame")]
FragmentedControlFrame, FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less. /// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")] #[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig, ControlFrameTooBig,
@ -225,9 +218,6 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue. /// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")] #[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame, UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments. /// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")] #[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data), ExpectedFragment(Data),
@ -240,15 +230,6 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid. /// The payload for the closing frame is invalid.
#[error("Invalid close sequence")] #[error("Invalid close sequence")]
InvalidCloseSequence, InvalidCloseSequence,
/// The negotiation response included an extension not offered.
#[error("Extension negotiation response had invalid extension: {0}")]
InvalidExtension(String),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(String),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
InvalidExtensionsHeader,
} }
/// Indicates the specific type/cause of URL error. /// Indicates the specific type/cause of URL error.
@ -290,10 +271,6 @@ pub enum TlsError {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")] #[error("rustls error: {0}")]
Rustls(#[from] rustls::Error), Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error. /// DNS name resolution error.
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name")] #[error("Invalid DNS name")]

@ -1,442 +0,0 @@
use std::convert::TryFrom;
use bytes::BytesMut;
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
use headers::WebsocketExtension;
use http::HeaderValue;
use thiserror::Error;
use crate::protocol::Role;
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
/// Errors from `permessage-deflate` extension.
#[derive(Debug, Error)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress(#[source] std::io::Error),
/// Decompress failed
#[error("Failed to decompress")]
Decompress(#[source] std::io::Error),
/// Extension negotiation failed.
#[error("Extension negotiation failed")]
Negotiation(#[source] NegotiationError),
}
/// Errors from `permessage-deflate` extension negotiation.
#[derive(Debug, Error)]
pub enum NegotiationError {
/// Unknown parameter in a negotiation response.
#[error("Unknown parameter in a negotiation response: {0}")]
UnknownParameter(String),
/// Duplicate parameter in a negotiation response.
#[error("Duplicate parameter in a negotiation response: {0}")]
DuplicateParameter(String),
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
UnexpectedClientMaxWindowBits,
/// Received unsupported `server_max_window_bits` in a negotiation response.
#[error("Received unsupported server_max_window_bits in a negotiation response")]
ServerMaxWindowBitsNotSupported,
/// Invalid `client_max_window_bits` value in a negotiation response.
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
InvalidClientMaxWindowBitsValue(String),
/// Invalid `server_max_window_bits` value in a negotiation response.
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
InvalidServerMaxWindowBitsValue(String),
/// Missing `server_max_window_bits` value in a negotiation response.
#[error("Missing server_max_window_bits value in a negotiation response")]
MissingServerMaxWindowBitsValue,
}
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
// because custom window size requires `flate2/zlib` feature.
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeflateConfig {
/// Compression level.
pub compression: Compression,
/// Request the peer server not to use context takeover.
pub server_no_context_takeover: bool,
/// Hint that context takeover is not used.
pub client_no_context_takeover: bool,
}
impl DeflateConfig {
pub(crate) fn name(&self) -> &str {
PER_MESSAGE_DEFLATE
}
/// Value for `Sec-WebSocket-Extensions` request header.
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
let mut offers = Vec::new();
if self.server_no_context_takeover {
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
// > a client informs the peer server of a hint that even if the server doesn't include the
// > "client_no_context_takeover" extension parameter in the corresponding
// > extension negotiation response to the offer, the client is not going
// > to use context takeover.
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
if self.client_no_context_takeover {
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
to_header_value(&offers)
}
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
pub(crate) fn accept_offer(
&self,
offers: &headers::SecWebsocketExtensions,
) -> Option<(WebsocketExtension, DeflateContext)> {
// Accept the first valid offer for `permessage-deflate`.
// A server MUST decline an extension negotiation offer for this
// extension if any of the following conditions are met:
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
// 2. The negotiation offer contains an extension parameter with an invalid value.
// 3. The negotiation offer contains multiple extension parameters with the same name.
// 4. The server doesn't support the offered configuration.
offers.iter().find_map(|extension| {
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
let mut config =
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
let mut agreed = Vec::new();
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
let mut seen_client_max_window_bits = false;
for (key, val) in params {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_server_no_context_takeover {
return None;
}
seen_server_no_context_takeover = true;
config.server_no_context_takeover = true;
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Invalid offer with multiple params with same name is declined.
if seen_client_no_context_takeover {
return None;
}
seen_client_no_context_takeover = true;
config.client_no_context_takeover = true;
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
}
// Max window bits are not supported at the moment.
SERVER_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `server_max_window_bits` requires a value in range [8, 15].
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
} else {
return None;
}
// A server declines an extension negotiation offer with this parameter
// if the server doesn't support it.
return None;
}
// Not supported, but server may ignore and accept the offer.
CLIENT_MAX_WINDOW_BITS => {
// Decline offer with invalid parameter value.
// `client_max_window_bits` requires a value in range [8, 15] or no value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return None;
}
}
// Invalid offer with multiple params with same name is declined.
if seen_client_max_window_bits {
return None;
}
seen_client_max_window_bits = true;
}
// Offer with unknown parameter MUST be declined.
_ => {
return None;
}
}
}
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
} else {
None
}
})
}
pub(crate) fn accept_response<'a>(
&'a self,
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
) -> Result<DeflateContext, DeflateError> {
let mut config = DeflateConfig {
compression: self.compression,
// If this was hinted in the offer, the client won't use context takeover
// even if the response doesn't include it.
// See `generate_offer`.
client_no_context_takeover: self.client_no_context_takeover,
..DeflateConfig::default()
};
let mut seen_server_no_context_takeover = false;
let mut seen_client_no_context_takeover = false;
// A client MUST _Fail the WebSocket Connection_ if the peer server
// accepted an extension negotiation offer for this extension with an
// extension negotiation response meeting any of the following
// conditions:
// 1. The negotiation response contains an extension parameter not defined for use in a response.
// 2. The negotiation response contains an extension parameter with an invalid value.
// 3. The negotiation response contains multiple extension parameters with the same name.
// 4. The client does not support the configuration that the response represents.
for (key, val) in agreed {
match key {
SERVER_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_server_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_server_no_context_takeover = true;
// A server MAY include the "server_no_context_takeover" extension
// parameter in an extension negotiation response even if the extension
// negotiation offer being accepted by the extension negotiation
// response didn't include the "server_no_context_takeover" extension
// parameter.
config.server_no_context_takeover = true;
}
CLIENT_NO_CONTEXT_TAKEOVER => {
// Fail the connection when the response contains multiple parameters with the same name.
if seen_client_no_context_takeover {
return Err(DeflateError::Negotiation(
NegotiationError::DuplicateParameter(key.to_owned()),
));
}
seen_client_no_context_takeover = true;
// The server may include this parameter in the response and the client MUST support it.
config.client_no_context_takeover = true;
}
SERVER_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
));
}
} else {
return Err(DeflateError::Negotiation(
NegotiationError::MissingServerMaxWindowBitsValue,
));
}
// A server may include the "server_max_window_bits" extension parameter
// in an extension negotiation response even if the extension
// negotiation offer being accepted by the response didn't include the
// "server_max_window_bits" extension parameter.
//
// However, but we need to fail the connection because we don't support it (condition 4).
return Err(DeflateError::Negotiation(
NegotiationError::ServerMaxWindowBitsNotSupported,
));
}
CLIENT_MAX_WINDOW_BITS => {
// Fail the connection when the response contains a parameter with invalid value.
if let Some(bits) = val {
if !is_valid_max_window_bits(bits) {
return Err(DeflateError::Negotiation(
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
));
}
}
// Fail the connection because the parameter is invalid when the client didn't offer.
//
// If a received extension negotiation offer doesn't have the
// "client_max_window_bits" extension parameter, the corresponding
// extension negotiation response to the offer MUST NOT include the
// "client_max_window_bits" extension parameter.
return Err(DeflateError::Negotiation(
NegotiationError::UnexpectedClientMaxWindowBits,
));
}
// Response with unknown parameter MUST fail the WebSocket connection.
_ => {
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
key.to_owned(),
)));
}
}
}
Ok(DeflateContext::new(Role::Client, config))
}
}
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
fn is_valid_max_window_bits(bits: &str) -> bool {
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
}
#[cfg(test)]
mod tests {
use super::is_valid_max_window_bits;
#[test]
fn valid_max_window_bits() {
for bits in 8..=15 {
assert!(is_valid_max_window_bits(&bits.to_string()));
}
}
#[test]
fn invalid_max_window_bits() {
assert!(!is_valid_max_window_bits(""));
assert!(!is_valid_max_window_bits("0"));
assert!(!is_valid_max_window_bits("08"));
assert!(!is_valid_max_window_bits("+8"));
assert!(!is_valid_max_window_bits("-8"));
}
}
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
role: Role,
config: DeflateConfig,
compressor: Compress,
decompressor: Decompress,
}
impl DeflateContext {
fn new(role: Role, config: DeflateConfig) -> Self {
DeflateContext {
role,
config,
compressor: Compress::new(config.compression, false),
decompressor: Decompress::new(false),
}
}
fn own_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.server_no_context_takeover,
Role::Client => !self.config.client_no_context_takeover,
}
}
fn peer_context_takeover(&self) -> bool {
match self.role {
Role::Server => !self.config.client_no_context_takeover,
Role::Client => !self.config.server_no_context_takeover,
}
}
// Compress the data of message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
// 1. Compress all the octets of the payload of the message using DEFLATE.
let mut output = Vec::with_capacity(data.len());
let before_in = self.compressor.total_in() as usize;
while (self.compressor.total_in() as usize) - before_in < data.len() {
let offset = (self.compressor.total_in() as usize) - before_in;
match self
.compressor
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok => continue,
Status::BufError => output.reserve(4096),
Status::StreamEnd => break,
}
}
// 2. If the resulting data does not end with an empty DEFLATE block
// with no compression (the "BTYPE" bits are set to 00), append an
// empty DEFLATE block with no compression to the tail end.
while !output.ends_with(&TRAILER) {
output.reserve(5);
match self
.compressor
.compress_vec(&[], &mut output, FlushCompress::Sync)
.map_err(|e| DeflateError::Compress(e.into()))?
{
Status::Ok | Status::BufError => continue,
Status::StreamEnd => break,
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
// After this step, the last octet of the compressed data contains
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
// set to 00.
output.truncate(output.len() - 4);
if !self.own_context_takeover() {
self.compressor.reset();
}
Ok(output)
}
pub(crate) fn decompress(
&mut self,
mut data: Vec<u8>,
is_final: bool,
) -> Result<Vec<u8>, DeflateError> {
if is_final {
data.extend_from_slice(&TRAILER);
}
let before_in = self.decompressor.total_in() as usize;
let mut output = Vec::with_capacity(2 * data.len());
loop {
let offset = (self.decompressor.total_in() as usize) - before_in;
match self
.decompressor
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
.map_err(|e| DeflateError::Decompress(e.into()))?
{
Status::Ok => output.reserve(2 * output.len()),
Status::BufError | Status::StreamEnd => break,
}
}
if is_final && !self.peer_context_takeover() {
self.decompressor.reset(false);
}
Ok(output)
}
}
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
for param in params {
buf.extend_from_slice(b"; ");
buf.extend_from_slice(param.as_bytes());
}
let header = HeaderValue::from_maybe_shared(buf.freeze())
.expect("semicolon separated HeaderValue is valid");
WebsocketExtension::try_from(header).expect("valid extension")
}

@ -1,4 +0,0 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
pub mod deflate;

@ -1,18 +0,0 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#[cfg(feature = "deflate")]
mod compression;
#[cfg(feature = "deflate")]
use compression::deflate::DeflateContext;
#[cfg(feature = "deflate")]
pub use compression::deflate::{DeflateConfig, DeflateError};
/// Container for configured extensions.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
// Per-Message Compression. Only `permessage-deflate` is supported.
#[cfg(feature = "deflate")]
pub(crate) compression: Option<DeflateContext>,
}

@ -5,7 +5,6 @@ use std::{
marker::PhantomData, marker::PhantomData,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -20,7 +19,6 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result, UrlError}, error::{Error, ProtocolError, Result, UrlError},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -58,7 +56,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC. // Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request). // Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request, &config)?; let (request, key) = generate_request(request)?;
let machine = HandshakeMachine::start_write(stream, request); let machine = HandshakeMachine::start_write(stream, request);
@ -85,8 +83,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
let (result, extensions) = let result = match self.verify_data.verify_response(result) {
match self.verify_data.verify_response(result, &self.config) {
Ok(r) => r, Ok(r) => r,
Err(Error::Http(mut e)) => { Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail); *e.body_mut() = Some(tail);
@ -96,13 +93,8 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}; };
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read_with_extensions( let websocket =
stream, WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
}) })
@ -110,10 +102,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
} }
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it. /// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request( pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
mut request: Request,
config: &Option<WebSocketConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new(); let mut req = Vec::new();
write!( write!(
req, req,
@ -184,9 +173,6 @@ pub fn generate_request(
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap(); writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
} }
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
}
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req)); trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok((req, key)) Ok((req, key))
@ -200,11 +186,7 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response( pub fn verify_response(&self, response: Response) -> Result<Response> {
&self,
response: Response,
_config: &Option<WebSocketConfig>,
) -> Result<(Response, Option<Extensions>)> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -249,14 +231,7 @@ impl VerifyData {
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
let extensions = if let Some(agreed) = headers // TODO
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
{
verify_extensions(&agreed, _config)?
} else {
None
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field // 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was // and this header field indicates the use of a subprotocol that was
@ -265,47 +240,8 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok((response, extensions)) Ok(response)
}
}
fn verify_extensions(
agreed_extensions: &headers::SecWebsocketExtensions,
_config: &Option<WebSocketConfig>,
) -> Result<Option<Extensions>> {
#[cfg(feature = "deflate")]
{
if let Some(compression) = _config.and_then(|c| c.compression) {
let mut extensions = None;
for extension in agreed_extensions.iter() {
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
// > the client MUST _Fail the WebSocket Connection_.
if extension.name() != compression.name() {
return Err(Error::Protocol(ProtocolError::InvalidExtension(
extension.name().to_string(),
)));
}
// Already had PMCE configured
if extensions.is_some() {
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
extension.name().to_string(),
)));
}
extensions = Some(Extensions {
compression: Some(compression.accept_response(extension.params())?),
});
}
return Ok(extensions);
}
}
if let Some(extension) = agreed_extensions.iter().next() {
// The client didn't request anything, but got something
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
} }
Ok(None)
} }
impl TryParse for Response { impl TryParse for Response {
@ -322,7 +258,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -350,8 +286,6 @@ pub fn generate_key() -> String {
mod tests { mod tests {
use super::{super::machine::TryParse, generate_key, generate_request, Response}; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest; use crate::client::IntoClientRequest;
#[cfg(feature = "deflate")]
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -388,7 +322,7 @@ mod tests {
#[test] #[test]
fn request_formatting() { fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request, &None).unwrap(); let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost", &key); let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -396,7 +330,7 @@ mod tests {
#[test] #[test]
fn request_formatting_with_host() { fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request, &None).unwrap(); let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
@ -404,40 +338,11 @@ mod tests {
#[test] #[test]
fn request_formatting_with_at() { fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request, &None).unwrap(); let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost:9001", &key); let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
&Some(WebSocketConfig {
compression: Some(DeflateConfig::default()),
..WebSocketConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
\r\n",
host = "localhost",
key = key
)
.into_bytes();
assert_eq!(&request[..], &correct[..]);
}
#[test] #[test]
fn response_parsing() { fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -449,6 +354,6 @@ mod tests {
#[test] #[test]
fn invalid_custom_request() { fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap(); let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request, &None).is_err()); assert!(generate_request(request).is_err());
} }
} }

@ -6,7 +6,6 @@ use std::{
result::Result as StdResult, result::Result as StdResult,
}; };
use headers::{HeaderMapExt, SecWebsocketExtensions};
use http::{ use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
}; };
@ -21,7 +20,6 @@ use super::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
protocol::{Role, WebSocket, WebSocketConfig}, protocol::{Role, WebSocket, WebSocketConfig},
}; };
@ -204,8 +202,6 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>, error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Option<Extensions>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -223,7 +219,6 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback), callback: Some(callback),
config, config,
error_response: None, error_response: None,
extensions: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -245,19 +240,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let mut response = create_response(&result)?; let response = create_response(&result)?;
if let Some(config) = &self.config {
if let Some((agreed, extensions)) = result
.headers()
.typed_try_get::<SecWebsocketExtensions>()
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
.and_then(|values| config.accept_offers(&values))
{
response.headers_mut().typed_insert(agreed);
self.extensions = Some(extensions);
}
}
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result, response)
} else { } else {
@ -300,12 +283,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body))); return Err(Error::Http(http::Response::from_parts(parts, body)));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket_with_extensions( let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
stream,
Role::Server,
self.config,
self.extensions.take(),
);
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }
} }

@ -19,14 +19,13 @@ pub mod buffer;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod client; pub mod client;
pub mod error; pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub mod handshake; pub mod handshake;
pub mod protocol; pub mod protocol;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
mod server; mod server;
pub mod stream; pub mod stream;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] #[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
mod tls; mod tls;
pub mod util; pub mod util;
@ -45,5 +44,5 @@ pub use crate::{
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
}; };
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] #[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
pub use tls::{client_tls, client_tls_with_config, Connector}; pub use tls::{client_tls, client_tls_with_config, Connector};

@ -311,18 +311,6 @@ impl Frame {
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
} }
/// Create a new compressed data frame.
#[inline]
#[cfg(feature = "deflate")]
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
payload: data,
}
}
/// Create a new Pong control frame. /// Create a new Pong control frame.
#[inline] #[inline]
pub fn pong(data: Vec<u8>) -> Frame { pub fn pong(data: Vec<u8>) -> Frame {

@ -6,15 +6,14 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
use log::*;
pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::{ use crate::{
error::{CapacityError, Error, Result}, error::{CapacityError, Error, Result},
ReadBuffer, Message, ReadBuffer,
}; };
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
pub use self::frame::{CloseFrame, Frame, FrameHeader};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
@ -57,7 +56,7 @@ where
Stream: Read, Stream: Read,
{ {
/// Read a frame from stream. /// Read a frame from stream.
pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> { pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
self.codec.read_frame(&mut self.stream, max_size) self.codec.read_frame(&mut self.stream, max_size)
} }
} }
@ -66,18 +65,28 @@ impl<Stream> FrameSocket<Stream>
where where
Stream: Write, Stream: Write,
{ {
/// Writes and immediately flushes a frame.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, frame: Frame) -> Result<()> {
self.write(frame)?;
self.flush()
}
/// Write a frame to stream. /// Write a frame to stream.
/// ///
/// This function guarantees that the frame is queued regardless of any errors. /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// There is no need to resend the frame. In order to handle WouldBlock or Incomplete, ///
/// call write_pending() afterwards. /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
pub fn write_frame(&mut self, frame: Frame) -> Result<()> { /// is returned.
self.codec.write_frame(&mut self.stream, frame) /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
pub fn write(&mut self, frame: Frame) -> Result<()> {
self.codec.buffer_frame(&mut self.stream, frame)
} }
/// Complete pending write, if any. /// Flush writes.
pub fn write_pending(&mut self) -> Result<()> { pub fn flush(&mut self) -> Result<()> {
self.codec.write_pending(&mut self.stream) self.codec.write_out_buffer(&mut self.stream)?;
Ok(self.stream.flush()?)
} }
} }
@ -88,6 +97,14 @@ pub(super) struct FrameCodec {
in_buffer: ReadBuffer, in_buffer: ReadBuffer,
/// Buffer to send packets to the network. /// Buffer to send packets to the network.
out_buffer: Vec<u8>, out_buffer: Vec<u8>,
/// Capacity limit for `out_buffer`.
max_out_buffer_len: usize,
/// Buffer target length to reach before writing to the stream
/// on calls to `buffer_frame`.
///
/// Setting this to non-zero will buffer small writes from hitting
/// the stream.
out_buffer_write_len: usize,
/// Header and remaining size of the incoming packet being processed. /// Header and remaining size of the incoming packet being processed.
header: Option<(FrameHeader, u64)>, header: Option<(FrameHeader, u64)>,
} }
@ -95,7 +112,13 @@ pub(super) struct FrameCodec {
impl FrameCodec { impl FrameCodec {
/// Create a new frame codec. /// Create a new frame codec.
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None } Self {
in_buffer: ReadBuffer::new(),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None,
}
} }
/// Create a new frame codec from partially read data. /// Create a new frame codec from partially read data.
@ -103,10 +126,23 @@ impl FrameCodec {
Self { Self {
in_buffer: ReadBuffer::from_partially_read(part), in_buffer: ReadBuffer::from_partially_read(part),
out_buffer: Vec::new(), out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None, header: None,
} }
} }
/// Sets a maximum size for the out buffer.
pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
self.max_out_buffer_len = max;
}
/// Sets [`Self::buffer_frame`] buffer target length to reach before
/// writing to the stream.
pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
self.out_buffer_write_len = len;
}
/// Read a frame from the provided stream. /// Read a frame from the provided stream.
pub(super) fn read_frame<Stream>( pub(super) fn read_frame<Stream>(
&mut self, &mut self,
@ -165,19 +201,37 @@ impl FrameCodec {
Ok(Some(frame)) Ok(Some(frame))
} }
/// Write a frame to the provided stream. /// Writes a frame into the `out_buffer`.
pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()> /// If the out buffer size is over the `out_buffer_write_len` will also write
/// the out buffer into the provided `stream`.
///
/// To ensure buffered frames are written call [`Self::write_out_buffer`].
///
/// May write to the stream, will **not** flush.
pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
return Err(Error::WriteBufferFull(Message::Frame(frame)));
}
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
self.write_pending(stream)
if self.out_buffer.len() > self.out_buffer_write_len {
self.write_out_buffer(stream)
} else {
Ok(())
}
} }
/// Complete pending write, if any. /// Writes the out_buffer to the provided stream.
pub(super) fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()> ///
/// Does **not** flush.
pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
@ -193,16 +247,8 @@ impl FrameCodec {
} }
self.out_buffer.drain(0..len); self.out_buffer.drain(0..len);
} }
stream.flush()?;
Ok(())
}
}
#[cfg(test)] Ok(())
impl FrameCodec {
/// Returns the size of the output buffer.
pub(super) fn output_buffer_len(&self) -> usize {
self.out_buffer.len()
} }
} }
@ -224,11 +270,11 @@ mod tests {
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!( assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(), sock.read(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame(None).unwrap().is_none()); assert!(sock.read(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]); assert_eq!(rest, vec![0x99]);
@ -239,7 +285,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!( assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(), sock.read(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
} }
@ -249,10 +295,10 @@ mod tests {
let mut sock = FrameSocket::new(Vec::new()); let mut sock = FrameSocket::new(Vec::new());
let frame = Frame::ping(vec![0x04, 0x05]); let frame = Frame::ping(vec![0x04, 0x05]);
sock.write_frame(frame).unwrap(); sock.send(frame).unwrap();
let frame = Frame::pong(vec![0x01]); let frame = Frame::pong(vec![0x01]);
sock.write_frame(frame).unwrap(); sock.send(frame).unwrap();
let (buf, _) = sock.into_inner(); let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
@ -264,7 +310,7 @@ mod tests {
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
let _ = sock.read_frame(None); // should not crash let _ = sock.read(None); // should not crash
} }
#[test] #[test]
@ -272,7 +318,7 @@ mod tests {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert!(matches!( assert!(matches!(
sock.read_frame(Some(5)), sock.read(Some(5)),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
)); ));
} }

@ -84,8 +84,6 @@ use self::string_collect::StringCollector;
#[derive(Debug)] #[derive(Debug)]
pub struct IncompleteMessage { pub struct IncompleteMessage {
collector: IncompleteMessageCollector, collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -96,7 +94,6 @@ enum IncompleteMessageCollector {
impl IncompleteMessage { impl IncompleteMessage {
/// Create new. /// Create new.
#[cfg(not(feature = "deflate"))]
pub fn new(message_type: IncompleteMessageType) -> Self { pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage { IncompleteMessage {
collector: match message_type { collector: match message_type {
@ -108,25 +105,6 @@ impl IncompleteMessage {
} }
} }
/// Create new.
#[cfg(feature = "deflate")]
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed,
}
}
#[cfg(feature = "deflate")]
pub fn compressed(&self) -> bool {
self.compressed
}
/// Get the current filled size of the buffer. /// Get the current filled size of the buffer.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self.collector { match self.collector {
@ -207,7 +185,7 @@ impl Message {
Message::Text(string.into()) Message::Text(string.into())
} }
/// Create a new binary WebSocket message by converting to Vec<u8>. /// Create a new binary WebSocket message by converting to `Vec<u8>`.
pub fn binary<B>(bin: B) -> Message pub fn binary<B>(bin: B) -> Message
where where
B: Into<Vec<u8>>, B: Into<Vec<u8>>,

@ -6,13 +6,6 @@ mod message;
pub use self::{frame::CloseFrame, message::Message}; pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{ use self::{
frame::{ frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
@ -22,9 +15,13 @@ use self::{
}; };
use crate::{ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
extensions::Extensions,
util::NonBlockingResult, util::NonBlockingResult,
}; };
use log::*;
use std::{
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -38,10 +35,21 @@ pub enum Role {
/// The configuration for WebSocket connection. /// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig { pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None` /// Does nothing, instead use `max_write_buffer_size`.
/// means here that the size of the queue is unlimited. The default value is the unlimited #[deprecated]
/// queue.
pub max_send_queue: Option<usize>, pub max_send_queue: Option<usize>,
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
/// The default value is 128 KiB.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited.
///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size).
pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent /// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user. /// memory eating by a malicious user.
@ -57,76 +65,18 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC. /// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455. /// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool, pub accept_unmasked_frames: bool,
/// Optional configuration for Per-Message Compression Extension.
#[cfg(feature = "deflate")]
pub compression: Option<crate::extensions::DeflateConfig>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
fn default() -> Self { fn default() -> Self {
#[allow(deprecated)]
WebSocketConfig { WebSocketConfig {
max_send_queue: None, max_send_queue: None,
write_buffer_size: 128 * 1024,
max_write_buffer_size: usize::MAX,
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
accept_unmasked_frames: false, accept_unmasked_frames: false,
#[cfg(feature = "deflate")]
compression: None,
}
}
}
impl WebSocketConfig {
// Generate extension negotiation offers for configured extensions.
// Only `permessage-deflate` is supported at the moment.
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
#[cfg(feature = "deflate")]
{
let mut offers = Vec::new();
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
offers.push(compression);
}
if offers.is_empty() {
None
} else {
Some(headers::SecWebsocketExtensions::new(offers))
}
}
#[cfg(not(feature = "deflate"))]
{
None
}
}
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
pub fn accept_offers(
&self,
_offers: &headers::SecWebsocketExtensions,
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
#[cfg(feature = "deflate")]
{
// To support more extensions, store extension context in `Extensions` and
// concatenate negotiation responses from each extension.
let mut agreed_extensions = Vec::new();
let mut extensions = Extensions::default();
if let Some(compression) = &self.compression {
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
agreed_extensions.push(agreed);
extensions.compression = Some(compression);
}
}
if agreed_extensions.is_empty() {
None
} else {
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
}
}
#[cfg(not(feature = "deflate"))]
{
None
} }
} }
} }
@ -135,6 +85,8 @@ impl WebSocketConfig {
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions. /// It may be created by calling `connect`, `accept` or `client` functions.
///
/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocket<Stream> { pub struct WebSocket<Stream> {
/// The underlying socket. /// The underlying socket.
@ -153,18 +105,6 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) } WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
} }
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
/// ///
/// Call this function if you're using Tungstenite as a part of a web framework /// Call this function if you're using Tungstenite as a part of a web framework
@ -182,21 +122,6 @@ impl<Stream> WebSocket<Stream> {
} }
} }
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Option<Extensions>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
&self.socket &self.socket
@ -235,82 +160,116 @@ impl<Stream> WebSocket<Stream> {
impl<Stream: Read + Write> WebSocket<Stream> { impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
/// This will queue responses to ping and close messages to be sent. It will call /// This will also queue responses to ping and close messages. These responses
/// `write_pending` before trying to read in order to make sure that those responses /// will be written and flushed on the next call to [`read`](Self::read),
/// make progress even if you never call `write_pending`. That does mean that they /// [`write`](Self::write) or [`flush`](Self::flush).
/// get sent out earliest on the next call to `read_message`, `write_message` or `write_pending`.
/// ///
/// ## Closing the connection /// # Closing the connection
/// When the remote endpoint decides to close the connection this will return /// When the remote endpoint decides to close the connection this will return
/// the close message with an optional close frame. /// the close message with an optional close frame.
/// ///
/// You should continue calling `read_message`, `write_message` or `write_pending` to drive /// You should continue calling [`read`](Self::read), [`write`](Self::write) or
/// the reply to the close frame until [Error::ConnectionClosed] is returned. Once that happens /// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`]
/// it is safe to drop the underlying connection. /// is returned. Once that happens it is safe to drop the underlying connection.
pub fn read_message(&mut self) -> Result<Message> { pub fn read(&mut self) -> Result<Message> {
self.context.read_message(&mut self.socket) self.context.read(&mut self.socket)
} }
/// Send a message to stream, if possible. /// Writes and immediately flushes a message.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, message: Message) -> Result<()> {
self.write(message)?;
self.flush()
}
/// Write a message to the provided stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
///
/// In the event of stream write failure the message frame will be stored
/// in the write buffer and will try again on the next call to [`write`](Self::write)
/// or [`flush`](Self::flush).
///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
/// ///
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping /// This call will generally not flush. However, if there are queued automatic messages
/// requests. A Pong reply will jump the queue because the /// they will be written and eagerly flushed.
/// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent
/// as soon as is practical.
/// ///
/// Note that upon receiving a ping message, tungstenite cues a pong reply automatically. /// For example, upon receiving ping messages tungstenite queues pong replies automatically.
/// When you call either `read_message`, `write_message` or `write_pending` next it will try to send /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
/// that pong out if the underlying connection can take more data. This means you should not /// will write & flush the pong reply. This means you should not respond to ping frames manually.
/// respond to ping frames manually.
/// ///
/// You can however send pong frames manually in order to indicate a unidirectional heartbeat /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
/// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
/// if `read_message` returns a ping, you should call `write_pending` until it doesn't return /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing
/// WouldBlock before passing a pong to `write_message`, otherwise the response to the /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the
/// ping will not be sent, but rather replaced by your custom pong message. /// ping will not be sent as it will be replaced by your custom pong message.
/// ///
/// ## Errors /// # Errors
/// - If the WebSocket's send queue is full, `SendQueueFull` will be returned /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. /// along with the equivalent passed message frame.
/// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed]. /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`].
/// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`, /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from
/// [Error::AlreadyClosed] will be returned. This indicates a program error on your part. /// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program
/// - [Error::Io] is returned if the underlying connection returns an error /// error on your part.
/// - [`Error::Io`] is returned if the underlying connection returns an error
/// (consider these fatal except for WouldBlock). /// (consider these fatal except for WouldBlock).
/// - [Error::Capacity] if your message size is bigger than the configured max message size. /// - [`Error::Capacity`] if your message size is bigger than the configured max message size.
pub fn write_message(&mut self, message: Message) -> Result<()> { pub fn write(&mut self, message: Message) -> Result<()> {
self.context.write_message(&mut self.socket, message) self.context.write(&mut self.socket, message)
} }
/// Flush the pending send queue. /// Flush writes.
pub fn write_pending(&mut self) -> Result<()> { ///
self.context.write_pending(&mut self.socket) /// Ensures all messages previously passed to [`write`](Self::write) and automatic
/// queued pong responses are written & flushed into the underlying stream.
pub fn flush(&mut self) -> Result<()> {
self.context.flush(&mut self.socket)
} }
/// Close the connection. /// Close the connection.
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `write_message(Message::Close(..))`. /// the same as calling `write(Message::Close(..))`.
/// ///
/// After queuing the close frame you should continue calling `read_message` or /// After queuing the close frame you should continue calling [`read`](Self::read) or
/// `write_pending` to drive the close handshake to completion. /// [`flush`](Self::flush) to drive the close handshake to completion.
/// ///
/// The websocket RFC defines that the underlying connection should be closed /// The websocket RFC defines that the underlying connection should be closed
/// by the server. Tungstenite takes care of this asymmetry for you. /// by the server. Tungstenite takes care of this asymmetry for you.
/// ///
/// When the close handshake is finished (we have both sent and received /// When the close handshake is finished (we have both sent and received
/// a close message), `read_message` or `write_pending` will return /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return
/// [Error::ConnectionClosed] if this endpoint is the server. /// [Error::ConnectionClosed] if this endpoint is the server.
/// ///
/// If this endpoint is a client, [Error::ConnectionClosed] will only be /// If this endpoint is a client, [Error::ConnectionClosed] will only be
/// returned after the server has closed the underlying connection. /// returned after the server has closed the underlying connection.
/// ///
/// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed] /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
/// is returned from `read_message` or `write_pending`. /// is returned from [`read`](Self::read) or [`flush`](Self::flush).
pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> { pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
self.context.close(&mut self.socket, code) self.context.close(&mut self.socket, code)
} }
/// Old name for [`read`](Self::read).
#[deprecated(note = "Use `read`")]
pub fn read_message(&mut self) -> Result<Message> {
self.read()
}
/// Old name for [`send`](Self::send).
#[deprecated(note = "Use `send`")]
pub fn write_message(&mut self, message: Message) -> Result<()> {
self.send(message)
}
/// Old name for [`flush`](Self::flush).
#[deprecated(note = "Use `flush`")]
pub fn write_pending(&mut self) -> Result<()> {
self.flush()
}
} }
/// A context for managing WebSocket stream. /// A context for managing WebSocket stream.
@ -324,55 +283,41 @@ pub struct WebSocketContext {
state: WebSocketState, state: WebSocketState,
/// Receive: an incomplete message being processed. /// Receive: an incomplete message being processed.
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send: a data send queue. /// Send in addition to regular messages E.g. "pong" or "close".
send_queue: VecDeque<Frame>, additional_send: Option<Frame>,
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
// Container for extensions.
pub(crate) extensions: Option<Extensions>,
} }
impl WebSocketContext { impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream. /// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self { pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocketContext { Self::_new(role, FrameCodec::new(), config.unwrap_or_default())
role,
frame: FrameCodec::new(),
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_default(),
extensions: None,
}
} }
/// Create a WebSocket context that manages an post-handshake stream. /// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocketContext { Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default())
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
}
} }
pub(crate) fn from_partially_read_with_extensions( fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
part: Vec<u8>, frame.set_max_out_buffer_len(config.max_write_buffer_size);
role: Role, frame.set_out_buffer_write_len(config.write_buffer_size);
config: Option<WebSocketConfig>, Self {
extensions: Option<Extensions>, role,
) -> Self { frame,
WebSocketContext { state: WebSocketState::Active,
frame: FrameCodec::from_partially_read(part), incomplete: None,
extensions, additional_send: None,
..WebSocketContext::new(role, config) config,
} }
} }
/// Change the configuration. /// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config) set_func(&mut self.config);
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
} }
/// Read the configuration. /// Read the configuration.
@ -399,17 +344,23 @@ impl WebSocketContext {
/// ///
/// This function sends pong and close responses automatically. /// This function sends pong and close responses automatically.
/// However, it never blocks on write. /// However, it never blocks on write.
pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message> pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// Do not read from already closed connections. // Do not read from already closed connections.
self.state.check_active()?; self.state.check_not_terminated()?;
loop { loop {
if self.additional_send.is_some() {
// Since we may get ping or close, we need to reply to the messages even during read. // Since we may get ping or close, we need to reply to the messages even during read.
// Thus we call write_pending() but ignore its blocking. // Thus we flush but ignore its blocking.
self.write_pending(stream).no_block()?; self.flush(stream).no_block()?;
} else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed);
}
// If we get here, either write blocks or we have nothing to write. // If we get here, either write blocks or we have nothing to write.
// Thus if read blocks, just let it return WouldBlock. // Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? { if let Some(message) = self.read_message_frame(stream)? {
@ -419,89 +370,94 @@ impl WebSocketContext {
} }
} }
/// Send a message to the provided stream, if possible. /// Write a message to the provided stream.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// ///
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping /// In the event of stream write failure the message frame will be stored
/// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned /// in the write buffer and will try again on the next call to [`write`](Self::write)
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. /// or [`flush`](Self::flush).
/// ///
/// Note that only the last pong frame is stored to be sent, and only the /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`]
/// most recent pong frame is sent if multiple pong frames are queued. /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned.
pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()> pub fn write<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// When terminated, return AlreadyClosed. // When terminated, return AlreadyClosed.
self.state.check_active()?; self.state.check_not_terminated()?;
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol(ProtocolError::SendAfterClosing)); return Err(Error::Protocol(ProtocolError::SendAfterClosing));
} }
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
// Try to make some room for the new message.
// Do not return here if write would block, ignore WouldBlock silently
// since we must queue the message anyway.
self.write_pending(stream).no_block()?;
}
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message { let frame = match message {
Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?, Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?, Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
self.pong = Some(Frame::pong(data)); self.set_additional(Frame::pong(data));
return self.write_pending(stream); // Note: user pongs can be user flushed so no need to flush here
return self._write(stream, None).map(|_| ());
} }
Message::Close(code) => return self.close(stream, code), Message::Close(code) => return self.close(stream, code),
Message::Frame(f) => f, Message::Frame(f) => f,
}; };
self.send_queue.push_back(frame); let should_flush = self._write(stream, Some(frame))?;
self.write_pending(stream) if should_flush {
self.flush(stream)?;
} }
Ok(())
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
let opcode = OpCode::Data(opdata);
let is_final = true;
#[cfg(feature = "deflate")]
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
} }
Ok(Frame::message(data, opcode, is_final))
/// Flush writes.
///
/// Ensures all messages previously passed to [`write`](Self::write) and automatically
/// queued pong responses are written & flushed into the `stream`.
#[inline]
pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
Stream: Read + Write,
{
self._write(stream, None)?;
self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?)
} }
/// Flush the pending send queue. /// Writes any data in the out_buffer, `additional_send` and given `data`.
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()> ///
/// Does **not** flush.
///
/// Returns true if the write contents indicate we should flush immediately.
fn _write<Stream>(&mut self, stream: &mut Stream, data: Option<Frame>) -> Result<bool>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// First, make sure we have no pending frame sending. if let Some(data) = data {
self.frame.write_pending(stream)?; self.buffer_frame(stream, data)?;
}
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
// response, unless it already received a Close frame. It SHOULD // response, unless it already received a Close frame. It SHOULD
// respond with Pong frame as soon as is practical. (RFC 6455) // respond with Pong frame as soon as is practical. (RFC 6455)
if let Some(pong) = self.pong.take() { let should_flush = if let Some(msg) = self.additional_send.take() {
trace!("Sending pong reply"); trace!("Sending pong/close");
self.send_one_frame(stream, pong)?; match self.buffer_frame(stream, msg) {
Err(Error::WriteBufferFull(Message::Frame(msg))) => {
// if an system message would exceed the buffer put it back in
// `additional_send` for retry. Otherwise returning this error
// may not make sense to the user, e.g. calling `flush`.
self.set_additional(msg);
false
} }
// If we have any unsent frames, send them. Err(err) => return Err(err),
trace!("Frames still in queue: {}", self.send_queue.len()); Ok(_) => true,
while let Some(data) = self.send_queue.pop_front() {
self.send_one_frame(stream, data)?;
} }
} else {
// If we get to this point, the send queue is empty and the underlying socket is still false
// willing to take more data. };
// If we're closing and there is nothing to send anymore, we should close the connection. // If we're closing and there is nothing to send anymore, we should close the connection.
if self.role == Role::Server && !self.state.can_read() { if self.role == Role::Server && !self.state.can_read() {
@ -511,10 +467,11 @@ impl WebSocketContext {
// maximum segment lifetimes (2MSL), while there is no corresponding // maximum segment lifetimes (2MSL), while there is no corresponding
// server impact as a TIME_WAIT connection is immediately reopened upon // server impact as a TIME_WAIT connection is immediately reopened upon
// a new SYN with a higher seq number). (RFC 6455) // a new SYN with a higher seq number). (RFC 6455)
self.frame.write_out_buffer(stream)?;
self.state = WebSocketState::Terminated; self.state = WebSocketState::Terminated;
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} else { } else {
Ok(()) Ok(should_flush)
} }
} }
@ -522,7 +479,7 @@ impl WebSocketContext {
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `write(Message::Close(..))`. /// the same as calling `send(Message::Close(..))`.
pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()> pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
@ -530,11 +487,9 @@ impl WebSocketContext {
if let WebSocketState::Active = self.state { if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs; self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(code); let frame = Frame::close(code);
self.send_queue.push_back(frame); self._write(stream, Some(frame))?;
} else {
// Already closed, nothing to do.
} }
self.write_pending(stream) self.flush(stream)
} }
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
@ -555,14 +510,12 @@ impl WebSocketContext {
// the negotiated extensions defines the meaning of such a nonzero // the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket // value, the receiving endpoint MUST _Fail the WebSocket
// Connection_. // Connection_.
let is_compressed = { {
let hdr = frame.header(); let hdr = frame.header();
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 { if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
} }
}
hdr.rsv1
};
match self.role { match self.role {
Role::Server => { Role::Server => {
@ -597,10 +550,6 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
} }
// Control frames must not have compress bit.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => { OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -609,7 +558,7 @@ impl WebSocketContext {
let data = frame.into_data(); let data = frame.into_data();
// No ping processing after we sent a close frame. // No ping processing after we sent a close frame.
if self.state.is_active() { if self.state.is_active() {
self.pong = Some(Frame::pong(data.clone())); self.set_additional(Frame::pong(data.clone()));
} }
Ok(Some(Message::Ping(data))) Ok(Some(Message::Ping(data)))
} }
@ -621,34 +570,39 @@ impl WebSocketContext {
let fin = frame.header().is_final; let fin = frame.header().is_final;
match data { match data {
OpData::Continue => { OpData::Continue => {
if self.incomplete.is_some() && is_compressed { if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol( return Err(Error::Protocol(
ProtocolError::CompressedContinueFrame, ProtocolError::UnexpectedContinueFrame,
)); ));
} }
if fin {
let msg = self Ok(Some(self.incomplete.take().unwrap().complete()?))
.incomplete } else {
.take() Ok(None)
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?; }
self.extend_incomplete(msg, frame.into_data(), fin)
} }
c if self.incomplete.is_some() => { c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
} }
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data { let message_type = match data {
OpData::Text => IncompleteMessageType::Text, OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary, OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"), _ => panic!("Bug: message is not text nor binary"),
}; };
#[cfg(feature = "deflate")] let mut m = IncompleteMessage::new(message_type);
let msg = IncompleteMessage::new(message_type, is_compressed); m.extend(frame.into_data(), self.config.max_message_size)?;
#[cfg(not(feature = "deflate"))] m
let msg = IncompleteMessage::new(message_type); };
self.extend_incomplete(msg, frame.into_data(), fin) if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
} }
OpData::Reserved(i) => { OpData::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
@ -667,32 +621,6 @@ impl WebSocketContext {
} }
} }
fn extend_incomplete(
&mut self,
mut msg: IncompleteMessage,
data: Vec<u8>,
is_final: bool,
) -> Result<Option<Message>> {
#[cfg(feature = "deflate")]
let data = if msg.compressed() {
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
self.extensions
.as_mut()
.and_then(|x| x.compression.as_mut())
.unwrap()
.decompress(data, is_final)?
} else {
data
};
msg.extend(data, self.config.max_message_size)?;
if is_final {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
/// Received a close frame. Tells if we need to return a close frame to the user. /// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)] #[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> { fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
@ -714,7 +642,7 @@ impl WebSocketContext {
let reply = Frame::close(close.clone()); let reply = Frame::close(close.clone());
debug!("Replying to close with {:?}", reply); debug!("Replying to close with {:?}", reply);
self.send_queue.push_back(reply); self.set_additional(reply);
Some(close) Some(close)
} }
@ -731,8 +659,8 @@ impl WebSocketContext {
} }
} }
/// Send a single pending frame. /// Write a single frame into the write-buffer.
fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> fn buffer_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
@ -746,17 +674,17 @@ impl WebSocketContext {
} }
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state) self.frame.buffer_frame(stream, frame).check_connection_reset(self.state)
} }
fn has_compression(&self) -> bool { /// Replace `additional_send` if it is currently a `Pong` message.
#[cfg(feature = "deflate")] fn set_additional(&mut self, add: Frame) {
{ let empty_or_pong = self
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some() .additional_send
} .as_ref()
#[cfg(not(feature = "deflate"))] .map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong));
{ if empty_or_pong {
false self.additional_send.replace(add);
} }
} }
} }
@ -790,7 +718,7 @@ impl WebSocketState {
} }
/// Check if the state is active, return error if not. /// Check if the state is active, return error if not.
fn check_active(self) -> Result<()> { fn check_not_terminated(self) -> Result<()> {
match self { match self {
WebSocketState::Terminated => Err(Error::AlreadyClosed), WebSocketState::Terminated => Err(Error::AlreadyClosed),
_ => Ok(()), _ => Ok(()),
@ -842,64 +770,6 @@ mod tests {
} }
} }
struct WouldBlockStreamMoc;
impl io::Write for WouldBlockStreamMoc {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
fn flush(&mut self) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
impl io::Read for WouldBlockStreamMoc {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
#[test]
fn queue_logic() {
// Create a socket with the queue size of 1.
let mut socket = WebSocket::from_raw_socket(
WouldBlockStreamMoc,
Role::Client,
Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }),
);
// Test message that we're going to send.
let message = Message::Binary(vec![0xFF; 1024]);
// Helper to check the error.
let assert_would_block = |error| {
if let Error::Io(io_error) = error {
assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock);
} else {
panic!("Expected WouldBlock error");
}
};
// The first attempt of writing must not fail, since the queue is empty at start.
// But since the underlying mock object always returns `WouldBlock`, so is the result.
assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err()));
// Any subsequent attempts must return an error telling that the queue is full.
for _i in 0..100 {
assert!(matches!(
socket.write_message(message.clone()).unwrap_err(),
Error::SendQueueFull(..)
));
}
// The size of the output buffer must not be bigger than the size of that message
// that we managed to write to the output buffer at first. Since we could not make
// any progress (because of the logic of the moc buffer), the size remains unchanged.
if socket.context.frame.output_buffer_len() > message.len() {
panic!("Too many frames in the queue");
}
}
#[test] #[test]
fn receive_messages() { fn receive_messages() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![
@ -908,10 +778,10 @@ mod tests {
0x03, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
} }
#[test] #[test]
@ -924,7 +794,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read_message(), socket.read(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
)); ));
} }
@ -936,7 +806,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read_message(), socket.read(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
)); ));
} }

@ -104,11 +104,13 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")] #[cfg(feature = "rustls-tls-native-roots")]
{ {
for cert in rustls_native_certs::load_native_certs()? { let native_certs = rustls_native_certs::load_native_certs()?;
root_store let der_certs: Vec<Vec<u8>> =
.add(&rustls::Certificate(cert.0)) native_certs.into_iter().map(|cert| cert.0).collect();
.map_err(TlsError::Webpki)?; let total_number = der_certs.len();
} let (number_added, number_ignored) =
root_store.add_parsable_certificates(&der_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
} }
#[cfg(feature = "rustls-tls-webpki-roots")] #[cfg(feature = "rustls-tls-webpki-roots")]
{ {

@ -1,6 +1,7 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closed from the server's point of view and drop the underlying tcp socket. //! is closed from the server's point of view and drop the underlying tcp socket.
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
#![cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
use std::{ use std::{
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
@ -51,27 +52,27 @@ fn test_server_close() {
do_test( do_test(
3012, 3012,
|mut cli_sock| { |mut cli_sock| {
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive close from server let message = cli_sock.read().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read_message().unwrap(); let message = srv_sock.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read_message().unwrap(); // receive acknowledgement let message = srv_sock.read().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -85,26 +86,26 @@ fn test_evil_server_close() {
do_test( do_test(
3013, 3013,
|mut cli_sock| { |mut cli_sock| {
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
let message = cli_sock.read_message().unwrap(); // receive close from server let message = cli_sock.read().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read_message().unwrap(); let message = srv_sock.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read_message().unwrap(); // receive acknowledgement let message = srv_sock.read().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed` // and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
@ -118,32 +119,32 @@ fn test_client_close() {
do_test( do_test(
3014, 3014,
|mut cli_sock| { |mut cli_sock| {
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive answer from server let message = cli_sock.read().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server"); assert_eq!(message.into_data(), b"From Server");
cli_sock.close(None).unwrap(); // send close to server cli_sock.close(None).unwrap(); // send close to server
let message = cli_sock.read_message().unwrap(); // receive acknowledgement from server let message = cli_sock.read().unwrap(); // receive acknowledgement from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read_message().unwrap(); let message = srv_sock.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.write_message(Message::Text("From Server".into())).unwrap(); srv_sock.send(Message::Text("From Server".into())).unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client let message = srv_sock.read().unwrap(); // receive close from client
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -1,6 +1,8 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -8,7 +10,6 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, error::ProtocolError, Error, Message}; use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url; use url::Url;
@ -28,10 +29,10 @@ fn test_no_send_after_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
let message = client.read_message().unwrap(); // receive close from server let message = client.read().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -43,7 +44,7 @@ fn test_no_send_after_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
let err = client_handler.write_message(Message::Text("Hello WebSocket".into())); let err = client_handler.send(Message::Text("Hello WebSocket".into()));
assert!(err.is_err()); assert!(err.is_err());

@ -1,6 +1,8 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -8,7 +10,6 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;
@ -28,12 +29,12 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); client.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read_message().unwrap(); // receive close from server let message = client.read().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -46,12 +47,12 @@ fn test_receive_after_init_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close // This read should succeed even though we already initiated a close
let message = client_handler.read_message().unwrap(); let message = client_handler.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read_message().unwrap().is_close()); // receive acknowledgement assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -0,0 +1,68 @@
use std::io::{self, Read, Write};
use tungstenite::{protocol::WebSocketConfig, Message, WebSocket};
/// `Write` impl that records call stats and drops the data.
#[derive(Debug, Default)]
struct MockWrite {
written_bytes: usize,
write_count: usize,
flush_count: usize,
}
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written_bytes += buf.len();
self.write_count += 1;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_count += 1;
Ok(())
}
}
/// Test for write buffering and flushing behaviour.
#[test]
fn write_flush_behaviour() {
const SEND_ME_LEN: usize = 10;
const BATCH_ME_LEN: usize = 11;
const WRITE_BUFFER_SIZE: usize = 600;
let mut ws = WebSocket::from_raw_socket(
MockWrite::default(),
tungstenite::protocol::Role::Server,
Some(WebSocketConfig { write_buffer_size: WRITE_BUFFER_SIZE, ..<_>::default() }),
);
assert_eq!(ws.get_ref().written_bytes, 0);
assert_eq!(ws.get_ref().write_count, 0);
assert_eq!(ws.get_ref().flush_count, 0);
// `send` writes & flushes immediately
ws.send(Message::Text("Send me!".into())).unwrap();
assert_eq!(ws.get_ref().written_bytes, SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 1);
assert_eq!(ws.get_ref().flush_count, 1);
// send a batch of messages
for msg in (0..100).map(|_| Message::Text("Batch me!".into())) {
ws.write(msg).unwrap();
}
// after 55 writes the out_buffer will exceed write_buffer_size=600
// and so do a single underlying write (not flushing).
assert_eq!(ws.get_ref().written_bytes, 55 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 2);
assert_eq!(ws.get_ref().flush_count, 1);
// flushing will perform a single write for the remaining out_buffer & flush.
ws.flush().unwrap();
assert_eq!(ws.get_ref().written_bytes, 100 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 3);
assert_eq!(ws.get_ref().flush_count, 2);
}
Loading…
Cancel
Save