Add support for shared websocket messages

This adds the `SharedMessage` enum as well as the `write_shared_message`
for both `Websocket` and the `WebsocketContext`.
pull/104/head
jean-airoldie 5 years ago
parent 010266e001
commit 4bef574d63
  1. 12
      examples/autobahn-server.rs
  2. 4
      src/error.rs
  3. 18
      src/handshake/client.rs
  4. 140
      src/protocol/frame/frame.rs
  5. 8
      src/protocol/frame/mod.rs
  6. 31
      src/protocol/message.rs
  7. 69
      src/protocol/mod.rs
  8. 38
      tests/connection_reset.rs

@ -32,12 +32,14 @@ fn main() {
for stream in server.incoming() {
spawn(move || match stream {
Ok(stream) => if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
Ok(stream) => {
if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
}
}
},
}
Err(e) => error!("Error accepting stream: {}", e),
});
}

@ -12,7 +12,7 @@ use std::string;
use http;
use httparse;
use crate::protocol::Message;
use crate::protocol::EitherMessage;
#[cfg(feature = "tls")]
pub mod tls {
@ -59,7 +59,7 @@ pub enum Error {
/// Protocol violation.
Protocol(Cow<'static, str>),
/// Message send queue full.
SendQueueFull(Message),
SendQueueFull(EitherMessage),
/// UTF coding error
Utf8,
/// Invalid URL.

@ -105,16 +105,18 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority()
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}
write!(
@ -261,8 +263,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
#[test]
fn random_keys() {
@ -299,7 +301,9 @@ mod tests {
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
@ -316,7 +320,9 @@ mod tests {
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://user:pass@localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\

@ -1,10 +1,13 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Bytes, BytesMut};
use log::*;
use std::borrow::Cow;
use std::convert::{TryFrom, TryInto};
use std::default::Default;
use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult;
use std::str;
use std::string::{FromUtf8Error, String};
use super::coding::{CloseCode, Control, Data, OpCode};
@ -205,11 +208,83 @@ impl FrameHeader {
}
}
/// A binary payload that might or might not be shared.
#[derive(Debug, Clone)]
pub enum Payload {
Bytes(Vec<u8>),
ShBytes(Bytes),
}
impl Payload {
pub fn len(&self) -> usize {
match self {
Self::Bytes(bytes) => bytes.len(),
Self::ShBytes(bytes) => bytes.len(),
}
}
pub fn as_bytes(&self) -> &[u8] {
match self {
Self::Bytes(bytes) => bytes.as_slice(),
Self::ShBytes(bytes) => bytes.as_ref(),
}
}
pub fn unwrap_bytes(self) -> Vec<u8> {
match self {
Self::Bytes(bytes) => bytes,
_ => panic!("expected variant `Payload::Bytes`"),
}
}
}
impl TryFrom<Payload> for String {
type Error = FromUtf8Error;
fn try_from(payload: Payload) -> std::result::Result<Self, Self::Error> {
let vec = match payload {
Payload::Bytes(bytes) => bytes,
Payload::ShBytes(bytes) => bytes.as_ref().to_owned(),
};
String::from_utf8(vec)
}
}
impl From<Vec<u8>> for Payload {
fn from(bytes: Vec<u8>) -> Self {
Self::Bytes(bytes)
}
}
impl From<&[u8]> for Payload {
fn from(bytes: &[u8]) -> Self {
bytes.to_owned().into()
}
}
impl From<String> for Payload {
fn from(string: String) -> Self {
Self::Bytes(string.into())
}
}
impl From<&str> for Payload {
fn from(string: &str) -> Self {
string.to_owned().into()
}
}
impl From<Bytes> for Payload {
fn from(bytes: Bytes) -> Self {
Self::ShBytes(bytes)
}
}
/// A struct representing a WebSocket frame.
#[derive(Debug, Clone)]
pub struct Frame {
header: FrameHeader,
payload: Vec<u8>,
payload: Payload,
}
impl Frame {
@ -241,14 +316,8 @@ impl Frame {
/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &Vec<u8> {
&self.payload
}
/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload
pub fn payload(&self) -> &[u8] {
self.payload.as_bytes()
}
/// Test whether the frame is masked.
@ -271,20 +340,27 @@ impl Frame {
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask)
match &mut self.payload {
Payload::Bytes(bytes) => apply_mask(bytes, mask),
Payload::ShBytes(bytes) => {
let mut bytes_mut = BytesMut::from(bytes.as_ref());
apply_mask(&mut bytes_mut, mask);
*bytes = bytes_mut.freeze();
}
}
}
}
/// Consume the frame into its payload as binary.
#[inline]
pub fn into_data(self) -> Vec<u8> {
pub fn into_payload(self) -> Payload {
self.payload
}
/// Consume the frame into its payload as string.
#[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload)
self.payload.try_into()
}
/// Consume the frame into a closing frame.
@ -294,10 +370,16 @@ impl Frame {
0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
let data = self.payload;
let code = NetworkEndian::read_u16(&data.as_bytes()[0..2]).into();
let bytes = match data {
Payload::Bytes(mut bytes) => {
bytes.drain(0..2);
bytes
}
Payload::ShBytes(bytes) => bytes.as_ref()[2..].to_owned(),
};
let text = String::from_utf8(bytes)?;
Ok(Some(CloseFrame {
code,
reason: text.into(),
@ -308,7 +390,10 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
pub fn message<P>(payload: P, opcode: OpCode, is_final: bool) -> Frame
where
P: Into<Payload>,
{
debug_assert!(
match opcode {
OpCode::Data(_) => true,
@ -323,7 +408,7 @@ impl Frame {
opcode,
..FrameHeader::default()
},
payload: data,
payload: payload.into(),
}
}
@ -335,7 +420,7 @@ impl Frame {
opcode: OpCode::Control(Control::Pong),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}
@ -347,7 +432,7 @@ impl Frame {
opcode: OpCode::Control(Control::Ping),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}
@ -365,12 +450,12 @@ impl Frame {
Frame {
header: FrameHeader::default(),
payload,
payload: payload.into(),
}
}
/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
pub fn from_payload(header: FrameHeader, payload: Payload) -> Self {
Frame { header, payload }
}
@ -405,6 +490,7 @@ payload: 0x{}
self.len(),
self.payload.len(),
self.payload
.as_bytes()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
@ -476,11 +562,11 @@ mod tests {
Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
let mut bytes = Vec::new();
raw.read_to_end(&mut bytes).unwrap();
let frame = Frame::from_payload(header, bytes.into());
assert_eq!(
frame.into_data(),
frame.into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
@ -495,7 +581,7 @@ mod tests {
#[test]
fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
let f = Frame::message("hi there", OpCode::Data(Data::Text), true);
let view = format!("{}", f);
assert!(view.contains("payload:"));
}

@ -172,7 +172,7 @@ impl FrameCodec {
let (header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, payload);
let frame = Frame::from_payload(header, payload.into());
trace!("received frame {}", frame);
Ok(Some(frame))
}
@ -228,11 +228,11 @@ mod tests {
let mut sock = FrameSocket::new(raw);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none());
@ -246,7 +246,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}

@ -6,6 +6,8 @@ use std::str;
use super::frame::CloseFrame;
use crate::error::{Error, Result};
use bytes::Bytes;
mod string_collect {
use utf8;
@ -168,6 +170,35 @@ pub enum IncompleteMessageType {
Binary,
}
/// Either a owned or shard `WebSocket` message.
#[derive(Debug, Clone)]
pub enum EitherMessage {
/// A owned `WebSocket` message.
Message(Message),
/// A shared `WebSocket` message.
SharedMessage(SharedMessage),
}
impl From<Message> for EitherMessage {
fn from(m: Message) -> Self {
EitherMessage::Message(m)
}
}
impl From<SharedMessage> for EitherMessage {
fn from(m: SharedMessage) -> Self {
EitherMessage::SharedMessage(m)
}
}
/// A shared websocket message.
#[derive(Debug, Clone)]
pub enum SharedMessage {
/// A shared binary `WebSocket` message.
Binary(Bytes),
}
/// An enum representing the various forms of a WebSocket message.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {

@ -5,7 +5,7 @@ pub mod frame;
mod message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
pub use self::message::{Message, SharedMessage, EitherMessage};
use log::*;
use std::collections::VecDeque;
@ -162,6 +162,13 @@ impl<Stream: Read + Write> WebSocket<Stream> {
self.context.write_message(&mut self.socket, message)
}
/// Send a shared message to stream, if possible.
///
/// This is essentially the same method as `write_message` but for shared messages.
pub fn write_shared_message(&mut self, message: SharedMessage) -> Result<()> {
self.context.write_shared_message(&mut self.socket, message)
}
/// Flush the pending send queue.
pub fn write_pending(&mut self) -> Result<()> {
self.context.write_pending(&mut self.socket)
@ -272,6 +279,31 @@ impl WebSocketContext {
/// Note that only the last pong frame is stored to be sent, and only the
/// most recent pong frame is sent if multiple pong frames are queued.
pub fn write_message<Stream>(&mut self, stream: &mut Stream, message: Message) -> Result<()>
where
Stream: Read + Write,
{
self.write_either_message(stream, message.into())
}
/// Send a shared message to the provided stream, if possible.
///
/// This is the same method than `write_message`, but for shared messages instead.
pub fn write_shared_message<Stream>(
&mut self,
stream: &mut Stream,
message: SharedMessage,
) -> Result<()>
where
Stream: Read + Write,
{
self.write_either_message(stream, message.into())
}
fn write_either_message<Stream>(
&mut self,
stream: &mut Stream,
message: EitherMessage,
) -> Result<()>
where
Stream: Read + Write,
{
@ -299,14 +331,19 @@ impl WebSocketContext {
}
let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
return self.write_pending(stream);
}
Message::Close(code) => return self.close(stream, code),
EitherMessage::Message(message) => match message {
Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
return self.write_pending(stream);
}
Message::Close(code) => return self.close(stream, code),
},
EitherMessage::SharedMessage(message) => match message {
SharedMessage::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
},
};
self.send_queue.push_back(frame);
@ -440,14 +477,14 @@ impl WebSocketContext {
format!("Unknown control frame type {}", i).into(),
)),
OpCtl::Ping => {
let data = frame.into_data();
let data = frame.into_payload().unwrap_bytes();
// No ping processing after we sent a close frame.
if self.state.is_active() {
self.pong = Some(Frame::pong(data.clone()));
}
Ok(Some(Message::Ping(data)))
}
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_payload().unwrap_bytes()))),
}
}
@ -456,7 +493,10 @@ impl WebSocketContext {
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
msg.extend(
frame.into_payload().as_bytes(),
self.config.max_message_size,
)?;
} else {
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
@ -479,7 +519,10 @@ impl WebSocketContext {
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m.extend(
frame.into_payload().as_bytes(),
self.config.max_message_size,
)?;
m
};
if fin {

@ -1,15 +1,15 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpStream, TcpListener};
use std::net::{TcpListener, TcpStream};
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream};
use native_tls::TlsStream;
use url::Url;
use net2::TcpStreamExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
@ -26,8 +26,8 @@ where
exit(1);
});
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is port already in use?");
let server =
TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?");
let client_thread = spawn(move || {
let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap())
@ -46,7 +46,8 @@ where
#[test]
fn test_server_close() {
do_test(3012,
do_test(
3012,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -75,12 +76,14 @@ fn test_server_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}
#[test]
fn test_evil_server_close() {
do_test(3013,
do_test(
3013,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -106,14 +109,19 @@ fn test_evil_server_close() {
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
srv_sock
.get_mut()
.set_linger(Some(Duration::from_secs(0)))
.unwrap();
drop(srv_sock);
});
},
);
}
#[test]
fn test_client_close() {
do_test(3014,
do_test(
3014,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
@ -137,7 +145,9 @@ fn test_client_close() {
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock.write_message(Message::Text("From Server".into())).unwrap();
srv_sock
.write_message(Message::Text("From Server".into()))
.unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close());
@ -147,6 +157,6 @@ fn test_client_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}

Loading…
Cancel
Save