diff --git a/examples/srv_accept_unmasked_frames.rs b/examples/srv_accept_unmasked_frames.rs index b280fba..0614214 100644 --- a/examples/srv_accept_unmasked_frames.rs +++ b/examples/srv_accept_unmasked_frames.rs @@ -27,14 +27,12 @@ fn main() { }; let config = Some(WebSocketConfig { - max_send_queue: None, - max_message_size: None, - max_frame_size: None, // This setting allows to accept client frames which are not masked // This is not in compliance with RFC 6455 but might be handy in some // rare cases where it is necessary to integrate with existing/legacy // clients which are sending unmasked frames accept_unmasked_frames: true, + ..<_>::default() }); let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap(); diff --git a/src/error.rs b/src/error.rs index b9a957f..a7b3354 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,9 +53,9 @@ pub enum Error { /// Protocol violation. #[error("WebSocket protocol error: {0}")] Protocol(#[from] ProtocolError), - /// Message send queue full. - #[error("Send queue is full")] - SendQueueFull(Message), + /// Message write buffer is full. + #[error("Write buffer is full")] + WriteBufferFull(Message), /// UTF coding error. #[error("UTF-8 encoding error")] Utf8, diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index c25837e..1cdf376 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,15 +6,14 @@ pub mod coding; mod frame; mod mask; -use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; - -use log::*; - -pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::{ error::{CapacityError, Error, Result}, - ReadBuffer, + Message, ReadBuffer, }; +use log::*; +use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; + +pub use self::frame::{CloseFrame, Frame, FrameHeader}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -89,6 +88,8 @@ pub(super) struct FrameCodec { in_buffer: ReadBuffer, /// Buffer to send packets to the network. out_buffer: Vec, + /// Capacity limit for `out_buffer`. + max_out_buffer_len: usize, /// Header and remaining size of the incoming packet being processed. header: Option<(FrameHeader, u64)>, } @@ -96,7 +97,12 @@ pub(super) struct FrameCodec { impl FrameCodec { /// Create a new frame codec. pub(super) fn new() -> Self { - Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None } + Self { + in_buffer: ReadBuffer::new(), + out_buffer: Vec::new(), + max_out_buffer_len: usize::MAX, + header: None, + } } /// Create a new frame codec from partially read data. @@ -104,10 +110,22 @@ impl FrameCodec { Self { in_buffer: ReadBuffer::from_partially_read(part), out_buffer: Vec::new(), + max_out_buffer_len: usize::MAX, header: None, } } + /// Sets a maximum size for the out buffer. + pub(super) fn with_max_out_buffer_len(mut self, max: usize) -> Self { + self.max_out_buffer_len = max; + self + } + + /// Sets a maximum size for the out buffer. + pub(super) fn set_max_out_buffer_len(&mut self, max: usize) { + self.max_out_buffer_len = max; + } + /// Read a frame from the provided stream. pub(super) fn read_frame( &mut self, @@ -173,10 +191,25 @@ impl FrameCodec { where Stream: Write, { + if frame.len() + self.out_buffer.len() > self.max_out_buffer_len { + return Err(Error::WriteBufferFull(Message::Frame(frame))); + } + trace!("writing frame {}", frame); + self.out_buffer.reserve(frame.len()); frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); + self.write_out_buffer(stream) + } + + /// Write any buffered frames to the provided stream. + /// + /// Does **not** flush. + pub(super) fn write_out_buffer(&mut self, stream: &mut Stream) -> Result<()> + where + Stream: Write, + { while !self.out_buffer.is_empty() { let len = stream.write(&self.out_buffer)?; if len == 0 { @@ -194,14 +227,6 @@ impl FrameCodec { } } -#[cfg(test)] -impl FrameCodec { - /// Returns the size of the output buffer. - pub(super) fn output_buffer_len(&self) -> usize { - self.out_buffer.len() - } -} - #[cfg(test)] mod tests { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 00f0b63..9839b31 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -6,13 +6,6 @@ mod message; pub use self::{frame::CloseFrame, message::Message}; -use log::*; -use std::{ - collections::VecDeque, - io::{ErrorKind as IoErrorKind, Read, Write}, - mem::replace, -}; - use self::{ frame::{ coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, @@ -24,6 +17,11 @@ use crate::{ error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; +use log::*; +use std::{ + io::{ErrorKind as IoErrorKind, Read, Write}, + mem::replace, +}; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -37,10 +35,12 @@ pub enum Role { /// The configuration for WebSocket connection. #[derive(Debug, Clone, Copy)] pub struct WebSocketConfig { - /// The size of the send queue. You can use it to turn on/off the backpressure features. `None` - /// means here that the size of the queue is unlimited. The default value is the unlimited - /// queue. + /// Does nothing, instead use `max_write_buffer_size`. + #[deprecated] pub max_send_queue: Option, + /// The max size of the write buffer in bytes. Setting this can provide backpressure. + /// The default value is unlimited. + pub max_write_buffer_size: usize, /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// which should be reasonably big for all normal use-cases but small enough to prevent /// memory eating by a malicious user. @@ -60,8 +60,10 @@ pub struct WebSocketConfig { impl Default for WebSocketConfig { fn default() -> Self { + #[allow(deprecated)] WebSocketConfig { max_send_queue: None, + max_write_buffer_size: usize::MAX, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false, @@ -162,17 +164,24 @@ impl WebSocket { self.context.read_message(&mut self.socket) } - /// Send a message to stream, if possible. + /// Write a message to the provided stream, if possible. + /// + /// A subsequent call should be made to [`Self::write_pending`] to flush writes. /// - /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping - /// requests. A Pong reply will jump the queue because the + /// In the event of stream write failure the message frame will be stored + /// in the write buffer and will try again on the next call to [`Self::write_message`] or [`Self::write_pending`]. + /// + /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] + /// `Err(WriteBufferFull(msg_frame))` is returned. + /// + /// This call will not flush, except to reply to Ping + /// requests. A Pong reply will flush early because the /// [websocket RFC](https://tools.ietf.org/html/rfc6455#section-5.5.2) specifies it should be sent /// as soon as is practical. /// /// Note that upon receiving a ping message, tungstenite cues a pong reply automatically. - /// When you call either `read_message`, `write_message` or `write_pending` next it will try to send - /// that pong out if the underlying connection can take more data. This means you should not - /// respond to ping frames manually. + /// When you call either `read_message`, `write_message` or `write_pending` next it will try to + /// write & flush the pong reply if possible. This means you should not respond to ping frames manually. /// /// You can however send pong frames manually in order to indicate a unidirectional heartbeat /// as described in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.5.3). Note that @@ -181,19 +190,19 @@ impl WebSocket { /// ping will not be sent, but rather replaced by your custom pong message. /// /// ## Errors - /// - If the WebSocket's send queue is full, `SendQueueFull` will be returned - /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. - /// - If the connection is closed and should be dropped, this will return [Error::ConnectionClosed]. - /// - If you try again after [Error::ConnectionClosed] was returned either from here or from `read_message`, - /// [Error::AlreadyClosed] will be returned. This indicates a program error on your part. - /// - [Error::Io] is returned if the underlying connection returns an error + /// - If the WebSocket's write buffer is full, [`Error::WriteBufferFull`] will be returned + /// along with the equivalent passed message frame. Otherwise, the message is queued and Ok(()) is returned. + /// - If the connection is closed and should be dropped, this will return [`Error::ConnectionClosed`]. + /// - If you try again after [`Error::ConnectionClosed`] was returned either from here or from `read_message`, + /// [`Error::AlreadyClosed`] will be returned. This indicates a program error on your part. + /// - [`Error::Io`] is returned if the underlying connection returns an error /// (consider these fatal except for WouldBlock). - /// - [Error::Capacity] if your message size is bigger than the configured max message size. + /// - [`Error::Capacity`] if your message size is bigger than the configured max message size. pub fn write_message(&mut self, message: Message) -> Result<()> { self.context.write_message(&mut self.socket, message) } - /// Flush the pending send queue. + /// Flush pending writes. pub fn write_pending(&mut self) -> Result<()> { self.context.write_pending(&mut self.socket) } @@ -235,10 +244,8 @@ pub struct WebSocketContext { state: WebSocketState, /// Receive: an incomplete message being processed. incomplete: Option, - /// Send: a data send queue. - send_queue: VecDeque, - /// Send: an OOB pong message. - pong: Option, + /// Send in addition to regular messages E.g. "pong" or "close". + additional_send: Option, /// The configuration for the websocket session. config: WebSocketConfig, } @@ -246,28 +253,32 @@ pub struct WebSocketContext { impl WebSocketContext { /// Create a WebSocket context that manages a post-handshake stream. pub fn new(role: Role, config: Option) -> Self { + let config = config.unwrap_or_default(); + WebSocketContext { role, - frame: FrameCodec::new(), + frame: FrameCodec::new().with_max_out_buffer_len(config.max_write_buffer_size), state: WebSocketState::Active, incomplete: None, - send_queue: VecDeque::new(), - pong: None, - config: config.unwrap_or_default(), + additional_send: None, + config, } } /// Create a WebSocket context that manages an post-handshake stream. pub fn from_partially_read(part: Vec, role: Role, config: Option) -> Self { + let config = config.unwrap_or_default(); WebSocketContext { - frame: FrameCodec::from_partially_read(part), - ..WebSocketContext::new(role, config) + frame: FrameCodec::from_partially_read(part) + .with_max_out_buffer_len(config.max_write_buffer_size), + ..WebSocketContext::new(role, Some(config)) } } /// Change the configuration. pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { - set_func(&mut self.config) + set_func(&mut self.config); + self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size); } /// Read the configuration. @@ -299,12 +310,19 @@ impl WebSocketContext { Stream: Read + Write, { // Do not read from already closed connections. - self.state.check_active()?; + self.state.check_not_terminated()?; loop { - // Since we may get ping or close, we need to reply to the messages even during read. - // Thus we call write_pending() but ignore its blocking. - self.write_pending(stream).no_block()?; + if self.additional_send.is_some() { + // Since we may get ping or close, we need to reply to the messages even during read. + // Thus we call write_pending() but ignore its blocking. + self.write_pending(stream).no_block()?; + } else if self.role == Role::Server && !self.state.can_read() { + self.state = WebSocketState::Terminated; + return Err(Error::ConnectionClosed); + } + + // TODO don't flush writes when reading // If we get here, either write blocks or we have nothing to write. // Thus if read blocks, just let it return WouldBlock. if let Some(message) = self.read_message_frame(stream)? { @@ -314,87 +332,84 @@ impl WebSocketContext { } } - /// Send a message to the provided stream, if possible. + /// Write a message to the provided stream, if possible. + /// + /// A subsequent call should be made to [`Self::write_pending`] to flush writes. /// - /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping - /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned - /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. + /// In the event of stream write failure the message frame will be stored + /// in the write buffer and will try again on the next call to [`Self::write_message`] or [`Self::write_pending`]. /// - /// Note that only the last pong frame is stored to be sent, and only the + /// If the write buffer would exceed the configured [`WebSocketConfig::max_write_buffer_size`] + /// `Err(WriteBufferFull(msg_frame))` is returned. + /// + /// Note that only the latest pong frame is stored to be sent, so only the /// most recent pong frame is sent if multiple pong frames are queued. pub fn write_message(&mut self, stream: &mut Stream, message: Message) -> Result<()> where Stream: Read + Write, { // When terminated, return AlreadyClosed. - self.state.check_active()?; + self.state.check_not_terminated()?; // Do not write after sending a close frame. if !self.state.is_active() { return Err(Error::Protocol(ProtocolError::SendAfterClosing)); } - if let Some(max_send_queue) = self.config.max_send_queue { - if self.send_queue.len() >= max_send_queue { - // Try to make some room for the new message. - // Do not return here if write would block, ignore WouldBlock silently - // since we must queue the message anyway. - self.write_pending(stream).no_block()?; - } - - if self.send_queue.len() >= max_send_queue { - return Err(Error::SendQueueFull(message)); - } - } - let frame = match message { Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { - self.pong = Some(Frame::pong(data)); - return self.write_queue(stream); + self.set_additional(Frame::pong(data)); + // Note: user pongs can be user flushed so no need to flush here + return self.write(stream, None).map(|_| ()); } Message::Close(code) => return self.close(stream, code), Message::Frame(f) => f, }; - self.send_queue.push_back(frame); - self.write_queue(stream) + let should_flush = self.write(stream, Some(frame))?; + if should_flush { + self.write_pending(stream)?; + } + Ok(()) } - /// Flush the pending send queue. + /// Flush pending writes. #[inline] pub fn write_pending(&mut self, stream: &mut Stream) -> Result<()> where Stream: Read + Write, { - self.write_queue(stream)?; + _ = self.write(stream, None)?; Ok(stream.flush()?) } /// Write send queue & pongs. /// /// Does **not** flush. - fn write_queue(&mut self, stream: &mut Stream) -> Result<()> + /// + /// Returns if the write contents indicate we should flush immediately. + fn write(&mut self, stream: &mut Stream, data: Option) -> Result where Stream: Read + Write, { + match data { + Some(data) => self.write_one_frame(stream, data)?, + None => self.frame.write_out_buffer(stream)?, + } + // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // response, unless it already received a Close frame. It SHOULD // respond with Pong frame as soon as is practical. (RFC 6455) - if let Some(pong) = self.pong.take() { - trace!("Sending pong reply"); - self.send_one_frame(stream, pong)?; - } - // If we have any unsent frames, send them. - trace!("Frames still in queue: {}", self.send_queue.len()); - while let Some(data) = self.send_queue.pop_front() { - self.send_one_frame(stream, data)?; - } - - // If we get to this point, the send queue is empty and the underlying socket is still - // willing to take more data. + let should_flush = if let Some(msg) = self.additional_send.take() { + trace!("Sending pong/close"); + self.write_one_frame(stream, msg)?; + true + } else { + false + }; // If we're closing and there is nothing to send anymore, we should close the connection. if self.role == Role::Server && !self.state.can_read() { @@ -407,7 +422,7 @@ impl WebSocketContext { self.state = WebSocketState::Terminated; Err(Error::ConnectionClosed) } else { - Ok(()) + Ok(should_flush) } } @@ -423,11 +438,11 @@ impl WebSocketContext { if let WebSocketState::Active = self.state { self.state = WebSocketState::ClosedByUs; let frame = Frame::close(code); - self.send_queue.push_back(frame); + self.write(stream, Some(frame))?; } else { // Already closed, nothing to do. } - self.write_pending(stream) + Ok(stream.flush()?) } /// Try to decode one message frame. May return None. @@ -496,7 +511,7 @@ impl WebSocketContext { let data = frame.into_data(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.pong = Some(Frame::pong(data.clone())); + self.set_additional(Frame::pong(data.clone())); } Ok(Some(Message::Ping(data))) } @@ -580,7 +595,7 @@ impl WebSocketContext { let reply = Frame::close(close.clone()); debug!("Replying to close with {:?}", reply); - self.send_queue.push_back(reply); + self.set_additional(reply); Some(close) } @@ -597,8 +612,8 @@ impl WebSocketContext { } } - /// Send a single pending frame. - fn send_one_frame(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> + /// Write a single frame into the stream via the write-buffer. + fn write_one_frame(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> where Stream: Read + Write, { @@ -614,6 +629,17 @@ impl WebSocketContext { trace!("Sending frame: {:?}", frame); self.frame.write_frame(stream, frame).check_connection_reset(self.state) } + + /// Replace `additional_send` if it is currently a `Pong` message. + fn set_additional(&mut self, add: Frame) { + let empty_or_pong = self + .additional_send + .as_ref() + .map_or(true, |f| f.header().opcode == OpCode::Control(OpCtl::Pong)); + if empty_or_pong { + self.additional_send.replace(add); + } + } } /// The current connection state. @@ -645,7 +671,7 @@ impl WebSocketState { } /// Check if the state is active, return error if not. - fn check_active(self) -> Result<()> { + fn check_not_terminated(self) -> Result<()> { match self { WebSocketState::Terminated => Err(Error::AlreadyClosed), _ => Ok(()), @@ -697,64 +723,6 @@ mod tests { } } - struct WouldBlockStreamMoc; - - impl io::Write for WouldBlockStreamMoc { - fn write(&mut self, _: &[u8]) -> io::Result { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - fn flush(&mut self) -> io::Result<()> { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - } - - impl io::Read for WouldBlockStreamMoc { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) - } - } - - #[test] - fn queue_logic() { - // Create a socket with the queue size of 1. - let mut socket = WebSocket::from_raw_socket( - WouldBlockStreamMoc, - Role::Client, - Some(WebSocketConfig { max_send_queue: Some(1), ..Default::default() }), - ); - - // Test message that we're going to send. - let message = Message::Binary(vec![0xFF; 1024]); - - // Helper to check the error. - let assert_would_block = |error| { - if let Error::Io(io_error) = error { - assert_eq!(io_error.kind(), io::ErrorKind::WouldBlock); - } else { - panic!("Expected WouldBlock error"); - } - }; - - // The first attempt of writing must not fail, since the queue is empty at start. - // But since the underlying mock object always returns `WouldBlock`, so is the result. - assert_would_block(dbg!(socket.write_message(message.clone()).unwrap_err())); - - // Any subsequent attempts must return an error telling that the queue is full. - for _i in 0..100 { - assert!(matches!( - socket.write_message(message.clone()).unwrap_err(), - Error::SendQueueFull(..) - )); - } - - // The size of the output buffer must not be bigger than the size of that message - // that we managed to write to the output buffer at first. Since we could not make - // any progress (because of the logic of the moc buffer), the size remains unchanged. - if socket.context.frame.output_buffer_len() > message.len() { - panic!("Too many frames in the queue"); - } - } - #[test] fn receive_messages() { let incoming = Cursor::new(vec![