diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 56711d0..470a795 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -67,7 +67,9 @@ impl FrameSocket where Stream: Read { /// Read a frame from stream. - pub fn read_frame(&mut self) -> Result> { + pub fn read_frame(&mut self, max_size: Option) -> Result> { + let max_size = max_size.unwrap_or_else(usize::max_value); + let payload = loop { { let cursor = self.in_buffer.as_cursor_mut(); @@ -79,10 +81,11 @@ impl FrameSocket if let Some((_, ref length)) = self.header { let length = *length; - // Make sure `length` is not too big (fits into `usize`). - if length > usize::max_value() as u64 { + // Enforce frame size limit early and make sure `length` + // is not too big (fits into `usize`). + if length > max_size as u64 { return Err(Error::Capacity( - format!("Message length too big: {}", length).into() + format!("Message length too big: {} > {}", length, max_size).into() )) } @@ -160,11 +163,11 @@ mod tests { ]); let mut sock = FrameSocket::new(raw); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); - assert!(sock.read_frame().unwrap().is_none()); + assert!(sock.read_frame(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); assert_eq!(rest, vec![0x99]); @@ -176,7 +179,7 @@ mod tests { 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, ]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); - assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } @@ -204,6 +207,17 @@ mod tests { 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, ]); let mut sock = FrameSocket::new(raw); - let _ = sock.read_frame(); // should not crash + let _ = sock.read_frame(None); // should not crash + } + + #[test] + fn size_limit_hit() { + let raw = Cursor::new(vec![ + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ]); + let mut sock = FrameSocket::new(raw); + assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(), + "Space limit exceeded: Message length too big: 7 > 5" + ); } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8124e73..e35b9b7 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -37,6 +37,11 @@ pub struct WebSocketConfig { /// which should be reasonably big for all normal use-cases but small enough to prevent /// memory eating by a malicious user. pub max_message_size: Option, + /// The maximum size of a single message frame. `None` means no size limit. The limit is for + /// frame payload NOT including the frame header. The default value is 16 megabytes which should + /// be reasonably big for all normal use-cases but small enough to prevent memory eating + /// by a malicious user. + pub max_frame_size: Option, } impl Default for WebSocketConfig { @@ -44,6 +49,7 @@ impl Default for WebSocketConfig { WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), + max_frame_size: Some(16 << 20), } } } @@ -239,7 +245,7 @@ impl WebSocket { impl WebSocket { /// Try to decode one message frame. May return None. fn read_message_frame(&mut self) -> Result> { - if let Some(mut frame) = self.socket.read_frame()? { + if let Some(mut frame) = self.socket.read_frame(self.config.max_frame_size)? { // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of