pull/144/head
SirCipher 5 years ago
parent 3cf0b83949
commit 2744d1be4f
  1. 128
      src/protocol/mod.rs

@ -18,8 +18,8 @@ use self::message::IncompleteMessage;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension; use crate::extensions::WebSocketExtension;
use crate::util::NonBlockingResult;
use crate::protocol::frame::coding::Data; use crate::protocol::frame::coding::Data;
use crate::util::NonBlockingResult;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
@ -35,8 +35,8 @@ pub enum Role {
/// The configuration for WebSocket connection. /// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig<E = UncompressedExt> pub struct WebSocketConfig<E = UncompressedExt>
where where
E: WebSocketExtension, E: WebSocketExtension,
{ {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None` /// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited /// means here that the size of the queue is unlimited. The default value is the unlimited
@ -52,8 +52,8 @@ pub struct WebSocketConfig<E = UncompressedExt>
} }
impl<E> Default for WebSocketConfig<E> impl<E> Default for WebSocketConfig<E>
where where
E: WebSocketExtension, E: WebSocketExtension,
{ {
fn default() -> Self { fn default() -> Self {
WebSocketConfig { WebSocketConfig {
@ -65,8 +65,8 @@ impl<E> Default for WebSocketConfig<E>
} }
impl<E> WebSocketConfig<E> impl<E> WebSocketConfig<E>
where where
E: WebSocketExtension, E: WebSocketExtension,
{ {
/// Creates a `WebSocketConfig` instance using the default configuration and the provided /// Creates a `WebSocketConfig` instance using the default configuration and the provided
/// encoder for new connections. /// encoder for new connections.
@ -85,8 +85,8 @@ impl<E> WebSocketConfig<E>
/// It may be created by calling `connect`, `accept` or `client` functions. /// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocket<Stream, Ext> pub struct WebSocket<Stream, Ext>
where where
Ext: WebSocketExtension, Ext: WebSocketExtension,
{ {
/// The underlying socket. /// The underlying socket.
socket: Stream, socket: Stream,
@ -95,8 +95,8 @@ pub struct WebSocket<Stream, Ext>
} }
impl<Stream, Ext> WebSocket<Stream, Ext> impl<Stream, Ext> WebSocket<Stream, Ext>
where where
Ext: WebSocketExtension, Ext: WebSocketExtension,
{ {
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
/// ///
@ -168,9 +168,9 @@ impl<Stream, Ext> WebSocket<Stream, Ext>
} }
impl<Stream, Ext> WebSocket<Stream, Ext> impl<Stream, Ext> WebSocket<Stream, Ext>
where where
Stream: Read + Write, Stream: Read + Write,
Ext: WebSocketExtension, Ext: WebSocketExtension,
{ {
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
@ -255,8 +255,8 @@ impl<Stream, Ext> WebSocket<Stream, Ext>
/// A context for managing WebSocket stream. /// A context for managing WebSocket stream.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocketContext<Ext = UncompressedExt> pub struct WebSocketContext<Ext = UncompressedExt>
where where
Ext: WebSocketExtension, Ext: WebSocketExtension,
{ {
/// Server or client? /// Server or client?
role: Role, role: Role,
@ -275,8 +275,8 @@ pub struct WebSocketContext<Ext = UncompressedExt>
} }
impl<Ext> WebSocketContext<Ext> impl<Ext> WebSocketContext<Ext>
where where
Ext: WebSocketExtension, Ext: WebSocketExtension,
{ {
/// Create a WebSocket context that manages a post-handshake stream. /// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig<Ext>>) -> Self { pub fn new(role: Role, config: Option<WebSocketConfig<Ext>>) -> Self {
@ -335,8 +335,8 @@ impl<Ext> WebSocketContext<Ext>
/// This function sends pong and close responses automatically. /// This function sends pong and close responses automatically.
/// However, it never blocks on write. /// However, it never blocks on write.
pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message> pub fn read_message<Stream>(&mut self, stream: &mut Stream) -> Result<Message>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// Do not read from already closed connections. // Do not read from already closed connections.
self.state.check_active()?; self.state.check_active()?;
@ -363,8 +363,8 @@ impl<Ext> WebSocketContext<Ext>
/// Note that only the last pong frame is stored to be sent, and only the /// Note that only the last pong frame is stored to be sent, and only the
/// most recent pong frame is sent if multiple pong frames are queued. /// most recent pong frame is sent if multiple pong frames are queued.
pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()> pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// When terminated, return AlreadyClosed. // When terminated, return AlreadyClosed.
self.state.check_active()?; self.state.check_active()?;
@ -406,8 +406,8 @@ impl<Ext> WebSocketContext<Ext>
/// Flush the pending send queue. /// Flush the pending send queue.
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()> pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
// First, make sure we have no pending frame sending. // First, make sure we have no pending frame sending.
self.frame.write_pending(stream)?; self.frame.write_pending(stream)?;
@ -449,8 +449,8 @@ impl<Ext> WebSocketContext<Ext>
/// There is no need to call it again. Calling this function is /// There is no need to call it again. Calling this function is
/// the same as calling `write(Message::Close(..))`. /// the same as calling `write(Message::Close(..))`.
pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()> pub fn close<Stream>(&mut self, stream: &mut Stream, code: Option<CloseFrame>) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
if let WebSocketState::Active = self.state { if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs; self.state = WebSocketState::ClosedByUs;
@ -464,8 +464,8 @@ impl<Ext> WebSocketContext<Ext>
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
fn read_message_frame<Stream>(&mut self, stream: &mut Stream) -> Result<Option<Message>> fn read_message_frame<Stream>(&mut self, stream: &mut Stream) -> Result<Option<Message>>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
if let Some(mut frame) = self if let Some(mut frame) = self
.frame .frame
@ -590,8 +590,8 @@ impl<Ext> WebSocketContext<Ext>
/// Send a single pending frame. /// Send a single pending frame.
fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()> fn send_one_frame<Stream>(&mut self, stream: &mut Stream, mut frame: Frame) -> Result<()>
where where
Stream: Read + Write, Stream: Read + Write,
{ {
match self.role { match self.role {
Role::Server => {} Role::Server => {}
@ -611,16 +611,31 @@ impl<Ext> WebSocketContext<Ext>
let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value); let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value);
if frame.payload().len() > max_frame_size { if frame.payload().len() > max_frame_size {
let mut chunks = frame.payload().chunks(self.config.max_frame_size.unwrap_or_else(usize::max_value)).peekable(); let mut chunks = frame
let data_frame = Frame::message(Vec::from(chunks.next().unwrap()), frame.header().opcode, false); .payload()
self.frame.write_frame(stream, data_frame).check_connection_reset(self.state)?; .chunks(self.config.max_frame_size.unwrap_or_else(usize::max_value))
.peekable();
let data_frame = Frame::message(
Vec::from(chunks.next().unwrap()),
frame.header().opcode,
false,
);
self.frame
.write_frame(stream, data_frame)
.check_connection_reset(self.state)?;
while let Some(chunk) = chunks.next() { while let Some(chunk) = chunks.next() {
let frame = Frame::message(Vec::from(chunk), OpCode::Data(Data::Continue), chunks.peek().is_none()); let frame = Frame::message(
Vec::from(chunk),
OpCode::Data(Data::Continue),
chunks.peek().is_none(),
);
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state)?; self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)?;
} }
Ok(()) Ok(())
@ -701,10 +716,10 @@ mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::uncompressed::UncompressedExt;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;
use crate::protocol::frame::Frame;
use crate::protocol::frame::coding::{OpCode, Data};
struct WriteMoc<Stream>(Stream); struct WriteMoc<Stream>(Stream);
@ -788,16 +803,23 @@ mod tests {
encoder: UncompressedExt::new(Some(max_message_size)), encoder: UncompressedExt::new(Some(max_message_size)),
}; };
let mut socket = WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit)); let mut socket =
WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit));
socket.write_message(Message::text(input_str)).unwrap(); socket.write_message(Message::text(input_str)).unwrap();
socket.socket.set_position(0); socket.socket.set_position(0);
let WebSocket { mut socket, mut context } = socket; let WebSocket {
mut socket,
mut context,
} = socket;
let vec = input_str.chars().collect::<Vec<_>>(); let vec = input_str.chars().collect::<Vec<_>>();
let mut iter = vec.chunks(max_message_size).map(|c| c.iter().collect::<String>()) let mut iter = vec
.into_iter().peekable(); .chunks(max_message_size)
.map(|c| c.iter().collect::<String>())
.into_iter()
.peekable();
let frame_eq = |expected: Frame, actual: Frame| { let frame_eq = |expected: Frame, actual: Frame| {
assert_eq!(expected.payload(), actual.payload()); assert_eq!(expected.payload(), actual.payload());
@ -806,11 +828,29 @@ mod tests {
}; };
let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false); let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false);
frame_eq(expected, context.frame.read_frame(&mut socket, Some(max_message_size)).unwrap().unwrap()); frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
while let Some(chars) = iter.next() { while let Some(chars) = iter.next() {
let expected = Frame::message(chars.into(), OpCode::Data(Data::Continue), iter.peek().is_none()); let expected = Frame::message(
frame_eq(expected, context.frame.read_frame(&mut socket, Some(max_message_size)).unwrap().unwrap()); chars.into(),
OpCode::Data(Data::Continue),
iter.peek().is_none(),
);
frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
} }
} }
} }

Loading…
Cancel
Save