allow sending messages with static data

pull/175/head
Robin Appelman 4 years ago
parent c101024c28
commit 59818a01fa
  1. 42
      src/protocol/frame/frame.rs
  2. 57
      src/protocol/message.rs
  3. 12
      src/protocol/mod.rs

@ -201,7 +201,7 @@ impl FrameHeader {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Frame { pub struct Frame {
header: FrameHeader, header: FrameHeader,
payload: Vec<u8>, payload: Cow<'static, [u8]>,
} }
impl Frame { impl Frame {
@ -233,13 +233,13 @@ impl Frame {
/// Get a reference to the frame's payload. /// Get a reference to the frame's payload.
#[inline] #[inline]
pub fn payload(&self) -> &Vec<u8> { pub fn payload(&self) -> &[u8] {
&self.payload &self.payload
} }
/// Get a mutable reference to the frame's payload. /// Get a mutable reference to the frame's payload.
#[inline] #[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> { pub fn payload_mut(&mut self) -> &mut Cow<'static, [u8]> {
&mut self.payload &mut self.payload
} }
@ -263,20 +263,28 @@ impl Frame {
#[inline] #[inline]
pub(crate) fn apply_mask(&mut self) { pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() { if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask) match &mut self.payload {
Cow::Owned(data) => apply_mask(data, mask),
Cow::Borrowed(data) => {
// can't modify static data, so we have to take ownership first
let mut data = data.to_vec();
apply_mask(&mut data, mask);
self.payload = Cow::Owned(data);
}
}
} }
} }
/// Consume the frame into its payload as binary. /// Consume the frame into its payload as binary.
#[inline] #[inline]
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
self.payload self.payload.into_owned()
} }
/// Consume the frame into its payload as string. /// Consume the frame into its payload as string.
#[inline] #[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> { pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload) String::from_utf8(self.into_data())
} }
/// Consume the frame into a closing frame. /// Consume the frame into a closing frame.
@ -286,7 +294,7 @@ impl Frame {
0 => Ok(None), 0 => Ok(None),
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => { _ => {
let mut data = self.payload; let mut data = self.into_data();
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
@ -297,10 +305,16 @@ impl Frame {
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { pub fn message<D>(data: D, opcode: OpCode, is_final: bool) -> Frame
where
D: Into<Cow<'static, [u8]>>,
{
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } Frame {
header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
payload: data.into(),
}
} }
/// Create a new Pong control frame. /// Create a new Pong control frame.
@ -311,7 +325,7 @@ impl Frame {
opcode: OpCode::Control(Control::Pong), opcode: OpCode::Control(Control::Pong),
..FrameHeader::default() ..FrameHeader::default()
}, },
payload: data, payload: Cow::Owned(data),
} }
} }
@ -323,7 +337,7 @@ impl Frame {
opcode: OpCode::Control(Control::Ping), opcode: OpCode::Control(Control::Ping),
..FrameHeader::default() ..FrameHeader::default()
}, },
payload: data, payload: Cow::Owned(data),
} }
} }
@ -339,12 +353,12 @@ impl Frame {
Vec::new() Vec::new()
}; };
Frame { header: FrameHeader::default(), payload } Frame { header: FrameHeader::default(), payload: Cow::Owned(payload) }
} }
/// Create a frame from given header and data. /// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self { pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame { header, payload } Frame { header, payload: Cow::Owned(payload) }
} }
/// Write a frame out to a buffer /// Write a frame out to a buffer
@ -462,7 +476,7 @@ mod tests {
#[test] #[test]
fn display() { fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); let f = Frame::message(&b"hi there"[..], OpCode::Data(Data::Text), true);
let view = format!("{}", f); let view = format!("{}", f);
assert!(view.contains("payload:")); assert!(view.contains("payload:"));
} }

@ -79,6 +79,7 @@ mod string_collect {
} }
use self::string_collect::StringCollector; use self::string_collect::StringCollector;
use std::borrow::Cow;
/// A struct representing the incomplete message. /// A struct representing the incomplete message.
#[derive(Debug)] #[derive(Debug)]
@ -140,10 +141,10 @@ impl IncompleteMessage {
/// Convert an incomplete message into a complete one. /// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> { pub fn complete(self) -> Result<Message> {
match self.collector { match self.collector {
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), IncompleteMessageCollector::Binary(v) => Ok(Message::binary(v)),
IncompleteMessageCollector::Text(t) => { IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?; let text = t.into_string()?;
Ok(Message::Text(text)) Ok(Message::text(text))
} }
} }
} }
@ -159,9 +160,9 @@ pub enum IncompleteMessageType {
#[derive(Debug, Eq, PartialEq, Clone)] #[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message { pub enum Message {
/// A text WebSocket message /// A text WebSocket message
Text(String), Text(Cow<'static, str>),
/// A binary WebSocket message /// A binary WebSocket message
Binary(Vec<u8>), Binary(Cow<'static, [u8]>),
/// A ping message with the specified payload /// A ping message with the specified payload
/// ///
/// The payload here must have a length less than 125 bytes /// The payload here must have a length less than 125 bytes
@ -180,7 +181,12 @@ impl Message {
where where
S: Into<String>, S: Into<String>,
{ {
Message::Text(string.into()) Message::Text(Cow::Owned(string.into()))
}
/// Create a new static text WebSocket message from a &'static str.
pub fn static_text(string: &'static str) -> Message {
Message::Text(Cow::Borrowed(string))
} }
/// Create a new binary WebSocket message by converting to Vec<u8>. /// Create a new binary WebSocket message by converting to Vec<u8>.
@ -188,7 +194,12 @@ impl Message {
where where
B: Into<Vec<u8>>, B: Into<Vec<u8>>,
{ {
Message::Binary(bin.into()) Message::Binary(Cow::Owned(bin.into()))
}
/// Create a new static binary WebSocket message from a &'static [u8].
pub fn static_binary(bin: &'static [u8]) -> Message {
Message::Binary(Cow::Borrowed(bin))
} }
/// Indicates whether a message is a text message. /// Indicates whether a message is a text message.
@ -218,12 +229,11 @@ impl Message {
/// Get the length of the WebSocket message. /// Get the length of the WebSocket message.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match *self { match self {
Message::Text(ref string) => string.len(), Message::Text(string) => string.len(),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { Message::Binary(data) => data.len(),
data.len() Message::Ping(data) | Message::Pong(data) => data.len(),
} Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
} }
} }
@ -236,8 +246,9 @@ impl Message {
/// Consume the WebSocket and return it as binary data. /// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
match self { match self {
Message::Text(string) => string.into_bytes(), Message::Text(string) => string.into_owned().into_bytes(),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, Message::Binary(data) => data.into_owned(),
Message::Ping(data) | Message::Pong(data) => data,
Message::Close(None) => Vec::new(), Message::Close(None) => Vec::new(),
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
} }
@ -246,8 +257,11 @@ impl Message {
/// Attempt to consume the WebSocket message and convert it to a String. /// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String> { pub fn into_text(self) -> Result<String> {
match self { match self {
Message::Text(string) => Ok(string), Message::Text(string) => Ok(string.into_owned()),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { Message::Binary(data) => {
Ok(String::from_utf8(data.into_owned()).map_err(|err| err.utf8_error())?)
}
Message::Ping(data) | Message::Pong(data) => {
Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?)
} }
Message::Close(None) => Ok(String::new()), Message::Close(None) => Ok(String::new()),
@ -258,13 +272,12 @@ impl Message {
/// Attempt to get a &str from the WebSocket message, /// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8. /// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str> { pub fn to_text(&self) -> Result<&str> {
match *self { match self {
Message::Text(ref string) => Ok(string), Message::Text(string) => Ok(string.as_ref()),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { Message::Binary(data) => Ok(str::from_utf8(data.as_ref())?),
Ok(str::from_utf8(data)?) Message::Ping(data) | Message::Pong(data) => Ok(str::from_utf8(data)?),
}
Message::Close(None) => Ok(""), Message::Close(None) => Ok(""),
Message::Close(Some(ref frame)) => Ok(&frame.reason), Message::Close(Some(frame)) => Ok(&frame.reason),
} }
} }
} }

@ -24,6 +24,7 @@ use crate::{
error::{Error, ProtocolError, Result}, error::{Error, ProtocolError, Result},
util::NonBlockingResult, util::NonBlockingResult,
}; };
use std::borrow::Cow;
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -348,7 +349,12 @@ impl WebSocketContext {
} }
let frame = match message { let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), Message::Text(Cow::Owned(data)) => {
Frame::message(data.into_bytes(), OpCode::Data(OpData::Text), true)
}
Message::Text(Cow::Borrowed(data)) => {
Frame::message(data.as_bytes(), OpCode::Data(OpData::Text), true)
}
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
@ -700,8 +706,8 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read_message().unwrap(), Message::text("Hello, World!"));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); assert_eq!(socket.read_message().unwrap(), Message::binary(vec![0x01, 0x02, 0x03]));
} }
#[test] #[test]

Loading…
Cancel
Save