diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 7c361ca..89f82f3 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -35,17 +35,176 @@ impl<'t> fmt::Display for CloseFrame<'t> { } } -/// A struct representing a WebSocket frame. +/// A struct representing a WebSocket frame header. +#[allow(missing_copy_implementations)] #[derive(Debug, Clone)] -pub struct Frame { - finished: bool, - rsv1: bool, - rsv2: bool, - rsv3: bool, - opcode: OpCode, +pub struct FrameHeader { + /// Indicates that the frame is the last one of a possibly fragmented message. + pub is_final: bool, + /// Reserved for protocol extensions. + pub rsv1: bool, + /// Reserved for protocol extensions. + pub rsv2: bool, + /// Reserved for protocol extensions. + pub rsv3: bool, + /// WebSocket protocol opcode. + pub opcode: OpCode, + /// A frame mask, if any. + pub mask: Option<[u8; 4]>, +} + +impl Default for FrameHeader { + fn default() -> Self { + FrameHeader { + is_final: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: OpCode::Control(Control::Close), + mask: None, + } + } +} + +impl FrameHeader { + /// Parse a header from an input stream. + /// Returns `None` if insufficient data and does not consume anything in this case. + /// Payload size is returned along with the header. + pub fn parse(cursor: &mut Cursor>) -> Result> { + let initial = cursor.position(); + match Self::parse_internal(cursor) { + ret @ Ok(None) => { + cursor.set_position(initial); + ret + } + ret => ret + } + } + + /// Get the size of the header formatted with given payload length. + pub fn len(&self, length: u64) -> usize { + 2 + + LengthFormat::for_length(length).extra_bytes() + + if self.mask.is_some() { 4 } else { 0 } + } + + /// Format a header for given payload size. + pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> { + let code: u8 = self.opcode.into(); + + let one = { + code + | if self.is_final { 0x80 } else { 0 } + | if self.rsv1 { 0x40 } else { 0 } + | if self.rsv2 { 0x20 } else { 0 } + | if self.rsv3 { 0x10 } else { 0 } + }; + + let lenfmt = LengthFormat::for_length(length); + + let two = { + lenfmt.length_byte() + | if self.mask.is_some() { 0x80 } else { 0 } + }; + + output.write(&[one, two])?; + match lenfmt { + LengthFormat::U8(_) => (), + LengthFormat::U16 => output.write_u16::(length as u16)?, + LengthFormat::U64 => output.write_u64::(length)?, + } + + if let Some(ref mask) = self.mask { + output.write_all(mask)? + } + + Ok(()) + } +} + +impl FrameHeader { + /// Internal parse engine. + /// Returns `None` if insufficient data. + /// Payload size is returned along with the header. + fn parse_internal(cursor: &mut impl Read) -> Result> { + let (first, second) = { + let mut head = [0u8; 2]; + if cursor.read(&mut head)? != 2 { + return Ok(None) + } + trace!("Parsed headers {:?}", head); + (head[0], head[1]) + }; + + trace!("First: {:b}", first); + trace!("Second: {:b}", second); + + let is_final = first & 0x80 != 0; + + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; + + let opcode = OpCode::from(first & 0x0F); + trace!("Opcode: {:?}", opcode); + + let masked = second & 0x80 != 0; + trace!("Masked: {:?}", masked); + + let length = { + let length_byte = second & 0x7F; + let length_length = LengthFormat::for_byte(length_byte).extra_bytes(); + if length_length > 0 { + match cursor.read_uint::(length_length) { + Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(err) => { + return Err(err.into()); + } + Ok(read) => read + } + } else { + length_byte as u64 + } + }; + + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if cursor.read(&mut mask_bytes)? != 4 { + return Ok(None) + } else { + Some(mask_bytes) + } + } else { + None + }; + + // Disallow bad opcode + match opcode { + OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { + return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) + } + _ => () + } + + let hdr = FrameHeader { + is_final, + rsv1, + rsv2, + rsv3, + opcode, + mask, + }; - mask: Option<[u8; 4]>, + Ok(Some((hdr, length))) + } +} +/// A struct representing a WebSocket frame. +#[derive(Debug, Clone)] +pub struct Frame { + header: FrameHeader, payload: Vec, } @@ -55,51 +214,38 @@ impl Frame { /// This is the length of the header + the length of the payload. #[inline] pub fn len(&self) -> usize { - let mut header_length = 2; - let payload_len = self.payload().len(); - if payload_len > 125 { - if payload_len <= u16::max_value() as usize { - header_length += 2; - } else { - header_length += 8; - } - } - - if self.is_masked() { - header_length += 4; - } - - header_length + payload_len + let length = self.payload.len(); + self.header.len(length as u64) + length } /// Test whether the frame is a final frame. #[inline] pub fn is_final(&self) -> bool { - self.finished + self.header.is_final } /// Test whether the first reserved bit is set. #[inline] pub fn has_rsv1(&self) -> bool { - self.rsv1 + self.header.rsv1 } /// Test whether the second reserved bit is set. #[inline] pub fn has_rsv2(&self) -> bool { - self.rsv2 + self.header.rsv2 } /// Test whether the third reserved bit is set. #[inline] pub fn has_rsv3(&self) -> bool { - self.rsv3 + self.header.rsv3 } /// Get the OpCode of the frame. #[inline] pub fn opcode(&self) -> OpCode { - self.opcode + self.header.opcode } /// Get a reference to the frame's payload. @@ -112,7 +258,7 @@ impl Frame { #[doc(hidden)] #[inline] pub fn is_masked(&self) -> bool { - self.mask.is_some() + self.header.mask.is_some() } // Get an optional reference to the frame's mask. @@ -120,35 +266,35 @@ impl Frame { #[allow(dead_code)] #[inline] pub fn mask(&self) -> Option<&[u8; 4]> { - self.mask.as_ref() + self.header.mask.as_ref() } /// Make this frame a final frame. #[allow(dead_code)] #[inline] pub fn set_final(&mut self, is_final: bool) -> &mut Frame { - self.finished = is_final; + self.header.is_final = is_final; self } /// Set the first reserved bit. #[inline] pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { - self.rsv1 = has_rsv1; + self.header.rsv1 = has_rsv1; self } /// Set the second reserved bit. #[inline] pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { - self.rsv2 = has_rsv2; + self.header.rsv2 = has_rsv2; self } /// Set the third reserved bit. #[inline] pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { - self.rsv3 = has_rsv3; + self.header.rsv3 = has_rsv3; self } @@ -156,7 +302,7 @@ impl Frame { #[allow(dead_code)] #[inline] pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { - self.opcode = opcode; + self.header.opcode = opcode; self } @@ -175,7 +321,7 @@ impl Frame { #[doc(hidden)] #[inline] pub fn set_mask(&mut self) -> &mut Frame { - self.mask = Some(generate_mask()); + self.header.mask = Some(generate_mask()); self } @@ -184,10 +330,10 @@ impl Frame { #[doc(hidden)] #[inline] pub fn remove_mask(&mut self) { - self.mask.and_then(|mask| { + self.header.mask.and_then(|mask| { Some(apply_mask(&mut self.payload, &mask)) }); - self.mask = None; + self.header.mask = None; } /// Consume the frame into its payload as binary. @@ -220,17 +366,19 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: Vec, code: OpCode, finished: bool) -> Frame { - debug_assert!(match code { + pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + debug_assert!(match opcode { OpCode::Data(_) => true, _ => false, }, "Invalid opcode for data frame."); Frame { - finished: finished, - opcode: code, + header: FrameHeader { + is_final, + opcode, + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -238,9 +386,11 @@ impl Frame { #[inline] pub fn pong(data: Vec) -> Frame { Frame { - opcode: OpCode::Control(Control::Pong), + header: FrameHeader { + opcode: OpCode::Control(Control::Pong), + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -248,9 +398,11 @@ impl Frame { #[inline] pub fn ping(data: Vec) -> Frame { Frame { - opcode: OpCode::Control(Control::Ping), + header: FrameHeader { + opcode: OpCode::Control(Control::Ping), + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -267,199 +419,28 @@ impl Frame { }; Frame { + header: FrameHeader::default(), payload: payload, - .. Frame::default() } } - /// Parse the input stream into a frame. - pub fn parse(cursor: &mut Cursor>) -> Result> { - let size = cursor.get_ref().len() as u64 - cursor.position(); - let initial = cursor.position(); - trace!("Position in buffer {}", initial); - - let mut head = [0u8; 2]; - if try!(cursor.read(&mut head)) != 2 { - cursor.set_position(initial); - return Ok(None) - } - - trace!("Parsed headers {:?}", head); - - let first = head[0]; - let second = head[1]; - trace!("First: {:b}", first); - trace!("Second: {:b}", second); - - let finished = first & 0x80 != 0; - - let rsv1 = first & 0x40 != 0; - let rsv2 = first & 0x20 != 0; - let rsv3 = first & 0x10 != 0; - - let opcode = OpCode::from(first & 0x0F); - trace!("Opcode: {:?}", opcode); - - let masked = second & 0x80 != 0; - trace!("Masked: {:?}", masked); - - let mut header_length = 2; - - let mut length = (second & 0x7F) as u64; - - if let Some(length_nbytes) = match length { - 126 => Some(2), - 127 => Some(8), - _ => None, - } { - match cursor.read_uint::(length_nbytes) { - Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { - cursor.set_position(initial); - return Ok(None); - } - Err(err) => { - return Err(Error::from(err)); - } - Ok(read) => { - length = read; - } - }; - header_length += length_nbytes as u64; - } - trace!("Payload length: {}", length); - - let mask = if masked { - let mut mask_bytes = [0u8; 4]; - if try!(cursor.read(&mut mask_bytes)) != 4 { - cursor.set_position(initial); - return Ok(None) - } else { - header_length += 4; - Some(mask_bytes) - } - } else { - None - }; - - // Disallow bad opcode - match opcode { - OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) - } - _ => () - } - - // Make sure `length` is not too big (fits into `usize`). - if length > usize::max_value() as u64 { - return Err(Error::Capacity(format!("Message length too big: {}", length).into())); - } - - if size < header_length || size - header_length < length { - cursor.set_position(initial); - return Ok(None) - } - - // Size is checked above, so it won't be truncated here. - let mut data = Vec::with_capacity(length as usize); - if length > 0 { - try!(cursor.take(length).read_to_end(&mut data)); + /// Create a frame from given header and data. + pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { + Frame { + header, + payload, } - debug_assert_eq!(data.len() as u64, length); - - let frame = Frame { - finished: finished, - rsv1: rsv1, - rsv2: rsv2, - rsv3: rsv3, - opcode: opcode, - mask: mask, - payload: data, - }; - - - Ok(Some(frame)) } /// Write a frame out to a buffer - pub fn format(mut self, w: &mut W) -> Result<()> - where W: Write - { - let mut one = 0u8; - let code: u8 = self.opcode.into(); - if self.is_final() { - one |= 0x80; - } - if self.has_rsv1() { - one |= 0x40; - } - if self.has_rsv2() { - one |= 0x20; - } - if self.has_rsv3() { - one |= 0x10; - } - one |= code; - - let mut two = 0u8; - - if self.is_masked() { - two |= 0x80; - } - - if self.payload.len() < 126 { - two |= self.payload.len() as u8; - let headers = [one, two]; - try!(w.write(&headers)); - } else if self.payload.len() <= 65535 { - two |= 126; - let mut length_bytes = [0u8; 2]; - NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16); - let headers = [one, two, length_bytes[0], length_bytes[1]]; - try!(w.write(&headers)); - } else { - two |= 127; - let mut length_bytes = [0u8; 8]; - NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64); - let headers = [ - one, - two, - length_bytes[0], - length_bytes[1], - length_bytes[2], - length_bytes[3], - length_bytes[4], - length_bytes[5], - length_bytes[6], - length_bytes[7], - ]; - try!(w.write(&headers)); - } - - if self.is_masked() { - let mask = self.mask.take().unwrap(); - apply_mask(&mut self.payload, &mask); - try!(w.write(&mask)); - } - - try!(w.write(&self.payload)); + pub fn format(mut self, output: &mut impl Write) -> Result<()> { + self.header.format(self.payload.len() as u64, output)?; + self.remove_mask(); + output.write_all(self.payload())?; Ok(()) } } -impl Default for Frame { - fn default() -> Frame { - Frame { - finished: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: OpCode::Control(Control::Close), - mask: None, - payload: Vec::new(), - } - } -} - impl fmt::Display for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, @@ -472,11 +453,11 @@ length: {} payload length: {} payload: 0x{} ", - self.finished, - self.rsv1, - self.rsv2, - self.rsv3, - self.opcode, + self.header.is_final, + self.header.rsv1, + self.header.rsv2, + self.header.rsv3, + self.header.opcode, // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), @@ -484,6 +465,57 @@ payload: 0x{} } } +/// Handling of the length format. +enum LengthFormat { + U8(u8), + U16, + U64, +} + +impl LengthFormat { + /// Get length format for a given data size. + #[inline] + fn for_length(length: u64) -> Self { + if length < 126 { + LengthFormat::U8(length as u8) + } else if length < 65536 { + LengthFormat::U16 + } else { + LengthFormat::U64 + } + } + + /// Get the size of length encoding. + #[inline] + fn extra_bytes(&self) -> usize { + match *self { + LengthFormat::U8(_) => 0, + LengthFormat::U16 => 2, + LengthFormat::U64 => 8, + } + } + + /// Encode the givem length. + #[inline] + fn length_byte(&self) -> u8 { + match *self { + LengthFormat::U8(b) => b, + LengthFormat::U16 => 126, + LengthFormat::U64 => 127, + } + } + + /// Get length format for a given length byte. + #[inline] + fn for_byte(byte: u8) -> Self { + match byte & 0x7F { + 126 => LengthFormat::U16, + 127 => LengthFormat::U64, + b => LengthFormat::U8(b) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -496,7 +528,11 @@ mod tests { let mut raw: Cursor> = Cursor::new(vec![ 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); - let frame = Frame::parse(&mut raw).unwrap().unwrap(); + let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); + assert_eq!(length, 7); + let mut payload = Vec::new(); + raw.read_to_end(&mut payload).unwrap(); + let frame = Frame::from_payload(header, payload); assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); } @@ -515,12 +551,4 @@ mod tests { assert!(view.contains("payload:")); } - #[test] - fn parse_overflow() { - let mut raw: Cursor> = Cursor::new(vec![ - 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, - ]); - let _ = Frame::parse(&mut raw); // should not crash - } } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index f90c288..56711d0 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,7 +5,7 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::Frame; +pub use self::frame::{Frame, FrameHeader}; pub use self::frame::CloseFrame; use std::io::{Read, Write}; @@ -16,9 +16,14 @@ use error::{Error, Result}; /// A reader and writer for WebSocket frames. #[derive(Debug)] pub struct FrameSocket { + /// The underlying network stream. stream: Stream, + /// Buffer to read data from the stream. in_buffer: InputBuffer, + /// Buffer to send packets to the network. out_buffer: Vec, + /// Header and remaining size of the incoming packet being processed. + header: Option<(FrameHeader, u64)>, } impl FrameSocket { @@ -28,24 +33,30 @@ impl FrameSocket { stream: stream, in_buffer: InputBuffer::with_capacity(MIN_READ), out_buffer: Vec::new(), + header: None, } } + /// Create a new frame socket from partially read data. pub fn from_partially_read(stream: Stream, part: Vec) -> Self { FrameSocket { stream: stream, in_buffer: InputBuffer::from_partially_read(part), out_buffer: Vec::new(), + header: None, } } + /// Extract a stream from the socket. pub fn into_inner(self) -> (Stream, Vec) { (self.stream, self.in_buffer.into_vec()) } + /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &Stream { &self.stream } + /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut Stream { &mut self.stream @@ -57,12 +68,37 @@ impl FrameSocket { /// Read a frame from stream. pub fn read_frame(&mut self) -> Result> { - loop { - if let Some(frame) = Frame::parse(&mut self.in_buffer.as_cursor_mut())? { - trace!("received frame {}", frame); - return Ok(Some(frame)); + let payload = loop { + { + let cursor = self.in_buffer.as_cursor_mut(); + + if self.header.is_none() { + self.header = FrameHeader::parse(cursor)?; + } + + if let Some((_, ref length)) = self.header { + let length = *length; + + // Make sure `length` is not too big (fits into `usize`). + if length > usize::max_value() as u64 { + return Err(Error::Capacity( + format!("Message length too big: {}", length).into() + )) + } + + let input_size = cursor.get_ref().len() as u64 - cursor.position(); + if length <= input_size { + // No truncation here since `length` is checked above + let mut payload = Vec::with_capacity(length as usize); + if length > 0 { + cursor.take(length).read_to_end(&mut payload)?; + } + break payload + } + } } - // No full frames in buffer. + + // Not enough data in buffer. let size = self.in_buffer.prepare_reserve(MIN_READ) .with_limit(usize::max_value()) .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? @@ -71,7 +107,13 @@ impl FrameSocket trace!("no frame received"); return Ok(None) } - } + }; + + let (header, length) = self.header.take().expect("Bug: no frame header"); + debug_assert_eq!(payload.len() as u64, length); + let frame = Frame::from_payload(header, payload); + trace!("received frame {}", frame); + Ok(Some(frame)) } } @@ -155,4 +197,13 @@ mod tests { ]); } + #[test] + fn parse_overflow() { + let raw = Cursor::new(vec![ + 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + ]); + let mut sock = FrameSocket::new(raw); + let _ = sock.read_frame(); // should not crash + } }