Derive Debugs

Requires updated input_buffer with Debug derive fix: https://github.com/snapview/input_buffer/pull/1
Resolved build errors with patch in Cargo.toml with the fix above.
Deny missing debug and resolve resulting errors to satisfy:
https://rust-lang-nursery.github.io/api-guidelines/debuggability.html#all-public-types-implement-debug-c-debug
Formatted with rustfmt-nightly v0.3.6-nightly
pull/24/head
Sean Schwartz 7 years ago
parent 2d8395031b
commit 1154d5a5e4
  1. 3
      Cargo.toml
  2. 34
      examples/autobahn-client.rs
  3. 25
      examples/autobahn-server.rs
  4. 13
      examples/client.rs
  5. 5
      examples/server.rs
  6. 47
      src/client.rs
  7. 10
      src/error.rs
  8. 93
      src/handshake/client.rs
  9. 24
      src/handshake/headers.rs
  10. 32
      src/handshake/machine.rs
  11. 29
      src/handshake/mod.rs
  12. 68
      src/handshake/server.rs
  13. 22
      src/lib.rs
  14. 103
      src/protocol/frame/coding.rs
  15. 79
      src/protocol/frame/frame.rs
  16. 13
      src/protocol/frame/mask.rs
  17. 48
      src/protocol/frame/mod.rs
  18. 62
      src/protocol/message.rs
  19. 114
      src/protocol/mod.rs
  20. 13
      src/server.rs
  21. 9
      src/stream.rs
  22. 5
      src/util.rs

@ -27,6 +27,9 @@ sha1 = "0.4.0"
url = "1.5.1"
utf-8 = "0.7.1"
[patch.crates-io]
input_buffer = { git = "https://github.com/unv-annihilator/input_buffer", rev = "f940362f34afd61a34d126d211a9ad2bf2ec903a" }
[dependencies.native-tls]
optional = true
version = "0.1.5"

@ -1,18 +1,17 @@
#[macro_use] extern crate log;
extern crate env_logger;
#[macro_use]
extern crate log;
extern crate tungstenite;
extern crate url;
use url::Url;
use tungstenite::{connect, Error, Result, Message};
use tungstenite::{connect, Error, Message, Result};
const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?;
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read_message()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
@ -20,7 +19,10 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
)).unwrap(),
)?;
socket.close(None)?;
Ok(())
@ -28,18 +30,17 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let case_url = Url::parse(&format!(
"ws://localhost:9001/runCase?case={}&agent={}",
case, AGENT
)).unwrap();
let (mut socket, _) = connect(case_url)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |
msg @ Message::Binary(_) => {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?;
}
Message::Ping(_) |
Message::Pong(_) => {}
Message::Ping(_) | Message::Pong(_) => {}
}
}
}
@ -52,12 +53,13 @@ fn main() {
for case in 1..(total + 1) {
if let Err(e) = run_test(case) {
match e {
Error::Protocol(_) => { }
err => { warn!("test: {}", err); }
Error::Protocol(_) => {}
err => {
warn!("test: {}", err);
}
}
}
}
update_reports().unwrap();
}

@ -1,11 +1,12 @@
#[macro_use] extern crate log;
extern crate env_logger;
#[macro_use]
extern crate log;
extern crate tungstenite;
use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message};
use tungstenite::{accept, Error, HandshakeError, Message, Result};
use tungstenite::handshake::HandshakeRole;
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
@ -19,12 +20,10 @@ fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |
msg @ Message::Binary(_) => {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?;
}
Message::Ping(_) |
Message::Pong(_) => {}
Message::Ping(_) | Message::Pong(_) => {}
}
}
}
@ -35,14 +34,12 @@ fn main() {
let server = TcpListener::bind("127.0.0.1:9001").unwrap();
for stream in server.incoming() {
spawn(move || {
match stream {
Ok(stream) => match handle_client(stream) {
Ok(_) => (),
Err(e) => warn!("Error in client: {}", e),
},
Err(e) => warn!("Error accepting stream: {}", e),
}
spawn(move || match stream {
Ok(stream) => match handle_client(stream) {
Ok(_) => (),
Err(e) => warn!("Error in client: {}", e),
},
Err(e) => warn!("Error accepting stream: {}", e),
});
}
}

@ -1,15 +1,15 @@
extern crate env_logger;
extern crate tungstenite;
extern crate url;
extern crate env_logger;
use url::Url;
use tungstenite::{Message, connect};
use tungstenite::{connect, Message};
fn main() {
env_logger::init().unwrap();
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect");
let (mut socket, response) =
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");
println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
@ -18,11 +18,12 @@ fn main() {
println!("* {}", header);
}
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
socket
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
loop {
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);
}
// socket.close(None);
}

@ -21,7 +21,10 @@ fn main() {
// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")),
(
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"),
),
];
Ok(Some(extra_headers))
};

@ -1,6 +1,6 @@
//! Methods to connect to an WebSocket as a client.
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use std::io::{Read, Write};
@ -8,10 +8,10 @@ use url::Url;
use handshake::client::Response;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
mod encryption {
use std::net::TcpStream;
use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError};
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector};
pub use native_tls::TlsStream;
pub use stream::Stream as StreamSwitcher;
@ -26,10 +26,13 @@ mod encryption {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder()?.build()?;
connector.connect(domain, stream)
connector
.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
TlsHandshakeError::Interrupted(_) => {
panic!("Bug: TLS handshake not blocked")
}
})
.map(StreamSwitcher::Tls)
}
@ -37,7 +40,7 @@ mod encryption {
}
}
#[cfg(not(feature="tls"))]
#[cfg(not(feature = "tls"))]
mod encryption {
use std::net::TcpStream;
@ -61,10 +64,9 @@ use self::encryption::wrap_stream;
use protocol::WebSocket;
use handshake::HandshakeError;
use handshake::client::{ClientHandshake, Request};
use stream::{NoDelay, Mode};
use stream::{Mode, NoDelay};
use error::{Error, Result};
/// Connect to the given WebSocket in blocking mode.
///
/// The URL may be either ws:// or wss://.
@ -77,30 +79,31 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Response)>
{
pub fn connect<'t, Req: Into<Request<'t>>>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?;
let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client(request, stream)
.map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
client(request, stream).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr>
where
A: Iterator<Item = SocketAddr>,
{
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let domain = url.host_str()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream)
return Ok(stream);
}
}
}
@ -115,7 +118,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into()))
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
@ -126,8 +129,8 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
request: Req,
stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,

@ -13,7 +13,7 @@ use httparse;
use protocol::frame::CloseFrame;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
pub mod tls {
//! TLS error wrapper module, feature-gated.
pub use native_tls::Error;
@ -29,7 +29,7 @@ pub enum Error {
ConnectionClosed(Option<CloseFrame<'static>>),
/// Input-output error
Io(io::Error),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
/// TLS error
Tls(tls::Error),
/// Buffer capacity exhausted
@ -55,7 +55,7 @@ impl fmt::Display for Error {
}
}
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
@ -71,7 +71,7 @@ impl ErrorTrait for Error {
match *self {
Error::ConnectionClosed(_) => "A close handshake is performed",
Error::Io(ref err) => err.description(),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
Error::Tls(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(),
@ -100,7 +100,7 @@ impl From<string::FromUtf8Error> for Error {
}
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)

@ -11,12 +11,13 @@ use rand;
use url::Url;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use protocol::{Role, WebSocket};
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
/// Client request.
#[derive(Debug)]
pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to.
pub url: Url,
@ -67,6 +68,7 @@ impl From<Url> for Request<'static> {
}
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S> {
verify_data: VerifyData,
_marker: PhantomData<S>,
@ -79,14 +81,19 @@ impl<S: Read + Write> ClientHandshake<S> {
let machine = {
let mut req = Vec::new();
write!(req, "\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), path = request.get_path(), key = key).unwrap();
write!(
req,
"\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(),
path = request.get_path(),
key = key
).unwrap();
if let Some(eh) = request.extra_headers {
for (k, v) in eh {
write!(req, "{}: {}\r\n", k, v).unwrap();
@ -105,7 +112,10 @@ impl<S: Read + Write> ClientHandshake<S> {
};
trace!("Client handshake initiated.");
MidHandshake { role: client, machine: machine }
MidHandshake {
role: client,
machine: machine,
}
}
}
@ -113,24 +123,32 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail, } => {
StageResult::DoneReading {
stream,
result,
tail,
} => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result))
ProcessingResult::Done((
WebSocket::from_partially_read(stream, tail, Role::Client),
result,
))
}
})
}
}
/// Information for handshake verification.
#[derive(Debug)]
struct VerifyData {
/// Accepted server key.
accept_key: String,
@ -147,22 +165,37 @@ impl VerifyData {
// header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Upgrade", "websocket") {
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
if !response
.headers
.header_is_ignore_case("Upgrade", "websocket")
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Connection", "Upgrade") {
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
if !response
.headers
.header_is_ignore_case("Connection", "Upgrade")
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
if !response
.headers
.header_is("Sec-WebSocket-Accept", &self.accept_key)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
@ -183,6 +216,7 @@ impl VerifyData {
}
/// Server response.
#[derive(Debug)]
pub struct Response {
/// HTTP response code of the response.
pub code: u16,
@ -204,7 +238,9 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
@ -223,7 +259,7 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{Response, generate_key};
use super::{generate_key, Response};
use super::super::machine::TryParse;
#[test]
@ -247,6 +283,9 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
assert_eq!(
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
}
}

@ -19,7 +19,6 @@ pub struct Headers {
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
@ -29,7 +28,7 @@ impl Headers {
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name: name,
iter: self.data.iter()
iter: self.data.iter(),
}
}
@ -42,7 +41,8 @@ impl Headers {
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
@ -52,10 +52,10 @@ impl Headers {
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers.
#[derive(Debug)]
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
@ -66,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
return Some(value);
}
}
None
}
}
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
@ -94,8 +93,8 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
@ -108,8 +107,7 @@ mod tests {
#[test]
fn headers() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
@ -125,8 +123,7 @@ mod tests {
#[test]
fn headers_iter() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
@ -141,8 +138,7 @@ mod tests {
#[test]
fn headers_incomplete() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap();

@ -6,6 +6,7 @@ use error::{Error, Result};
use util::NonBlockingResult;
/// A generic handshake state machine.
#[derive(Debug)]
pub struct HandshakeMachine<Stream> {
stream: Stream,
state: HandshakeState,
@ -47,11 +48,9 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.map_err(|_| Error::Capacity("Header too long".into()))?
.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => {
Err(Error::Protocol("Handshake not finished".into()))
}
Some(_) => {
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok(
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
@ -63,14 +62,12 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
state: HandshakeState::Reading(buf),
..self
})
})
}
None => {
Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
}))
}
},
),
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})),
}
}
HandshakeState::Writing(mut buf) => {
@ -98,6 +95,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
}
/// The result of the round.
#[derive(Debug)]
pub enum RoundResult<Obj, Stream> {
/// Round not done, I/O would block.
WouldBlock(HandshakeMachine<Stream>),
@ -108,9 +106,14 @@ pub enum RoundResult<Obj, Stream> {
}
/// The result of the stage.
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
DoneReading {
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
/// Writing round finished.
DoneWriting(Stream),
}
@ -122,6 +125,7 @@ pub trait TryParse: Sized {
}
/// The handshake state.
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer.
Reading(InputBuffer),

@ -17,6 +17,7 @@ use error::Error;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake.
#[derive(Debug)]
pub struct MidHandshake<Role: HandshakeRole> {
role: Role,
machine: HandshakeMachine<Role::InternalStream>,
@ -29,15 +30,16 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop {
mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
return Err(HandshakeError::Interrupted(MidHandshake {
machine: m,
..self
}))
}
RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(result) => return Ok(result),
}
}
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(result) => return Ok(result),
},
}
}
}
@ -93,12 +95,15 @@ pub trait HandshakeRole {
#[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
/// Stage processing result.
#[doc(hidden)]
#[derive(Debug)]
pub enum ProcessingResult<Stream, FinalResult> {
Continue(HandshakeMachine<Stream>),
Done(FinalResult),
@ -122,8 +127,10 @@ mod tests {
#[test]
fn key_conversion() {
// example from RFC 6455
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
assert_eq!(
convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
);
}
}

@ -8,12 +8,13 @@ use httparse;
use httparse::Status;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use protocol::{Role, WebSocket};
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
/// Request from the client.
#[derive(Debug)]
pub struct Request {
/// Path part of the URL.
pub path: String,
@ -24,14 +25,15 @@ pub struct Request {
impl Request {
/// Reply to the response.
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
let key = self.headers
.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
);
if let Some(eh) = extra_headers {
@ -61,11 +63,13 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
return Err(Error::Protocol("Method is not GET".into()));
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?
headers: Headers::from_httparse(raw.headers)?,
})
}
}
@ -83,14 +87,17 @@ pub trait Callback: Sized {
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>>;
}
impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>> {
impl<F> Callback for F
where
F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>>,
{
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>> {
self(request)
}
}
/// Stub for callback that does nothing.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub struct NoCallback;
impl Callback for NoCallback {
@ -101,6 +108,7 @@ impl Callback for NoCallback {
/// Server handshake role.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
@ -119,7 +127,10 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), _marker: PhantomData },
role: ServerHandshake {
callback: Some(callback),
_marker: PhantomData,
},
}
}
}
@ -129,13 +140,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type InternalStream = S;
type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
StageResult::DoneReading {
stream,
result,
tail,
} => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
return Err(Error::Protocol("Junk after client request".into()));
}
let extra_headers = {
if let Some(callback) = self.callback.take() {
@ -182,13 +198,19 @@ mod tests {
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let extra_headers = Some(vec![
(
String::from("MyCustomHeader"),
String::from("MyCustomValue"),
),
(String::from("MyVersion"), String::from("LOL")),
]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(
req.headers.find_first("MyCustomHeader"),
Some(b"MyCustomValue".as_ref())
);
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
}

@ -1,25 +1,21 @@
//! Lightweight, flexible WebSockets for Rust.
#![deny(
missing_docs,
missing_copy_implementations,
trivial_casts, trivial_numeric_casts,
unstable_features,
unused_must_use,
unused_mut,
unused_imports,
unused_import_braces)]
#![deny(missing_docs, missing_copy_implementations, missing_debug_implementations, trivial_casts,
trivial_numeric_casts, unstable_features, unused_must_use, unused_mut, unused_imports,
unused_import_braces)]
#[macro_use] extern crate log;
extern crate base64;
extern crate byteorder;
extern crate bytes;
extern crate httparse;
extern crate input_buffer;
#[macro_use]
extern crate log;
#[cfg(feature = "tls")]
extern crate native_tls;
extern crate rand;
extern crate sha1;
extern crate url;
extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls;
pub mod error;
pub mod protocol;
@ -29,10 +25,10 @@ pub mod handshake;
pub mod stream;
pub mod util;
pub use client::{connect, client};
pub use client::{client, connect};
pub use server::{accept, accept_hdr};
pub use error::{Error, Result};
pub use protocol::{WebSocket, Message};
pub use protocol::{Message, WebSocket};
pub use handshake::HandshakeError;
pub use handshake::client::ClientHandshake;
pub use handshake::server::ServerHandshake;

@ -1,7 +1,7 @@
//! Various codes defined in RFC 6455.
use std::fmt;
use std::convert::{Into, From};
use std::convert::{From, Into};
/// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -42,8 +42,8 @@ impl fmt::Display for Data {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Data::Continue => write!(f, "CONTINUE"),
Data::Text => write!(f, "TEXT"),
Data::Binary => write!(f, "BINARY"),
Data::Text => write!(f, "TEXT"),
Data::Binary => write!(f, "BINARY"),
Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x),
}
}
@ -53,8 +53,8 @@ impl fmt::Display for Control {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Control::Close => write!(f, "CLOSE"),
Control::Ping => write!(f, "PING"),
Control::Pong => write!(f, "PONG"),
Control::Ping => write!(f, "PING"),
Control::Pong => write!(f, "PONG"),
Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x),
}
}
@ -71,18 +71,18 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode {
fn into(self) -> u8 {
use self::Data::{Continue, Text, Binary};
use self::Data::{Binary, Continue, Text};
use self::Control::{Close, Ping, Pong};
use self::OpCode::*;
match self {
Data(Continue) => 0,
Data(Text) => 1,
Data(Binary) => 2,
Data(Text) => 1,
Data(Binary) => 2,
Data(self::Data::Reserved(i)) => i,
Control(Close) => 8,
Control(Ping) => 9,
Control(Pong) => 10,
Control(Ping) => 9,
Control(Pong) => 10,
Control(self::Control::Reserved(i)) => i,
}
}
@ -90,19 +90,19 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
use self::Data::{Continue, Text, Binary};
use self::Data::{Binary, Continue, Text};
use self::Control::{Close, Ping, Pong};
use self::OpCode::*;
match byte {
0 => Data(Continue),
1 => Data(Text),
2 => Data(Binary),
i @ 3 ... 7 => Data(self::Data::Reserved(i)),
8 => Control(Close),
9 => Control(Ping),
10 => Control(Pong),
i @ 11 ... 15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"),
0 => Data(Continue),
1 => Data(Text),
2 => Data(Binary),
i @ 3...7 => Data(self::Data::Reserved(i)),
8 => Control(Close),
9 => Control(Ping),
10 => Control(Pong),
i @ 11...15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"),
}
}
}
@ -169,27 +169,22 @@ pub enum CloseCode {
/// to a different IP (when multiple targets exist), or reconnect to the same IP
/// when a user has performed an action.
Again,
#[doc(hidden)]
Tls,
#[doc(hidden)]
Reserved(u16),
#[doc(hidden)]
Iana(u16),
#[doc(hidden)]
Library(u16),
#[doc(hidden)]
Bad(u16),
#[doc(hidden)] Tls,
#[doc(hidden)] Reserved(u16),
#[doc(hidden)] Iana(u16),
#[doc(hidden)] Library(u16),
#[doc(hidden)] Bad(u16),
}
impl CloseCode {
/// Check if this CloseCode is allowed.
pub fn is_allowed(&self) -> bool {
match *self {
Bad(_) => false,
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
}
@ -205,24 +200,24 @@ impl fmt::Display for CloseCode {
impl<'t> Into<u16> for &'t CloseCode {
fn into(self) -> u16 {
match *self {
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Reserved(code) => code,
Iana(code) => code,
Library(code) => code,
Bad(code) => code,
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Reserved(code) => code,
Iana(code) => code,
Library(code) => code,
Bad(code) => code,
}
}
}
@ -250,11 +245,11 @@ impl From<u16> for CloseCode {
1012 => Restart,
1013 => Again,
1015 => Tls,
1...999 => Bad(code),
1...999 => Bad(code),
1000...2999 => Reserved(code),
3000...3999 => Iana(code),
4000...4999 => Library(code),
_ => Bad(code)
_ => Bad(code),
}
}
}

@ -1,16 +1,16 @@
use std::fmt;
use std::borrow::Cow;
use std::mem::transmute;
use std::io::{Cursor, Read, Write, ErrorKind};
use std::io::{Cursor, ErrorKind, Read, Write};
use std::default::Default;
use std::string::{String, FromUtf8Error};
use std::string::{FromUtf8Error, String};
use std::result::Result as StdResult;
use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian};
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt};
use bytes::BufMut;
use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode};
use super::mask::{generate_mask, apply_mask};
use super::coding::{CloseCode, Control, Data, OpCode};
use super::mask::{apply_mask, generate_mask};
/// A struct representing the close command.
#[derive(Debug, Clone)]
@ -52,7 +52,6 @@ pub struct Frame {
}
impl Frame {
/// Get the length of the frame.
/// This is the length of the header + the length of the payload.
#[inline]
@ -186,9 +185,8 @@ impl Frame {
#[doc(hidden)]
#[inline]
pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask))
});
self.mask
.and_then(|mask| Some(apply_mask(&mut self.payload, &mask)));
self.mask = None;
}
@ -204,7 +202,7 @@ impl Frame {
String::from_utf8(self.payload)
}
/// Consume the frame into a closing frame.
/// Consume the frame into a closing frame.
#[inline]
pub fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() {
@ -215,7 +213,10 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code: code, reason: text.into() }))
Ok(Some(CloseFrame {
code: code,
reason: text.into(),
}))
}
}
}
@ -223,16 +224,19 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
debug_assert!(match code {
OpCode::Data(_) => true,
_ => false,
}, "Invalid opcode for data frame.");
debug_assert!(
match code {
OpCode::Data(_) => true,
_ => false,
},
"Invalid opcode for data frame."
);
Frame {
finished: finished,
opcode: code,
payload: data,
.. Frame::default()
..Frame::default()
}
}
@ -242,7 +246,7 @@ impl Frame {
Frame {
opcode: OpCode::Control(Control::Pong),
payload: data,
.. Frame::default()
..Frame::default()
}
}
@ -252,7 +256,7 @@ impl Frame {
Frame {
opcode: OpCode::Control(Control::Ping),
payload: data,
.. Frame::default()
..Frame::default()
}
}
@ -271,7 +275,7 @@ impl Frame {
Frame {
payload: payload,
.. Frame::default()
..Frame::default()
}
}
@ -284,7 +288,7 @@ impl Frame {
let mut head = [0u8; 2];
if try!(cursor.read(&mut head)) != 2 {
cursor.set_position(initial);
return Ok(None)
return Ok(None);
}
trace!("Parsed headers {:?}", head);
@ -335,7 +339,7 @@ impl Frame {
let mut mask_bytes = [0u8; 4];
if try!(cursor.read(&mut mask_bytes)) != 4 {
cursor.set_position(initial);
return Ok(None)
return Ok(None);
} else {
header_length += 4;
Some(mask_bytes)
@ -346,7 +350,7 @@ impl Frame {
if size < length + header_length {
cursor.set_position(initial);
return Ok(None)
return Ok(None);
}
let mut data = Vec::with_capacity(length as usize);
@ -360,9 +364,11 @@ impl Frame {
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
return Err(Error::Protocol(
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
}
_ => ()
_ => (),
}
let frame = Frame {
@ -375,13 +381,13 @@ impl Frame {
payload: data,
};
Ok(Some(frame))
}
/// Write a frame out to a buffer
pub fn format<W>(mut self, w: &mut W) -> Result<()>
where W: Write
where
W: Write,
{
let mut one = 0u8;
let code: u8 = self.opcode.into();
@ -461,7 +467,8 @@ impl Default for Frame {
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f,
write!(
f,
"
<FRAME>
final: {}
@ -479,7 +486,11 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>())
self.payload
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
)
}
}
@ -487,16 +498,18 @@ payload: 0x{}
mod tests {
use super::*;
use super::super::coding::{OpCode, Data};
use super::super::coding::{Data, OpCode};
use std::io::Cursor;
#[test]
fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
]);
let mut raw: Cursor<Vec<u8>> =
Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let frame = Frame::parse(&mut raw).unwrap().unwrap();
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
assert_eq!(
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
#[test]

@ -71,7 +71,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) {
// Possible last block.
if len > 0 {
unsafe { xor_mem(ptr, mask_u32, len); }
unsafe {
xor_mem(ptr, mask_u32, len);
}
}
}
@ -94,12 +96,10 @@ mod tests {
#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let mask = [0x6d, 0xb6, 0xb2, 0x80];
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82,
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03,
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0x12, 0x03,
];
// Check masking with proper alignment.
@ -126,4 +126,3 @@ mod tests {
}
}

@ -14,6 +14,7 @@ use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
/// A reader and writer for WebSocket frames.
#[derive(Debug)]
pub struct FrameSocket<Stream> {
stream: Stream,
in_buffer: InputBuffer,
@ -52,7 +53,8 @@ impl<Stream> FrameSocket<Stream> {
}
impl<Stream> FrameSocket<Stream>
where Stream: Read
where
Stream: Read,
{
/// Read a frame from stream.
pub fn read_frame(&mut self) -> Result<Option<Frame>> {
@ -62,21 +64,22 @@ impl<Stream> FrameSocket<Stream>
return Ok(Some(frame));
}
// No full frames in buffer.
let size = self.in_buffer.prepare_reserve(MIN_READ)
let size = self.in_buffer
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
.read_from(&mut self.stream)?;
if size == 0 {
trace!("no frame received");
return Ok(None)
return Ok(None);
}
}
}
}
impl<Stream> FrameSocket<Stream>
where Stream: Write
where
Stream: Write,
{
/// Write a frame to stream.
///
@ -86,7 +89,9 @@ impl<Stream> FrameSocket<Stream>
pub fn write_frame(&mut self, frame: Frame) -> Result<()> {
trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending()
}
/// Complete pending write, if any.
@ -100,7 +105,6 @@ impl<Stream> FrameSocket<Stream>
}
}
#[cfg(test)]
mod tests {
@ -111,16 +115,19 @@ mod tests {
#[test]
fn read_frames() {
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x82, 0x03, 0x03, 0x02, 0x01,
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
0x99,
]);
let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]);
assert_eq!(
sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame().unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame().unwrap().is_none());
let (_, rest) = sock.into_inner();
@ -129,12 +136,12 @@ mod tests {
#[test]
fn from_partially_read() {
let raw = Cursor::new(vec![
0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(
sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
#[test]
@ -148,10 +155,7 @@ mod tests {
sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![
0x89, 0x02, 0x04, 0x05,
0x8a, 0x01, 0x01
]);
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
}
}

@ -1,4 +1,4 @@
use std::convert::{From, Into, AsRef};
use std::convert::{AsRef, From, Into};
use std::fmt;
use std::result::Result as StdResult;
use std::str;
@ -12,6 +12,7 @@ mod string_collect {
use error::{Error, Result};
#[derive(Debug)]
pub struct StringCollector {
data: String,
incomplete: Option<utf8::Incomplete>,
@ -34,7 +35,7 @@ mod string_collect {
if let Ok(text) = result {
self.data.push_str(text);
} else {
return Err(Error::Utf8)
return Err(Error::Utf8);
}
true
} else {
@ -52,7 +53,10 @@ mod string_collect {
self.data.push_str(text);
Ok(())
}
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
Err(DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => {
self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix);
Ok(())
@ -81,10 +85,12 @@ mod string_collect {
use self::string_collect::StringCollector;
/// A struct representing the incomplete message.
#[derive(Debug)]
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
}
#[derive(Debug)]
enum IncompleteMessageCollector {
Text(StringCollector),
Binary(Vec<u8>),
@ -95,11 +101,11 @@ impl IncompleteMessage {
pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary =>
IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text =>
IncompleteMessageCollector::Text(StringCollector::new()),
}
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
}
}
/// Add more data to an existing message.
@ -109,17 +115,13 @@ impl IncompleteMessage {
v.extend(tail.as_ref());
Ok(())
}
IncompleteMessageCollector::Text(ref mut t) => {
t.extend(tail)
}
IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
}
}
/// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> {
match self.collector {
IncompleteMessageCollector::Binary(v) => {
Ok(Message::Binary(v))
}
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?;
Ok(Message::Text(text))
@ -152,17 +154,18 @@ pub enum Message {
}
impl Message {
/// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message
where S: Into<String>
where
S: Into<String>,
{
Message::Text(string.into())
}
/// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message
where B: Into<Vec<u8>>
where
B: Into<Vec<u8>>,
{
Message::Binary(bin.into())
}
@ -203,9 +206,9 @@ impl Message {
pub fn len(&self) -> usize {
match *self {
Message::Text(ref string) => string.len(),
Message::Binary(ref data) |
Message::Ping(ref data) |
Message::Pong(ref data) => data.len(),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
data.len()
}
}
}
@ -219,9 +222,7 @@ impl Message {
pub fn into_data(self) -> Vec<u8> {
match self {
Message::Text(string) => string.into_bytes(),
Message::Binary(data) |
Message::Ping(data) |
Message::Pong(data) => data,
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
}
}
@ -229,10 +230,9 @@ impl Message {
pub fn into_text(self) -> Result<String> {
match self {
Message::Text(string) => Ok(string),
Message::Binary(data) |
Message::Ping(data) |
Message::Pong(data) => Ok(try!(
String::from_utf8(data).map_err(|err| err.utf8_error()))),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => Ok(try!(
String::from_utf8(data).map_err(|err| err.utf8_error())
)),
}
}
@ -241,12 +241,11 @@ impl Message {
pub fn to_text(&self) -> Result<&str> {
match *self {
Message::Text(ref string) => Ok(string),
Message::Binary(ref data) |
Message::Ping(ref data) |
Message::Pong(ref data) => Ok(try!(str::from_utf8(data))),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
Ok(try!(str::from_utf8(data)))
}
}
}
}
impl From<String> for Message {
@ -304,7 +303,6 @@ mod tests {
assert!(msg.into_text().is_err());
}
#[test]
fn binary_convert_vec() {
let bin = vec![6u8, 7, 8, 9, 10, 241];

@ -8,13 +8,13 @@ pub use self::message::Message;
pub use self::frame::CloseFrame;
use std::collections::VecDeque;
use std::io::{Read, Write, ErrorKind as IoErrorKind};
use std::io::{ErrorKind as IoErrorKind, Read, Write};
use std::mem::replace;
use error::{Error, Result};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::frame::{Frame, FrameSocket};
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode};
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use util::NonBlockingResult;
/// Indicates a Client or Server role of the websocket
@ -30,6 +30,7 @@ pub enum Role {
///
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)]
pub struct WebSocket<Stream> {
/// Server or client?
role: Role,
@ -93,7 +94,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
let res = self.read_message_frame();
if let Some(message) = self.translate_close(res)? {
trace!("Received message {}", message);
return Ok(message)
return Ok(message);
}
}
}
@ -108,16 +109,12 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// most recent pong frame is sent if multiple pong frames are queued up.
pub fn write_message(&mut self, message: Message) -> Result<()> {
let frame = match message {
Message::Text(data) => {
Frame::message(data.into(), OpCode::Data(OpData::Text), true)
}
Message::Binary(data) => {
Frame::message(data, OpCode::Data(OpData::Binary), true)
}
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
return self.write_pending()
return self.write_pending();
}
};
self.send_queue.push_back(frame);
@ -182,14 +179,13 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Try to decode one message frame. May return None.
fn read_message_frame(&mut self) -> Result<Option<Message>> {
if let Some(mut frame) = self.socket.read_frame()? {
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() {
return Err(Error::Protocol("Reserved bits are non-zero".into()))
return Err(Error::Protocol("Reserved bits are non-zero".into()));
}
match self.role {
@ -201,19 +197,22 @@ impl<Stream: Read + Write> WebSocket<Stream> {
} else {
// The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455)
return Err(Error::Protocol("Received an unmasked frame from client".into()))
return Err(Error::Protocol(
"Received an unmasked frame from client".into(),
));
}
}
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into()))
return Err(Error::Protocol(
"Received a masked frame from server".into(),
));
}
}
}
match frame.opcode() {
OpCode::Control(ctl) => {
match ctl {
// All control frames MUST have a payload length of 125 bytes or less
@ -224,12 +223,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => {
self.do_close(frame.into_close()?).map(|_| None)
}
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Close => self.do_close(frame.into_close()?).map(|_| None),
OpCtl::Reserved(i) => Err(Error::Protocol(
format!("Unknown control frame type {}", i).into(),
)),
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
// No ping processing while closing.
Ok(None)
@ -239,9 +236,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
self.pong = Some(Frame::pong(data.clone()));
Ok(Some(Message::Ping(data)))
}
OpCtl::Pong => {
Ok(Some(Message::Pong(frame.into_data())))
}
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
}
}
@ -258,19 +253,21 @@ impl<Stream: Read + Write> WebSocket<Stream> {
// TODO if msg too big
msg.extend(frame.into_data())?;
} else {
return Err(Error::Protocol("Continue frame but nothing to continue".into()))
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(replace(&mut self.incomplete, None).unwrap().complete()?))
Ok(Some(replace(&mut self.incomplete, None)
.unwrap()
.complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into()
))
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
@ -289,22 +286,20 @@ impl<Stream: Read + Write> WebSocket<Stream> {
Ok(None)
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
}
OpData::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
}
}
} // match opcode
} else {
match replace(&mut self.state, WebSocketState::Terminated) {
WebSocketState::CloseAcknowledged(close) | WebSocketState::ClosedByPeer(close) => {
Err(Error::ConnectionClosed(close))
}
_ => {
Err(Error::Protocol("Connection reset without closing handshake".into()))
}
_ => Err(Error::Protocol(
"Connection reset without closing handshake".into(),
)),
}
}
}
@ -325,7 +320,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
} else {
Frame::close(Some(CloseFrame {
code: CloseCode::Protocol,
reason: "Protocol violation".into()
reason: "Protocol violation".into(),
}))
}
} else {
@ -361,8 +356,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Send a single pending frame.
fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> {
match self.role {
Role::Server => {
}
Role::Server => {}
Role::Client => {
// 5. If the data is being sent by the client, the frame(s) MUST be
// masked as defined in Section 5.3. (RFC 6455)
@ -379,10 +373,12 @@ impl<Stream: Read + Write> WebSocket<Stream> {
Err(Error::Io(err)) => Err({
if err.kind() == IoErrorKind::ConnectionReset {
match self.state {
WebSocketState::ClosedByPeer(ref mut frame) =>
Error::ConnectionClosed(replace(frame, None)),
WebSocketState::CloseAcknowledged(ref mut frame) =>
Error::ConnectionClosed(replace(frame, None)),
WebSocketState::ClosedByPeer(ref mut frame) => {
Error::ConnectionClosed(replace(frame, None))
}
WebSocketState::CloseAcknowledged(ref mut frame) => {
Error::ConnectionClosed(replace(frame, None))
}
_ => Error::Io(err),
}
} else {
@ -392,10 +388,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
x => x,
}
}
}
/// The current connection state.
#[derive(Debug)]
enum WebSocketState {
/// The connection is active.
Active,
@ -421,7 +417,7 @@ impl WebSocketState {
#[cfg(test)]
mod tests {
use super::{WebSocket, Role, Message};
use super::{Message, Role, WebSocket};
use std::io;
use std::io::Cursor;
@ -443,24 +439,24 @@ mod tests {
}
}
#[test]
fn receive_messages() {
let incoming = Cursor::new(vec![
0x89, 0x02, 0x01, 0x02,
0x8a, 0x01, 0x03,
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
0x82, 0x03,
0x01, 0x02, 0x03,
0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
assert_eq!(
socket.read_message().unwrap(),
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
}
}

@ -15,9 +15,9 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(stream: S)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_hdr(stream, NoCallback)
}
@ -26,8 +26,9 @@ pub fn accept<S: Read + Write>(stream: S)
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
{
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback).handshake()
}

@ -4,15 +4,15 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use std::io::{Read, Write, Result as IoResult};
use std::io::{Read, Result as IoResult, Write};
use std::net::TcpStream;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use native_tls::TlsStream;
/// Stream mode, either plain TCP or TLS.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub enum Mode {
/// Plain mode (`ws://` URL).
Plain,
@ -32,7 +32,7 @@ impl NoDelay for TcpStream {
}
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().set_nodelay(nodelay)
@ -40,6 +40,7 @@ impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
}
/// Stream, either plain TCP or TLS.
#[derive(Debug)]
pub enum Stream<S, T> {
/// Unencrypted socket stream.
Plain(S),

@ -40,7 +40,8 @@ pub trait NonBlockingResult {
}
impl<T, E> NonBlockingResult for StdResult<T, E>
where E : NonBlockingError
where
E: NonBlockingError,
{
type Result = StdResult<Option<T>, E>;
fn no_block(self) -> Self::Result {
@ -49,7 +50,7 @@ impl<T, E> NonBlockingResult for StdResult<T, E>
Err(e) => match e.into_non_blocking() {
Some(e) => Err(e),
None => Ok(None),
}
},
}
}
}

Loading…
Cancel
Save