Compare commits

..

5 Commits

  1. 70
      .github/workflows/ci.yml
  2. 1
      .gitignore
  3. 14
      .travis.yml
  4. 17
      CHANGELOG.md
  5. 49
      Cargo.toml
  6. 34
      README.md
  7. 3
      benches/buffer.rs
  8. 75
      benches/write.rs
  9. 6
      examples/autobahn-client.rs
  10. 4
      examples/autobahn-server.rs
  11. 4
      examples/client.rs
  12. 4
      examples/server.rs
  13. 6
      examples/srv_accept_unmasked_frames.rs
  14. 2
      fuzz/fuzz_targets/read_message_client.rs
  15. 2
      fuzz/fuzz_targets/read_message_server.rs
  16. 1
      src/client.rs
  17. 10
      src/error.rs
  18. 6
      src/handshake/client.rs
  19. 43
      src/handshake/server.rs
  20. 4
      src/lib.rs
  21. 122
      src/protocol/frame/mod.rs
  22. 2
      src/protocol/message.rs
  23. 400
      src/protocol/mod.rs
  24. 12
      src/tls.rs
  25. 41
      tests/connection_reset.rs
  26. 9
      tests/no_send_after_close.rs
  27. 15
      tests/receive_after_init_close.rs
  28. 68
      tests/write.rs

@ -1,70 +0,0 @@
name: CI
on: [push, pull_request]
jobs:
fmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@nightly
with:
components: rustfmt
- run: cargo fmt --all --check
test:
name: Test
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Install cargo-hack
uses: taiki-e/install-action@cargo-hack
- name: Check
run: cargo hack check --feature-powerset --all-targets
- name: Test
run: cargo test --release
autobahn:
name: Autobahn tests
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- stable
- beta
- nightly
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Running Autobahn TestSuite for client
run: ./scripts/autobahn-client.sh
- name: Running Autobahn TestSuite for server
run: ./scripts/autobahn-server.sh

1
.gitignore vendored

@ -1,3 +1,2 @@
target target
Cargo.lock Cargo.lock
.vscode

@ -0,0 +1,14 @@
language: rust
rust:
- stable
services:
- docker
before_script:
- export PATH="$PATH:$HOME/.cargo/bin"
script:
- cargo test --release
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

@ -1,20 +1,3 @@
# Unreleased (0.20.0)
- Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing.
This allows writing a batch of messages and flushing once, improving performance.
- Add `WebSocket::read`, `write`, `send`, `flush`. Deprecate `read_message`, `write_message`, `write_pending`.
- Add `FrameSocket::read`, `write`, `send`, `flush`. Remove `read_frame`, `write_frame`, `write_pending`.
Note: Previous use of `write_frame` may be replaced with `send`.
- Add `WebSocketContext::read`, `write`, `flush`. Remove `read_message`, `write_message`, `write_pending`.
Note: Previous use of `write_message` may be replaced with `write` + `flush`.
- Remove `send_queue`, replaced with using the frame write buffer to achieve similar results.
* Add `WebSocketConfig::max_write_buffer_size`. Deprecate `max_send_queue`.
* Add `Error::WriteBufferFull`. Remove `Error::SendQueueFull`.
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
# 0.19.0 # 0.19.0
- Update TLS dependencies. - Update TLS dependencies.

@ -1,14 +1,14 @@
[package] [package]
name = "tungstenite" name = "ng-tungstenite"
description = "Lightweight stream-based WebSocket implementation" description = "fork of tungstenite for Nextgraph.org"
categories = ["web-programming::websocket", "network-programming"] categories = []
keywords = ["websocket", "io", "web"] keywords = ["websocket", "io", "web"]
authors = ["Alexey Galakhov", "Daniel Abramov"] authors = ["Alexey Galakhov", "Daniel Abramov", "Niko PLP <niko@nextgraph.org>"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://git.nextgraph.org/NextGraph/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.19.0" documentation = "https://docs.rs/tungstenite/0.19.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://git.nextgraph.org/NextGraph/tungstenite-rs"
version = "0.19.0" version = "0.19.0"
edition = "2018" edition = "2018"
rust-version = "1.51" rust-version = "1.51"
@ -24,7 +24,7 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls"] __rustls-tls = ["rustls", "webpki"]
[dependencies] [dependencies]
data-encoding = { version = "2", optional = true } data-encoding = { version = "2", optional = true }
@ -52,12 +52,17 @@ version = "0.21.0"
optional = true optional = true
version = "0.6.0" version = "0.6.0"
[dependencies.webpki]
optional = true
version = "0.22"
features = ["std"]
[dependencies.webpki-roots] [dependencies.webpki-roots]
optional = true optional = true
version = "0.23" version = "0.23"
[dev-dependencies] [dev-dependencies]
criterion = "0.5.0" criterion = "0.4.0"
env_logger = "0.10.0" env_logger = "0.10.0"
input_buffer = "0.5.0" input_buffer = "0.5.0"
net2 = "0.2.37" net2 = "0.2.37"
@ -66,31 +71,3 @@ rand = "0.8.4"
[[bench]] [[bench]]
name = "buffer" name = "buffer"
harness = false harness = false
[[bench]]
name = "write"
harness = false
[[example]]
name = "client"
required-features = ["handshake"]
[[example]]
name = "server"
required-features = ["handshake"]
[[example]]
name = "autobahn-client"
required-features = ["handshake"]
[[example]]
name = "autobahn-server"
required-features = ["handshake"]
[[example]]
name = "callback-error"
required-features = ["handshake"]
[[example]]
name = "srv_accept_unmasked_frames"
required-features = ["handshake"]

@ -1,4 +1,6 @@
# Tungstenite # ng-tungstenite
fork of https://github.com/snapview/tungstenite-rs for the needs of NextGraph.org
Lightweight stream-based WebSocket implementation for [Rust](https://www.rust-lang.org/). Lightweight stream-based WebSocket implementation for [Rust](https://www.rust-lang.org/).
@ -14,11 +16,11 @@ fn main () {
spawn (move || { spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap(); let mut websocket = accept(stream.unwrap()).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages. // We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap(); websocket.write_message(msg).unwrap();
} }
} }
}); });
@ -36,12 +38,12 @@ take a look at [`tokio-tungstenite`](https://github.com/snapview/tokio-tungsteni
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)
[![Build Status](https://github.com/snapview/tungstenite-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/snapview/tungstenite-rs/actions) [![Build Status](https://travis-ci.org/snapview/tungstenite-rs.svg?branch=master)](https://travis-ci.org/snapview/tungstenite-rs)
[Documentation](https://docs.rs/tungstenite) [Documentation](https://docs.rs/tungstenite)
Introduction ## Introduction
------------
This library provides an implementation of WebSockets, This library provides an implementation of WebSockets,
[RFC6455](https://tools.ietf.org/html/rfc6455). It allows for both synchronous (like TcpStream) [RFC6455](https://tools.ietf.org/html/rfc6455). It allows for both synchronous (like TcpStream)
and asynchronous usage and is easy to integrate into any third-party event loops including and asynchronous usage and is easy to integrate into any third-party event loops including
@ -49,23 +51,21 @@ and asynchronous usage and is easy to integrate into any third-party event loops
WebSocket protocol but still makes them accessible for those who wants full control over the WebSocket protocol but still makes them accessible for those who wants full control over the
network. network.
Why Tungstenite? ## Why Tungstenite?
----------------
It's formerly WS2, the 2nd implementation of WS. WS2 is the chemical formula of It's formerly WS2, the 2nd implementation of WS. WS2 is the chemical formula of
tungsten disulfide, the tungstenite mineral. tungsten disulfide, the tungstenite mineral.
Features ## Features
--------
Tungstenite provides a complete implementation of the WebSocket specification. Tungstenite provides a complete implementation of the WebSocket specification.
TLS is supported on all platforms using `native-tls` or `rustls`. The following TLS is supported on all platforms using `native-tls` or `rustls`. The following
features are available: features are available:
* `native-tls` - `native-tls`
* `native-tls-vendored` - `native-tls-vendored`
* `rustls-tls-native-roots` - `rustls-tls-native-roots`
* `rustls-tls-webpki-roots` - `rustls-tls-webpki-roots`
Choose the one that is appropriate for your needs. Choose the one that is appropriate for your needs.
@ -74,13 +74,11 @@ otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink: There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing ## Testing
-------
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for
WebSockets. It is also covered by internal unit tests as well as possible. WebSockets. It is also covered by internal unit tests as well as possible.
Contributing ## Contributing
------------
Please report bugs and make feature requests [here](https://github.com/snapview/tungstenite-rs/issues). Please report bugs and make feature requests [here](https://github.com/snapview/tungstenite-rs/issues).

@ -1,4 +1,5 @@
use std::io::{Cursor, Read, Result as IoResult}; use std::io::Result as IoResult;
use std::io::{Cursor, Read};
use bytes::Buf; use bytes::Buf;
use criterion::*; use criterion::*;

@ -1,75 +0,0 @@
//! Benchmarks for write performance.
use criterion::{BatchSize, Criterion};
use std::{
hint,
io::{self, Read, Write},
time::{Duration, Instant},
};
use tungstenite::{Message, WebSocket};
const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// `Write` impl that simulates slowish writes and slow flushes.
///
/// Each `write` can buffer up to 8 MiB before flushing but takes an additional **~80ns**
/// to simulate stuff going on in the underlying stream.
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
}
// simulate io
spin(Duration::from_nanos(80));
self.0.extend(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.0.is_empty() {
// simulate io
spin(Duration::from_micros(8));
self.0.clear();
}
Ok(())
}
}
fn spin(duration: Duration) {
let a = Instant::now();
while a.elapsed() < duration {
hint::spin_loop();
}
}
fn benchmark(c: &mut Criterion) {
// Writes 100k small json text messages then flushes
c.bench_function("write 100k small texts then flush", |b| {
let mut ws = WebSocket::from_raw_socket(
MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)),
tungstenite::protocol::Role::Server,
None,
);
b.iter_batched(
|| (0..100_000).map(|i| Message::Text(format!("{{\"id\":{i}}}"))),
|batch| {
for msg in batch {
ws.write(msg).unwrap();
}
ws.flush().unwrap();
},
BatchSize::SmallInput,
)
});
}
criterion::criterion_group!(write_benches, benchmark);
criterion::criterion_main!(write_benches);

@ -7,7 +7,7 @@ const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
} }
@ -26,9 +26,9 @@ fn run_test(case: u32) -> Result<()> {
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect(case_url)?;
loop { loop {
match socket.read()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -17,9 +17,9 @@ fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept(stream).map_err(must_not_block)?;
info!("Running test"); info!("Running test");
loop { loop {
match socket.read()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => { msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.send(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {}
} }

@ -14,9 +14,9 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket.send(Message::Text("Hello WebSocket".into())).unwrap(); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop { loop {
let msg = socket.read().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);
} }
// socket.close(None); // socket.close(None);

@ -28,9 +28,9 @@ fn main() {
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
websocket.send(msg).unwrap(); websocket.write_message(msg).unwrap();
} }
} }
}); });

@ -27,18 +27,20 @@ fn main() {
}; };
let config = Some(WebSocketConfig { let config = Some(WebSocketConfig {
max_send_queue: None,
max_message_size: None,
max_frame_size: None,
// This setting allows to accept client frames which are not masked // This setting allows to accept client frames which are not masked
// This is not in compliance with RFC 6455 but might be handy in some // This is not in compliance with RFC 6455 but might be handy in some
// rare cases where it is necessary to integrate with existing/legacy // rare cases where it is necessary to integrate with existing/legacy
// clients which are sending unmasked frames // clients which are sending unmasked frames
accept_unmasked_frames: true, accept_unmasked_frames: true,
..<_>::default()
}); });
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
loop { loop {
let msg = websocket.read().unwrap(); let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() { if msg.is_binary() || msg.is_text() {
println!("received message {}", msg); println!("received message {}", msg);
} }

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None);
socket.read().ok(); socket.read_message().ok();
}); });

@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| {
//let vector: Vec<u8> = data.into(); //let vector: Vec<u8> = data.into();
let cursor = Cursor::new(data); let cursor = Cursor::new(data);
let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None);
socket.read().ok(); socket.read_message().ok();
}); });

@ -54,7 +54,6 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,

@ -53,9 +53,9 @@ pub enum Error {
/// Protocol violation. /// Protocol violation.
#[error("WebSocket protocol error: {0}")] #[error("WebSocket protocol error: {0}")]
Protocol(#[from] ProtocolError), Protocol(#[from] ProtocolError),
/// Message write buffer is full. /// Message send queue full.
#[error("Write buffer is full")] #[error("Send queue is full")]
WriteBufferFull(Message), SendQueueFull(Message),
/// UTF coding error. /// UTF coding error.
#[error("UTF-8 encoding error")] #[error("UTF-8 encoding error")]
Utf8, Utf8,
@ -271,6 +271,10 @@ pub enum TlsError {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")] #[error("rustls error: {0}")]
Rustls(#[from] rustls::Error), Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error. /// DNS name resolution error.
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name")] #[error("Invalid DNS name")]

@ -87,8 +87,8 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
Ok(r) => r, Ok(r) => r,
Err(Error::Http(mut e)) => { Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail); *e.body_mut() = Some(tail);
return Err(Error::Http(e)); return Err(Error::Http(e))
} },
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
@ -258,7 +258,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;

@ -30,7 +30,7 @@ pub type Request = HttpRequest<()>;
pub type Response = HttpResponse<()>; pub type Response = HttpResponse<()>;
/// Server error response type. /// Server error response type.
pub type ErrorResponse = HttpResponse<Option<String>>; pub type ErrorResponse = HttpResponse<Option<Vec<u8>>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
@ -156,23 +156,15 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse>;
self,
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>, F: FnOnce(&Request) -> StdResult<(), ErrorResponse>,
{ {
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse> {
self, self(request)
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
self(request, response)
} }
} }
@ -181,12 +173,8 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request( fn on_request(self, _request: &Request) -> StdResult<(), ErrorResponse> {
self, Ok(())
_request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
Ok(response)
} }
} }
@ -240,24 +228,24 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result)
} else { } else {
Ok(response) Ok(())
}; };
match callback_result { match callback_result {
Ok(response) => { Ok(_) => {
let response = create_response(&result)?;
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &response)?; write_response(&mut output, &response)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
Err(resp) => { Err(resp) => {
if resp.status().is_success() { // if resp.status().is_success() {
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); // return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
} // }
self.error_response = Some(resp); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap(); let resp = self.error_response.as_ref().unwrap();
@ -266,7 +254,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
write_response(&mut output, resp)?; write_response(&mut output, resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend(body);
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
@ -279,7 +267,6 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
debug!("Server handshake failed."); debug!("Server handshake failed.");
let (parts, body) = err.into_parts(); let (parts, body) = err.into_parts();
let body = body.map(|b| b.as_bytes().to_vec());
return Err(Error::Http(http::Response::from_parts(parts, body))); return Err(Error::Http(http::Response::from_parts(parts, body)));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");

@ -25,7 +25,7 @@ pub mod protocol;
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
mod server; mod server;
pub mod stream; pub mod stream;
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))] #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls; mod tls;
pub mod util; pub mod util;
@ -44,5 +44,5 @@ pub use crate::{
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
}; };
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))] #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_tls, client_tls_with_config, Connector}; pub use tls::{client_tls, client_tls_with_config, Connector};

@ -6,14 +6,15 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
use crate::{
error::{CapacityError, Error, Result},
Message, ReadBuffer,
};
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
use log::*;
pub use self::frame::{CloseFrame, Frame, FrameHeader}; pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::{
error::{CapacityError, Error, Result},
ReadBuffer,
};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
@ -56,7 +57,7 @@ where
Stream: Read, Stream: Read,
{ {
/// Read a frame from stream. /// Read a frame from stream.
pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> { pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
self.codec.read_frame(&mut self.stream, max_size) self.codec.read_frame(&mut self.stream, max_size)
} }
} }
@ -65,28 +66,18 @@ impl<Stream> FrameSocket<Stream>
where where
Stream: Write, Stream: Write,
{ {
/// Writes and immediately flushes a frame.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, frame: Frame) -> Result<()> {
self.write(frame)?;
self.flush()
}
/// Write a frame to stream. /// Write a frame to stream.
/// ///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes. /// This function guarantees that the frame is queued regardless of any errors.
/// /// There is no need to resend the frame. In order to handle WouldBlock or Incomplete,
/// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`] /// call write_pending() afterwards.
/// is returned. pub fn write_frame(&mut self, frame: Frame) -> Result<()> {
/// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards. self.codec.write_frame(&mut self.stream, frame)
pub fn write(&mut self, frame: Frame) -> Result<()> {
self.codec.buffer_frame(&mut self.stream, frame)
} }
/// Flush writes. /// Complete pending write, if any.
pub fn flush(&mut self) -> Result<()> { pub fn write_pending(&mut self) -> Result<()> {
self.codec.write_out_buffer(&mut self.stream)?; self.codec.write_pending(&mut self.stream)
Ok(self.stream.flush()?)
} }
} }
@ -97,14 +88,6 @@ pub(super) struct FrameCodec {
in_buffer: ReadBuffer, in_buffer: ReadBuffer,
/// Buffer to send packets to the network. /// Buffer to send packets to the network.
out_buffer: Vec<u8>, out_buffer: Vec<u8>,
/// Capacity limit for `out_buffer`.
max_out_buffer_len: usize,
/// Buffer target length to reach before writing to the stream
/// on calls to `buffer_frame`.
///
/// Setting this to non-zero will buffer small writes from hitting
/// the stream.
out_buffer_write_len: usize,
/// Header and remaining size of the incoming packet being processed. /// Header and remaining size of the incoming packet being processed.
header: Option<(FrameHeader, u64)>, header: Option<(FrameHeader, u64)>,
} }
@ -112,13 +95,7 @@ pub(super) struct FrameCodec {
impl FrameCodec { impl FrameCodec {
/// Create a new frame codec. /// Create a new frame codec.
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None }
in_buffer: ReadBuffer::new(),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None,
}
} }
/// Create a new frame codec from partially read data. /// Create a new frame codec from partially read data.
@ -126,23 +103,10 @@ impl FrameCodec {
Self { Self {
in_buffer: ReadBuffer::from_partially_read(part), in_buffer: ReadBuffer::from_partially_read(part),
out_buffer: Vec::new(), out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
header: None, header: None,
} }
} }
/// Sets a maximum size for the out buffer.
pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
self.max_out_buffer_len = max;
}
/// Sets [`Self::buffer_frame`] buffer target length to reach before
/// writing to the stream.
pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
self.out_buffer_write_len = len;
}
/// Read a frame from the provided stream. /// Read a frame from the provided stream.
pub(super) fn read_frame<Stream>( pub(super) fn read_frame<Stream>(
&mut self, &mut self,
@ -201,37 +165,19 @@ impl FrameCodec {
Ok(Some(frame)) Ok(Some(frame))
} }
/// Writes a frame into the `out_buffer`. /// Write a frame to the provided stream.
/// If the out buffer size is over the `out_buffer_write_len` will also write pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
/// the out buffer into the provided `stream`.
///
/// To ensure buffered frames are written call [`Self::write_out_buffer`].
///
/// May write to the stream, will **not** flush.
pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
return Err(Error::WriteBufferFull(Message::Frame(frame)));
}
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
self.write_pending(stream)
if self.out_buffer.len() > self.out_buffer_write_len {
self.write_out_buffer(stream)
} else {
Ok(())
}
} }
/// Writes the out_buffer to the provided stream. /// Complete pending write, if any.
/// pub(super) fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
/// Does **not** flush.
pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
@ -247,11 +193,19 @@ impl FrameCodec {
} }
self.out_buffer.drain(0..len); self.out_buffer.drain(0..len);
} }
stream.flush()?;
Ok(()) Ok(())
} }
} }
#[cfg(test)]
impl FrameCodec {
/// Returns the size of the output buffer.
pub(super) fn output_buffer_len(&self) -> usize {
self.out_buffer.len()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -270,11 +224,11 @@ mod tests {
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!( assert_eq!(
sock.read(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read(None).unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]); assert_eq!(rest, vec![0x99]);
@ -285,7 +239,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!( assert_eq!(
sock.read(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
} }
@ -295,10 +249,10 @@ mod tests {
let mut sock = FrameSocket::new(Vec::new()); let mut sock = FrameSocket::new(Vec::new());
let frame = Frame::ping(vec![0x04, 0x05]); let frame = Frame::ping(vec![0x04, 0x05]);
sock.send(frame).unwrap(); sock.write_frame(frame).unwrap();
let frame = Frame::pong(vec![0x01]); let frame = Frame::pong(vec![0x01]);
sock.send(frame).unwrap(); sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner(); let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
@ -310,7 +264,7 @@ mod tests {
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
let _ = sock.read(None); // should not crash let _ = sock.read_frame(None); // should not crash
} }
#[test] #[test]
@ -318,7 +272,7 @@ mod tests {
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert!(matches!( assert!(matches!(
sock.read(Some(5)), sock.read_frame(Some(5)),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
)); ));
} }

@ -185,7 +185,7 @@ impl Message {
Message::Text(string.into()) Message::Text(string.into())
} }
/// Create a new binary WebSocket message by converting to `Vec<u8>`. /// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message pub fn binary<B>(bin: B) -> Message
where where
B: Into<Vec<u8>>, B: Into<Vec<u8>>,

@ -6,6 +6,13 @@ mod message;
pub use self::{frame::CloseFrame, message::Message}; pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{ use self::{
frame::{ frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
@ -17,11 +24,6 @@ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
use log::*;
use std::{
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -35,21 +37,10 @@ pub enum Role {
/// The configuration for WebSocket connection. /// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig { pub struct WebSocketConfig {
/// Does nothing, instead use `max_write_buffer_size`. /// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
#[deprecated] /// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>, pub max_send_queue: Option<usize>,
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
/// The default value is 128 KiB.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited.
///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size).
pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent /// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user. /// memory eating by a malicious user.
@ -69,11 +60,8 @@ pub struct WebSocketConfig {
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
fn default() -> Self { fn default() -> Self {
#[allow(deprecated)]
WebSocketConfig { WebSocketConfig {
max_send_queue: None, max_send_queue: None,
write_buffer_size: 128 * 1024,
max_write_buffer_size: usize::MAX,
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
accept_unmasked_frames: false, accept_unmasked_frames: false,
@ -85,8 +73,6 @@ impl Default for WebSocketConfig {
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions. /// It may be created by calling `connect`, `accept` or `client` functions.
///
/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocket<Stream> { pub struct WebSocket<Stream> {
/// The underlying socket. /// The underlying socket.
@ -160,116 +146,82 @@ impl<Stream> WebSocket<Stream> {
impl<Stream: Read + Write> WebSocket<Stream> { impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
/// This will also queue responses to ping and close messages. These responses /// This will queue responses to ping and close messages to be sent. It will call
/// will be written and flushed on the next call to [`read`](Self::read), /// `write_pending` before trying to read in order to make sure that those responses
/// [`write`](Self::write) or [`flush`](Self::flush). /// make progress even if you never call `write_pending`. That does mean that they
/// get sent out earliest on the next call to `read_message`, `write_message` or `write_pending`.
/// ///
/// # Closing the connection /// ## Closing the connection
/// When the remote endpoint decides to close the connection this will return /// When the remote endpoint decides to close the connection this will return
/// the close message with an optional close frame. /// the close message with an optional close frame.
/// ///
/// You should continue calling [`read`](Self::read), [`write`](Self::write) or /// You should continue calling `read_message`, `write_message` or `write_pending` to drive
/// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`] /// the reply to the close frame until [Error::ConnectionClosed] is returned. Once that happens
/// is returned. Once that happens it is safe to drop the underlying connection. /// it is safe to drop the underlying connection.
pub fn read(&mut self) -> Result<Message> { pub fn read_message(&mut self) -> Result<Message> {
self.context.read(&mut self.socket) self.context.read_message(&mut self.socket)
}
/// Writes and immediately flushes a message.
/// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
pub fn send(&mut self, message: Message) -> Result<()> {
self.write(message)?;
self.flush()
} }
/// Write a message to the provided stream, if possible. /// Send a message to stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// ///
/// In the event of stream write failure the message frame will be stored /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// in the write buffer and will try again on the next call to [`write`](Self::write) /// requests. A Pong reply will jump the queue because the
/// or [`flush`](Self::flush). /// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent
/// as soon as is practical.
/// ///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] /// Note that upon receiving a ping message, tungstenite cues a pong reply automatically.
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. /// When you call either `read_message`, `write_message` or `write_pending` next it will try to send
/// /// that pong out if the underlying connection can take more data. This means you should not
/// This call will generally not flush. However, if there are queued automatic messages /// respond to ping frames manually.
/// they will be written and eagerly flushed.
///
/// For example, upon receiving ping messages tungstenite queues pong replies automatically.
/// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush)
/// will write & flush the pong reply. This means you should not respond to ping frames manually.
/// ///
/// You can however send pong frames manually in order to indicate a unidirectional heartbeat /// You can however send pong frames manually in order to indicate a unidirectional heartbeat
/// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that
/// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing /// if `read_message` returns a ping, you should call `write_pending` until it doesn't return
/// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the /// WouldBlock before passing a pong to `write_message`, otherwise the response to the
/// ping will not be sent as it will be replaced by your custom pong message. /// ping will not be sent, but rather replaced by your custom pong message.
/// ///
/// # Errors /// ## Errors
/// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned /// - If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the equivalent passed message frame. /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`]. /// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed].
/// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from /// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`,
/// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program /// [Error::AlreadyClosed] will be returned. This indicates a program error on your part.
/// error on your part. /// - [Error::Io] is returned if the underlying connection returns an error
/// - [`Error::Io`] is returned if the underlying connection returns an error
/// (consider these fatal except for WouldBlock). /// (consider these fatal except for WouldBlock).
/// - [`Error::Capacity`] if your message size is bigger than the configured max message size. /// - [Error::Capacity] if your message size is bigger than the configured max message size.
pub fn write(&mut self, message: Message) -> Result<()> { pub fn write_message(&mut self, message: Message) -> Result<()> {
self.context.write(&mut self.socket, message) self.context.write_message(&mut self.socket, message)
} }
/// Flush writes. /// Flush the pending send queue.
/// pub fn write_pending(&mut self) -> Result<()> {
/// Ensures all messages previously passed to [`write`](Self::write) and automatic self.context.write_pending(&mut self.socket)
/// queued pong responses are written & flushed into the underlying stream.
pub fn flush(&mut self) -> Result<()> {
self.context.flush(&mut self.socket)
} }
/// Close the connection. /// Close the connection.
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `write(Message::Close(..))`. /// the same as calling `write_message(Message::Close(..))`.
/// ///
/// After queuing the close frame you should continue calling [`read`](Self::read) or /// After queuing the close frame you should continue calling `read_message` or
/// [`flush`](Self::flush) to drive the close handshake to completion. /// `write_pending` to drive the close handshake to completion.
/// ///
/// The websocket RFC defines that the underlying connection should be closed /// The websocket RFC defines that the underlying connection should be closed
/// by the server. Tungstenite takes care of this asymmetry for you. /// by the server. Tungstenite takes care of this asymmetry for you.
/// ///
/// When the close handshake is finished (we have both sent and received /// When the close handshake is finished (we have both sent and received
/// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return /// a close message), `read_message` or `write_pending` will return
/// [Error::ConnectionClosed] if this endpoint is the server. /// [Error::ConnectionClosed] if this endpoint is the server.
/// ///
/// If this endpoint is a client, [Error::ConnectionClosed] will only be /// If this endpoint is a client, [Error::ConnectionClosed] will only be
/// returned after the server has closed the underlying connection. /// returned after the server has closed the underlying connection.
/// ///
/// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed] /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed]
/// is returned from [`read`](Self::read) or [`flush`](Self::flush). /// is returned from `read_message` or `write_pending`.
pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> { pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
self.context.close(&mut self.socket, code) self.context.close(&mut self.socket, code)
} }
/// Old name for [`read`](Self::read).
#[deprecated(note = "Use `read`")]
pub fn read_message(&mut self) -> Result<Message> {
self.read()
}
/// Old name for [`send`](Self::send).
#[deprecated(note = "Use `send`")]
pub fn write_message(&mut self, message: Message) -> Result<()> {
self.send(message)
}
/// Old name for [`flush`](Self::flush).
#[deprecated(note = "Use `flush`")]
pub fn write_pending(&mut self) -> Result<()> {
self.flush()
}
} }
/// A context for managing WebSocket stream. /// A context for managing WebSocket stream.
@ -283,8 +235,10 @@ pub struct WebSocketContext {
state: WebSocketState, state: WebSocketState,
/// Receive: an incomplete message being processed. /// Receive: an incomplete message being processed.
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send in addition to regular messages E.g. "pong" or "close". /// Send: a data send queue.
additional_send: Option<Frame>, send_queue: VecDeque<Frame>,
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
} }
@ -292,32 +246,28 @@ pub struct WebSocketContext {
impl WebSocketContext { impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream. /// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self { pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::new(), config.unwrap_or_default()) WebSocketContext {
role,
frame: FrameCodec::new(),
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
config: config.unwrap_or_default(),
}
} }
/// Create a WebSocket context that manages an post-handshake stream. /// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default()) WebSocketContext {
} frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
frame.set_max_out_buffer_len(config.max_write_buffer_size);
frame.set_out_buffer_write_len(config.write_buffer_size);
Self {
role,
frame,
state: WebSocketState::Active,
incomplete: None,
additional_send: None,
config,
} }
} }
/// Change the configuration. /// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config); set_func(&mut self.config)
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
} }
/// Read the configuration. /// Read the configuration.
@ -344,23 +294,17 @@ impl WebSocketContext {
/// ///
/// This function sends pong and close responses automatically. /// This function sends pong and close responses automatically.
/// However, it never blocks on write. /// However, it never blocks on write.
pub fn read<Stream>(&mut self, stream: &mut Stream) -> Result<Message> pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// Do not read from already closed connections. // Do not read from already closed connections.
self.state.check_not_terminated()?; self.state.check_active()?;
loop { loop {
if self.additional_send.is_some() {
// Since we may get ping or close, we need to reply to the messages even during read. // Since we may get ping or close, we need to reply to the messages even during read.
// Thus we flush but ignore its blocking. // Thus we call write_pending() but ignore its blocking.
self.flush(stream).no_block()?; self.write_pending(stream).no_block()?;
} else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed);
}
// If we get here, either write blocks or we have nothing to write. // If we get here, either write blocks or we have nothing to write.
// Thus if read blocks, just let it return WouldBlock. // Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? { if let Some(message) = self.read_message_frame(stream)? {
@ -370,94 +314,78 @@ impl WebSocketContext {
} }
} }
/// Write a message to the provided stream. /// Send a message to the provided stream, if possible.
///
/// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
/// ///
/// In the event of stream write failure the message frame will be stored /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// in the write buffer and will try again on the next call to [`write`](Self::write) /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// or [`flush`](Self::flush). /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// ///
/// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] /// Note that only the last pong frame is stored to be sent, and only the
/// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. /// most recent pong frame is sent if multiple pong frames are queued.
pub fn write<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()> pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// When terminated, return AlreadyClosed. // When terminated, return AlreadyClosed.
self.state.check_not_terminated()?; self.state.check_active()?;
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol(ProtocolError::SendAfterClosing)); return Err(Error::Protocol(ProtocolError::SendAfterClosing));
} }
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
// Try to make some room for the new message.
// Do not return here if write would block, ignore WouldBlock silently
// since we must queue the message anyway.
self.write_pending(stream).no_block()?;
}
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message { let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
self.set_additional(Frame::pong(data)); self.pong = Some(Frame::pong(data));
// Note: user pongs can be user flushed so no need to flush here return self.write_pending(stream);
return self._write(stream, None).map(|_| ());
} }
Message::Close(code) => return self.close(stream, code), Message::Close(code) => return self.close(stream, code),
Message::Frame(f) => f, Message::Frame(f) => f,
}; };
let should_flush = self._write(stream, Some(frame))?; self.send_queue.push_back(frame);
if should_flush { self.write_pending(stream)
self.flush(stream)?;
}
Ok(())
}
/// Flush writes.
///
/// Ensures all messages previously passed to [`write`](Self::write) and automatically
/// queued pong responses are written & flushed into the `stream`.
#[inline]
pub fn flush<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where
Stream: Read + Write,
{
self._write(stream, None)?;
self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?)
} }
/// Writes any data in the out_buffer, `additional_send` and given `data`. /// Flush the pending send queue.
/// pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
/// Does **not** flush.
///
/// Returns true if the write contents indicate we should flush immediately.
fn _write<Stream>(&mut self, stream: &mut Stream, data: Option<Frame>) -> Result<bool>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
if let Some(data) = data { // First, make sure we have no pending frame sending.
self.buffer_frame(stream, data)?; self.frame.write_pending(stream)?;
}
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
// response, unless it already received a Close frame. It SHOULD // response, unless it already received a Close frame. It SHOULD
// respond with Pong frame as soon as is practical. (RFC 6455) // respond with Pong frame as soon as is practical. (RFC 6455)
let should_flush = if let Some(msg) = self.additional_send.take() { if let Some(pong) = self.pong.take() {
trace!("Sending pong/close"); trace!("Sending pong reply");
match self.buffer_frame(stream, msg) { self.send_one_frame(stream, pong)?;
Err(Error::WriteBufferFull(Message::Frame(msg))) => { }
// if an system message would exceed the buffer put it back in // If we have any unsent frames, send them.
// `additional_send` for retry. Otherwise returning this error trace!("Frames still in queue: {}", self.send_queue.len());
// may not make sense to the user, e.g. calling `flush`. while let Some(data) = self.send_queue.pop_front() {
self.set_additional(msg); self.send_one_frame(stream, data)?;
false
}
Err(err) => return Err(err),
Ok(_) => true,
} }
} else {
false // If we get to this point, the send queue is empty and the underlying socket is still
}; // willing to take more data.
// If we're closing and there is nothing to send anymore, we should close the connection. // If we're closing and there is nothing to send anymore, we should close the connection.
if self.role == Role::Server && !self.state.can_read() { if self.role == Role::Server && !self.state.can_read() {
@ -467,11 +395,10 @@ impl WebSocketContext {
// maximum segment lifetimes (2MSL), while there is no corresponding // maximum segment lifetimes (2MSL), while there is no corresponding
// server impact as a TIME_WAIT connection is immediately reopened upon // server impact as a TIME_WAIT connection is immediately reopened upon
// a new SYN with a higher seq number). (RFC 6455) // a new SYN with a higher seq number). (RFC 6455)
self.frame.write_out_buffer(stream)?;
self.state = WebSocketState::Terminated; self.state = WebSocketState::Terminated;
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} else { } else {
Ok(should_flush) Ok(())
} }
} }
@ -479,7 +406,7 @@ impl WebSocketContext {
/// ///
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `send(Message::Close(..))`. /// the same as calling `write(Message::Close(..))`.
pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()> pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
@ -487,9 +414,11 @@ impl WebSocketContext {
if let WebSocketState::Active = self.state { if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs; self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(code); let frame = Frame::close(code);
self._write(stream, Some(frame))?; self.send_queue.push_back(frame);
} else {
// Already closed, nothing to do.
} }
self.flush(stream) self.write_pending(stream)
} }
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
@ -558,7 +487,7 @@ impl WebSocketContext {
let data = frame.into_data(); let data = frame.into_data();
// No ping processing after we sent a close frame. // No ping processing after we sent a close frame.
if self.state.is_active() { if self.state.is_active() {
self.set_additional(Frame::pong(data.clone())); self.pong = Some(Frame::pong(data.clone()));
} }
Ok(Some(Message::Ping(data))) Ok(Some(Message::Ping(data)))
} }
@ -642,7 +571,7 @@ impl WebSocketContext {
let reply = Frame::close(close.clone()); let reply = Frame::close(close.clone());
debug!("Replying to close with {:?}", reply); debug!("Replying to close with {:?}", reply);
self.set_additional(reply); self.send_queue.push_back(reply);
Some(close) Some(close)
} }
@ -659,8 +588,8 @@ impl WebSocketContext {
} }
} }
/// Write a single frame into the write-buffer. /// Send a single pending frame.
fn buffer_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
@ -674,18 +603,7 @@ impl WebSocketContext {
} }
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.buffer_frame(stream, frame).check_connection_reset(self.state) self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
/// Replace `additional_send` if it is currently a `Pong` message.
fn set_additional(&mut self, add: Frame) {
let empty_or_pong = self
.additional_send
.as_ref()
.map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong));
if empty_or_pong {
self.additional_send.replace(add);
}
} }
} }
@ -718,7 +636,7 @@ impl WebSocketState {
} }
/// Check if the state is active, return error if not. /// Check if the state is active, return error if not.
fn check_not_terminated(self) -> Result<()> { fn check_active(self) -> Result<()> {
match self { match self {
WebSocketState::Terminated => Err(Error::AlreadyClosed), WebSocketState::Terminated => Err(Error::AlreadyClosed),
_ => Ok(()), _ => Ok(()),
@ -770,6 +688,64 @@ mod tests {
} }
} }
struct WouldBlockStreamMoc;
impl io::Write for WouldBlockStreamMoc {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
fn flush(&mut self) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
impl io::Read for WouldBlockStreamMoc {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
}
}
#[test]
fn queue_logic() {
// Create a socket with the queue size of 1.
let mut socket = WebSocket::from_raw_socket(
WouldBlockStreamMoc,
Role::Client,
Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }),
);
// Test message that we're going to send.
let message = Message::Binary(vec![0xFF; 1024]);
// Helper to check the error.
let assert_would_block = |error| {
if let Error::Io(io_error) = error {
assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock);
} else {
panic!("Expected WouldBlock error");
}
};
// The first attempt of writing must not fail, since the queue is empty at start.
// But since the underlying mock object always returns `WouldBlock`, so is the result.
assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err()));
// Any subsequent attempts must return an error telling that the queue is full.
for _i in 0..100 {
assert!(matches!(
socket.write_message(message.clone()).unwrap_err(),
Error::SendQueueFull(..)
));
}
// The size of the output buffer must not be bigger than the size of that message
// that we managed to write to the output buffer at first. Since we could not make
// any progress (because of the logic of the moc buffer), the size remains unchanged.
if socket.context.frame.output_buffer_len() > message.len() {
panic!("Too many frames in the queue");
}
}
#[test] #[test]
fn receive_messages() { fn receive_messages() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![
@ -778,10 +754,10 @@ mod tests {
0x03, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
} }
#[test] #[test]
@ -794,7 +770,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read(), socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 }))
)); ));
} }
@ -806,7 +782,7 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert!(matches!( assert!(matches!(
socket.read(), socket.read_message(),
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
)); ));
} }

@ -104,13 +104,11 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")] #[cfg(feature = "rustls-tls-native-roots")]
{ {
let native_certs = rustls_native_certs::load_native_certs()?; for cert in rustls_native_certs::load_native_certs()? {
let der_certs: Vec<Vec<u8>> = root_store
native_certs.into_iter().map(|cert| cert.0).collect(); .add(&rustls::Certificate(cert.0))
let total_number = der_certs.len(); .map_err(TlsError::Rustls)?;
let (number_added, number_ignored) = }
root_store.add_parsable_certificates(&der_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
} }
#[cfg(feature = "rustls-tls-webpki-roots")] #[cfg(feature = "rustls-tls-webpki-roots")]
{ {

@ -1,7 +1,6 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closed from the server's point of view and drop the underlying tcp socket. //! is closed from the server's point of view and drop the underlying tcp socket.
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
#![cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "handshake"))]
use std::{ use std::{
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
@ -52,27 +51,27 @@ fn test_server_close() {
do_test( do_test(
3012, 3012,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -86,26 +85,26 @@ fn test_evil_server_close() {
do_test( do_test(
3013, 3013,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
let message = cli_sock.read().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.close(None).unwrap(); // send close to client srv_sock.close(None).unwrap(); // send close to client
let message = srv_sock.read().unwrap(); // receive acknowledgement let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed` // and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
@ -119,32 +118,32 @@ fn test_client_close() {
do_test( do_test(
3014, 3014,
|mut cli_sock| { |mut cli_sock| {
cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read().unwrap(); // receive answer from server let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server"); assert_eq!(message.into_data(), b"From Server");
cli_sock.close(None).unwrap(); // send close to server cli_sock.close(None).unwrap(); // send close to server
let message = cli_sock.read().unwrap(); // receive acknowledgement from server let message = cli_sock.read_message().unwrap(); // receive acknowledgement from server
assert!(message.is_close()); assert!(message.is_close());
let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}, },
|mut srv_sock| { |mut srv_sock| {
let message = srv_sock.read().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.send(Message::Text("From Server".into())).unwrap(); srv_sock.write_message(Message::Text("From Server".into())).unwrap();
let message = srv_sock.read().unwrap(); // receive close from client let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close()); assert!(message.is_close());
let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, error::ProtocolError, Error, Message}; use tungstenite::{accept, connect, error::ProtocolError, Error, Message};
use url::Url; use url::Url;
@ -29,10 +28,10 @@ fn test_no_send_after_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
let message = client.read().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -44,7 +43,7 @@ fn test_no_send_after_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
let err = client_handler.send(Message::Text("Hello WebSocket".into())); let err = client_handler.write_message(Message::Text("Hello WebSocket".into()));
assert!(err.is_err()); assert!(err.is_err());

@ -1,8 +1,6 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
#![cfg(feature = "handshake")]
use std::{ use std::{
net::TcpListener, net::TcpListener,
process::exit, process::exit,
@ -10,6 +8,7 @@ use std::{
time::Duration, time::Duration,
}; };
#[cfg(feature = "handshake")]
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;
@ -29,12 +28,12 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap(); client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
@ -47,12 +46,12 @@ fn test_receive_after_init_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close // This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap(); let message = client_handler.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement assert!(client_handler.read_message().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),

@ -1,68 +0,0 @@
use std::io::{self, Read, Write};
use tungstenite::{protocol::WebSocketConfig, Message, WebSocket};
/// `Write` impl that records call stats and drops the data.
#[derive(Debug, Default)]
struct MockWrite {
written_bytes: usize,
write_count: usize,
flush_count: usize,
}
impl Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written_bytes += buf.len();
self.write_count += 1;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_count += 1;
Ok(())
}
}
/// Test for write buffering and flushing behaviour.
#[test]
fn write_flush_behaviour() {
const SEND_ME_LEN: usize = 10;
const BATCH_ME_LEN: usize = 11;
const WRITE_BUFFER_SIZE: usize = 600;
let mut ws = WebSocket::from_raw_socket(
MockWrite::default(),
tungstenite::protocol::Role::Server,
Some(WebSocketConfig { write_buffer_size: WRITE_BUFFER_SIZE, ..<_>::default() }),
);
assert_eq!(ws.get_ref().written_bytes, 0);
assert_eq!(ws.get_ref().write_count, 0);
assert_eq!(ws.get_ref().flush_count, 0);
// `send` writes & flushes immediately
ws.send(Message::Text("Send me!".into())).unwrap();
assert_eq!(ws.get_ref().written_bytes, SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 1);
assert_eq!(ws.get_ref().flush_count, 1);
// send a batch of messages
for msg in (0..100).map(|_| Message::Text("Batch me!".into())) {
ws.write(msg).unwrap();
}
// after 55 writes the out_buffer will exceed write_buffer_size=600
// and so do a single underlying write (not flushing).
assert_eq!(ws.get_ref().written_bytes, 55 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 2);
assert_eq!(ws.get_ref().flush_count, 1);
// flushing will perform a single write for the remaining out_buffer & flush.
ws.flush().unwrap();
assert_eq!(ws.get_ref().written_bytes, 100 * BATCH_ME_LEN + SEND_ME_LEN);
assert_eq!(ws.get_ref().write_count, 3);
assert_eq!(ws.get_ref().flush_count, 2);
}
Loading…
Cancel
Save