diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 3997803..a32fb88 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -110,6 +110,12 @@ impl WebSocket { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { self.context.set_config(set_func) } + + /// Consume the websocket, returning the content of its message queue, + /// from most to least recently queued. + pub fn drain(self) -> Vec { + self.context.drain() + } } impl WebSocket { @@ -370,6 +376,22 @@ impl WebSocketContext { } self.write_pending(stream) } + + /// Consume the websocket context, returning its queued messages, + /// from most to least recently queued. + pub fn drain(mut self) -> Vec { + let mut messages = Vec::with_capacity(self.send_queue.len()); + + let send_queue = replace(&mut self.send_queue, VecDeque::new()); + for frame in send_queue.into_iter().rev() { + // This should not error since we created those frames. + if let Some(message) = self._read_message_frame(frame).unwrap() { + messages.push(message); + } + } + + messages + } } impl WebSocketContext { @@ -378,130 +400,134 @@ impl WebSocketContext { where Stream: Read + Write, { - if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { - if !self.state.can_read() { - return Err(Error::Protocol( - "Remote sent frame after having sent a Close Frame".into(), - )); - } - // MUST be 0 unless an extension is negotiated that defines meanings - // for non-zero values. If a nonzero value is received and none of - // the negotiated extensions defines the meaning of such a nonzero - // value, the receiving endpoint MUST _Fail the WebSocket - // Connection_. - { - let hdr = frame.header(); - if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol("Reserved bits are non-zero".into())); + if let Some(frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { + self._read_message_frame(frame) + } else { + // Connection closed by peer + match replace(&mut self.state, WebSocketState::Terminated) { + WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { + Err(Error::ConnectionClosed) } + _ => Err(Error::Protocol( + "Connection reset without closing handshake".into(), + )), } + } + } - match self.role { - Role::Server => { - if frame.is_masked() { - // A server MUST remove masking for data frames received from a client - // as described in Section 5.3. (RFC 6455) - frame.apply_mask() - } else { - // The server MUST close the connection upon receiving a - // frame that is not masked. (RFC 6455) - return Err(Error::Protocol( - "Received an unmasked frame from client".into(), - )); - } + fn _read_message_frame(&mut self, mut frame: Frame) -> Result> { + if !self.state.can_read() { + return Err(Error::Protocol( + "Remote sent frame after having sent a Close Frame".into(), + )); + } + // MUST be 0 unless an extension is negotiated that defines meanings + // for non-zero values. If a nonzero value is received and none of + // the negotiated extensions defines the meaning of such a nonzero + // value, the receiving endpoint MUST _Fail the WebSocket + // Connection_. + { + let hdr = frame.header(); + if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { + return Err(Error::Protocol("Reserved bits are non-zero".into())); + } + } + + match self.role { + Role::Server => { + if frame.is_masked() { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + frame.apply_mask() + } else { + // The server MUST close the connection upon receiving a + // frame that is not masked. (RFC 6455) + return Err(Error::Protocol( + "Received an unmasked frame from client".into(), + )); } - Role::Client => { - if frame.is_masked() { - // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol( - "Received a masked frame from server".into(), - )); - } + } + Role::Client => { + if frame.is_masked() { + // A client MUST close a connection if it detects a masked frame. (RFC 6455) + return Err(Error::Protocol( + "Received a masked frame from server".into(), + )); } } + } - match frame.header().opcode { - OpCode::Control(ctl) => { - match ctl { - // All control frames MUST have a payload length of 125 bytes or less - // and MUST NOT be fragmented. (RFC 6455) - _ if !frame.header().is_final => { - Err(Error::Protocol("Fragmented control frame".into())) - } - _ if frame.payload().len() > 125 => { - Err(Error::Protocol("Control frame too big".into())) - } - OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), - OpCtl::Reserved(i) => Err(Error::Protocol( - format!("Unknown control frame type {}", i).into(), - )), - OpCtl::Ping => { - let data = frame.into_data(); - // No ping processing after we sent a close frame. - if self.state.is_active() { - self.pong = Some(Frame::pong(data.clone())); - } - Ok(Some(Message::Ping(data))) + match frame.header().opcode { + OpCode::Control(ctl) => { + match ctl { + // All control frames MUST have a payload length of 125 bytes or less + // and MUST NOT be fragmented. (RFC 6455) + _ if !frame.header().is_final => { + Err(Error::Protocol("Fragmented control frame".into())) + } + _ if frame.payload().len() > 125 => { + Err(Error::Protocol("Control frame too big".into())) + } + OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), + OpCtl::Reserved(i) => Err(Error::Protocol( + format!("Unknown control frame type {}", i).into(), + )), + OpCtl::Ping => { + let data = frame.into_data(); + // No ping processing after we sent a close frame. + if self.state.is_active() { + self.pong = Some(Frame::pong(data.clone())); } - OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), + Ok(Some(Message::Ping(data))) } + OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), } + } - OpCode::Data(data) => { - let fin = frame.header().is_final; - match data { - OpData::Continue => { - if let Some(ref mut msg) = self.incomplete { - msg.extend(frame.into_data(), self.config.max_message_size)?; - } else { - return Err(Error::Protocol( - "Continue frame but nothing to continue".into(), - )); - } - if fin { - Ok(Some(self.incomplete.take().unwrap().complete()?)) - } else { - Ok(None) - } + OpCode::Data(data) => { + let fin = frame.header().is_final; + match data { + OpData::Continue => { + if let Some(ref mut msg) = self.incomplete { + msg.extend(frame.into_data(), self.config.max_message_size)?; + } else { + return Err(Error::Protocol( + "Continue frame but nothing to continue".into(), + )); } - c if self.incomplete.is_some() => Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into(), - )), - OpData::Text | OpData::Binary => { - let msg = { - let message_type = match data { - OpData::Text => IncompleteMessageType::Text, - OpData::Binary => IncompleteMessageType::Binary, - _ => panic!("Bug: message is not text nor binary"), - }; - let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data(), self.config.max_message_size)?; - m + if fin { + Ok(Some(self.incomplete.take().unwrap().complete()?)) + } else { + Ok(None) + } + } + c if self.incomplete.is_some() => Err(Error::Protocol( + format!("Received {} while waiting for more fragments", c).into(), + )), + OpData::Text | OpData::Binary => { + let msg = { + let message_type = match data { + OpData::Text => IncompleteMessageType::Text, + OpData::Binary => IncompleteMessageType::Binary, + _ => panic!("Bug: message is not text nor binary"), }; - if fin { - Ok(Some(msg.complete()?)) - } else { - self.incomplete = Some(msg); - Ok(None) - } + let mut m = IncompleteMessage::new(message_type); + m.extend(frame.into_data(), self.config.max_message_size)?; + m + }; + if fin { + Ok(Some(msg.complete()?)) + } else { + self.incomplete = Some(msg); + Ok(None) } - OpData::Reserved(i) => Err(Error::Protocol( - format!("Unknown data frame type {}", i).into(), - )), } + OpData::Reserved(i) => Err(Error::Protocol( + format!("Unknown data frame type {}", i).into(), + )), } - } // match opcode - } else { - // Connection closed by peer - match replace(&mut self.state, WebSocketState::Terminated) { - WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { - Err(Error::ConnectionClosed) - } - _ => Err(Error::Protocol( - "Connection reset without closing handshake".into(), - )), } - } + } // match opcode } /// Received a close frame. Tells if we need to return a close frame to the user.