diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index bb695e5..9af32c7 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -18,6 +18,7 @@ use self::message::IncompleteMessage; use crate::error::{Error, Result}; use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; +use crate::protocol::frame::coding::Data; use crate::util::NonBlockingResult; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; @@ -50,26 +51,26 @@ where pub encoder: E, } -impl Default for WebSocketConfig +impl Default for WebSocketConfig where - E: WebSocketExtension, + Ext: WebSocketExtension, { fn default() -> Self { WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: E::new(Some(MAX_MESSAGE_SIZE)), + encoder: Ext::new(Some(MAX_MESSAGE_SIZE)), } } } -impl WebSocketConfig +impl WebSocketConfig where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// Creates a `WebSocketConfig` instance using the default configuration and the provided /// encoder for new connections. - pub fn default_with_encoder(encoder: E) -> WebSocketConfig { + pub fn default_with_encoder(encoder: Ext) -> WebSocketConfig { WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), @@ -608,10 +609,39 @@ where }; } - trace!("Sending frame: {:?}", frame); - self.frame - .write_frame(stream, frame) - .check_connection_reset(self.state) + let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value); + if frame.payload().len() > max_frame_size { + let mut chunks = frame.payload().chunks(max_frame_size).peekable(); + let data_frame = Frame::message( + Vec::from(chunks.next().unwrap()), + frame.header().opcode, + false, + ); + self.frame + .write_frame(stream, data_frame) + .check_connection_reset(self.state)?; + + while let Some(chunk) = chunks.next() { + let frame = Frame::message( + Vec::from(chunk), + OpCode::Data(Data::Continue), + chunks.peek().is_none(), + ); + + trace!("Sending frame: {:?}", frame); + + self.frame + .write_frame(stream, frame) + .check_connection_reset(self.state)?; + } + + Ok(()) + } else { + trace!("Sending frame: {:?}", frame); + self.frame + .write_frame(stream, frame) + .check_connection_reset(self.state) + } } } @@ -683,6 +713,8 @@ mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; use crate::extensions::uncompressed::UncompressedExt; + use crate::protocol::frame::coding::{Data, OpCode}; + use crate::protocol::frame::Frame; use std::io; use std::io::Cursor; @@ -756,4 +788,66 @@ mod tests { "Space limit exceeded: Message too big: 0 + 3 > 2" ); } + + #[test] + fn fragmented_tx() { + let max_message_size = 2; + let input_str = "hello unit test"; + + let limit = WebSocketConfig { + max_send_queue: None, + max_frame_size: Some(2), + encoder: UncompressedExt::new(Some(max_message_size)), + }; + + let mut socket = + WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit)); + + socket.write_message(Message::text(input_str)).unwrap(); + socket.socket.set_position(0); + + let WebSocket { + mut socket, + mut context, + } = socket; + + let vec = input_str.chars().collect::>(); + let mut iter = vec + .chunks(max_message_size) + .map(|c| c.iter().collect::()) + .into_iter() + .peekable(); + + let frame_eq = |expected: Frame, actual: Frame| { + assert_eq!(expected.payload(), actual.payload()); + assert_eq!(expected.header().opcode, actual.header().opcode); + assert_eq!(expected.header().rsv1, actual.header().rsv1); + }; + + let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false); + frame_eq( + expected, + context + .frame + .read_frame(&mut socket, Some(max_message_size)) + .unwrap() + .unwrap(), + ); + + while let Some(chars) = iter.next() { + let expected = Frame::message( + chars.into(), + OpCode::Data(Data::Continue), + iter.peek().is_none(), + ); + frame_eq( + expected, + context + .frame + .read_frame(&mut socket, Some(max_message_size)) + .unwrap() + .unwrap(), + ); + } + } }