diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 3c99545..5c34499 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -32,12 +32,14 @@ fn main() { for stream in server.incoming() { spawn(move || match stream { - Ok(stream) => if let Err(err) = handle_client(stream) { - match err { - Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), - e => error!("test: {}", e), + Ok(stream) => { + if let Err(err) = handle_client(stream) { + match err { + Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), + e => error!("test: {}", e), + } } - }, + } Err(e) => error!("Error accepting stream: {}", e), }); } diff --git a/src/error.rs b/src/error.rs index ac2de26..bafc099 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,7 +12,7 @@ use std::string; use http; use httparse; -use crate::protocol::Message; +use crate::protocol::EitherMessage; #[cfg(feature = "tls")] pub mod tls { @@ -59,7 +59,7 @@ pub enum Error { /// Protocol violation. Protocol(Cow<'static, str>), /// Message send queue full. - SendQueueFull(Message), + SendQueueFull(EitherMessage), /// UTF coding error Utf8, /// Invalid URL. diff --git a/src/handshake/client.rs b/src/handshake/client.rs index c1721f5..954bff7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -105,16 +105,18 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority() + let authority = uri + .authority() .ok_or_else(|| Error::Url("No host name in the URL".into()))? .as_str(); - let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ + let host = if let Some(idx) = authority.find('@') { + // handle possible name:password@ authority.split_at(idx + 1).1 } else { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())) + return Err(Error::Url("URL contains empty host name".into())); } write!( @@ -261,8 +263,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { use super::super::machine::TryParse; - use crate::client::IntoClientRequest; use super::{generate_key, generate_request, Response}; + use crate::client::IntoClientRequest; #[test] fn random_keys() { @@ -299,7 +301,9 @@ mod tests { #[test] fn request_formatting_with_host() { - let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ @@ -316,7 +320,9 @@ mod tests { #[test] fn request_formatting_with_at() { - let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://user:pass@localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 5992932..0a6ffb6 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,10 +1,13 @@ use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; +use bytes::{Bytes, BytesMut}; use log::*; use std::borrow::Cow; +use std::convert::{TryFrom, TryInto}; use std::default::Default; use std::fmt; use std::io::{Cursor, ErrorKind, Read, Write}; use std::result::Result as StdResult; +use std::str; use std::string::{FromUtf8Error, String}; use super::coding::{CloseCode, Control, Data, OpCode}; @@ -205,11 +208,83 @@ impl FrameHeader { } } +/// A binary payload that might or might not be shared. +#[derive(Debug, Clone)] +pub enum Payload { + Bytes(Vec), + ShBytes(Bytes), +} + +impl Payload { + pub fn len(&self) -> usize { + match self { + Self::Bytes(bytes) => bytes.len(), + Self::ShBytes(bytes) => bytes.len(), + } + } + + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Bytes(bytes) => bytes.as_slice(), + Self::ShBytes(bytes) => bytes.as_ref(), + } + } + + pub fn unwrap_bytes(self) -> Vec { + match self { + Self::Bytes(bytes) => bytes, + _ => panic!("expected variant `Payload::Bytes`"), + } + } +} + +impl TryFrom for String { + type Error = FromUtf8Error; + + fn try_from(payload: Payload) -> std::result::Result { + let vec = match payload { + Payload::Bytes(bytes) => bytes, + Payload::ShBytes(bytes) => bytes.as_ref().to_owned(), + }; + String::from_utf8(vec) + } +} + +impl From> for Payload { + fn from(bytes: Vec) -> Self { + Self::Bytes(bytes) + } +} + +impl From<&[u8]> for Payload { + fn from(bytes: &[u8]) -> Self { + bytes.to_owned().into() + } +} + +impl From for Payload { + fn from(string: String) -> Self { + Self::Bytes(string.into()) + } +} + +impl From<&str> for Payload { + fn from(string: &str) -> Self { + string.to_owned().into() + } +} + +impl From for Payload { + fn from(bytes: Bytes) -> Self { + Self::ShBytes(bytes) + } +} + /// A struct representing a WebSocket frame. #[derive(Debug, Clone)] pub struct Frame { header: FrameHeader, - payload: Vec, + payload: Payload, } impl Frame { @@ -241,14 +316,8 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] - pub fn payload(&self) -> &Vec { - &self.payload - } - - /// Get a mutable reference to the frame's payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut Vec { - &mut self.payload + pub fn payload(&self) -> &[u8] { + self.payload.as_bytes() } /// Test whether the frame is masked. @@ -271,20 +340,27 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - apply_mask(&mut self.payload, mask) + match &mut self.payload { + Payload::Bytes(bytes) => apply_mask(bytes, mask), + Payload::ShBytes(bytes) => { + let mut bytes_mut = BytesMut::from(bytes.as_ref()); + apply_mask(&mut bytes_mut, mask); + *bytes = bytes_mut.freeze(); + } + } } } /// Consume the frame into its payload as binary. #[inline] - pub fn into_data(self) -> Vec { + pub fn into_payload(self) -> Payload { self.payload } /// Consume the frame into its payload as string. #[inline] pub fn into_string(self) -> StdResult { - String::from_utf8(self.payload) + self.payload.try_into() } /// Consume the frame into a closing frame. @@ -294,10 +370,16 @@ impl Frame { 0 => Ok(None), 1 => Err(Error::Protocol("Invalid close sequence".into())), _ => { - let mut data = self.payload; - let code = NetworkEndian::read_u16(&data[0..2]).into(); - data.drain(0..2); - let text = String::from_utf8(data)?; + let data = self.payload; + let code = NetworkEndian::read_u16(&data.as_bytes()[0..2]).into(); + let bytes = match data { + Payload::Bytes(mut bytes) => { + bytes.drain(0..2); + bytes + } + Payload::ShBytes(bytes) => bytes.as_ref()[2..].to_owned(), + }; + let text = String::from_utf8(bytes)?; Ok(Some(CloseFrame { code, reason: text.into(), @@ -308,7 +390,10 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + pub fn message

(payload: P, opcode: OpCode, is_final: bool) -> Frame + where + P: Into, + { debug_assert!( match opcode { OpCode::Data(_) => true, @@ -323,7 +408,7 @@ impl Frame { opcode, ..FrameHeader::default() }, - payload: data, + payload: payload.into(), } } @@ -335,7 +420,7 @@ impl Frame { opcode: OpCode::Control(Control::Pong), ..FrameHeader::default() }, - payload: data, + payload: data.into(), } } @@ -347,7 +432,7 @@ impl Frame { opcode: OpCode::Control(Control::Ping), ..FrameHeader::default() }, - payload: data, + payload: data.into(), } } @@ -365,12 +450,12 @@ impl Frame { Frame { header: FrameHeader::default(), - payload, + payload: payload.into(), } } /// Create a frame from given header and data. - pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { + pub fn from_payload(header: FrameHeader, payload: Payload) -> Self { Frame { header, payload } } @@ -405,6 +490,7 @@ payload: 0x{} self.len(), self.payload.len(), self.payload + .as_bytes() .iter() .map(|byte| format!("{:x}", byte)) .collect::() @@ -476,11 +562,11 @@ mod tests { Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); assert_eq!(length, 7); - let mut payload = Vec::new(); - raw.read_to_end(&mut payload).unwrap(); - let frame = Frame::from_payload(header, payload); + let mut bytes = Vec::new(); + raw.read_to_end(&mut bytes).unwrap(); + let frame = Frame::from_payload(header, bytes.into()); assert_eq!( - frame.into_data(), + frame.into_payload().unwrap_bytes(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); } @@ -495,7 +581,7 @@ mod tests { #[test] fn display() { - let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); + let f = Frame::message("hi there", OpCode::Data(Data::Text), true); let view = format!("{}", f); assert!(view.contains("payload:")); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 6756f0a..c74446a 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -172,7 +172,7 @@ impl FrameCodec { let (header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, payload); + let frame = Frame::from_payload(header, payload.into()); trace!("received frame {}", frame); Ok(Some(frame)) } @@ -228,11 +228,11 @@ mod tests { let mut sock = FrameSocket::new(raw); assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), + sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), + sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(), vec![0x03, 0x02, 0x01] ); assert!(sock.read_frame(None).unwrap().is_none()); @@ -246,7 +246,7 @@ mod tests { let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( - sock.read_frame(None).unwrap().unwrap().into_data(), + sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 5df4ba0..9e9c261 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -6,6 +6,8 @@ use std::str; use super::frame::CloseFrame; use crate::error::{Error, Result}; +use bytes::Bytes; + mod string_collect { use utf8; @@ -168,6 +170,35 @@ pub enum IncompleteMessageType { Binary, } +/// Either a owned or shard `WebSocket` message. +#[derive(Debug, Clone)] +pub enum EitherMessage { + /// A owned `WebSocket` message. + Message(Message), + + /// A shared `WebSocket` message. + SharedMessage(SharedMessage), +} + +impl From for EitherMessage { + fn from(m: Message) -> Self { + EitherMessage::Message(m) + } +} + +impl From for EitherMessage { + fn from(m: SharedMessage) -> Self { + EitherMessage::SharedMessage(m) + } +} + +/// A shared websocket message. +#[derive(Debug, Clone)] +pub enum SharedMessage { + /// A shared binary `WebSocket` message. + Binary(Bytes), +} + /// An enum representing the various forms of a WebSocket message. #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 7627b46..76fcd13 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -5,7 +5,7 @@ pub mod frame; mod message; pub use self::frame::CloseFrame; -pub use self::message::Message; +pub use self::message::{Message, SharedMessage, EitherMessage}; use log::*; use std::collections::VecDeque; @@ -162,6 +162,13 @@ impl WebSocket { self.context.write_message(&mut self.socket, message) } + /// Send a shared message to stream, if possible. + /// + /// This is essentially the same method as `write_message` but for shared messages. + pub fn write_shared_message(&mut self, message: SharedMessage) -> Result<()> { + self.context.write_shared_message(&mut self.socket, message) + } + /// Flush the pending send queue. pub fn write_pending(&mut self) -> Result<()> { self.context.write_pending(&mut self.socket) @@ -272,6 +279,31 @@ impl WebSocketContext { /// Note that only the last pong frame is stored to be sent, and only the /// most recent pong frame is sent if multiple pong frames are queued. pub fn write_message(&mut self, stream: &mut Stream, message: Message) -> Result<()> + where + Stream: Read + Write, + { + self.write_either_message(stream, message.into()) + } + + /// Send a shared message to the provided stream, if possible. + /// + /// This is the same method than `write_message`, but for shared messages instead. + pub fn write_shared_message( + &mut self, + stream: &mut Stream, + message: SharedMessage, + ) -> Result<()> + where + Stream: Read + Write, + { + self.write_either_message(stream, message.into()) + } + + fn write_either_message( + &mut self, + stream: &mut Stream, + message: EitherMessage, + ) -> Result<()> where Stream: Read + Write, { @@ -299,14 +331,19 @@ impl WebSocketContext { } let frame = match message { - Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), - Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), - Message::Ping(data) => Frame::ping(data), - Message::Pong(data) => { - self.pong = Some(Frame::pong(data)); - return self.write_pending(stream); - } - Message::Close(code) => return self.close(stream, code), + EitherMessage::Message(message) => match message { + Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true), + Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), + Message::Ping(data) => Frame::ping(data), + Message::Pong(data) => { + self.pong = Some(Frame::pong(data)); + return self.write_pending(stream); + } + Message::Close(code) => return self.close(stream, code), + }, + EitherMessage::SharedMessage(message) => match message { + SharedMessage::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), + }, }; self.send_queue.push_back(frame); @@ -440,14 +477,14 @@ impl WebSocketContext { format!("Unknown control frame type {}", i).into(), )), OpCtl::Ping => { - let data = frame.into_data(); + let data = frame.into_payload().unwrap_bytes(); // No ping processing after we sent a close frame. if self.state.is_active() { self.pong = Some(Frame::pong(data.clone())); } Ok(Some(Message::Ping(data))) } - OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), + OpCtl::Pong => Ok(Some(Message::Pong(frame.into_payload().unwrap_bytes()))), } } @@ -456,7 +493,10 @@ impl WebSocketContext { match data { OpData::Continue => { if let Some(ref mut msg) = self.incomplete { - msg.extend(frame.into_data(), self.config.max_message_size)?; + msg.extend( + frame.into_payload().as_bytes(), + self.config.max_message_size, + )?; } else { return Err(Error::Protocol( "Continue frame but nothing to continue".into(), @@ -479,7 +519,10 @@ impl WebSocketContext { _ => panic!("Bug: message is not text nor binary"), }; let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data(), self.config.max_message_size)?; + m.extend( + frame.into_payload().as_bytes(), + self.config.max_message_size, + )?; m }; if fin { diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index d95ee81..b86bfe0 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,15 +1,15 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. -use std::net::{TcpStream, TcpListener}; +use std::net::{TcpListener, TcpStream}; use std::process::exit; use std::thread::{sleep, spawn}; use std::time::Duration; -use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream}; use native_tls::TlsStream; -use url::Url; use net2::TcpStreamExt; +use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; +use url::Url; type Sock = WebSocket>>; @@ -26,8 +26,8 @@ where exit(1); }); - let server = TcpListener::bind(("127.0.0.1", port)) - .expect("Can't listen, is port already in use?"); + let server = + TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?"); let client_thread = spawn(move || { let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap()) @@ -46,7 +46,8 @@ where #[test] fn test_server_close() { - do_test(3012, + do_test( + 3012, |mut cli_sock| { cli_sock .write_message(Message::Text("Hello WebSocket".into())) @@ -75,12 +76,14 @@ fn test_server_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); + }, + ); } #[test] fn test_evil_server_close() { - do_test(3013, + do_test( + 3013, |mut cli_sock| { cli_sock .write_message(Message::Text("Hello WebSocket".into())) @@ -106,14 +109,19 @@ fn test_evil_server_close() { let message = srv_sock.read_message().unwrap(); // receive acknowledgement assert!(message.is_close()); // and now just drop the connection without waiting for `ConnectionClosed` - srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); + srv_sock + .get_mut() + .set_linger(Some(Duration::from_secs(0))) + .unwrap(); drop(srv_sock); - }); + }, + ); } #[test] fn test_client_close() { - do_test(3014, + do_test( + 3014, |mut cli_sock| { cli_sock .write_message(Message::Text("Hello WebSocket".into())) @@ -137,7 +145,9 @@ fn test_client_close() { let message = srv_sock.read_message().unwrap(); assert_eq!(message.into_data(), b"Hello WebSocket"); - srv_sock.write_message(Message::Text("From Server".into())).unwrap(); + srv_sock + .write_message(Message::Text("From Server".into())) + .unwrap(); let message = srv_sock.read_message().unwrap(); // receive close from client assert!(message.is_close()); @@ -147,6 +157,6 @@ fn test_client_close() { Error::ConnectionClosed => {} _ => panic!("unexpected error: {:?}", err), } - }); - + }, + ); }