From 2787031c2f15e8465bae25c46fdbcdeb609d0bb7 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 2 Jul 2021 19:22:47 +0200 Subject: [PATCH] Replace `InputBuffer` with a faster alternative We're also deprecating the usage of `input_buffer` crate, see: https://github.com/snapview/input_buffer/issues/6#issuecomment-870548303 --- Cargo.toml | 2 +- src/buffer.rs | 111 ++++++++++++++++++++++++++++++++++++++ src/client.rs | 3 +- src/error.rs | 5 -- src/handshake/machine.rs | 18 ++----- src/handshake/mod.rs | 5 +- src/lib.rs | 4 ++ src/protocol/frame/mod.rs | 28 ++++------ 8 files changed, 135 insertions(+), 41 deletions(-) create mode 100644 src/buffer.rs diff --git a/Cargo.toml b/Cargo.toml index 2ef26fc..90ef5aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ byteorder = "1.3.2" bytes = "1.0" http = "0.2" httparse = "1.3.4" -input_buffer = "0.4.0" log = "0.4.8" rand = "0.8.0" sha-1 = "0.9" @@ -53,5 +52,6 @@ optional = true version = "0.5.0" [dev-dependencies] +input_buffer = "0.5.0" env_logger = "0.8.1" net2 = "0.2.33" diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..74aef57 --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,111 @@ +//! A buffer for reading data from the network. +//! +//! The `ReadBuffer` is a buffer of bytes similar to a first-in, first-out queue. +//! It is filled by reading from a stream supporting `Read` and is then +//! accessible as a cursor for reading bytes. + +use std::io::{Cursor, Read, Result as IoResult}; + +use bytes::Buf; + +/// A FIFO buffer for reading packets from the network. +#[derive(Debug)] +pub struct ReadBuffer { + storage: Cursor>, + chunk: [u8; CHUNK_SIZE], +} + +impl ReadBuffer { + /// Create a new empty input buffer. + pub fn new() -> Self { + Self::with_capacity(CHUNK_SIZE) + } + + /// Create a new empty input buffer with a given `capacity`. + pub fn with_capacity(capacity: usize) -> Self { + Self::from_partially_read(Vec::with_capacity(capacity)) + } + + /// Create a input buffer filled with previously read data. + pub fn from_partially_read(part: Vec) -> Self { + Self { storage: Cursor::new(part), chunk: [0; CHUNK_SIZE] } + } + + /// Get a cursor to the data storage. + pub fn as_cursor(&self) -> &Cursor> { + &self.storage + } + + /// Get a cursor to the mutable data storage. + pub fn as_cursor_mut(&mut self) -> &mut Cursor> { + &mut self.storage + } + + /// Consume the `ReadBuffer` and get the internal storage. + pub fn into_vec(mut self) -> Vec { + // Current implementation of `tungstenite-rs` expects that the `into_vec()` drains + // the data from the container that has already been read by the cursor. + let pos = self.storage.position() as usize; + self.storage.get_mut().drain(0..pos).count(); + self.storage.set_position(0); + + // Now we can safely return the internal container. + self.storage.into_inner() + } + + /// Read next portion of data from the given input stream. + pub fn read_from(&mut self, stream: &mut S) -> IoResult { + let size = stream.read(&mut self.chunk)?; + self.storage.get_mut().extend_from_slice(&self.chunk[..size]); + Ok(size) + } +} + +impl Buf for ReadBuffer { + fn remaining(&self) -> usize { + Buf::remaining(self.as_cursor()) + } + + fn chunk(&self) -> &[u8] { + Buf::chunk(self.as_cursor()) + } + + fn advance(&mut self, cnt: usize) { + Buf::advance(self.as_cursor_mut(), cnt) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple_reading() { + let mut input = Cursor::new(b"Hello World!".to_vec()); + let mut buffer = ReadBuffer::<4096>::new(); + let size = buffer.read_from(&mut input).unwrap(); + assert_eq!(size, 12); + assert_eq!(buffer.chunk(), b"Hello World!"); + } + + #[test] + fn reading_in_chunks() { + let mut inp = Cursor::new(b"Hello World!".to_vec()); + let mut buf = ReadBuffer::<4>::new(); + + let size = buf.read_from(&mut inp).unwrap(); + assert_eq!(size, 4); + assert_eq!(buf.chunk(), b"Hell"); + + buf.advance(2); + assert_eq!(buf.chunk(), b"ll"); + + let size = buf.read_from(&mut inp).unwrap(); + assert_eq!(size, 4); + assert_eq!(buf.chunk(), b"llo Wo"); + + let size = buf.read_from(&mut inp).unwrap(); + assert_eq!(size, 4); + assert_eq!(buf.chunk(), b"llo World!"); + } +} diff --git a/src/client.rs b/src/client.rs index 55ba080..d4d9492 100644 --- a/src/client.rs +++ b/src/client.rs @@ -72,7 +72,8 @@ mod encryption { Mode::Tls => { let config = { let mut config = ClientConfig::new(); - config.root_store = rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?; + config.root_store = + rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?; Arc::new(config) }; diff --git a/src/error.rs b/src/error.rs index 406e64f..6ab7420 100644 --- a/src/error.rs +++ b/src/error.rs @@ -127,8 +127,6 @@ pub enum CapacityError { #[error("Too many headers")] TooManyHeaders, /// Received header is too long. - #[error("Header too long")] - HeaderTooLong, /// Message is bigger than the maximum allowed size. #[error("Message too long: {size} > {max_size}")] MessageTooLong { @@ -137,9 +135,6 @@ pub enum CapacityError { /// The maximum allowed message size. max_size: usize, }, - /// TCP buffer is full. - #[error("Incoming TCP buffer is full")] - TcpBufferFull, } /// Indicates the specific type/cause of a protocol error. diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index ced0153..83dae1f 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,10 +3,10 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{CapacityError, Error, ProtocolError, Result}, + error::{Error, ProtocolError, Result}, util::NonBlockingResult, + ReadBuffer, }; -use input_buffer::{InputBuffer, MIN_READ}; /// A generic handshake state machine. #[derive(Debug)] @@ -18,10 +18,7 @@ pub struct HandshakeMachine { impl HandshakeMachine { /// Start reading data from the peer. pub fn start_read(stream: Stream) -> Self { - HandshakeMachine { - stream, - state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), - } + HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) } } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { @@ -43,12 +40,7 @@ impl HandshakeMachine { trace!("Doing handshake round."); match self.state { HandshakeState::Reading(mut buf) => { - let read = buf - .prepare_reserve(MIN_READ) - .with_limit(usize::max_value()) // TODO limit size - .map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))? - .read_from(&mut self.stream) - .no_block()?; + let read = buf.read_from(&mut self.stream).no_block()?; match read { Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { @@ -124,7 +116,7 @@ pub trait TryParse: Sized { #[derive(Debug)] enum HandshakeState { /// Reading data from the peer. - Reading(InputBuffer), + Reading(ReadBuffer), /// Sending data to the peer. Writing(Cursor>), } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index c2c63de..e063d4a 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -131,9 +131,6 @@ mod tests { #[test] fn key_conversion() { // example from RFC 6455 - assert_eq!( - derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="), - "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" - ); + assert_eq!(derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); } } diff --git a/src/lib.rs b/src/lib.rs index 82f7822..4ac56d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub use http; +mod buffer; pub mod client; pub mod error; pub mod handshake; @@ -22,6 +23,9 @@ pub mod server; pub mod stream; pub mod util; +const READ_BUFFER_CHUNK_SIZE: usize = 4096; +type ReadBuffer = buffer::ReadBuffer; + pub use crate::{ client::{client, connect}, error::{Error, Result}, diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 1e41853..3c45dd9 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -6,12 +6,15 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::{CloseFrame, Frame, FrameHeader}; +use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; -use crate::error::{CapacityError, Error, Result}; -use input_buffer::{InputBuffer, MIN_READ}; use log::*; -use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; + +pub use self::frame::{CloseFrame, Frame, FrameHeader}; +use crate::{ + error::{CapacityError, Error, Result}, + ReadBuffer, +}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -82,7 +85,7 @@ where #[derive(Debug)] pub(super) struct FrameCodec { /// Buffer to read data from the stream. - in_buffer: InputBuffer, + in_buffer: ReadBuffer, /// Buffer to send packets to the network. out_buffer: Vec, /// Header and remaining size of the incoming packet being processed. @@ -92,17 +95,13 @@ pub(super) struct FrameCodec { impl FrameCodec { /// Create a new frame codec. pub(super) fn new() -> Self { - Self { - in_buffer: InputBuffer::with_capacity(MIN_READ), - out_buffer: Vec::new(), - header: None, - } + Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None } } /// Create a new frame codec from partially read data. pub(super) fn from_partially_read(part: Vec) -> Self { Self { - in_buffer: InputBuffer::from_partially_read(part), + in_buffer: ReadBuffer::from_partially_read(part), out_buffer: Vec::new(), header: None, } @@ -152,12 +151,7 @@ impl FrameCodec { } // Not enough data in buffer. - let size = self - .in_buffer - .prepare_reserve(MIN_READ) - .with_limit(usize::max_value()) - .map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))? - .read_from(stream)?; + let size = self.in_buffer.read_from(stream)?; if size == 0 { trace!("no frame received"); return Ok(None);