do payload masking directly during reading and writing of the frame

this removes the need to have unique mutable access to the payload data
pull/175/head
Robin Appelman 4 years ago
parent a73434bfc4
commit be3cfdca90
  1. 18
      src/protocol/frame/frame.rs
  2. 36
      src/protocol/frame/mask.rs
  3. 27
      src/protocol/frame/mod.rs
  4. 8
      src/protocol/mod.rs

@ -11,7 +11,7 @@ use std::{
use super::{ use super::{
coding::{CloseCode, Control, Data, OpCode}, coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask}, mask::{generate_mask, write_masked},
}; };
use crate::error::{Error, ProtocolError, Result}; use crate::error::{Error, ProtocolError, Result};
use crate::protocol::data::MessageData; use crate::protocol::data::MessageData;
@ -253,15 +253,6 @@ impl Frame {
self.header.set_random_mask() self.header.set_random_mask()
} }
/// This method unmasks the payload and should only be called on frames that are actually
/// masked. In other words, those frames that have just been received from a client endpoint.
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(self.payload.as_mut(), mask)
}
}
/// Consume the frame into its payload as binary. /// Consume the frame into its payload as binary.
#[inline] #[inline]
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
@ -351,8 +342,11 @@ impl Frame {
/// Write a frame out to a buffer /// Write a frame out to a buffer
pub fn format(mut self, output: &mut impl Write) -> Result<()> { pub fn format(mut self, output: &mut impl Write) -> Result<()> {
self.header.format(self.payload.len() as u64, output)?; self.header.format(self.payload.len() as u64, output)?;
self.apply_mask(); if let Some(mask) = self.header.mask.take() {
output.write_all(self.payload())?; write_masked(self.payload(), output, mask)
} else {
output.write_all(self.payload())?;
}
Ok(()) Ok(())
} }
} }

@ -1,30 +1,31 @@
use std::io::Write;
/// Generate a random frame mask. /// Generate a random frame mask.
#[inline] #[inline]
pub fn generate_mask() -> [u8; 4] { pub fn generate_mask() -> [u8; 4] {
rand::random() rand::random()
} }
/// Mask/unmask a frame. /// Write data to an output, masking the data in the process
#[inline] pub fn write_masked(data: &[u8], output: &mut impl Write, mask: [u8; 4]) {
pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) { write_mask_fast32(data, output, mask)
apply_mask_fast32(buf, mask)
} }
/// A safe unoptimized mask application. /// A safe unoptimized mask application.
#[inline] #[inline]
fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { fn write_mask_fallback(data: &[u8], output: &mut impl Write, mask: [u8; 4]) {
for (i, byte) in buf.iter_mut().enumerate() { for (i, byte) in data.iter().enumerate() {
*byte ^= mask[i & 3]; output.write(&[*byte ^ mask[i & 3]]).unwrap();
} }
} }
/// Faster version of `apply_mask()` which operates on 4-byte blocks. /// Faster version of `apply_mask()` which operates on 4-byte blocks.
#[inline] #[inline]
pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { fn write_mask_fast32(data: &[u8], output: &mut impl Write, mask: [u8; 4]) {
let mask_u32 = u32::from_ne_bytes(mask); let mask_u32 = u32::from_ne_bytes(mask);
let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::<u32>() }; let (mut prefix, words, mut suffix) = unsafe { data.align_to::<u32>() };
apply_mask_fallback(&mut prefix, mask); write_mask_fallback(&mut prefix, output, mask);
let head = prefix.len() & 3; let head = prefix.len() & 3;
let mask_u32 = if head > 0 { let mask_u32 = if head > 0 {
if cfg!(target_endian = "big") { if cfg!(target_endian = "big") {
@ -35,10 +36,11 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
} else { } else {
mask_u32 mask_u32
}; };
for word in words.iter_mut() { for word in words {
*word ^= mask_u32; let bytes = (*word ^ mask_u32).to_ne_bytes();
output.write(&bytes).unwrap();
} }
apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); write_mask_fallback(&mut suffix, output, mask_u32.to_ne_bytes());
} }
#[cfg(test)] #[cfg(test)]
@ -60,11 +62,11 @@ mod tests {
if unmasked.len() < off { if unmasked.len() < off {
continue; continue;
} }
let mut masked = unmasked.to_vec(); let mut masked = Vec::new();
apply_mask_fallback(&mut masked[off..], mask); write_mask_fallback(&unmasked, &mut masked, mask);
let mut masked_fast = unmasked.to_vec(); let mut masked_fast = Vec::new();
apply_mask_fast32(&mut masked_fast[off..], mask); write_mask_fast32(&unmasked, &mut masked_fast, mask);
assert_eq!(masked, masked_fast); assert_eq!(masked, masked_fast);
} }

@ -4,11 +4,16 @@ pub mod coding;
#[allow(clippy::module_inception)] #[allow(clippy::module_inception)]
mod frame; mod frame;
#[cfg(feature = "__expose_benchmark_fn")]
#[allow(missing_docs)]
pub mod mask;
#[cfg(not(feature = "__expose_benchmark_fn"))]
mod mask; mod mask;
pub use self::frame::{CloseFrame, Frame, FrameHeader}; pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::error::{CapacityError, Error, Result}; use crate::error::{CapacityError, Error, Result};
use crate::protocol::frame::mask::write_masked;
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use log::*; use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
@ -142,10 +147,28 @@ impl FrameCodec {
let input_size = cursor.get_ref().len() as u64 - cursor.position(); let input_size = cursor.get_ref().len() as u64 - cursor.position();
if length <= input_size { if length <= input_size {
// No truncation here since `length` is checked above // No truncation here since `length` is checked above
let mut payload = Vec::with_capacity(length as usize);
// take a slice from the cursor
let payload_input = &cursor.get_ref().as_slice()
[(cursor.position() as usize)..(cursor.position() + length) as usize];
let mut payload = Vec::new();
if length > 0 { if length > 0 {
cursor.take(length).read_to_end(&mut payload)?; if let Some(mask) =
self.header.as_ref().and_then(|header| header.0.mask)
{
// A server MUST remove masking for data frames received from a client
// as described in Section 5.3. (RFC 6455)
payload = Vec::with_capacity(length as usize);
write_masked(payload_input, &mut payload, mask);
} else {
payload = payload_input.to_vec();
}
} }
cursor.set_position(cursor.position() + length);
break payload; break payload;
} }
} }

@ -426,7 +426,7 @@ impl WebSocketContext {
where where
Stream: Read + Write, Stream: Read + Write,
{ {
if let Some(mut frame) = self if let Some(frame) = self
.frame .frame
.read_frame(stream, self.config.max_frame_size) .read_frame(stream, self.config.max_frame_size)
.check_connection_reset(self.state)? .check_connection_reset(self.state)?
@ -448,11 +448,7 @@ impl WebSocketContext {
match self.role { match self.role {
Role::Server => { Role::Server => {
if frame.is_masked() { if !frame.is_masked() && !self.config.accept_unmasked_frames {
// A server MUST remove masking for data frames received from a client
// as described in Section 5.3. (RFC 6455)
frame.apply_mask()
} else if !self.config.accept_unmasked_frames {
// The server MUST close the connection upon receiving a // The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455) // frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given // The only exception here is if the user explicitly accepts given

Loading…
Cancel
Save