From be3cfdca90675778644c10093af01596b82419c3 Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Sun, 14 Feb 2021 20:51:12 +0100 Subject: [PATCH] do payload masking directly during reading and writing of the frame this removes the need to have unique mutable access to the payload data --- src/protocol/frame/frame.rs | 18 ++++++------------ src/protocol/frame/mask.rs | 36 +++++++++++++++++++----------------- src/protocol/frame/mod.rs | 27 +++++++++++++++++++++++++-- src/protocol/mod.rs | 8 ++------ 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 2d4b244..34fa95b 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -11,7 +11,7 @@ use std::{ use super::{ coding::{CloseCode, Control, Data, OpCode}, - mask::{apply_mask, generate_mask}, + mask::{generate_mask, write_masked}, }; use crate::error::{Error, ProtocolError, Result}; use crate::protocol::data::MessageData; @@ -253,15 +253,6 @@ impl Frame { self.header.set_random_mask() } - /// This method unmasks the payload and should only be called on frames that are actually - /// masked. In other words, those frames that have just been received from a client endpoint. - #[inline] - pub(crate) fn apply_mask(&mut self) { - if let Some(mask) = self.header.mask.take() { - apply_mask(self.payload.as_mut(), mask) - } - } - /// Consume the frame into its payload as binary. #[inline] pub fn into_data(self) -> Vec { @@ -351,8 +342,11 @@ impl Frame { /// Write a frame out to a buffer pub fn format(mut self, output: &mut impl Write) -> Result<()> { self.header.format(self.payload.len() as u64, output)?; - self.apply_mask(); - output.write_all(self.payload())?; + if let Some(mask) = self.header.mask.take() { + write_masked(self.payload(), output, mask) + } else { + output.write_all(self.payload())?; + } Ok(()) } } diff --git a/src/protocol/frame/mask.rs b/src/protocol/frame/mask.rs index 28f0eaf..1352483 100644 --- a/src/protocol/frame/mask.rs +++ b/src/protocol/frame/mask.rs @@ -1,30 +1,31 @@ +use std::io::Write; + /// Generate a random frame mask. #[inline] pub fn generate_mask() -> [u8; 4] { rand::random() } -/// Mask/unmask a frame. -#[inline] -pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) { - apply_mask_fast32(buf, mask) +/// Write data to an output, masking the data in the process +pub fn write_masked(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { + write_mask_fast32(data, output, mask) } /// A safe unoptimized mask application. #[inline] -fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { - for (i, byte) in buf.iter_mut().enumerate() { - *byte ^= mask[i & 3]; +fn write_mask_fallback(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { + for (i, byte) in data.iter().enumerate() { + output.write(&[*byte ^ mask[i & 3]]).unwrap(); } } /// Faster version of `apply_mask()` which operates on 4-byte blocks. #[inline] -pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { +fn write_mask_fast32(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { let mask_u32 = u32::from_ne_bytes(mask); - let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::() }; - apply_mask_fallback(&mut prefix, mask); + let (mut prefix, words, mut suffix) = unsafe { data.align_to::() }; + write_mask_fallback(&mut prefix, output, mask); let head = prefix.len() & 3; let mask_u32 = if head > 0 { if cfg!(target_endian = "big") { @@ -35,10 +36,11 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { } else { mask_u32 }; - for word in words.iter_mut() { - *word ^= mask_u32; + for word in words { + let bytes = (*word ^ mask_u32).to_ne_bytes(); + output.write(&bytes).unwrap(); } - apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); + write_mask_fallback(&mut suffix, output, mask_u32.to_ne_bytes()); } #[cfg(test)] @@ -60,11 +62,11 @@ mod tests { if unmasked.len() < off { continue; } - let mut masked = unmasked.to_vec(); - apply_mask_fallback(&mut masked[off..], mask); + let mut masked = Vec::new(); + write_mask_fallback(&unmasked, &mut masked, mask); - let mut masked_fast = unmasked.to_vec(); - apply_mask_fast32(&mut masked_fast[off..], mask); + let mut masked_fast = Vec::new(); + write_mask_fast32(&unmasked, &mut masked_fast, mask); assert_eq!(masked, masked_fast); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 1e41853..a1e49ac 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -4,11 +4,16 @@ pub mod coding; #[allow(clippy::module_inception)] mod frame; +#[cfg(feature = "__expose_benchmark_fn")] +#[allow(missing_docs)] +pub mod mask; +#[cfg(not(feature = "__expose_benchmark_fn"))] mod mask; pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::error::{CapacityError, Error, Result}; +use crate::protocol::frame::mask::write_masked; use input_buffer::{InputBuffer, MIN_READ}; use log::*; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; @@ -142,10 +147,28 @@ impl FrameCodec { let input_size = cursor.get_ref().len() as u64 - cursor.position(); if length <= input_size { // No truncation here since `length` is checked above - let mut payload = Vec::with_capacity(length as usize); + + // take a slice from the cursor + let payload_input = &cursor.get_ref().as_slice() + [(cursor.position() as usize)..(cursor.position() + length) as usize]; + + let mut payload = Vec::new(); if length > 0 { - cursor.take(length).read_to_end(&mut payload)?; + if let Some(mask) = + self.header.as_ref().and_then(|header| header.0.mask) + { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + + payload = Vec::with_capacity(length as usize); + write_masked(payload_input, &mut payload, mask); + } else { + payload = payload_input.to_vec(); + } } + + cursor.set_position(cursor.position() + length); + break payload; } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 122183e..49f65cc 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -426,7 +426,7 @@ impl WebSocketContext { where Stream: Read + Write, { - if let Some(mut frame) = self + if let Some(frame) = self .frame .read_frame(stream, self.config.max_frame_size) .check_connection_reset(self.state)? @@ -448,11 +448,7 @@ impl WebSocketContext { match self.role { Role::Server => { - if frame.is_masked() { - // A server MUST remove masking for data frames received from a client - // as described in Section 5.3. (RFC 6455) - frame.apply_mask() - } else if !self.config.accept_unmasked_frames { + if !frame.is_masked() && !self.config.accept_unmasked_frames { // The server MUST close the connection upon receiving a // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given