diff --git a/.gitignore b/.gitignore index a9d37c5..c8e9e48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.vscode diff --git a/CHANGELOG.md b/CHANGELOG.md index 0aa0a41..5a60687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +# Unreleased (0.20.0) +- Remove many implicit flushing behaviours. In general reading and writing messages will no + longer flush until calling `flush`. An exception is automatic responses (e.g. pongs) + which will continue to be written and flushed when reading and writing. + This allows writing a batch of messages and flushing once. +- Add `WebSocket::read`, `write`, `send`, `flush`. Deprecate `read_message`, `write_message`, `write_pending`. +- Add `FrameSocket::read`, `write`, `send`, `flush`. Remove `read_frame`, `write_frame`, `write_pending`. + Note: Previous use of `write_frame` may be replaced with `send`. +- Add `WebSocketContext::read`, `write`, `flush`. Remove `read_message`, `write_message`, `write_pending`. + Note: Previous use of `write_message` may be replaced with `write` + `flush`. +- Remove `send_queue`, replaced with using the frame write buffer to achieve similar results. + * Add `WebSocketConfig::max_write_buffer_size`. Deprecate `max_send_queue`. + * Add `Error::WriteBufferFull`. Remove `Error::SendQueueFull`. + Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`. + # 0.19.0 - Update TLS dependencies. diff --git a/Cargo.toml b/Cargo.toml index 86d45aa..72cf2e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,10 @@ rand = "0.8.4" name = "buffer" harness = false +[[bench]] +name = "write" +harness = false + [[example]] name = "client" required-features = ["handshake"] diff --git a/README.md b/README.md index 302ac97..3522fb9 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ fn main () { spawn (move || { let mut websocket = accept(stream.unwrap()).unwrap(); loop { - let msg = websocket.read_message().unwrap(); + let msg = websocket.read().unwrap(); // We do not want to send back ping/pong messages. if msg.is_binary() || msg.is_text() { - websocket.write_message(msg).unwrap(); + websocket.send(msg).unwrap(); } } }); diff --git a/benches/write.rs b/benches/write.rs new file mode 100644 index 0000000..8a19874 --- /dev/null +++ b/benches/write.rs @@ -0,0 +1,67 @@ +//! Benchmarks for write performance. +use criterion::{BatchSize, Criterion}; +use std::{ + hint, + io::{self, Read, Write}, + time::{Duration, Instant}, +}; +use tungstenite::{Message, WebSocket}; + +const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024; + +/// `Write` impl that simulates fast writes and slow flushes. +/// +/// Buffers up to 8 MiB fast on `write`. Each `flush` takes ~100ns. +struct MockSlowFlushWrite(Vec); + +impl Read for MockSlowFlushWrite { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported")) + } +} +impl Write for MockSlowFlushWrite { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.0.len() + buf.len() > MOCK_WRITE_LEN { + self.flush()?; + } + self.0.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + if !self.0.is_empty() { + // simulate 100ns io + let a = Instant::now(); + while a.elapsed() < Duration::from_nanos(100) { + hint::spin_loop(); + } + self.0.clear(); + } + Ok(()) + } +} + +fn benchmark(c: &mut Criterion) { + // Writes 100k small json text messages then flushes + c.bench_function("write 100k small texts then flush", |b| { + let mut ws = WebSocket::from_raw_socket( + MockSlowFlushWrite(Vec::with_capacity(MOCK_WRITE_LEN)), + tungstenite::protocol::Role::Server, + None, + ); + + b.iter_batched( + || (0..100_000).map(|i| Message::Text(format!("{{\"id\":{i}}}"))), + |batch| { + for msg in batch { + ws.write(msg).unwrap(); + } + ws.flush().unwrap(); + }, + BatchSize::SmallInput, + ) + }); +} + +criterion::criterion_group!(write_benches, benchmark); +criterion::criterion_main!(write_benches); diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 2538d86..ac7a7d1 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -7,7 +7,7 @@ const AGENT: &str = "Tungstenite"; fn get_case_count() -> Result { let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; - let msg = socket.read_message()?; + let msg = socket.read()?; socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) } @@ -26,9 +26,9 @@ fn run_test(case: u32) -> Result<()> { Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap(); let (mut socket, _) = connect(case_url)?; loop { - match socket.read_message()? { + match socket.read()? { msg @ Message::Text(_) | msg @ Message::Binary(_) => { - socket.write_message(msg)?; + socket.send(msg)?; } Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index f002efa..dafe37b 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -17,9 +17,9 @@ fn handle_client(stream: TcpStream) -> Result<()> { let mut socket = accept(stream).map_err(must_not_block)?; info!("Running test"); loop { - match socket.read_message()? { + match socket.read()? { msg @ Message::Text(_) | msg @ Message::Binary(_) => { - socket.write_message(msg)?; + socket.send(msg)?; } Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => {} } diff --git a/examples/client.rs b/examples/client.rs index def6a3c..a24f316 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -14,9 +14,9 @@ fn main() { println!("* {}", header); } - socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + socket.send(Message::Text("Hello WebSocket".into())).unwrap(); loop { - let msg = socket.read_message().expect("Error reading message"); + let msg = socket.read().expect("Error reading message"); println!("Received: {}", msg); } // socket.close(None); diff --git a/examples/server.rs b/examples/server.rs index 420e5db..2183b96 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -28,9 +28,9 @@ fn main() { let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); loop { - let msg = websocket.read_message().unwrap(); + let msg = websocket.read().unwrap(); if msg.is_binary() || msg.is_text() { - websocket.write_message(msg).unwrap(); + websocket.send(msg).unwrap(); } } }); diff --git a/examples/srv_accept_unmasked_frames.rs b/examples/srv_accept_unmasked_frames.rs index b280fba..b65e4f7 100644 --- a/examples/srv_accept_unmasked_frames.rs +++ b/examples/srv_accept_unmasked_frames.rs @@ -27,20 +27,18 @@ fn main() { }; let config = Some(WebSocketConfig { - max_send_queue: None, - max_message_size: None, - max_frame_size: None, // This setting allows to accept client frames which are not masked // This is not in compliance with RFC 6455 but might be handy in some // rare cases where it is necessary to integrate with existing/legacy // clients which are sending unmasked frames accept_unmasked_frames: true, + ..<_>::default() }); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); loop { - let msg = websocket.read_message().unwrap(); + let msg = websocket.read().unwrap(); if msg.is_binary() || msg.is_text() { println!("received message {}", msg); } diff --git a/fuzz/fuzz_targets/read_message_client.rs b/fuzz/fuzz_targets/read_message_client.rs index 1c0708b..8e53512 100644 --- a/fuzz/fuzz_targets/read_message_client.rs +++ b/fuzz/fuzz_targets/read_message_client.rs @@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| { //let vector: Vec = data.into(); let cursor = Cursor::new(data); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Client, None); - socket.read_message().ok(); + socket.read().ok(); }); diff --git a/fuzz/fuzz_targets/read_message_server.rs b/fuzz/fuzz_targets/read_message_server.rs index d96db96..7f0e7ff 100644 --- a/fuzz/fuzz_targets/read_message_server.rs +++ b/fuzz/fuzz_targets/read_message_server.rs @@ -33,5 +33,5 @@ fuzz_target!(|data: &[u8]| { //let vector: Vec = data.into(); let cursor = Cursor::new(data); let mut socket = WebSocket::from_raw_socket(WriteMoc(cursor), Role::Server, None); - socket.read_message().ok(); + socket.read().ok(); }); diff --git a/src/error.rs b/src/error.rs index b9a957f..a7b3354 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,9 +53,9 @@ pub enum Error { /// Protocol violation. #[error("WebSocket protocol error: {0}")] Protocol(#[from] ProtocolError), - /// Message send queue full. - #[error("Send queue is full")] - SendQueueFull(Message), + /// Message write buffer is full. + #[error("Write buffer is full")] + WriteBufferFull(Message), /// UTF coding error. #[error("UTF-8 encoding error")] Utf8, diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 39066be..bb72e5a 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,15 +6,14 @@ pub mod coding; mod frame; mod mask; -use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; - -use log::*; - -pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::{ error::{CapacityError, Error, Result}, - ReadBuffer, + Message, ReadBuffer, }; +use log::*; +use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; + +pub use self::frame::{CloseFrame, Frame, FrameHeader}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -57,7 +56,7 @@ where Stream: Read, { /// Read a frame from stream. - pub fn read_frame(&mut self, max_size: Option) -> Result> { + pub fn read(&mut self, max_size: Option) -> Result> { self.codec.read_frame(&mut self.stream, max_size) } } @@ -66,18 +65,28 @@ impl FrameSocket where Stream: Write, { + /// Writes and immediately flushes a frame. + /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush). + pub fn send(&mut self, frame: Frame) -> Result<()> { + self.write(frame)?; + self.flush() + } + /// Write a frame to stream. /// - /// This function guarantees that the frame is queued regardless of any errors. - /// There is no need to resend the frame. In order to handle WouldBlock or Incomplete, - /// call write_pending() afterwards. - pub fn write_frame(&mut self, frame: Frame) -> Result<()> { + /// A subsequent call should be made to [`flush`](Self::flush) to flush writes. + /// + /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`] + /// is returned. + /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards. + pub fn write(&mut self, frame: Frame) -> Result<()> { self.codec.write_frame(&mut self.stream, frame) } - /// Complete pending write, if any. - pub fn write_pending(&mut self) -> Result<()> { - self.codec.write_pending(&mut self.stream) + /// Flush writes. + pub fn flush(&mut self) -> Result<()> { + self.codec.write_out_buffer(&mut self.stream)?; + Ok(self.stream.flush()?) } } @@ -88,6 +97,8 @@ pub(super) struct FrameCodec { in_buffer: ReadBuffer, /// Buffer to send packets to the network. out_buffer: Vec, + /// Capacity limit for `out_buffer`. + max_out_buffer_len: usize, /// Header and remaining size of the incoming packet being processed. header: Option<(FrameHeader, u64)>, } @@ -95,7 +106,12 @@ pub(super) struct FrameCodec { impl FrameCodec { /// Create a new frame codec. pub(super) fn new() -> Self { - Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None } + Self { + in_buffer: ReadBuffer::new(), + out_buffer: Vec::new(), + max_out_buffer_len: usize::MAX, + header: None, + } } /// Create a new frame codec from partially read data. @@ -103,10 +119,22 @@ impl FrameCodec { Self { in_buffer: ReadBuffer::from_partially_read(part), out_buffer: Vec::new(), + max_out_buffer_len: usize::MAX, header: None, } } + /// Sets a maximum size for the out buffer. + pub(super) fn with_max_out_buffer_len(mut self, max: usize) -> Self { + self.max_out_buffer_len = max; + self + } + + /// Sets a maximum size for the out buffer. + pub(super) fn set_max_out_buffer_len(&mut self, max: usize) { + self.max_out_buffer_len = max; + } + /// Read a frame from the provided stream. pub(super) fn read_frame( &mut self, @@ -166,18 +194,28 @@ impl FrameCodec { } /// Write a frame to the provided stream. + /// + /// Does **not** flush. pub(super) fn write_frame(&mut self, stream: &mut Stream, frame: Frame) -> Result<()> where Stream: Write, { + if frame.len() + self.out_buffer.len() > self.max_out_buffer_len { + return Err(Error::WriteBufferFull(Message::Frame(frame))); + } + trace!("writing frame {}", frame); + self.out_buffer.reserve(frame.len()); frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); - self.write_pending(stream) + + self.write_out_buffer(stream) } - /// Complete pending write, if any. - pub(super) fn write_pending(&mut self, stream: &mut Stream) -> Result<()> + /// Write any buffered frames to the provided stream. + /// + /// Does **not** flush. + pub(super) fn write_out_buffer(&mut self, stream: &mut Stream) -> Result<()> where Stream: Write, { @@ -193,16 +231,8 @@ impl FrameCodec { } self.out_buffer.drain(0..len); } - stream.flush()?; - Ok(()) - } -} -#[cfg(test)] -impl FrameCodec { - /// Returns the size of the output buffer. - pub(super) fn output_buffer_len(&self) -> usize { - self.out_buffer.len() + Ok(()) } } @@ -224,11 +254,11 @@ mod tests { let mut sock = FrameSocket::new(raw); assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), + sock.read(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); - assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); - assert!(sock.read_frame(None).unwrap().is_none()); + assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); + assert!(sock.read(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); assert_eq!(rest, vec![0x99]); @@ -239,7 +269,7 @@ mod tests { let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), + sock.read(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); } @@ -249,10 +279,10 @@ mod tests { let mut sock = FrameSocket::new(Vec::new()); let frame = Frame::ping(vec![0x04, 0x05]); - sock.write_frame(frame).unwrap(); + sock.send(frame).unwrap(); let frame = Frame::pong(vec![0x01]); - sock.write_frame(frame).unwrap(); + sock.send(frame).unwrap(); let (buf, _) = sock.into_inner(); assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]); @@ -264,7 +294,7 @@ mod tests { 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, ]); let mut sock = FrameSocket::new(raw); - let _ = sock.read_frame(None); // should not crash + let _ = sock.read(None); // should not crash } #[test] @@ -272,7 +302,7 @@ mod tests { let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); assert!(matches!( - sock.read_frame(Some(5)), + sock.read(Some(5)), Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) )); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index cdebabc..2b2ed0b 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -185,7 +185,7 @@ impl Message { Message::Text(string.into()) } - /// Create a new binary WebSocket message by converting to Vec. + /// Create a new binary WebSocket message by converting to `Vec`. pub fn binary(bin: B) -> Message where B: Into>, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 94397e9..1100a67 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -6,13 +6,6 @@ mod message; pub use self::{frame::CloseFrame, message::Message}; -use log::*; -use std::{ - collections::VecDeque, - io::{ErrorKind as IoErrorKind, Read, Write}, - mem::replace, -}; - use self::{ frame::{ coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, @@ -24,6 +17,11 @@ use crate::{ error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; +use log::*; +use std::{ + io::{ErrorKind as IoErrorKind, Read, Write}, + mem::replace, +}; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -37,10 +35,12 @@ pub enum Role { /// The configuration for WebSocket connection. #[derive(Debug, Clone, Copy)] pub struct WebSocketConfig { - /// The size of the send queue. You can use it to turn on/off the backpressure features. `None` - /// means here that the size of the queue is unlimited. The default value is the unlimited - /// queue. + /// Does nothing, instead use `max_write_buffer_size`. + #[deprecated] pub max_send_queue: Option, + /// The max size of the write buffer in bytes. Setting this can provide backpressure. + /// The default value is unlimited. + pub max_write_buffer_size: usize, /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// which should be reasonably big for all normal use-cases but small enough to prevent /// memory eating by a malicious user. @@ -60,8 +60,10 @@ pub struct WebSocketConfig { impl Default for WebSocketConfig { fn default() -> Self { + #[allow(deprecated)] WebSocketConfig { max_send_queue: None, + max_write_buffer_size: usize::MAX, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false, @@ -73,6 +75,8 @@ impl Default for WebSocketConfig { /// /// This is THE structure you want to create to be able to speak the WebSocket protocol. /// It may be created by calling `connect`, `accept` or `client` functions. +/// +/// Use [`WebSocket::read`], [`WebSocket::send`] to received and send messages. #[derive(Debug)] pub struct WebSocket { /// The underlying socket. @@ -146,82 +150,116 @@ impl WebSocket { impl WebSocket { /// Read a message from stream, if possible. /// - /// This will queue responses to ping and close messages to be sent. It will call - /// `write_pending` before trying to read in order to make sure that those responses - /// make progress even if you never call `write_pending`. That does mean that they - /// get sent out earliest on the next call to `read_message`, `write_message` or `write_pending`. + /// This will also queue responses to ping and close messages. These responses + /// will be written and flushed on the next call to [`read`](Self::read), + /// [`write`](Self::write) or [`flush`](Self::flush). /// - /// ## Closing the connection + /// # Closing the connection /// When the remote endpoint decides to close the connection this will return /// the close message with an optional close frame. /// - /// You should continue calling `read_message`, `write_message` or `write_pending` to drive - /// the reply to the close frame until [Error::ConnectionClosed] is returned. Once that happens - /// it is safe to drop the underlying connection. - pub fn read_message(&mut self) -> Result { - self.context.read_message(&mut self.socket) + /// You should continue calling [`read`](Self::read), [`write`](Self::write) or + /// [`flush`](Self::flush) to drive the reply to the close frame until [`Error::ConnectionClosed`] + /// is returned. Once that happens it is safe to drop the underlying connection. + pub fn read(&mut self) -> Result { + self.context.read(&mut self.socket) + } + + /// Writes and immediately flushes a message. + /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush). + pub fn send(&mut self, message: Message) -> Result<()> { + self.write(message)?; + self.flush() } - /// Send a message to stream, if possible. + /// Write a message to the provided stream, if possible. /// - /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping - /// requests. A Pong reply will jump the queue because the - /// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent - /// as soon as is practical. + /// A subsequent call should be made to [`flush`](Self::flush) to flush writes. /// - /// Note that upon receiving a ping message, tungstenite cues a pong reply automatically. - /// When you call either `read_message`, `write_message` or `write_pending` next it will try to send - /// that pong out if the underlying connection can take more data. This means you should not - /// respond to ping frames manually. + /// In the event of stream write failure the message frame will be stored + /// in the write buffer and will try again on the next call to [`write`](Self::write) + /// or [`flush`](Self::flush). + /// + /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] + /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. + /// + /// This call will generally not flush. However, if there are queued automatic messages + /// they will be written and eagerly flushed. + /// + /// For example, upon receiving ping messages tungstenite queues pong replies automatically. + /// The next call to [`read`](Self::read), [`write`](Self::write) or [`flush`](Self::flush) + /// will write & flush the pong reply. This means you should not respond to ping frames manually. /// /// You can however send pong frames manually in order to indicate a unidirectional heartbeat /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that - /// if `read_message` returns a ping, you should call `write_pending` until it doesn't return - /// WouldBlock before passing a pong to `write_message`, otherwise the response to the - /// ping will not be sent, but rather replaced by your custom pong message. + /// if [`read`](Self::read) returns a ping, you should [`flush`](Self::flush) before passing + /// a custom pong to [`write`](Self::write), otherwise the automatic queued response to the + /// ping will not be sent as it will be replaced by your custom pong message. /// - /// ## Errors - /// - If the WebSocket's send queue is full, `SendQueueFull` will be returned - /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. - /// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed]. - /// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`, - /// [Error::AlreadyClosed] will be returned. This indicates a program error on your part. - /// - [Error::Io] is returned if the underlying connection returns an error + /// # Errors + /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned + /// along with the equivalent passed message frame. + /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`]. + /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from + /// [`read`](Self::read), [`Error::AlreadyClosed`] will be returned. This indicates a program + /// error on your part. + /// - [`Error::Io`] is returned if the underlying connection returns an error /// (consider these fatal except for WouldBlock). - /// - [Error::Capacity] if your message size is bigger than the configured max message size. - pub fn write_message(&mut self, message: Message) -> Result<()> { - self.context.write_message(&mut self.socket, message) + /// - [`Error::Capacity`] if your message size is bigger than the configured max message size. + pub fn write(&mut self, message: Message) -> Result<()> { + self.context.write(&mut self.socket, message) } - /// Flush the pending send queue. - pub fn write_pending(&mut self) -> Result<()> { - self.context.write_pending(&mut self.socket) + /// Flush writes. + /// + /// Ensures all messages previously passed to [`write`](Self::write) and automatic + /// queued pong responses are written & flushed into the underlying stream. + pub fn flush(&mut self) -> Result<()> { + self.context.flush(&mut self.socket) } /// Close the connection. /// /// This function guarantees that the close frame will be queued. /// There is no need to call it again. Calling this function is - /// the same as calling `write_message(Message::Close(..))`. + /// the same as calling `write(Message::Close(..))`. /// - /// After queuing the close frame you should continue calling `read_message` or - /// `write_pending` to drive the close handshake to completion. + /// After queuing the close frame you should continue calling [`read`](Self::read) or + /// [`flush`](Self::flush) to drive the close handshake to completion. /// /// The websocket RFC defines that the underlying connection should be closed /// by the server. Tungstenite takes care of this asymmetry for you. /// /// When the close handshake is finished (we have both sent and received - /// a close message), `read_message` or `write_pending` will return + /// a close message), [`read`](Self::read) or [`flush`](Self::flush) will return /// [Error::ConnectionClosed] if this endpoint is the server. /// /// If this endpoint is a client, [Error::ConnectionClosed] will only be /// returned after the server has closed the underlying connection. /// /// It is thus safe to drop the underlying connection as soon as [Error::ConnectionClosed] - /// is returned from `read_message` or `write_pending`. + /// is returned from [`read`](Self::read) or [`flush`](Self::flush). pub fn close(&mut self, code: Option) -> Result<()> { self.context.close(&mut self.socket, code) } + + /// Old name for [`read`](Self::read). + #[deprecated(note = "Use `read`")] + pub fn read_message(&mut self) -> Result { + self.read() + } + + /// Old name for [`send`](Self::send). + #[deprecated(note = "Use `send`")] + pub fn write_message(&mut self, message: Message) -> Result<()> { + self.send(message) + } + + /// Old name for [`flush`](Self::flush). + #[deprecated(note = "Use `flush`")] + pub fn write_pending(&mut self) -> Result<()> { + self.flush() + } } /// A context for managing WebSocket stream. @@ -235,10 +273,8 @@ pub struct WebSocketContext { state: WebSocketState, /// Receive: an incomplete message being processed. incomplete: Option, - /// Send: a data send queue. - send_queue: VecDeque, - /// Send: an OOB pong message. - pong: Option, + /// Send in addition to regular messages E.g. "pong" or "close". + additional_send: Option, /// The configuration for the websocket session. config: WebSocketConfig, } @@ -246,28 +282,32 @@ pub struct WebSocketContext { impl WebSocketContext { /// Create a WebSocket context that manages a post-handshake stream. pub fn new(role: Role, config: Option) -> Self { + let config = config.unwrap_or_default(); + WebSocketContext { role, - frame: FrameCodec::new(), + frame: FrameCodec::new().with_max_out_buffer_len(config.max_write_buffer_size), state: WebSocketState::Active, incomplete: None, - send_queue: VecDeque::new(), - pong: None, - config: config.unwrap_or_default(), + additional_send: None, + config, } } /// Create a WebSocket context that manages an post-handshake stream. pub fn from_partially_read(part: Vec, role: Role, config: Option) -> Self { + let config = config.unwrap_or_default(); WebSocketContext { - frame: FrameCodec::from_partially_read(part), - ..WebSocketContext::new(role, config) + frame: FrameCodec::from_partially_read(part) + .with_max_out_buffer_len(config.max_write_buffer_size), + ..WebSocketContext::new(role, Some(config)) } } /// Change the configuration. pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { - set_func(&mut self.config) + set_func(&mut self.config); + self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size); } /// Read the configuration. @@ -294,17 +334,23 @@ impl WebSocketContext { /// /// This function sends pong and close responses automatically. /// However, it never blocks on write. - pub fn read_message(&mut self, stream: &mut Stream) -> Result + pub fn read(&mut self, stream: &mut Stream) -> Result where Stream: Read + Write, { // Do not read from already closed connections. - self.state.check_active()?; + self.state.check_not_terminated()?; loop { - // Since we may get ping or close, we need to reply to the messages even during read. - // Thus we call write_pending() but ignore its blocking. - self.write_pending(stream).no_block()?; + if self.additional_send.is_some() { + // Since we may get ping or close, we need to reply to the messages even during read. + // Thus we flush but ignore its blocking. + self.flush(stream).no_block()?; + } else if self.role == Role::Server && !self.state.can_read() { + self.state = WebSocketState::Terminated; + return Err(Error::ConnectionClosed); + } + // If we get here, either write blocks or we have nothing to write. // Thus if read blocks, just let it return WouldBlock. if let Some(message) = self.read_message_frame(stream)? { @@ -314,78 +360,94 @@ impl WebSocketContext { } } - /// Send a message to the provided stream, if possible. + /// Write a message to the provided stream. + /// + /// A subsequent call should be made to [`flush`](Self::flush) to flush writes. /// - /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping - /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned - /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. + /// In the event of stream write failure the message frame will be stored + /// in the write buffer and will try again on the next call to [`write`](Self::write) + /// or [`flush`](Self::flush). /// - /// Note that only the last pong frame is stored to be sent, and only the - /// most recent pong frame is sent if multiple pong frames are queued. - pub fn write_message(&mut self, stream: &mut Stream, message: Message) -> Result<()> + /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] + /// [`Err(WriteBufferFull(msg_frame))`](Error::WriteBufferFull) is returned. + pub fn write(&mut self, stream: &mut Stream, message: Message) -> Result<()> where Stream: Read + Write, { // When terminated, return AlreadyClosed. - self.state.check_active()?; + self.state.check_not_terminated()?; // Do not write after sending a close frame. if !self.state.is_active() { return Err(Error::Protocol(ProtocolError::SendAfterClosing)); } - if let Some(max_send_queue) = self.config.max_send_queue { - if self.send_queue.len() >= max_send_queue { - // Try to make some room for the new message. - // Do not return here if write would block, ignore WouldBlock silently - // since we must queue the message anyway. - self.write_pending(stream).no_block()?; - } - - if self.send_queue.len() >= max_send_queue { - return Err(Error::SendQueueFull(message)); - } - } - let frame = match message { Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { - self.pong = Some(Frame::pong(data)); - return self.write_pending(stream); + self.set_additional(Frame::pong(data)); + // Note: user pongs can be user flushed so no need to flush here + return self._write(stream, None).map(|_| ()); } Message::Close(code) => return self.close(stream, code), Message::Frame(f) => f, }; - self.send_queue.push_back(frame); - self.write_pending(stream) + let should_flush = self._write(stream, Some(frame))?; + if should_flush { + self.flush(stream)?; + } + Ok(()) + } + + /// Flush writes. + /// + /// Ensures all messages previously passed to [`write`](Self::write) and automatically + /// queued pong responses are written & flushed into the `stream`. + #[inline] + pub fn flush(&mut self, stream: &mut Stream) -> Result<()> + where + Stream: Read + Write, + { + self._write(stream, None)?; + Ok(stream.flush()?) } - /// Flush the pending send queue. - pub fn write_pending(&mut self, stream: &mut Stream) -> Result<()> + /// Writes any data in the out_buffer, `additional_send` and given `data`. + /// + /// Does **not** flush. + /// + /// Returns true if the write contents indicate we should flush immediately. + fn _write(&mut self, stream: &mut Stream, data: Option) -> Result where Stream: Read + Write, { - // First, make sure we have no pending frame sending. - self.frame.write_pending(stream)?; + match data { + Some(data) => self.write_one_frame(stream, data)?, + None => self.frame.write_out_buffer(stream)?, + } // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // response, unless it already received a Close frame. It SHOULD // respond with Pong frame as soon as is practical. (RFC 6455) - if let Some(pong) = self.pong.take() { - trace!("Sending pong reply"); - self.send_one_frame(stream, pong)?; - } - // If we have any unsent frames, send them. - trace!("Frames still in queue: {}", self.send_queue.len()); - while let Some(data) = self.send_queue.pop_front() { - self.send_one_frame(stream, data)?; - } - - // If we get to this point, the send queue is empty and the underlying socket is still - // willing to take more data. + let should_flush = if let Some(msg) = self.additional_send.take() { + trace!("Sending pong/close"); + match self.write_one_frame(stream, msg) { + Err(Error::WriteBufferFull(Message::Frame(msg))) => { + // if an system message would exceed the buffer put it back in + // `additional_send` for retry. Otherwise returning this error + // may not make sense to the user, e.g. calling `flush`. + self.set_additional(msg); + false + } + Err(err) => return Err(err), + Ok(_) => true, + } + } else { + false + }; // If we're closing and there is nothing to send anymore, we should close the connection. if self.role == Role::Server && !self.state.can_read() { @@ -398,7 +460,7 @@ impl WebSocketContext { self.state = WebSocketState::Terminated; Err(Error::ConnectionClosed) } else { - Ok(()) + Ok(should_flush) } } @@ -414,11 +476,11 @@ impl WebSocketContext { if let WebSocketState::Active = self.state { self.state = WebSocketState::ClosedByUs; let frame = Frame::close(code); - self.send_queue.push_back(frame); + self._write(stream, Some(frame))?; } else { // Already closed, nothing to do. } - self.write_pending(stream) + Ok(stream.flush()?) } /// Try to decode one message frame. May return None. @@ -487,7 +549,7 @@ impl WebSocketContext { let data = frame.into_data(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.pong = Some(Frame::pong(data.clone())); + self.set_additional(Frame::pong(data.clone())); } Ok(Some(Message::Ping(data))) } @@ -571,7 +633,7 @@ impl WebSocketContext { let reply = Frame::close(close.clone()); debug!("Replying to close with {:?}", reply); - self.send_queue.push_back(reply); + self.set_additional(reply); Some(close) } @@ -588,8 +650,8 @@ impl WebSocketContext { } } - /// Send a single pending frame. - fn send_one_frame(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> + /// Write a single frame into the stream via the write-buffer. + fn write_one_frame(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> where Stream: Read + Write, { @@ -605,6 +667,17 @@ impl WebSocketContext { trace!("Sending frame: {:?}", frame); self.frame.write_frame(stream, frame).check_connection_reset(self.state) } + + /// Replace `additional_send` if it is currently a `Pong` message. + fn set_additional(&mut self, add: Frame) { + let empty_or_pong = self + .additional_send + .as_ref() + .map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong)); + if empty_or_pong { + self.additional_send.replace(add); + } + } } /// The current connection state. @@ -636,7 +709,7 @@ impl WebSocketState { } /// Check if the state is active, return error if not. - fn check_active(self) -> Result<()> { + fn check_not_terminated(self) -> Result<()> { match self { WebSocketState::Terminated => Err(Error::AlreadyClosed), _ => Ok(()), @@ -688,64 +761,6 @@ mod tests { } } - struct WouldBlockStreamMoc; - - impl io::Write for WouldBlockStreamMoc { - fn write(&mut self, _: &[u8]) -> io::Result { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - fn flush(&mut self) -> io::Result<()> { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - } - - impl io::Read for WouldBlockStreamMoc { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - } - - #[test] - fn queue_logic() { - // Create a socket with the queue size of 1. - let mut socket = WebSocket::from_raw_socket( - WouldBlockStreamMoc, - Role::Client, - Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }), - ); - - // Test message that we're going to send. - let message = Message::Binary(vec![0xFF; 1024]); - - // Helper to check the error. - let assert_would_block = |error| { - if let Error::Io(io_error) = error { - assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock); - } else { - panic!("Expected WouldBlock error"); - } - }; - - // The first attempt of writing must not fail, since the queue is empty at start. - // But since the underlying mock object always returns `WouldBlock`, so is the result. - assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err())); - - // Any subsequent attempts must return an error telling that the queue is full. - for _i in 0..100 { - assert!(matches!( - socket.write_message(message.clone()).unwrap_err(), - Error::SendQueueFull(..) - )); - } - - // The size of the output buffer must not be bigger than the size of that message - // that we managed to write to the output buffer at first. Since we could not make - // any progress (because of the logic of the moc buffer), the size remains unchanged. - if socket.context.frame.output_buffer_len() > message.len() { - panic!("Too many frames in the queue"); - } - } - #[test] fn receive_messages() { let incoming = Cursor::new(vec![ @@ -754,10 +769,10 @@ mod tests { 0x03, ]); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); - assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); - assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2])); + assert_eq!(socket.read().unwrap(), Message::Pong(vec![3])); + assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into())); + assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); } #[test] @@ -770,7 +785,7 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert!(matches!( - socket.read_message(), + socket.read(), Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) )); } @@ -782,7 +797,7 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert!(matches!( - socket.read_message(), + socket.read(), Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) )); } diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 40b7469..015b4c1 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -52,27 +52,27 @@ fn test_server_close() { do_test( 3012, |mut cli_sock| { - cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); - let message = cli_sock.read_message().unwrap(); // receive close from server + let message = cli_sock.read().unwrap(); // receive close from server assert!(message.is_close()); - let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } }, |mut srv_sock| { - let message = srv_sock.read_message().unwrap(); + let message = srv_sock.read().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); srv_sock.close(None).unwrap(); // send close to client - let message = srv_sock.read_message().unwrap(); // receive acknowledgement + let message = srv_sock.read().unwrap(); // receive acknowledgement assert!(message.is_close()); - let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), @@ -86,26 +86,26 @@ fn test_evil_server_close() { do_test( 3013, |mut cli_sock| { - cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); sleep(Duration::from_secs(1)); - let message = cli_sock.read_message().unwrap(); // receive close from server + let message = cli_sock.read().unwrap(); // receive close from server assert!(message.is_close()); - let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } }, |mut srv_sock| { - let message = srv_sock.read_message().unwrap(); + let message = srv_sock.read().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); srv_sock.close(None).unwrap(); // send close to client - let message = srv_sock.read_message().unwrap(); // receive acknowledgement + let message = srv_sock.read().unwrap(); // receive acknowledgement assert!(message.is_close()); // and now just drop the connection without waiting for `ConnectionClosed` srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); @@ -119,32 +119,32 @@ fn test_client_close() { do_test( 3014, |mut cli_sock| { - cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); - let message = cli_sock.read_message().unwrap(); // receive answer from server + let message = cli_sock.read().unwrap(); // receive answer from server assert_eq!(message.into_data(), b"From Server"); cli_sock.close(None).unwrap(); // send close to server - let message = cli_sock.read_message().unwrap(); // receive acknowledgement from server + let message = cli_sock.read().unwrap(); // receive acknowledgement from server assert!(message.is_close()); - let err = cli_sock.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = cli_sock.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } }, |mut srv_sock| { - let message = srv_sock.read_message().unwrap(); + let message = srv_sock.read().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); - srv_sock.write_message(Message::Text("From Server".into())).unwrap(); + srv_sock.send(Message::Text("From Server".into())).unwrap(); - let message = srv_sock.read_message().unwrap(); // receive close from client + let message = srv_sock.read().unwrap(); // receive close from client assert!(message.is_close()); - let err = srv_sock.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = srv_sock.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index 2182b92..258d09a 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -29,10 +29,10 @@ fn test_no_send_after_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); - let message = client.read_message().unwrap(); // receive close from server + let message = client.read().unwrap(); // receive close from server assert!(message.is_close()); - let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = client.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), @@ -44,7 +44,7 @@ fn test_no_send_after_close() { client_handler.close(None).unwrap(); // send close to client - let err = client_handler.write_message(Message::Text("Hello WebSocket".into())); + let err = client_handler.send(Message::Text("Hello WebSocket".into())); assert!(err.is_err()); diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index c8661c8..af867a7 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -29,12 +29,12 @@ fn test_receive_after_init_close() { let client_thread = spawn(move || { let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); - client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); + client.send(Message::Text("Hello WebSocket".into())).unwrap(); - let message = client.read_message().unwrap(); // receive close from server + let message = client.read().unwrap(); // receive close from server assert!(message.is_close()); - let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = client.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), @@ -47,12 +47,12 @@ fn test_receive_after_init_close() { client_handler.close(None).unwrap(); // send close to client // This read should succeed even though we already initiated a close - let message = client_handler.read_message().unwrap(); + let message = client_handler.read().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); - assert!(client_handler.read_message().unwrap().is_close()); // receive acknowledgement + assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement - let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed + let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed match err { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err),