diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 15494f9..014af64 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -52,7 +52,7 @@ impl<'t> Request<'t> { /// Adds a custom header to the request. pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { - let mut headers = self.extra_headers.take().unwrap_or(vec![]); + let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new); headers.push((name, value)); self.extra_headers = Some(headers); } @@ -113,7 +113,7 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { role: client, machine: machine } + MidHandshake { role: client, machine } } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 73419b6..23f0d77 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -28,7 +28,7 @@ impl Headers { /// Iterate over all headers with the given name. pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { HeadersIter { - name: name, + name, iter: self.data.iter() } } diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index a945be1..39f37a5 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -16,14 +16,14 @@ impl HandshakeMachine { /// Start reading data from the peer. pub fn start_read(stream: Stream) -> Self { HandshakeMachine { - stream: stream, + stream, state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), } } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { HandshakeMachine { - stream: stream, + stream, state: HandshakeState::Writing(Cursor::new(data.into())), } } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index fcb3f78..f28445d 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -110,7 +110,7 @@ pub enum ProcessingResult { fn convert_key(input: &[u8]) -> Result { // ... field is constructed by concatenating /key/ ... // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; let mut sha1 = Sha1::default(); sha1.input(input); sha1.input(WS_GUID); diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 7c361ca..c5bf9c2 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -35,159 +35,242 @@ impl<'t> fmt::Display for CloseFrame<'t> { } } -/// A struct representing a WebSocket frame. +/// A struct representing a WebSocket frame header. +#[allow(missing_copy_implementations)] #[derive(Debug, Clone)] -pub struct Frame { - finished: bool, - rsv1: bool, - rsv2: bool, - rsv3: bool, - opcode: OpCode, - +pub struct FrameHeader { + /// Indicates that the frame is the last one of a possibly fragmented message. + pub is_final: bool, + /// Reserved for protocol extensions. + pub rsv1: bool, + /// Reserved for protocol extensions. + pub rsv2: bool, + /// Reserved for protocol extensions. + pub rsv3: bool, + /// WebSocket protocol opcode. + pub opcode: OpCode, + /// A frame mask, if any. mask: Option<[u8; 4]>, - - payload: Vec, } -impl Frame { +impl Default for FrameHeader { + fn default() -> Self { + FrameHeader { + is_final: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: OpCode::Control(Control::Close), + mask: None, + } + } +} - /// Get the length of the frame. - /// This is the length of the header + the length of the payload. - #[inline] - pub fn len(&self) -> usize { - let mut header_length = 2; - let payload_len = self.payload().len(); - if payload_len > 125 { - if payload_len <= u16::max_value() as usize { - header_length += 2; - } else { - header_length += 8; +impl FrameHeader { + /// Parse a header from an input stream. + /// Returns `None` if insufficient data and does not consume anything in this case. + /// Payload size is returned along with the header. + pub fn parse(cursor: &mut Cursor>) -> Result> { + let initial = cursor.position(); + match Self::parse_internal(cursor) { + ret @ Ok(None) => { + cursor.set_position(initial); + ret } + ret => ret } + } + + /// Get the size of the header formatted with given payload length. + pub fn len(&self, length: u64) -> usize { + 2 + + LengthFormat::for_length(length).extra_bytes() + + if self.mask.is_some() { 4 } else { 0 } + } - if self.is_masked() { - header_length += 4; + /// Format a header for given payload size. + pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> { + let code: u8 = self.opcode.into(); + + let one = { + code + | if self.is_final { 0x80 } else { 0 } + | if self.rsv1 { 0x40 } else { 0 } + | if self.rsv2 { 0x20 } else { 0 } + | if self.rsv3 { 0x10 } else { 0 } + }; + + let lenfmt = LengthFormat::for_length(length); + + let two = { + lenfmt.length_byte() + | if self.mask.is_some() { 0x80 } else { 0 } + }; + + output.write_all(&[one, two])?; + match lenfmt { + LengthFormat::U8(_) => (), + LengthFormat::U16 => output.write_u16::(length as u16)?, + LengthFormat::U64 => output.write_u64::(length)?, } - header_length + payload_len - } + if let Some(ref mask) = self.mask { + output.write_all(mask)? + } - /// Test whether the frame is a final frame. - #[inline] - pub fn is_final(&self) -> bool { - self.finished + Ok(()) } - /// Test whether the first reserved bit is set. - #[inline] - pub fn has_rsv1(&self) -> bool { - self.rsv1 + /// Generate a random frame mask and store this in the header. + /// + /// Of course this does not change frame contents. It just generates a mask. + pub(crate) fn set_random_mask(&mut self) { + self.mask = Some(generate_mask()) } +} - /// Test whether the second reserved bit is set. - #[inline] - pub fn has_rsv2(&self) -> bool { - self.rsv2 - } +impl FrameHeader { + /// Internal parse engine. + /// Returns `None` if insufficient data. + /// Payload size is returned along with the header. + fn parse_internal(cursor: &mut impl Read) -> Result> { + let (first, second) = { + let mut head = [0u8; 2]; + if cursor.read(&mut head)? != 2 { + return Ok(None) + } + trace!("Parsed headers {:?}", head); + (head[0], head[1]) + }; - /// Test whether the third reserved bit is set. - #[inline] - pub fn has_rsv3(&self) -> bool { - self.rsv3 - } + trace!("First: {:b}", first); + trace!("Second: {:b}", second); - /// Get the OpCode of the frame. - #[inline] - pub fn opcode(&self) -> OpCode { - self.opcode - } + let is_final = first & 0x80 != 0; - /// Get a reference to the frame's payload. - #[inline] - pub fn payload(&self) -> &Vec { - &self.payload - } + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; - // Test whether the frame is masked. - #[doc(hidden)] - #[inline] - pub fn is_masked(&self) -> bool { - self.mask.is_some() - } + let opcode = OpCode::from(first & 0x0F); + trace!("Opcode: {:?}", opcode); - // Get an optional reference to the frame's mask. - #[doc(hidden)] - #[allow(dead_code)] - #[inline] - pub fn mask(&self) -> Option<&[u8; 4]> { - self.mask.as_ref() - } + let masked = second & 0x80 != 0; + trace!("Masked: {:?}", masked); - /// Make this frame a final frame. - #[allow(dead_code)] - #[inline] - pub fn set_final(&mut self, is_final: bool) -> &mut Frame { - self.finished = is_final; - self + let length = { + let length_byte = second & 0x7F; + let length_length = LengthFormat::for_byte(length_byte).extra_bytes(); + if length_length > 0 { + match cursor.read_uint::(length_length) { + Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(err) => { + return Err(err.into()); + } + Ok(read) => read + } + } else { + length_byte as u64 + } + }; + + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if cursor.read(&mut mask_bytes)? != 4 { + return Ok(None) + } else { + Some(mask_bytes) + } + } else { + None + }; + + // Disallow bad opcode + match opcode { + OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { + return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) + } + _ => () + } + + let hdr = FrameHeader { + is_final, + rsv1, + rsv2, + rsv3, + opcode, + mask, + }; + + Ok(Some((hdr, length))) } +} + +/// A struct representing a WebSocket frame. +#[derive(Debug, Clone)] +pub struct Frame { + header: FrameHeader, + payload: Vec, +} - /// Set the first reserved bit. +impl Frame { + + /// Get the length of the frame. + /// This is the length of the header + the length of the payload. #[inline] - pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { - self.rsv1 = has_rsv1; - self + pub fn len(&self) -> usize { + let length = self.payload.len(); + self.header.len(length as u64) + length } - /// Set the second reserved bit. + /// Get a reference to the frame's header. #[inline] - pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { - self.rsv2 = has_rsv2; - self + pub fn header(&self) -> &FrameHeader { + &self.header } - /// Set the third reserved bit. + /// Get a mutable reference to the frame's header. #[inline] - pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { - self.rsv3 = has_rsv3; - self + pub fn header_mut(&mut self) -> &mut FrameHeader { + &mut self.header } - /// Set the OpCode. - #[allow(dead_code)] + /// Get a reference to the frame's payload. #[inline] - pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { - self.opcode = opcode; - self + pub fn payload(&self) -> &Vec { + &self.payload } - /// Edit the frame's payload. - #[allow(dead_code)] + /// Get a mutable reference to the frame's payload. #[inline] pub fn payload_mut(&mut self) -> &mut Vec { &mut self.payload } - // Generate a new mask for this frame. - // - // This method simply generates and stores the mask. It does not change the payload data. - // Instead, the payload data will be masked with the generated mask when the frame is sent - // to the other endpoint. - #[doc(hidden)] + /// Test whether the frame is masked. #[inline] - pub fn set_mask(&mut self) -> &mut Frame { - self.mask = Some(generate_mask()); - self + pub(crate) fn is_masked(&self) -> bool { + self.header.mask.is_some() } - // This method unmasks the payload and should only be called on frames that are actually - // masked. In other words, those frames that have just been received from a client endpoint. - #[doc(hidden)] + /// Generate a random mask for the frame. + /// + /// This just generates a mask, payload is not changed. The actual masking is performed + /// either on `format()` or on `apply_mask()` call. #[inline] - pub fn remove_mask(&mut self) { - self.mask.and_then(|mask| { - Some(apply_mask(&mut self.payload, &mask)) - }); - self.mask = None; + pub(crate) fn set_random_mask(&mut self) { + self.header.set_random_mask() + } + + /// This method unmasks the payload and should only be called on frames that are actually + /// masked. In other words, those frames that have just been received from a client endpoint. + #[inline] + pub(crate) fn apply_mask(&mut self) { + if let Some(mask) = self.header.mask.take() { + apply_mask(&mut self.payload, mask) + } } /// Consume the frame into its payload as binary. @@ -204,7 +287,7 @@ impl Frame { /// Consume the frame into a closing frame. #[inline] - pub fn into_close(self) -> Result>> { + pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), 1 => Err(Error::Protocol("Invalid close sequence".into())), @@ -213,24 +296,26 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some(CloseFrame { code: code, reason: text.into() })) + Ok(Some(CloseFrame { code, reason: text.into() })) } } } /// Create a new data frame. #[inline] - pub fn message(data: Vec, code: OpCode, finished: bool) -> Frame { - debug_assert!(match code { + pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + debug_assert!(match opcode { OpCode::Data(_) => true, _ => false, }, "Invalid opcode for data frame."); Frame { - finished: finished, - opcode: code, + header: FrameHeader { + is_final, + opcode, + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -238,9 +323,11 @@ impl Frame { #[inline] pub fn pong(data: Vec) -> Frame { Frame { - opcode: OpCode::Control(Control::Pong), + header: FrameHeader { + opcode: OpCode::Control(Control::Pong), + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -248,9 +335,11 @@ impl Frame { #[inline] pub fn ping(data: Vec) -> Frame { Frame { - opcode: OpCode::Control(Control::Ping), + header: FrameHeader { + opcode: OpCode::Control(Control::Ping), + .. FrameHeader::default() + }, payload: data, - .. Frame::default() } } @@ -267,199 +356,28 @@ impl Frame { }; Frame { - payload: payload, - .. Frame::default() + header: FrameHeader::default(), + payload, } } - /// Parse the input stream into a frame. - pub fn parse(cursor: &mut Cursor>) -> Result> { - let size = cursor.get_ref().len() as u64 - cursor.position(); - let initial = cursor.position(); - trace!("Position in buffer {}", initial); - - let mut head = [0u8; 2]; - if try!(cursor.read(&mut head)) != 2 { - cursor.set_position(initial); - return Ok(None) - } - - trace!("Parsed headers {:?}", head); - - let first = head[0]; - let second = head[1]; - trace!("First: {:b}", first); - trace!("Second: {:b}", second); - - let finished = first & 0x80 != 0; - - let rsv1 = first & 0x40 != 0; - let rsv2 = first & 0x20 != 0; - let rsv3 = first & 0x10 != 0; - - let opcode = OpCode::from(first & 0x0F); - trace!("Opcode: {:?}", opcode); - - let masked = second & 0x80 != 0; - trace!("Masked: {:?}", masked); - - let mut header_length = 2; - - let mut length = (second & 0x7F) as u64; - - if let Some(length_nbytes) = match length { - 126 => Some(2), - 127 => Some(8), - _ => None, - } { - match cursor.read_uint::(length_nbytes) { - Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { - cursor.set_position(initial); - return Ok(None); - } - Err(err) => { - return Err(Error::from(err)); - } - Ok(read) => { - length = read; - } - }; - header_length += length_nbytes as u64; - } - trace!("Payload length: {}", length); - - let mask = if masked { - let mut mask_bytes = [0u8; 4]; - if try!(cursor.read(&mut mask_bytes)) != 4 { - cursor.set_position(initial); - return Ok(None) - } else { - header_length += 4; - Some(mask_bytes) - } - } else { - None - }; - - // Disallow bad opcode - match opcode { - OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) - } - _ => () - } - - // Make sure `length` is not too big (fits into `usize`). - if length > usize::max_value() as u64 { - return Err(Error::Capacity(format!("Message length too big: {}", length).into())); - } - - if size < header_length || size - header_length < length { - cursor.set_position(initial); - return Ok(None) - } - - // Size is checked above, so it won't be truncated here. - let mut data = Vec::with_capacity(length as usize); - if length > 0 { - try!(cursor.take(length).read_to_end(&mut data)); + /// Create a frame from given header and data. + pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { + Frame { + header, + payload, } - debug_assert_eq!(data.len() as u64, length); - - let frame = Frame { - finished: finished, - rsv1: rsv1, - rsv2: rsv2, - rsv3: rsv3, - opcode: opcode, - mask: mask, - payload: data, - }; - - - Ok(Some(frame)) } /// Write a frame out to a buffer - pub fn format(mut self, w: &mut W) -> Result<()> - where W: Write - { - let mut one = 0u8; - let code: u8 = self.opcode.into(); - if self.is_final() { - one |= 0x80; - } - if self.has_rsv1() { - one |= 0x40; - } - if self.has_rsv2() { - one |= 0x20; - } - if self.has_rsv3() { - one |= 0x10; - } - one |= code; - - let mut two = 0u8; - - if self.is_masked() { - two |= 0x80; - } - - if self.payload.len() < 126 { - two |= self.payload.len() as u8; - let headers = [one, two]; - try!(w.write(&headers)); - } else if self.payload.len() <= 65535 { - two |= 126; - let mut length_bytes = [0u8; 2]; - NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16); - let headers = [one, two, length_bytes[0], length_bytes[1]]; - try!(w.write(&headers)); - } else { - two |= 127; - let mut length_bytes = [0u8; 8]; - NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64); - let headers = [ - one, - two, - length_bytes[0], - length_bytes[1], - length_bytes[2], - length_bytes[3], - length_bytes[4], - length_bytes[5], - length_bytes[6], - length_bytes[7], - ]; - try!(w.write(&headers)); - } - - if self.is_masked() { - let mask = self.mask.take().unwrap(); - apply_mask(&mut self.payload, &mask); - try!(w.write(&mask)); - } - - try!(w.write(&self.payload)); + pub fn format(mut self, output: &mut impl Write) -> Result<()> { + self.header.format(self.payload.len() as u64, output)?; + self.apply_mask(); + output.write_all(self.payload())?; Ok(()) } } -impl Default for Frame { - fn default() -> Frame { - Frame { - finished: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: OpCode::Control(Control::Close), - mask: None, - payload: Vec::new(), - } - } -} - impl fmt::Display for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, @@ -472,11 +390,11 @@ length: {} payload length: {} payload: 0x{} ", - self.finished, - self.rsv1, - self.rsv2, - self.rsv3, - self.opcode, + self.header.is_final, + self.header.rsv1, + self.header.rsv2, + self.header.rsv3, + self.header.opcode, // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), @@ -484,6 +402,57 @@ payload: 0x{} } } +/// Handling of the length format. +enum LengthFormat { + U8(u8), + U16, + U64, +} + +impl LengthFormat { + /// Get length format for a given data size. + #[inline] + fn for_length(length: u64) -> Self { + if length < 126 { + LengthFormat::U8(length as u8) + } else if length < 65536 { + LengthFormat::U16 + } else { + LengthFormat::U64 + } + } + + /// Get the size of length encoding. + #[inline] + fn extra_bytes(&self) -> usize { + match *self { + LengthFormat::U8(_) => 0, + LengthFormat::U16 => 2, + LengthFormat::U64 => 8, + } + } + + /// Encode the givem length. + #[inline] + fn length_byte(&self) -> u8 { + match *self { + LengthFormat::U8(b) => b, + LengthFormat::U16 => 126, + LengthFormat::U64 => 127, + } + } + + /// Get length format for a given length byte. + #[inline] + fn for_byte(byte: u8) -> Self { + match byte & 0x7F { + 126 => LengthFormat::U16, + 127 => LengthFormat::U64, + b => LengthFormat::U8(b) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -496,7 +465,11 @@ mod tests { let mut raw: Cursor> = Cursor::new(vec![ 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); - let frame = Frame::parse(&mut raw).unwrap().unwrap(); + let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); + assert_eq!(length, 7); + let mut payload = Vec::new(); + raw.read_to_end(&mut payload).unwrap(); + let frame = Frame::from_payload(header, payload); assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); } @@ -515,12 +488,4 @@ mod tests { assert!(view.contains("payload:")); } - #[test] - fn parse_overflow() { - let mut raw: Cursor> = Cursor::new(vec![ - 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, - ]); - let _ = Frame::parse(&mut raw); // should not crash - } } diff --git a/src/protocol/frame/mask.rs b/src/protocol/frame/mask.rs index 055be08..8b39d76 100644 --- a/src/protocol/frame/mask.rs +++ b/src/protocol/frame/mask.rs @@ -11,14 +11,14 @@ pub fn generate_mask() -> [u8; 4] { /// Mask/unmask a frame. #[inline] -pub fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { +pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) { apply_mask_fast32(buf, mask) } /// A safe unoptimized mask application. #[inline] #[allow(dead_code)] -fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { +fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { for (i, byte) in buf.iter_mut().enumerate() { *byte ^= mask[i & 3]; } @@ -27,7 +27,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { /// Faster version of `apply_mask()` which operates on 4-byte blocks. #[inline] #[allow(dead_code)] -fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { +fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { let mask_u32: u32 = unsafe { read_unaligned(mask.as_ptr() as *const u32) }; @@ -101,10 +101,10 @@ mod tests { // Check masking with proper alignment. { let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked, &mask); + apply_mask_fallback(&mut masked, mask); let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast, &mask); + apply_mask_fast32(&mut masked_fast, mask); assert_eq!(masked, masked_fast); } @@ -112,10 +112,10 @@ mod tests { // Check masking without alignment. { let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked[1..], &mask); + apply_mask_fallback(&mut masked[1..], mask); let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast[1..], &mask); + apply_mask_fast32(&mut masked_fast[1..], mask); assert_eq!(masked, masked_fast); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index f90c288..9d17338 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,7 +5,7 @@ pub mod coding; mod frame; mod mask; -pub use self::frame::Frame; +pub use self::frame::{Frame, FrameHeader}; pub use self::frame::CloseFrame; use std::io::{Read, Write}; @@ -16,36 +16,47 @@ use error::{Error, Result}; /// A reader and writer for WebSocket frames. #[derive(Debug)] pub struct FrameSocket { + /// The underlying network stream. stream: Stream, + /// Buffer to read data from the stream. in_buffer: InputBuffer, + /// Buffer to send packets to the network. out_buffer: Vec, + /// Header and remaining size of the incoming packet being processed. + header: Option<(FrameHeader, u64)>, } impl FrameSocket { /// Create a new frame socket. pub fn new(stream: Stream) -> Self { FrameSocket { - stream: stream, + stream, in_buffer: InputBuffer::with_capacity(MIN_READ), out_buffer: Vec::new(), + header: None, } } + /// Create a new frame socket from partially read data. pub fn from_partially_read(stream: Stream, part: Vec) -> Self { FrameSocket { - stream: stream, + stream, in_buffer: InputBuffer::from_partially_read(part), out_buffer: Vec::new(), + header: None, } } + /// Extract a stream from the socket. pub fn into_inner(self) -> (Stream, Vec) { (self.stream, self.in_buffer.into_vec()) } + /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &Stream { &self.stream } + /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut Stream { &mut self.stream @@ -56,13 +67,41 @@ impl FrameSocket where Stream: Read { /// Read a frame from stream. - pub fn read_frame(&mut self) -> Result> { - loop { - if let Some(frame) = Frame::parse(&mut self.in_buffer.as_cursor_mut())? { - trace!("received frame {}", frame); - return Ok(Some(frame)); + pub fn read_frame(&mut self, max_size: Option) -> Result> { + let max_size = max_size.unwrap_or_else(usize::max_value); + + let payload = loop { + { + let cursor = self.in_buffer.as_cursor_mut(); + + if self.header.is_none() { + self.header = FrameHeader::parse(cursor)?; + } + + if let Some((_, ref length)) = self.header { + let length = *length; + + // Enforce frame size limit early and make sure `length` + // is not too big (fits into `usize`). + if length > max_size as u64 { + return Err(Error::Capacity( + format!("Message length too big: {} > {}", length, max_size).into() + )) + } + + let input_size = cursor.get_ref().len() as u64 - cursor.position(); + if length <= input_size { + // No truncation here since `length` is checked above + let mut payload = Vec::with_capacity(length as usize); + if length > 0 { + cursor.take(length).read_to_end(&mut payload)?; + } + break payload + } + } } - // No full frames in buffer. + + // Not enough data in buffer. let size = self.in_buffer.prepare_reserve(MIN_READ) .with_limit(usize::max_value()) .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? @@ -71,7 +110,13 @@ impl FrameSocket trace!("no frame received"); return Ok(None) } - } + }; + + let (header, length) = self.header.take().expect("Bug: no frame header"); + debug_assert_eq!(payload.len() as u64, length); + let frame = Frame::from_payload(header, payload); + trace!("received frame {}", frame); + Ok(Some(frame)) } } @@ -118,11 +163,11 @@ mod tests { ]); let mut sock = FrameSocket::new(raw); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); - assert!(sock.read_frame().unwrap().is_none()); + assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); assert_eq!(rest, vec![0x99]); @@ -134,7 +179,7 @@ mod tests { 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, ]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } @@ -155,4 +200,24 @@ mod tests { ]); } + #[test] + fn parse_overflow() { + let raw = Cursor::new(vec![ + 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + ]); + let mut sock = FrameSocket::new(raw); + let _ = sock.read_frame(None); // should not crash + } + + #[test] + fn size_limit_hit() { + let raw = Cursor::new(vec![ + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ]); + let mut sock = FrameSocket::new(raw); + assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(), + "Space limit exceeded: Message length too big: 7 > 5" + ); + } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index b2ee1a5..8ee3eac 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -3,7 +3,7 @@ use std::fmt; use std::result::Result as StdResult; use std::str; -use error::Result; +use error::{Result, Error}; mod string_collect { @@ -26,6 +26,11 @@ mod string_collect { } } + pub fn len(&self) -> usize { + self.data.len() + .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0)) + } + pub fn extend>(&mut self, tail: T) -> Result<()> { let mut input: &[u8] = tail.as_ref(); @@ -105,8 +110,29 @@ impl IncompleteMessage { } } } + + /// Get the current filled size of the buffer. + pub fn len(&self) -> usize { + match self.collector { + IncompleteMessageCollector::Text(ref t) => t.len(), + IncompleteMessageCollector::Binary(ref b) => b.len(), + } + } + /// Add more data to an existing message. - pub fn extend>(&mut self, tail: T) -> Result<()> { + pub fn extend>(&mut self, tail: T, size_limit: Option) -> Result<()> { + // Always have a max size. This ensures an error in case of concatenating two buffers + // of more than `usize::max_value()` bytes in total. + let max_size = size_limit.unwrap_or_else(usize::max_value); + let my_size = self.len(); + let portion_size = tail.as_ref().len(); + // Be careful about integer overflows here. + if my_size > max_size || portion_size > max_size - my_size { + return Err(Error::Capacity( + format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into() + )) + } + match self.collector { IncompleteMessageCollector::Binary(ref mut v) => { v.extend(tail.as_ref()); @@ -117,6 +143,7 @@ impl IncompleteMessage { } } } + /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 869477f..8b3de04 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -33,12 +33,23 @@ pub struct WebSocketConfig { /// means here that the size of the queue is unlimited. The default value is the unlimited /// queue. pub max_send_queue: Option, + /// The maximum size of a message. `None` means no size limit. The default value is 64 megabytes + /// which should be reasonably big for all normal use-cases but small enough to prevent + /// memory eating by a malicious user. + pub max_message_size: Option, + /// The maximum size of a single message frame. `None` means no size limit. The limit is for + /// frame payload NOT including the frame header. The default value is 16 megabytes which should + /// be reasonably big for all normal use-cases but small enough to prevent memory eating + /// by a malicious user. + pub max_frame_size: Option, } impl Default for WebSocketConfig { fn default() -> Self { WebSocketConfig { max_send_queue: None, + max_message_size: Some(64 << 20), + max_frame_size: Some(16 << 20), } } } @@ -98,6 +109,13 @@ impl WebSocket { self.socket.get_mut() } + /// Change the configuration. + pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { + set_func(&mut self.config) + } +} + +impl WebSocket { /// Convert a frame socket into a WebSocket. fn from_frame_socket( socket: FrameSocket, @@ -105,13 +123,13 @@ impl WebSocket { config: Option ) -> Self { WebSocket { - role: role, - socket: socket, + role, + socket, state: WebSocketState::Active, incomplete: None, send_queue: VecDeque::new(), pong: None, - config: config.unwrap_or_else(|| WebSocketConfig::default()), + config: config.unwrap_or_else(WebSocketConfig::default), } } } @@ -145,10 +163,14 @@ impl WebSocket { /// Note that only the last pong frame is stored to be sent, and only the /// most recent pong frame is sent if multiple pong frames are queued. pub fn write_message(&mut self, message: Message) -> Result<()> { - // Try to make some room for the new message - self.write_pending().no_block()?; - if let Some(max_send_queue) = self.config.max_send_queue { + if self.send_queue.len() >= max_send_queue { + // Try to make some room for the new message. + // Do not return here if write would block, ignore WouldBlock silently + // since we must queue the message anyway. + self.write_pending().no_block()?; + } + if self.send_queue.len() >= max_send_queue { return Err(Error::SendQueueFull(message)); } @@ -167,8 +189,9 @@ impl WebSocket { return self.write_pending() } }; + self.send_queue.push_back(frame); - Ok(()) + self.write_pending() } /// Close the connection. @@ -229,15 +252,18 @@ impl WebSocket { impl WebSocket { /// Try to decode one message frame. May return None. fn read_message_frame(&mut self) -> Result> { - if let Some(mut frame) = self.socket.read_frame()? { + if let Some(mut frame) = self.socket.read_frame(self.config.max_frame_size)? { // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of // the negotiated extensions defines the meaning of such a nonzero // value, the receiving endpoint MUST _Fail the WebSocket // Connection_. - if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { - return Err(Error::Protocol("Reserved bits are non-zero".into())) + { + let hdr = frame.header(); + if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { + return Err(Error::Protocol("Reserved bits are non-zero".into())) + } } match self.role { @@ -245,7 +271,7 @@ impl WebSocket { if frame.is_masked() { // A server MUST remove masking for data frames received from a client // as described in Section 5.3. (RFC 6455) - frame.remove_mask() + frame.apply_mask() } else { // The server MUST close the connection upon receiving a // frame that is not masked. (RFC 6455) @@ -260,13 +286,13 @@ impl WebSocket { } } - match frame.opcode() { + match frame.header().opcode { OpCode::Control(ctl) => { match ctl { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) - _ if !frame.is_final() => { + _ if !frame.header().is_final => { Err(Error::Protocol("Fragmented control frame".into())) } _ if frame.payload().len() > 125 => { @@ -299,12 +325,11 @@ impl WebSocket { } OpCode::Data(data) => { - let fin = frame.is_final(); + let fin = frame.header().is_final; match data { OpData::Continue => { if let Some(ref mut msg) = self.incomplete { - // TODO if msg too big - msg.extend(frame.into_data())?; + msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol("Continue frame but nothing to continue".into())) } @@ -327,7 +352,7 @@ impl WebSocket { _ => panic!("Bug: message is not text nor binary"), }; let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data())?; + m.extend(frame.into_data(), self.config.max_message_size)?; m }; if fin { @@ -414,7 +439,7 @@ impl WebSocket { Role::Client => { // 5. If the data is being sent by the client, the frame(s) MUST be // masked as defined in Section 5.3. (RFC 6455) - frame.set_mask(); + frame.set_random_mask(); } } let res = self.socket.write_frame(frame); @@ -470,7 +495,7 @@ impl WebSocketState { #[cfg(test)] mod tests { - use super::{WebSocket, Role, Message}; + use super::{WebSocket, Role, Message, WebSocketConfig}; use std::io; use std::io::Cursor; @@ -512,4 +537,38 @@ mod tests { assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); } + + #[test] + fn size_limiting_text_fragmented() { + let incoming = Cursor::new(vec![ + 0x01, 0x07, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, + 0x80, 0x06, + 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, + ]); + let limit = WebSocketConfig { + max_message_size: Some(10), + .. WebSocketConfig::default() + }; + let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); + assert_eq!(socket.read_message().unwrap_err().to_string(), + "Space limit exceeded: Message too big: 7 + 6 > 10" + ); + } + + #[test] + fn size_limiting_binary() { + let incoming = Cursor::new(vec![ + 0x82, 0x03, + 0x01, 0x02, 0x03, + ]); + let limit = WebSocketConfig { + max_message_size: Some(2), + .. WebSocketConfig::default() + }; + let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); + assert_eq!(socket.read_message().unwrap_err().to_string(), + "Space limit exceeded: Message too big: 0 + 3 > 2" + ); + } }