From 59818a01faf28b90f12729b469377c07873ee9f2 Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Tue, 9 Feb 2021 20:19:09 +0100 Subject: [PATCH 1/4] allow sending messages with static data --- src/protocol/frame/frame.rs | 42 ++++++++++++++++++--------- src/protocol/message.rs | 57 +++++++++++++++++++++++-------------- src/protocol/mod.rs | 12 ++++++-- 3 files changed, 72 insertions(+), 39 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 986bba0..bc24b0c 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -201,7 +201,7 @@ impl FrameHeader { #[derive(Debug, Clone)] pub struct Frame { header: FrameHeader, - payload: Vec, + payload: Cow<'static, [u8]>, } impl Frame { @@ -233,13 +233,13 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] - pub fn payload(&self) -> &Vec { + pub fn payload(&self) -> &[u8] { &self.payload } /// Get a mutable reference to the frame's payload. #[inline] - pub fn payload_mut(&mut self) -> &mut Vec { + pub fn payload_mut(&mut self) -> &mut Cow<'static, [u8]> { &mut self.payload } @@ -263,20 +263,28 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - apply_mask(&mut self.payload, mask) + match &mut self.payload { + Cow::Owned(data) => apply_mask(data, mask), + Cow::Borrowed(data) => { + // can't modify static data, so we have to take ownership first + let mut data = data.to_vec(); + apply_mask(&mut data, mask); + self.payload = Cow::Owned(data); + } + } } } /// Consume the frame into its payload as binary. #[inline] pub fn into_data(self) -> Vec { - self.payload + self.payload.into_owned() } /// Consume the frame into its payload as string. #[inline] pub fn into_string(self) -> StdResult { - String::from_utf8(self.payload) + String::from_utf8(self.into_data()) } /// Consume the frame into a closing frame. @@ -286,7 +294,7 @@ impl Frame { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = self.payload; + let mut data = self.into_data(); let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; @@ -297,10 +305,16 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + pub fn message(data: D, opcode: OpCode, is_final: bool) -> Frame + where + D: Into>, + { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } + Frame { + header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, + payload: data.into(), + } } /// Create a new Pong control frame. @@ -311,7 +325,7 @@ impl Frame { opcode: OpCode::Control(Control::Pong), ..FrameHeader::default() }, - payload: data, + payload: Cow::Owned(data), } } @@ -323,7 +337,7 @@ impl Frame { opcode: OpCode::Control(Control::Ping), ..FrameHeader::default() }, - payload: data, + payload: Cow::Owned(data), } } @@ -339,12 +353,12 @@ impl Frame { Vec::new() }; - Frame { header: FrameHeader::default(), payload } + Frame { header: FrameHeader::default(), payload: Cow::Owned(payload) } } /// Create a frame from given header and data. pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { - Frame { header, payload } + Frame { header, payload: Cow::Owned(payload) } } /// Write a frame out to a buffer @@ -462,7 +476,7 @@ mod tests { #[test] fn display() { - let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); + let f = Frame::message(&b"hi there"[..], OpCode::Data(Data::Text), true); let view = format!("{}", f); assert!(view.contains("payload:")); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 6720c3c..3f5accf 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -79,6 +79,7 @@ mod string_collect { } use self::string_collect::StringCollector; +use std::borrow::Cow; /// A struct representing the incomplete message. #[derive(Debug)] @@ -140,10 +141,10 @@ impl IncompleteMessage { /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { - IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), + IncompleteMessageCollector::Binary(v) => Ok(Message::binary(v)), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; - Ok(Message::Text(text)) + Ok(Message::text(text)) } } } @@ -159,9 +160,9 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Cow<'static, str>), /// A binary WebSocket message - Binary(Vec), + Binary(Cow<'static, [u8]>), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes @@ -180,7 +181,12 @@ impl Message { where S: Into, { - Message::Text(string.into()) + Message::Text(Cow::Owned(string.into())) + } + + /// Create a new static text WebSocket message from a &'static str. + pub fn static_text(string: &'static str) -> Message { + Message::Text(Cow::Borrowed(string)) } /// Create a new binary WebSocket message by converting to Vec. @@ -188,7 +194,12 @@ impl Message { where B: Into>, { - Message::Binary(bin.into()) + Message::Binary(Cow::Owned(bin.into())) + } + + /// Create a new static binary WebSocket message from a &'static [u8]. + pub fn static_binary(bin: &'static [u8]) -> Message { + Message::Binary(Cow::Borrowed(bin)) } /// Indicates whether a message is a text message. @@ -218,12 +229,11 @@ impl Message { /// Get the length of the WebSocket message. pub fn len(&self) -> usize { - match *self { - Message::Text(ref string) => string.len(), - Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - data.len() - } - Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), + match self { + Message::Text(string) => string.len(), + Message::Binary(data) => data.len(), + Message::Ping(data) | Message::Pong(data) => data.len(), + Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), } } @@ -236,8 +246,9 @@ impl Message { /// Consume the WebSocket and return it as binary data. pub fn into_data(self) -> Vec { match self { - Message::Text(string) => string.into_bytes(), - Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, + Message::Text(string) => string.into_owned().into_bytes(), + Message::Binary(data) => data.into_owned(), + Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => Vec::new(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } @@ -246,8 +257,11 @@ impl Message { /// Attempt to consume the WebSocket message and convert it to a String. pub fn into_text(self) -> Result { match self { - Message::Text(string) => Ok(string), - Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { + Message::Text(string) => Ok(string.into_owned()), + Message::Binary(data) => { + Ok(String::from_utf8(data.into_owned()).map_err(|err| err.utf8_error())?) + } + Message::Ping(data) | Message::Pong(data) => { Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) } Message::Close(None) => Ok(String::new()), @@ -258,13 +272,12 @@ impl Message { /// Attempt to get a &str from the WebSocket message, /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str> { - match *self { - Message::Text(ref string) => Ok(string), - Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - Ok(str::from_utf8(data)?) - } + match self { + Message::Text(string) => Ok(string.as_ref()), + Message::Binary(data) => Ok(str::from_utf8(data.as_ref())?), + Message::Ping(data) | Message::Pong(data) => Ok(str::from_utf8(data)?), Message::Close(None) => Ok(""), - Message::Close(Some(ref frame)) => Ok(&frame.reason), + Message::Close(Some(frame)) => Ok(&frame.reason), } } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 215b061..787515a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -24,6 +24,7 @@ use crate::{ error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; +use std::borrow::Cow; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -348,7 +349,12 @@ impl WebSocketContext { } let frame = match message { - Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), + Message::Text(Cow::Owned(data)) => { + Frame::message(data.into_bytes(), OpCode::Data(OpData::Text), true) + } + Message::Text(Cow::Borrowed(data)) => { + Frame::message(data.as_bytes(), OpCode::Data(OpData::Text), true) + } Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { @@ -700,8 +706,8 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!(socket.read_message().unwrap(), Message::text("Hello, World!")); + assert_eq!(socket.read_message().unwrap(), Message::binary(vec![0x01, 0x02, 0x03])); } #[test] From a662560c6686742f4463cf3cfb5f2492f11fc3db Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Sat, 13 Feb 2021 18:46:35 +0100 Subject: [PATCH 2/4] switch to custom message data enums to handle shared and static data --- src/protocol/data.rs | 149 ++++++++++++++++++++++++++++++++++++ src/protocol/frame/frame.rs | 41 ++++------ src/protocol/message.rs | 46 +++++------ src/protocol/mod.rs | 9 +-- 4 files changed, 183 insertions(+), 62 deletions(-) create mode 100644 src/protocol/data.rs diff --git a/src/protocol/data.rs b/src/protocol/data.rs new file mode 100644 index 0000000..199a949 --- /dev/null +++ b/src/protocol/data.rs @@ -0,0 +1,149 @@ +use bytes::Bytes; + +/// Binary message data +#[derive(Debug, Clone)] +pub struct MessageData(MessageDataImpl); + +/// opaque inner type to allow modifying the implementation in the future +#[derive(Debug, Clone)] +enum MessageDataImpl { + Shared(Bytes), + Unique(Vec), +} + +impl MessageData { + pub fn len(&self) -> usize { + self.as_ref().len() + } + + fn make_unique(&mut self) { + if let MessageDataImpl::Shared(data) = &self.0 { + self.0 = MessageDataImpl::Unique(Vec::from(data.as_ref())); + } + } +} + +impl PartialEq for MessageData { + fn eq(&self, other: &MessageData) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl Eq for MessageData {} + +impl From for Vec { + fn from(data: MessageData) -> Vec { + match data.0 { + MessageDataImpl::Shared(data) => { + let mut bytes = Vec::with_capacity(data.len()); + bytes.copy_from_slice(data.as_ref()); + bytes + } + MessageDataImpl::Unique(data) => data, + } + } +} + +impl From for Bytes { + fn from(data: MessageData) -> Bytes { + match data.0 { + MessageDataImpl::Shared(data) => data, + MessageDataImpl::Unique(data) => data.into(), + } + } +} + +impl AsRef<[u8]> for MessageData { + fn as_ref(&self) -> &[u8] { + match &self.0 { + MessageDataImpl::Shared(data) => data.as_ref(), + MessageDataImpl::Unique(data) => data.as_ref(), + } + } +} + +impl AsMut<[u8]> for MessageData { + fn as_mut(&mut self) -> &mut [u8] { + self.make_unique(); + match &mut self.0 { + MessageDataImpl::Unique(data) => data.as_mut_slice(), + MessageDataImpl::Shared(_) => unreachable!("Data has just been made unique"), + } + } +} + +/// String message data +#[derive(Debug, Clone)] +pub struct MessageStringData(MessageStringDataImpl); + +/// opaque inner type to allow modifying the implementation in the future +#[derive(Debug, Clone)] +enum MessageStringDataImpl { + Static(&'static str), + Unique(String), +} + +impl PartialEq for MessageStringData { + fn eq(&self, other: &MessageStringData) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl Eq for MessageStringData {} + +impl From for String { + fn from(data: MessageStringData) -> String { + match data.0 { + MessageStringDataImpl::Static(data) => data.into(), + MessageStringDataImpl::Unique(data) => data, + } + } +} + +impl From for MessageData { + fn from(data: MessageStringData) -> MessageData { + match data.0 { + MessageStringDataImpl::Static(data) => MessageData::from(data.as_bytes()), + MessageStringDataImpl::Unique(data) => MessageData::from(data.into_bytes()), + } + } +} + +impl AsRef for MessageStringData { + fn as_ref(&self) -> &str { + match &self.0 { + MessageStringDataImpl::Static(data) => *data, + MessageStringDataImpl::Unique(data) => data.as_ref(), + } + } +} + +impl From for MessageStringData { + fn from(string: String) -> MessageStringData { + MessageStringData(MessageStringDataImpl::Unique(string)) + } +} + +impl From<&'static str> for MessageStringData { + fn from(string: &'static str) -> MessageStringData { + MessageStringData(MessageStringDataImpl::Static(string)) + } +} + +impl From> for MessageData { + fn from(data: Vec) -> MessageData { + MessageData(MessageDataImpl::Unique(data)) + } +} + +impl From<&'static [u8]> for MessageData { + fn from(data: &'static [u8]) -> MessageData { + MessageData(MessageDataImpl::Shared(Bytes::from_static(data))) + } +} + +impl From for MessageData { + fn from(data: Bytes) -> MessageData { + MessageData(MessageDataImpl::Shared(data)) + } +} diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index bc24b0c..2d4b244 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -14,6 +14,7 @@ use super::{ mask::{apply_mask, generate_mask}, }; use crate::error::{Error, ProtocolError, Result}; +use crate::protocol::data::MessageData; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,7 +187,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) + return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))); } _ => (), } @@ -201,7 +202,7 @@ impl FrameHeader { #[derive(Debug, Clone)] pub struct Frame { header: FrameHeader, - payload: Cow<'static, [u8]>, + payload: MessageData, } impl Frame { @@ -234,13 +235,7 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] pub fn payload(&self) -> &[u8] { - &self.payload - } - - /// Get a mutable reference to the frame's payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut Cow<'static, [u8]> { - &mut self.payload + self.payload.as_ref() } /// Test whether the frame is masked. @@ -263,28 +258,20 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - match &mut self.payload { - Cow::Owned(data) => apply_mask(data, mask), - Cow::Borrowed(data) => { - // can't modify static data, so we have to take ownership first - let mut data = data.to_vec(); - apply_mask(&mut data, mask); - self.payload = Cow::Owned(data); - } - } + apply_mask(self.payload.as_mut(), mask) } } /// Consume the frame into its payload as binary. #[inline] pub fn into_data(self) -> Vec { - self.payload.into_owned() + self.payload.into() } /// Consume the frame into its payload as string. #[inline] pub fn into_string(self) -> StdResult { - String::from_utf8(self.into_data()) + String::from_utf8(self.payload.into()) } /// Consume the frame into a closing frame. @@ -297,7 +284,7 @@ impl Frame { let mut data = self.into_data(); let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); - let text = String::from_utf8(data)?; + let text = String::from_utf8(data.into())?; Ok(Some(CloseFrame { code, reason: text.into() })) } } @@ -307,7 +294,7 @@ impl Frame { #[inline] pub fn message(data: D, opcode: OpCode, is_final: bool) -> Frame where - D: Into>, + D: Into, { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); @@ -325,7 +312,7 @@ impl Frame { opcode: OpCode::Control(Control::Pong), ..FrameHeader::default() }, - payload: Cow::Owned(data), + payload: data.into(), } } @@ -337,7 +324,7 @@ impl Frame { opcode: OpCode::Control(Control::Ping), ..FrameHeader::default() }, - payload: Cow::Owned(data), + payload: data.into(), } } @@ -353,12 +340,12 @@ impl Frame { Vec::new() }; - Frame { header: FrameHeader::default(), payload: Cow::Owned(payload) } + Frame { header: FrameHeader::default(), payload: payload.into() } } /// Create a frame from given header and data. pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { - Frame { header, payload: Cow::Owned(payload) } + Frame { header, payload: payload.into() } } /// Write a frame out to a buffer @@ -391,7 +378,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.iter().map(|byte| format!("{:x}", byte)).collect::() + self.payload.as_ref().iter().map(|byte| format!("{:x}", byte)).collect::() ) } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 3f5accf..909c2d8 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -79,7 +79,7 @@ mod string_collect { } use self::string_collect::StringCollector; -use std::borrow::Cow; +use crate::protocol::data::{MessageData, MessageStringData}; /// A struct representing the incomplete message. #[derive(Debug)] @@ -160,9 +160,9 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(Cow<'static, str>), + Text(MessageStringData), /// A binary WebSocket message - Binary(Cow<'static, [u8]>), + Binary(MessageData), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes @@ -179,27 +179,17 @@ impl Message { /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message where - S: Into, + S: Into, { - Message::Text(Cow::Owned(string.into())) - } - - /// Create a new static text WebSocket message from a &'static str. - pub fn static_text(string: &'static str) -> Message { - Message::Text(Cow::Borrowed(string)) + Message::Text(string.into()) } /// Create a new binary WebSocket message by converting to Vec. pub fn binary(bin: B) -> Message where - B: Into>, + B: Into, { - Message::Binary(Cow::Owned(bin.into())) - } - - /// Create a new static binary WebSocket message from a &'static [u8]. - pub fn static_binary(bin: &'static [u8]) -> Message { - Message::Binary(Cow::Borrowed(bin)) + Message::Binary(bin.into()) } /// Indicates whether a message is a text message. @@ -230,8 +220,8 @@ impl Message { /// Get the length of the WebSocket message. pub fn len(&self) -> usize { match self { - Message::Text(string) => string.len(), - Message::Binary(data) => data.len(), + Message::Text(string) => string.as_ref().len(), + Message::Binary(data) => data.as_ref().len(), Message::Ping(data) | Message::Pong(data) => data.len(), Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), } @@ -246,8 +236,8 @@ impl Message { /// Consume the WebSocket and return it as binary data. pub fn into_data(self) -> Vec { match self { - Message::Text(string) => string.into_owned().into_bytes(), - Message::Binary(data) => data.into_owned(), + Message::Text(string) => String::from(string).into(), + Message::Binary(data) => data.into(), Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => Vec::new(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), @@ -257,9 +247,9 @@ impl Message { /// Attempt to consume the WebSocket message and convert it to a String. pub fn into_text(self) -> Result { match self { - Message::Text(string) => Ok(string.into_owned()), + Message::Text(string) => Ok(string.into()), Message::Binary(data) => { - Ok(String::from_utf8(data.into_owned()).map_err(|err| err.utf8_error())?) + Ok(String::from_utf8(data.into()).map_err(|err| err.utf8_error())?) } Message::Ping(data) | Message::Pong(data) => { Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) @@ -290,13 +280,13 @@ impl From for Message { impl<'s> From<&'s str> for Message { fn from(string: &'s str) -> Message { - Message::text(string) + Message::text(string.to_string()) } } impl<'b> From<&'b [u8]> for Message { fn from(data: &'b [u8]) -> Message { - Message::binary(data) + Message::binary(data.to_vec()) } } @@ -306,9 +296,9 @@ impl From> for Message { } } -impl Into> for Message { - fn into(self) -> Vec { - self.into_data() +impl From for Vec { + fn from(message: Message) -> Vec { + message.into_data() } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 787515a..122183e 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -2,6 +2,7 @@ pub mod frame; +mod data; mod message; pub use self::{frame::CloseFrame, message::Message}; @@ -24,7 +25,6 @@ use crate::{ error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; -use std::borrow::Cow; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -349,12 +349,7 @@ impl WebSocketContext { } let frame = match message { - Message::Text(Cow::Owned(data)) => { - Frame::message(data.into_bytes(), OpCode::Data(OpData::Text), true) - } - Message::Text(Cow::Borrowed(data)) => { - Frame::message(data.as_bytes(), OpCode::Data(OpData::Text), true) - } + Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { From a73434bfc4e1156de6bd50e3068a0a6c5772d760 Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Tue, 16 Feb 2021 19:33:19 +0100 Subject: [PATCH 3/4] move code around a bit, use a Cow for string data --- src/protocol/data.rs | 78 +++++++++++++++----------------------------- 1 file changed, 26 insertions(+), 52 deletions(-) diff --git a/src/protocol/data.rs b/src/protocol/data.rs index 199a949..3c53552 100644 --- a/src/protocol/data.rs +++ b/src/protocol/data.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use std::borrow::Cow; /// Binary message data #[derive(Debug, Clone)] @@ -72,78 +73,51 @@ impl AsMut<[u8]> for MessageData { } } -/// String message data -#[derive(Debug, Clone)] -pub struct MessageStringData(MessageStringDataImpl); +impl From> for MessageData { + fn from(data: Vec) -> MessageData { + MessageData(MessageDataImpl::Unique(data)) + } +} -/// opaque inner type to allow modifying the implementation in the future -#[derive(Debug, Clone)] -enum MessageStringDataImpl { - Static(&'static str), - Unique(String), +impl From<&'static [u8]> for MessageData { + fn from(data: &'static [u8]) -> MessageData { + MessageData(MessageDataImpl::Shared(Bytes::from_static(data))) + } } -impl PartialEq for MessageStringData { - fn eq(&self, other: &MessageStringData) -> bool { - self.as_ref().eq(other.as_ref()) +impl From for MessageData { + fn from(data: Bytes) -> MessageData { + MessageData(MessageDataImpl::Shared(data)) } } -impl Eq for MessageStringData {} +/// String message data +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MessageStringData(Cow<'static, str>); + +impl>> From for MessageStringData { + fn from(str: T) -> Self { + MessageStringData(str.into()) + } +} impl From for String { fn from(data: MessageStringData) -> String { - match data.0 { - MessageStringDataImpl::Static(data) => data.into(), - MessageStringDataImpl::Unique(data) => data, - } + data.0.into() } } impl From for MessageData { fn from(data: MessageStringData) -> MessageData { match data.0 { - MessageStringDataImpl::Static(data) => MessageData::from(data.as_bytes()), - MessageStringDataImpl::Unique(data) => MessageData::from(data.into_bytes()), + Cow::Borrowed(data) => MessageData::from(data.as_bytes()), + Cow::Owned(data) => MessageData::from(data.into_bytes()), } } } impl AsRef for MessageStringData { fn as_ref(&self) -> &str { - match &self.0 { - MessageStringDataImpl::Static(data) => *data, - MessageStringDataImpl::Unique(data) => data.as_ref(), - } - } -} - -impl From for MessageStringData { - fn from(string: String) -> MessageStringData { - MessageStringData(MessageStringDataImpl::Unique(string)) - } -} - -impl From<&'static str> for MessageStringData { - fn from(string: &'static str) -> MessageStringData { - MessageStringData(MessageStringDataImpl::Static(string)) - } -} - -impl From> for MessageData { - fn from(data: Vec) -> MessageData { - MessageData(MessageDataImpl::Unique(data)) - } -} - -impl From<&'static [u8]> for MessageData { - fn from(data: &'static [u8]) -> MessageData { - MessageData(MessageDataImpl::Shared(Bytes::from_static(data))) - } -} - -impl From for MessageData { - fn from(data: Bytes) -> MessageData { - MessageData(MessageDataImpl::Shared(data)) + self.0.as_ref() } } From be3cfdca90675778644c10093af01596b82419c3 Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Sun, 14 Feb 2021 20:51:12 +0100 Subject: [PATCH 4/4] do payload masking directly during reading and writing of the frame this removes the need to have unique mutable access to the payload data --- src/protocol/frame/frame.rs | 18 ++++++------------ src/protocol/frame/mask.rs | 36 +++++++++++++++++++----------------- src/protocol/frame/mod.rs | 27 +++++++++++++++++++++++++-- src/protocol/mod.rs | 8 ++------ 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 2d4b244..34fa95b 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -11,7 +11,7 @@ use std::{ use super::{ coding::{CloseCode, Control, Data, OpCode}, - mask::{apply_mask, generate_mask}, + mask::{generate_mask, write_masked}, }; use crate::error::{Error, ProtocolError, Result}; use crate::protocol::data::MessageData; @@ -253,15 +253,6 @@ impl Frame { self.header.set_random_mask() } - /// This method unmasks the payload and should only be called on frames that are actually - /// masked. In other words, those frames that have just been received from a client endpoint. - #[inline] - pub(crate) fn apply_mask(&mut self) { - if let Some(mask) = self.header.mask.take() { - apply_mask(self.payload.as_mut(), mask) - } - } - /// Consume the frame into its payload as binary. #[inline] pub fn into_data(self) -> Vec { @@ -351,8 +342,11 @@ impl Frame { /// Write a frame out to a buffer pub fn format(mut self, output: &mut impl Write) -> Result<()> { self.header.format(self.payload.len() as u64, output)?; - self.apply_mask(); - output.write_all(self.payload())?; + if let Some(mask) = self.header.mask.take() { + write_masked(self.payload(), output, mask) + } else { + output.write_all(self.payload())?; + } Ok(()) } } diff --git a/src/protocol/frame/mask.rs b/src/protocol/frame/mask.rs index 28f0eaf..1352483 100644 --- a/src/protocol/frame/mask.rs +++ b/src/protocol/frame/mask.rs @@ -1,30 +1,31 @@ +use std::io::Write; + /// Generate a random frame mask. #[inline] pub fn generate_mask() -> [u8; 4] { rand::random() } -/// Mask/unmask a frame. -#[inline] -pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) { - apply_mask_fast32(buf, mask) +/// Write data to an output, masking the data in the process +pub fn write_masked(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { + write_mask_fast32(data, output, mask) } /// A safe unoptimized mask application. #[inline] -fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { - for (i, byte) in buf.iter_mut().enumerate() { - *byte ^= mask[i & 3]; +fn write_mask_fallback(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { + for (i, byte) in data.iter().enumerate() { + output.write(&[*byte ^ mask[i & 3]]).unwrap(); } } /// Faster version of `apply_mask()` which operates on 4-byte blocks. #[inline] -pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { +fn write_mask_fast32(data: &[u8], output: &mut impl Write, mask: [u8; 4]) { let mask_u32 = u32::from_ne_bytes(mask); - let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::() }; - apply_mask_fallback(&mut prefix, mask); + let (mut prefix, words, mut suffix) = unsafe { data.align_to::() }; + write_mask_fallback(&mut prefix, output, mask); let head = prefix.len() & 3; let mask_u32 = if head > 0 { if cfg!(target_endian = "big") { @@ -35,10 +36,11 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { } else { mask_u32 }; - for word in words.iter_mut() { - *word ^= mask_u32; + for word in words { + let bytes = (*word ^ mask_u32).to_ne_bytes(); + output.write(&bytes).unwrap(); } - apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); + write_mask_fallback(&mut suffix, output, mask_u32.to_ne_bytes()); } #[cfg(test)] @@ -60,11 +62,11 @@ mod tests { if unmasked.len() < off { continue; } - let mut masked = unmasked.to_vec(); - apply_mask_fallback(&mut masked[off..], mask); + let mut masked = Vec::new(); + write_mask_fallback(&unmasked, &mut masked, mask); - let mut masked_fast = unmasked.to_vec(); - apply_mask_fast32(&mut masked_fast[off..], mask); + let mut masked_fast = Vec::new(); + write_mask_fast32(&unmasked, &mut masked_fast, mask); assert_eq!(masked, masked_fast); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 1e41853..a1e49ac 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -4,11 +4,16 @@ pub mod coding; #[allow(clippy::module_inception)] mod frame; +#[cfg(feature = "__expose_benchmark_fn")] +#[allow(missing_docs)] +pub mod mask; +#[cfg(not(feature = "__expose_benchmark_fn"))] mod mask; pub use self::frame::{CloseFrame, Frame, FrameHeader}; use crate::error::{CapacityError, Error, Result}; +use crate::protocol::frame::mask::write_masked; use input_buffer::{InputBuffer, MIN_READ}; use log::*; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; @@ -142,10 +147,28 @@ impl FrameCodec { let input_size = cursor.get_ref().len() as u64 - cursor.position(); if length <= input_size { // No truncation here since `length` is checked above - let mut payload = Vec::with_capacity(length as usize); + + // take a slice from the cursor + let payload_input = &cursor.get_ref().as_slice() + [(cursor.position() as usize)..(cursor.position() + length) as usize]; + + let mut payload = Vec::new(); if length > 0 { - cursor.take(length).read_to_end(&mut payload)?; + if let Some(mask) = + self.header.as_ref().and_then(|header| header.0.mask) + { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + + payload = Vec::with_capacity(length as usize); + write_masked(payload_input, &mut payload, mask); + } else { + payload = payload_input.to_vec(); + } } + + cursor.set_position(cursor.position() + length); + break payload; } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 122183e..49f65cc 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -426,7 +426,7 @@ impl WebSocketContext { where Stream: Read + Write, { - if let Some(mut frame) = self + if let Some(frame) = self .frame .read_frame(stream, self.config.max_frame_size) .check_connection_reset(self.state)? @@ -448,11 +448,7 @@ impl WebSocketContext { match self.role { Role::Server => { - if frame.is_masked() { - // A server MUST remove masking for data frames received from a client - // as described in Section 5.3. (RFC 6455) - frame.apply_mask() - } else if !self.config.accept_unmasked_frames { + if !frame.is_masked() && !self.config.accept_unmasked_frames { // The server MUST close the connection upon receiving a // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given