switch to custom message data enums to handle shared and static data

pull/175/head
Robin Appelman 4 years ago
parent 59818a01fa
commit a662560c66
  1. 149
      src/protocol/data.rs
  2. 41
      src/protocol/frame/frame.rs
  3. 46
      src/protocol/message.rs
  4. 9
      src/protocol/mod.rs

@ -0,0 +1,149 @@
use bytes::Bytes;
/// Binary message data
#[derive(Debug, Clone)]
pub struct MessageData(MessageDataImpl);
/// opaque inner type to allow modifying the implementation in the future
#[derive(Debug, Clone)]
enum MessageDataImpl {
Shared(Bytes),
Unique(Vec<u8>),
}
impl MessageData {
pub fn len(&self) -> usize {
self.as_ref().len()
}
fn make_unique(&mut self) {
if let MessageDataImpl::Shared(data) = &self.0 {
self.0 = MessageDataImpl::Unique(Vec::from(data.as_ref()));
}
}
}
impl PartialEq for MessageData {
fn eq(&self, other: &MessageData) -> bool {
self.as_ref().eq(other.as_ref())
}
}
impl Eq for MessageData {}
impl From<MessageData> for Vec<u8> {
fn from(data: MessageData) -> Vec<u8> {
match data.0 {
MessageDataImpl::Shared(data) => {
let mut bytes = Vec::with_capacity(data.len());
bytes.copy_from_slice(data.as_ref());
bytes
}
MessageDataImpl::Unique(data) => data,
}
}
}
impl From<MessageData> for Bytes {
fn from(data: MessageData) -> Bytes {
match data.0 {
MessageDataImpl::Shared(data) => data,
MessageDataImpl::Unique(data) => data.into(),
}
}
}
impl AsRef<[u8]> for MessageData {
fn as_ref(&self) -> &[u8] {
match &self.0 {
MessageDataImpl::Shared(data) => data.as_ref(),
MessageDataImpl::Unique(data) => data.as_ref(),
}
}
}
impl AsMut<[u8]> for MessageData {
fn as_mut(&mut self) -> &mut [u8] {
self.make_unique();
match &mut self.0 {
MessageDataImpl::Unique(data) => data.as_mut_slice(),
MessageDataImpl::Shared(_) => unreachable!("Data has just been made unique"),
}
}
}
/// String message data
#[derive(Debug, Clone)]
pub struct MessageStringData(MessageStringDataImpl);
/// opaque inner type to allow modifying the implementation in the future
#[derive(Debug, Clone)]
enum MessageStringDataImpl {
Static(&'static str),
Unique(String),
}
impl PartialEq for MessageStringData {
fn eq(&self, other: &MessageStringData) -> bool {
self.as_ref().eq(other.as_ref())
}
}
impl Eq for MessageStringData {}
impl From<MessageStringData> for String {
fn from(data: MessageStringData) -> String {
match data.0 {
MessageStringDataImpl::Static(data) => data.into(),
MessageStringDataImpl::Unique(data) => data,
}
}
}
impl From<MessageStringData> for MessageData {
fn from(data: MessageStringData) -> MessageData {
match data.0 {
MessageStringDataImpl::Static(data) => MessageData::from(data.as_bytes()),
MessageStringDataImpl::Unique(data) => MessageData::from(data.into_bytes()),
}
}
}
impl AsRef<str> for MessageStringData {
fn as_ref(&self) -> &str {
match &self.0 {
MessageStringDataImpl::Static(data) => *data,
MessageStringDataImpl::Unique(data) => data.as_ref(),
}
}
}
impl From<String> for MessageStringData {
fn from(string: String) -> MessageStringData {
MessageStringData(MessageStringDataImpl::Unique(string))
}
}
impl From<&'static str> for MessageStringData {
fn from(string: &'static str) -> MessageStringData {
MessageStringData(MessageStringDataImpl::Static(string))
}
}
impl From<Vec<u8>> for MessageData {
fn from(data: Vec<u8>) -> MessageData {
MessageData(MessageDataImpl::Unique(data))
}
}
impl From<&'static [u8]> for MessageData {
fn from(data: &'static [u8]) -> MessageData {
MessageData(MessageDataImpl::Shared(Bytes::from_static(data)))
}
}
impl From<Bytes> for MessageData {
fn from(data: Bytes) -> MessageData {
MessageData(MessageDataImpl::Shared(data))
}
}

@ -14,6 +14,7 @@ use super::{
mask::{apply_mask, generate_mask}, mask::{apply_mask, generate_mask},
}; };
use crate::error::{Error, ProtocolError, Result}; use crate::error::{Error, ProtocolError, Result};
use crate::protocol::data::MessageData;
/// A struct representing the close command. /// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
@ -186,7 +187,7 @@ impl FrameHeader {
// Disallow bad opcode // Disallow bad opcode
match opcode { match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)));
} }
_ => (), _ => (),
} }
@ -201,7 +202,7 @@ impl FrameHeader {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Frame { pub struct Frame {
header: FrameHeader, header: FrameHeader,
payload: Cow<'static, [u8]>, payload: MessageData,
} }
impl Frame { impl Frame {
@ -234,13 +235,7 @@ impl Frame {
/// Get a reference to the frame's payload. /// Get a reference to the frame's payload.
#[inline] #[inline]
pub fn payload(&self) -> &[u8] { pub fn payload(&self) -> &[u8] {
&self.payload self.payload.as_ref()
}
/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut Cow<'static, [u8]> {
&mut self.payload
} }
/// Test whether the frame is masked. /// Test whether the frame is masked.
@ -263,28 +258,20 @@ impl Frame {
#[inline] #[inline]
pub(crate) fn apply_mask(&mut self) { pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() { if let Some(mask) = self.header.mask.take() {
match &mut self.payload { apply_mask(self.payload.as_mut(), mask)
Cow::Owned(data) => apply_mask(data, mask),
Cow::Borrowed(data) => {
// can't modify static data, so we have to take ownership first
let mut data = data.to_vec();
apply_mask(&mut data, mask);
self.payload = Cow::Owned(data);
}
}
} }
} }
/// Consume the frame into its payload as binary. /// Consume the frame into its payload as binary.
#[inline] #[inline]
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
self.payload.into_owned() self.payload.into()
} }
/// Consume the frame into its payload as string. /// Consume the frame into its payload as string.
#[inline] #[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> { pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.into_data()) String::from_utf8(self.payload.into())
} }
/// Consume the frame into a closing frame. /// Consume the frame into a closing frame.
@ -297,7 +284,7 @@ impl Frame {
let mut data = self.into_data(); let mut data = self.into_data();
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data.into())?;
Ok(Some(CloseFrame { code, reason: text.into() })) Ok(Some(CloseFrame { code, reason: text.into() }))
} }
} }
@ -307,7 +294,7 @@ impl Frame {
#[inline] #[inline]
pub fn message<D>(data: D, opcode: OpCode, is_final: bool) -> Frame pub fn message<D>(data: D, opcode: OpCode, is_final: bool) -> Frame
where where
D: Into<Cow<'static, [u8]>>, D: Into<MessageData>,
{ {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
@ -325,7 +312,7 @@ impl Frame {
opcode: OpCode::Control(Control::Pong), opcode: OpCode::Control(Control::Pong),
..FrameHeader::default() ..FrameHeader::default()
}, },
payload: Cow::Owned(data), payload: data.into(),
} }
} }
@ -337,7 +324,7 @@ impl Frame {
opcode: OpCode::Control(Control::Ping), opcode: OpCode::Control(Control::Ping),
..FrameHeader::default() ..FrameHeader::default()
}, },
payload: Cow::Owned(data), payload: data.into(),
} }
} }
@ -353,12 +340,12 @@ impl Frame {
Vec::new() Vec::new()
}; };
Frame { header: FrameHeader::default(), payload: Cow::Owned(payload) } Frame { header: FrameHeader::default(), payload: payload.into() }
} }
/// Create a frame from given header and data. /// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self { pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame { header, payload: Cow::Owned(payload) } Frame { header, payload: payload.into() }
} }
/// Write a frame out to a buffer /// Write a frame out to a buffer
@ -391,7 +378,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>() self.payload.as_ref().iter().map(|byte| format!("{:x}", byte)).collect::<String>()
) )
} }
} }

@ -79,7 +79,7 @@ mod string_collect {
} }
use self::string_collect::StringCollector; use self::string_collect::StringCollector;
use std::borrow::Cow; use crate::protocol::data::{MessageData, MessageStringData};
/// A struct representing the incomplete message. /// A struct representing the incomplete message.
#[derive(Debug)] #[derive(Debug)]
@ -160,9 +160,9 @@ pub enum IncompleteMessageType {
#[derive(Debug, Eq, PartialEq, Clone)] #[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message { pub enum Message {
/// A text WebSocket message /// A text WebSocket message
Text(Cow<'static, str>), Text(MessageStringData),
/// A binary WebSocket message /// A binary WebSocket message
Binary(Cow<'static, [u8]>), Binary(MessageData),
/// A ping message with the specified payload /// A ping message with the specified payload
/// ///
/// The payload here must have a length less than 125 bytes /// The payload here must have a length less than 125 bytes
@ -179,27 +179,17 @@ impl Message {
/// Create a new text WebSocket message from a stringable. /// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message pub fn text<S>(string: S) -> Message
where where
S: Into<String>, S: Into<MessageStringData>,
{ {
Message::Text(Cow::Owned(string.into())) Message::Text(string.into())
}
/// Create a new static text WebSocket message from a &'static str.
pub fn static_text(string: &'static str) -> Message {
Message::Text(Cow::Borrowed(string))
} }
/// Create a new binary WebSocket message by converting to Vec<u8>. /// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message pub fn binary<B>(bin: B) -> Message
where where
B: Into<Vec<u8>>, B: Into<MessageData>,
{ {
Message::Binary(Cow::Owned(bin.into())) Message::Binary(bin.into())
}
/// Create a new static binary WebSocket message from a &'static [u8].
pub fn static_binary(bin: &'static [u8]) -> Message {
Message::Binary(Cow::Borrowed(bin))
} }
/// Indicates whether a message is a text message. /// Indicates whether a message is a text message.
@ -230,8 +220,8 @@ impl Message {
/// Get the length of the WebSocket message. /// Get the length of the WebSocket message.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self { match self {
Message::Text(string) => string.len(), Message::Text(string) => string.as_ref().len(),
Message::Binary(data) => data.len(), Message::Binary(data) => data.as_ref().len(),
Message::Ping(data) | Message::Pong(data) => data.len(), Message::Ping(data) | Message::Pong(data) => data.len(),
Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
} }
@ -246,8 +236,8 @@ impl Message {
/// Consume the WebSocket and return it as binary data. /// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
match self { match self {
Message::Text(string) => string.into_owned().into_bytes(), Message::Text(string) => String::from(string).into(),
Message::Binary(data) => data.into_owned(), Message::Binary(data) => data.into(),
Message::Ping(data) | Message::Pong(data) => data, Message::Ping(data) | Message::Pong(data) => data,
Message::Close(None) => Vec::new(), Message::Close(None) => Vec::new(),
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
@ -257,9 +247,9 @@ impl Message {
/// Attempt to consume the WebSocket message and convert it to a String. /// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String> { pub fn into_text(self) -> Result<String> {
match self { match self {
Message::Text(string) => Ok(string.into_owned()), Message::Text(string) => Ok(string.into()),
Message::Binary(data) => { Message::Binary(data) => {
Ok(String::from_utf8(data.into_owned()).map_err(|err| err.utf8_error())?) Ok(String::from_utf8(data.into()).map_err(|err| err.utf8_error())?)
} }
Message::Ping(data) | Message::Pong(data) => { Message::Ping(data) | Message::Pong(data) => {
Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?)
@ -290,13 +280,13 @@ impl From<String> for Message {
impl<'s> From<&'s str> for Message { impl<'s> From<&'s str> for Message {
fn from(string: &'s str) -> Message { fn from(string: &'s str) -> Message {
Message::text(string) Message::text(string.to_string())
} }
} }
impl<'b> From<&'b [u8]> for Message { impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Message { fn from(data: &'b [u8]) -> Message {
Message::binary(data) Message::binary(data.to_vec())
} }
} }
@ -306,9 +296,9 @@ impl From<Vec<u8>> for Message {
} }
} }
impl Into<Vec<u8>> for Message { impl From<Message> for Vec<u8> {
fn into(self) -> Vec<u8> { fn from(message: Message) -> Vec<u8> {
self.into_data() message.into_data()
} }
} }

@ -2,6 +2,7 @@
pub mod frame; pub mod frame;
mod data;
mod message; mod message;
pub use self::{frame::CloseFrame, message::Message}; pub use self::{frame::CloseFrame, message::Message};
@ -24,7 +25,6 @@ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
use std::borrow::Cow;
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -349,12 +349,7 @@ impl WebSocketContext {
} }
let frame = match message { let frame = match message {
Message::Text(Cow::Owned(data)) => { Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true),
Frame::message(data.into_bytes(), OpCode::Data(OpData::Text), true)
}
Message::Text(Cow::Borrowed(data)) => {
Frame::message(data.as_bytes(), OpCode::Data(OpData::Text), true)
}
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {

Loading…
Cancel
Save