From 59818a01faf28b90f12729b469377c07873ee9f2 Mon Sep 17 00:00:00 2001 From: Robin Appelman Date: Tue, 9 Feb 2021 20:19:09 +0100 Subject: [PATCH] allow sending messages with static data --- src/protocol/frame/frame.rs | 42 ++++++++++++++++++--------- src/protocol/message.rs | 57 +++++++++++++++++++++++-------------- src/protocol/mod.rs | 12 ++++++-- 3 files changed, 72 insertions(+), 39 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 986bba0..bc24b0c 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -201,7 +201,7 @@ impl FrameHeader { #[derive(Debug, Clone)] pub struct Frame { header: FrameHeader, - payload: Vec, + payload: Cow<'static, [u8]>, } impl Frame { @@ -233,13 +233,13 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] - pub fn payload(&self) -> &Vec { + pub fn payload(&self) -> &[u8] { &self.payload } /// Get a mutable reference to the frame's payload. #[inline] - pub fn payload_mut(&mut self) -> &mut Vec { + pub fn payload_mut(&mut self) -> &mut Cow<'static, [u8]> { &mut self.payload } @@ -263,20 +263,28 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - apply_mask(&mut self.payload, mask) + match &mut self.payload { + Cow::Owned(data) => apply_mask(data, mask), + Cow::Borrowed(data) => { + // can't modify static data, so we have to take ownership first + let mut data = data.to_vec(); + apply_mask(&mut data, mask); + self.payload = Cow::Owned(data); + } + } } } /// Consume the frame into its payload as binary. #[inline] pub fn into_data(self) -> Vec { - self.payload + self.payload.into_owned() } /// Consume the frame into its payload as string. #[inline] pub fn into_string(self) -> StdResult { - String::from_utf8(self.payload) + String::from_utf8(self.into_data()) } /// Consume the frame into a closing frame. @@ -286,7 +294,7 @@ impl Frame { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = self.payload; + let mut data = self.into_data(); let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; @@ -297,10 +305,16 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + pub fn message(data: D, opcode: OpCode, is_final: bool) -> Frame + where + D: Into>, + { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } + Frame { + header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, + payload: data.into(), + } } /// Create a new Pong control frame. @@ -311,7 +325,7 @@ impl Frame { opcode: OpCode::Control(Control::Pong), ..FrameHeader::default() }, - payload: data, + payload: Cow::Owned(data), } } @@ -323,7 +337,7 @@ impl Frame { opcode: OpCode::Control(Control::Ping), ..FrameHeader::default() }, - payload: data, + payload: Cow::Owned(data), } } @@ -339,12 +353,12 @@ impl Frame { Vec::new() }; - Frame { header: FrameHeader::default(), payload } + Frame { header: FrameHeader::default(), payload: Cow::Owned(payload) } } /// Create a frame from given header and data. pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { - Frame { header, payload } + Frame { header, payload: Cow::Owned(payload) } } /// Write a frame out to a buffer @@ -462,7 +476,7 @@ mod tests { #[test] fn display() { - let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); + let f = Frame::message(&b"hi there"[..], OpCode::Data(Data::Text), true); let view = format!("{}", f); assert!(view.contains("payload:")); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 6720c3c..3f5accf 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -79,6 +79,7 @@ mod string_collect { } use self::string_collect::StringCollector; +use std::borrow::Cow; /// A struct representing the incomplete message. #[derive(Debug)] @@ -140,10 +141,10 @@ impl IncompleteMessage { /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { - IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), + IncompleteMessageCollector::Binary(v) => Ok(Message::binary(v)), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; - Ok(Message::Text(text)) + Ok(Message::text(text)) } } } @@ -159,9 +160,9 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Cow<'static, str>), /// A binary WebSocket message - Binary(Vec), + Binary(Cow<'static, [u8]>), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes @@ -180,7 +181,12 @@ impl Message { where S: Into, { - Message::Text(string.into()) + Message::Text(Cow::Owned(string.into())) + } + + /// Create a new static text WebSocket message from a &'static str. + pub fn static_text(string: &'static str) -> Message { + Message::Text(Cow::Borrowed(string)) } /// Create a new binary WebSocket message by converting to Vec. @@ -188,7 +194,12 @@ impl Message { where B: Into>, { - Message::Binary(bin.into()) + Message::Binary(Cow::Owned(bin.into())) + } + + /// Create a new static binary WebSocket message from a &'static [u8]. + pub fn static_binary(bin: &'static [u8]) -> Message { + Message::Binary(Cow::Borrowed(bin)) } /// Indicates whether a message is a text message. @@ -218,12 +229,11 @@ impl Message { /// Get the length of the WebSocket message. pub fn len(&self) -> usize { - match *self { - Message::Text(ref string) => string.len(), - Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - data.len() - } - Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), + match self { + Message::Text(string) => string.len(), + Message::Binary(data) => data.len(), + Message::Ping(data) | Message::Pong(data) => data.len(), + Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), } } @@ -236,8 +246,9 @@ impl Message { /// Consume the WebSocket and return it as binary data. pub fn into_data(self) -> Vec { match self { - Message::Text(string) => string.into_bytes(), - Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, + Message::Text(string) => string.into_owned().into_bytes(), + Message::Binary(data) => data.into_owned(), + Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => Vec::new(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } @@ -246,8 +257,11 @@ impl Message { /// Attempt to consume the WebSocket message and convert it to a String. pub fn into_text(self) -> Result { match self { - Message::Text(string) => Ok(string), - Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { + Message::Text(string) => Ok(string.into_owned()), + Message::Binary(data) => { + Ok(String::from_utf8(data.into_owned()).map_err(|err| err.utf8_error())?) + } + Message::Ping(data) | Message::Pong(data) => { Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) } Message::Close(None) => Ok(String::new()), @@ -258,13 +272,12 @@ impl Message { /// Attempt to get a &str from the WebSocket message, /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str> { - match *self { - Message::Text(ref string) => Ok(string), - Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - Ok(str::from_utf8(data)?) - } + match self { + Message::Text(string) => Ok(string.as_ref()), + Message::Binary(data) => Ok(str::from_utf8(data.as_ref())?), + Message::Ping(data) | Message::Pong(data) => Ok(str::from_utf8(data)?), Message::Close(None) => Ok(""), - Message::Close(Some(ref frame)) => Ok(&frame.reason), + Message::Close(Some(frame)) => Ok(&frame.reason), } } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 215b061..787515a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -24,6 +24,7 @@ use crate::{ error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; +use std::borrow::Cow; /// Indicates a Client or Server role of the websocket #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -348,7 +349,12 @@ impl WebSocketContext { } let frame = match message { - Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), + Message::Text(Cow::Owned(data)) => { + Frame::message(data.into_bytes(), OpCode::Data(OpData::Text), true) + } + Message::Text(Cow::Borrowed(data)) => { + Frame::message(data.as_bytes(), OpCode::Data(OpData::Text), true) + } Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { @@ -700,8 +706,8 @@ mod tests { let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); - assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!(socket.read_message().unwrap(), Message::text("Hello, World!")); + assert_eq!(socket.read_message().unwrap(), Message::binary(vec![0x01, 0x02, 0x03])); } #[test]