refactor: parse header separately from payload

Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
pull/43/head
Alexey Galakhov 7 years ago
parent 20242d19f7
commit 75aa0d54f3
  1. 512
      src/protocol/frame/frame.rs
  2. 65
      src/protocol/frame/mod.rs

@ -35,17 +35,176 @@ impl<'t> fmt::Display for CloseFrame<'t> {
}
}
/// A struct representing a WebSocket frame.
/// A struct representing a WebSocket frame header.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone)]
pub struct Frame {
finished: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: OpCode,
pub struct FrameHeader {
/// Indicates that the frame is the last one of a possibly fragmented message.
/// Carried in bit `0x80` of the first header byte.
pub is_final: bool,
/// Reserved for protocol extensions. Carried in bit `0x40` of the first header byte.
pub rsv1: bool,
/// Reserved for protocol extensions. Carried in bit `0x20` of the first header byte.
pub rsv2: bool,
/// Reserved for protocol extensions. Carried in bit `0x10` of the first header byte.
pub rsv3: bool,
/// WebSocket protocol opcode. Occupies the low four bits of the first header byte.
pub opcode: OpCode,
/// A frame mask, if any. When this is `Some`, bit `0x80` of the second header
/// byte is set and the four mask bytes are written after the length field.
pub mask: Option<[u8; 4]>,
}
impl Default for FrameHeader {
    /// Build the baseline header: a final, unmasked `Close` control frame with
    /// every reserved bit cleared.
    fn default() -> Self {
        let opcode = OpCode::Control(Control::Close);
        FrameHeader {
            opcode,
            mask: None,
            is_final: true,
            rsv3: false,
            rsv2: false,
            rsv1: false,
        }
    }
}
impl FrameHeader {
    /// Parse a header from an input stream.
    ///
    /// Returns `Ok(None)` if the cursor does not hold a complete header yet; in
    /// that case the cursor position is restored, so nothing is consumed.
    /// On success the payload size is returned along with the header.
    /// On error the cursor position is left wherever parsing stopped.
    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
        let initial = cursor.position();
        match Self::parse_internal(cursor) {
            ret @ Ok(None) => {
                // Incomplete header: rewind so the next attempt starts from the same byte.
                cursor.set_position(initial);
                ret
            }
            ret => ret,
        }
    }

    /// Get the size of the header formatted with given payload length.
    ///
    /// This is the two fixed bytes, plus the extended length field required for
    /// `length` (if any), plus four bytes when a mask is present.
    pub fn len(&self, length: u64) -> usize {
        2
            + LengthFormat::for_length(length).extra_bytes()
            + if self.mask.is_some() { 4 } else { 0 }
    }

    /// Format a header for given payload size.
    ///
    /// Writes the flag/opcode byte, the mask-bit/length byte, the extended length
    /// field (when `length` needs one) and finally the mask bytes (when present)
    /// to `output`. Any I/O error from `output` is propagated.
    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
        let code: u8 = self.opcode.into();

        // First byte: FIN and RSV flags in the high nibble, opcode in the low nibble.
        let one = {
            code
                | if self.is_final { 0x80 } else { 0 }
                | if self.rsv1 { 0x40 } else { 0 }
                | if self.rsv2 { 0x20 } else { 0 }
                | if self.rsv3 { 0x10 } else { 0 }
        };

        // Second byte: mask flag in the top bit, length (or length marker) below it.
        let lenfmt = LengthFormat::for_length(length);
        let two = {
            lenfmt.length_byte()
                | if self.mask.is_some() { 0x80 } else { 0 }
        };

        // `write_all` rather than `write`: a plain `write` may accept fewer bytes
        // than requested, which would silently drop part of the header.
        output.write_all(&[one, two])?;

        match lenfmt {
            LengthFormat::U8(_) => (),
            // The cast cannot truncate: the `U16` format is only chosen for lengths below 65536.
            LengthFormat::U16 => output.write_u16::<NetworkEndian>(length as u16)?,
            LengthFormat::U64 => output.write_u64::<NetworkEndian>(length)?,
        }

        if let Some(ref mask) = self.mask {
            output.write_all(mask)?
        }

        Ok(())
    }
}
impl FrameHeader {
    /// Internal parse engine.
    ///
    /// Returns `Ok(None)` if insufficient data is available. Unlike `parse`, this
    /// function does not rewind the reader in that case.
    /// Payload size is returned along with the header.
    ///
    /// The opcode is validated only after the length and mask have been read, so a
    /// truncated header yields `Ok(None)` even if its opcode is a reserved one.
    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
        let (first, second) = {
            let mut head = [0u8; 2];
            // NOTE(review): a single `read` returning fewer than 2 bytes is treated as
            // "not enough data"; this assumes the reader hands out everything it has
            // (the public `parse` entry point passes a `Cursor`) - confirm before
            // calling this with other readers.
            if cursor.read(&mut head)? != 2 {
                return Ok(None);
            }
            trace!("Parsed headers {:?}", head);
            (head[0], head[1])
        };

        trace!("First: {:b}", first);
        trace!("Second: {:b}", second);

        // First byte: FIN, RSV1-3 flags and the opcode.
        let is_final = first & 0x80 != 0;
        let rsv1 = first & 0x40 != 0;
        let rsv2 = first & 0x20 != 0;
        let rsv3 = first & 0x10 != 0;

        let opcode = OpCode::from(first & 0x0F);
        trace!("Opcode: {:?}", opcode);

        // Second byte: mask flag and the length byte.
        let masked = second & 0x80 != 0;
        trace!("Masked: {:?}", masked);

        let length = {
            let length_byte = second & 0x7F;
            let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
            if length_length > 0 {
                // Extended length: read the 2- or 8-byte big-endian field that follows.
                match cursor.read_uint::<NetworkEndian>(length_length) {
                    Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
                        return Ok(None);
                    }
                    Err(err) => {
                        return Err(err.into());
                    }
                    Ok(read) => read,
                }
            } else {
                u64::from(length_byte)
            }
        };

        let mask = if masked {
            let mut mask_bytes = [0u8; 4];
            if cursor.read(&mut mask_bytes)? != 4 {
                return Ok(None);
            } else {
                Some(mask_bytes)
            }
        } else {
            None
        };

        // Disallow bad opcode
        match opcode {
            OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
                return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
            }
            _ => (),
        }

        let hdr = FrameHeader {
            is_final,
            rsv1,
            rsv2,
            rsv3,
            opcode,
            mask,
        };

        Ok(Some((hdr, length)))
    }
}
/// A struct representing a WebSocket frame.
///
/// A frame is a parsed header paired with the payload bytes it describes.
#[derive(Debug, Clone)]
pub struct Frame {
/// Flags, opcode and optional mask of this frame.
header: FrameHeader,
/// Payload bytes of this frame.
payload: Vec<u8>,
}
@ -55,51 +214,38 @@ impl Frame {
/// This is the length of the header + the length of the payload.
#[inline]
pub fn len(&self) -> usize {
let mut header_length = 2;
let payload_len = self.payload().len();
if payload_len > 125 {
if payload_len <= u16::max_value() as usize {
header_length += 2;
} else {
header_length += 8;
}
}
if self.is_masked() {
header_length += 4;
}
header_length + payload_len
let length = self.payload.len();
self.header.len(length as u64) + length
}
/// Test whether the frame is a final frame.
#[inline]
pub fn is_final(&self) -> bool {
self.finished
self.header.is_final
}
/// Test whether the first reserved bit is set.
#[inline]
pub fn has_rsv1(&self) -> bool {
self.rsv1
self.header.rsv1
}
/// Test whether the second reserved bit is set.
#[inline]
pub fn has_rsv2(&self) -> bool {
self.rsv2
self.header.rsv2
}
/// Test whether the third reserved bit is set.
#[inline]
pub fn has_rsv3(&self) -> bool {
self.rsv3
self.header.rsv3
}
/// Get the OpCode of the frame.
#[inline]
pub fn opcode(&self) -> OpCode {
self.opcode
self.header.opcode
}
/// Get a reference to the frame's payload.
@ -112,7 +258,7 @@ impl Frame {
#[doc(hidden)]
#[inline]
pub fn is_masked(&self) -> bool {
self.mask.is_some()
self.header.mask.is_some()
}
// Get an optional reference to the frame's mask.
@ -120,35 +266,35 @@ impl Frame {
#[allow(dead_code)]
#[inline]
pub fn mask(&self) -> Option<&[u8; 4]> {
self.mask.as_ref()
self.header.mask.as_ref()
}
/// Make this frame a final frame.
#[allow(dead_code)]
#[inline]
pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
self.finished = is_final;
self.header.is_final = is_final;
self
}
/// Set the first reserved bit.
#[inline]
pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
self.rsv1 = has_rsv1;
self.header.rsv1 = has_rsv1;
self
}
/// Set the second reserved bit.
#[inline]
pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
self.rsv2 = has_rsv2;
self.header.rsv2 = has_rsv2;
self
}
/// Set the third reserved bit.
#[inline]
pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
self.rsv3 = has_rsv3;
self.header.rsv3 = has_rsv3;
self
}
@ -156,7 +302,7 @@ impl Frame {
#[allow(dead_code)]
#[inline]
pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
self.opcode = opcode;
self.header.opcode = opcode;
self
}
@ -175,7 +321,7 @@ impl Frame {
#[doc(hidden)]
#[inline]
pub fn set_mask(&mut self) -> &mut Frame {
self.mask = Some(generate_mask());
self.header.mask = Some(generate_mask());
self
}
@ -184,10 +330,10 @@ impl Frame {
#[doc(hidden)]
#[inline]
pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| {
self.header.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask))
});
self.mask = None;
self.header.mask = None;
}
/// Consume the frame into its payload as binary.
@ -220,17 +366,19 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
debug_assert!(match code {
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(match opcode {
OpCode::Data(_) => true,
_ => false,
}, "Invalid opcode for data frame.");
Frame {
finished: finished,
opcode: code,
header: FrameHeader {
is_final,
opcode,
.. FrameHeader::default()
},
payload: data,
.. Frame::default()
}
}
@ -238,9 +386,11 @@ impl Frame {
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {
Frame {
opcode: OpCode::Control(Control::Pong),
header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
.. FrameHeader::default()
},
payload: data,
.. Frame::default()
}
}
@ -248,9 +398,11 @@ impl Frame {
#[inline]
pub fn ping(data: Vec<u8>) -> Frame {
Frame {
opcode: OpCode::Control(Control::Ping),
header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
.. FrameHeader::default()
},
payload: data,
.. Frame::default()
}
}
@ -267,199 +419,28 @@ impl Frame {
};
Frame {
header: FrameHeader::default(),
payload: payload,
.. Frame::default()
}
}
/// Parse the input stream into a frame.
pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> {
let size = cursor.get_ref().len() as u64 - cursor.position();
let initial = cursor.position();
trace!("Position in buffer {}", initial);
let mut head = [0u8; 2];
if try!(cursor.read(&mut head)) != 2 {
cursor.set_position(initial);
return Ok(None)
}
trace!("Parsed headers {:?}", head);
let first = head[0];
let second = head[1];
trace!("First: {:b}", first);
trace!("Second: {:b}", second);
let finished = first & 0x80 != 0;
let rsv1 = first & 0x40 != 0;
let rsv2 = first & 0x20 != 0;
let rsv3 = first & 0x10 != 0;
let opcode = OpCode::from(first & 0x0F);
trace!("Opcode: {:?}", opcode);
let masked = second & 0x80 != 0;
trace!("Masked: {:?}", masked);
let mut header_length = 2;
let mut length = (second & 0x7F) as u64;
if let Some(length_nbytes) = match length {
126 => Some(2),
127 => Some(8),
_ => None,
} {
match cursor.read_uint::<NetworkEndian>(length_nbytes) {
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
cursor.set_position(initial);
return Ok(None);
}
Err(err) => {
return Err(Error::from(err));
}
Ok(read) => {
length = read;
}
};
header_length += length_nbytes as u64;
}
trace!("Payload length: {}", length);
let mask = if masked {
let mut mask_bytes = [0u8; 4];
if try!(cursor.read(&mut mask_bytes)) != 4 {
cursor.set_position(initial);
return Ok(None)
} else {
header_length += 4;
Some(mask_bytes)
}
} else {
None
};
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
}
_ => ()
}
// Make sure `length` is not too big (fits into `usize`).
if length > usize::max_value() as u64 {
return Err(Error::Capacity(format!("Message length too big: {}", length).into()));
}
if size < header_length || size - header_length < length {
cursor.set_position(initial);
return Ok(None)
}
// Size is checked above, so it won't be truncated here.
let mut data = Vec::with_capacity(length as usize);
if length > 0 {
try!(cursor.take(length).read_to_end(&mut data));
/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame {
header,
payload,
}
debug_assert_eq!(data.len() as u64, length);
let frame = Frame {
finished: finished,
rsv1: rsv1,
rsv2: rsv2,
rsv3: rsv3,
opcode: opcode,
mask: mask,
payload: data,
};
Ok(Some(frame))
}
/// Write a frame out to a buffer
pub fn format<W>(mut self, w: &mut W) -> Result<()>
where W: Write
{
let mut one = 0u8;
let code: u8 = self.opcode.into();
if self.is_final() {
one |= 0x80;
}
if self.has_rsv1() {
one |= 0x40;
}
if self.has_rsv2() {
one |= 0x20;
}
if self.has_rsv3() {
one |= 0x10;
}
one |= code;
let mut two = 0u8;
if self.is_masked() {
two |= 0x80;
}
if self.payload.len() < 126 {
two |= self.payload.len() as u8;
let headers = [one, two];
try!(w.write(&headers));
} else if self.payload.len() <= 65535 {
two |= 126;
let mut length_bytes = [0u8; 2];
NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16);
let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers));
} else {
two |= 127;
let mut length_bytes = [0u8; 8];
NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64);
let headers = [
one,
two,
length_bytes[0],
length_bytes[1],
length_bytes[2],
length_bytes[3],
length_bytes[4],
length_bytes[5],
length_bytes[6],
length_bytes[7],
];
try!(w.write(&headers));
}
if self.is_masked() {
let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask);
try!(w.write(&mask));
}
try!(w.write(&self.payload));
pub fn format(mut self, output: &mut impl Write) -> Result<()> {
self.header.format(self.payload.len() as u64, output)?;
self.remove_mask();
output.write_all(self.payload())?;
Ok(())
}
}
impl Default for Frame {
fn default() -> Frame {
Frame {
finished: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: OpCode::Control(Control::Close),
mask: None,
payload: Vec::new(),
}
}
}
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f,
@ -472,11 +453,11 @@ length: {}
payload length: {}
payload: 0x{}
",
self.finished,
self.rsv1,
self.rsv2,
self.rsv3,
self.opcode,
self.header.is_final,
self.header.rsv1,
self.header.rsv2,
self.header.rsv3,
self.header.opcode,
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
@ -484,6 +465,57 @@ payload: 0x{}
}
}
/// Handling of the length format.
///
/// Describes how a payload length is laid out in a frame header: either directly
/// in the length byte, or as a marker byte followed by extra length bytes.
enum LengthFormat {
    /// The length is stored directly in the length byte.
    U8(u8),
    /// The length byte is the marker 126; two extra bytes carry the length.
    U16,
    /// The length byte is the marker 127; eight extra bytes carry the length.
    U64,
}

impl LengthFormat {
    /// Pick the format needed to encode a payload of `length` bytes.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => LengthFormat::U8(length as u8),
            126..=65535 => LengthFormat::U16,
            _ => LengthFormat::U64,
        }
    }

    /// Number of extra header bytes that carry the length in this format.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            LengthFormat::U64 => 8,
            LengthFormat::U16 => 2,
            LengthFormat::U8(_) => 0,
        }
    }

    /// Value of the length byte for this format: the length itself for `U8`,
    /// otherwise the marker value 126 or 127.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            LengthFormat::U64 => 127,
            LengthFormat::U16 => 126,
            LengthFormat::U8(value) => *value,
        }
    }

    /// Determine the format from a length byte; the top (mask) bit is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let low_bits = byte & 0x7F;
        if low_bits == 127 {
            LengthFormat::U64
        } else if low_bits == 126 {
            LengthFormat::U16
        } else {
            LengthFormat::U8(low_bits)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -496,7 +528,11 @@ mod tests {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
]);
let frame = Frame::parse(&mut raw).unwrap().unwrap();
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
}
@ -515,12 +551,4 @@ mod tests {
assert!(view.contains("payload:"));
}
#[test]
fn parse_overflow() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]);
let _ = Frame::parse(&mut raw); // should not crash
}
}

@ -5,7 +5,7 @@ pub mod coding;
mod frame;
mod mask;
pub use self::frame::Frame;
pub use self::frame::{Frame, FrameHeader};
pub use self::frame::CloseFrame;
use std::io::{Read, Write};
@ -16,9 +16,14 @@ use error::{Error, Result};
/// A reader and writer for WebSocket frames.
#[derive(Debug)]
pub struct FrameSocket<Stream> {
/// The underlying network stream.
stream: Stream,
/// Buffer to read data from the stream.
in_buffer: InputBuffer,
/// Buffer to send packets to the network.
out_buffer: Vec<u8>,
/// Header and remaining size of the incoming packet being processed.
/// `None` until `read_frame` has parsed a complete header; the parsed value is
/// kept here until the payload is available and is taken out again when the
/// frame is returned.
header: Option<(FrameHeader, u64)>,
}
impl<Stream> FrameSocket<Stream> {
@ -28,24 +33,30 @@ impl<Stream> FrameSocket<Stream> {
stream: stream,
in_buffer: InputBuffer::with_capacity(MIN_READ),
out_buffer: Vec::new(),
header: None,
}
}
/// Create a new frame socket from partially read data.
///
/// The input buffer is built from `part`; the output buffer starts empty and no
/// frame header is pending.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
    let in_buffer = InputBuffer::from_partially_read(part);
    FrameSocket {
        stream,
        in_buffer,
        out_buffer: Vec::new(),
        header: None,
    }
}
/// Extract a stream from the socket.
pub fn into_inner(self) -> (Stream, Vec<u8>) {
(self.stream, self.in_buffer.into_vec())
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.stream
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
&mut self.stream
@ -57,12 +68,37 @@ impl<Stream> FrameSocket<Stream>
{
/// Read a frame from stream.
pub fn read_frame(&mut self) -> Result<Option<Frame>> {
loop {
if let Some(frame) = Frame::parse(&mut self.in_buffer.as_cursor_mut())? {
trace!("received frame {}", frame);
return Ok(Some(frame));
let payload = loop {
{
let cursor = self.in_buffer.as_cursor_mut();
if self.header.is_none() {
self.header = FrameHeader::parse(cursor)?;
}
if let Some((_, ref length)) = self.header {
let length = *length;
// Make sure `length` is not too big (fits into `usize`).
if length > usize::max_value() as u64 {
return Err(Error::Capacity(
format!("Message length too big: {}", length).into()
))
}
let input_size = cursor.get_ref().len() as u64 - cursor.position();
if length <= input_size {
// No truncation here since `length` is checked above
let mut payload = Vec::with_capacity(length as usize);
if length > 0 {
cursor.take(length).read_to_end(&mut payload)?;
}
break payload
}
}
}
// No full frames in buffer.
// Not enough data in buffer.
let size = self.in_buffer.prepare_reserve(MIN_READ)
.with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
@ -71,7 +107,13 @@ impl<Stream> FrameSocket<Stream>
trace!("no frame received");
return Ok(None)
}
}
};
let (header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, payload);
trace!("received frame {}", frame);
Ok(Some(frame))
}
}
@ -155,4 +197,13 @@ mod tests {
]);
}
#[test]
fn parse_overflow() {
    // Second byte 0xff: mask bit set and length marker 127, so the next eight
    // 0xff bytes are read as the payload length; the last four bytes are the mask.
    let bytes = vec![
        0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
        0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
    ];
    let mut socket = FrameSocket::new(Cursor::new(bytes));
    let _ = socket.read_frame(); // should not crash
}
}

Loading…
Cancel
Save