Initial commit, mostly working client

pull/7/head
Alexey Galakhov 8 years ago
commit e63f594a14
  1. 2
      .gitignore
  2. 24
      Cargo.toml
  3. 20
      LICENSE
  4. 62
      examples/autobahn-client.rs
  5. 25
      examples/client.rs
  6. 75
      src/client.rs
  7. 84
      src/error.rs
  8. 259
      src/handshake/client.rs
  9. 183
      src/handshake/mod.rs
  10. 90
      src/handshake/server.rs
  11. 14
      src/handshake/tls.rs
  12. 85
      src/input_buffer.rs
  13. 28
      src/lib.rs
  14. 274
      src/protocol/frame/coding.rs
  15. 516
      src/protocol/frame/frame.rs
  16. 135
      src/protocol/frame/mod.rs
  17. 260
      src/protocol/message.rs
  18. 374
      src/protocol/mod.rs
  19. 39
      src/stream.rs

2
.gitignore vendored

@ -0,0 +1,2 @@
target
Cargo.lock

@ -0,0 +1,24 @@
[package]
name = "ws2"
version = "0.1.0"
authors = ["Alexey Galakhov"]
[features]
default = []
tls = ["native-tls"]
[dependencies]
base64 = "*"
byteorder = "*"
bytes = { git = "https://github.com/carllerche/bytes.git" }
httparse = "*"
env_logger = "*"
log = "*"
rand = "*"
sha1 = "*"
url = "*"
utf-8 = "*"
[dependencies.native-tls]
optional = true
version = "*"

@ -0,0 +1,20 @@
Copyright (c) 2016 Alexey Galakhov
Copyright (c) 2016 Jason Housley
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -0,0 +1,62 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate ws2;
extern crate url;
use url::Url;
use ws2::protocol::Message;
use ws2::client::connect;
use ws2::handshake::Handshake;
use ws2::error::{Error, Result};
const AGENT: &'static str = "WS2-RS";
fn get_case_count() -> Result<u32> {
let mut socket = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?.handshake_wait()?;
let msg = socket.read_message()?;
socket.close();
Ok(msg.into_text()?.parse::<u32>().unwrap())
}
fn update_reports() -> Result<()> {
let mut socket = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?.handshake_wait()?;
socket.close();
Ok(())
}
fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let mut socket = connect(case_url)?.handshake_wait()?;
loop {
let msg = socket.read_message()?;
socket.write_message(msg)?;
}
socket.close();
Ok(())
}
fn main() {
env_logger::init().unwrap();
let total = get_case_count().unwrap();
for case in 1..(total + 1) {
if let Err(e) = run_test(case) {
match e {
Error::Protocol(_) => { }
err => { warn!("test: {}", err); }
}
}
}
update_reports().unwrap();
}

@ -0,0 +1,25 @@
extern crate ws2;
extern crate url;
extern crate env_logger;
use url::Url;
use ws2::protocol::Message;
use ws2::client::connect;
use ws2::protocol::handshake::Handshake;
fn main() {
env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect")
.handshake_wait()
.expect("Handshake error");
socket.write_message(Message::Text("Hello WebSocket".into()));
loop {
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);
}
// socket.close();
}

@ -0,0 +1,75 @@
use std::net::{TcpStream, ToSocketAddrs};
use url::{Url, SocketAddrs};
use protocol::WebSocket;
use handshake::{Handshake as HandshakeTrait, HandshakeResult};
use handshake::client::{ClientHandshake, Request};
use error::{Error, Result};
/// Connect to the given WebSocket.
///
/// Note that this function may block the current thread while DNS resolution is performed.
pub fn connect(url: Url) -> Result<Handshake> {
let mode = match url.scheme() {
"ws" => Mode::Plain,
#[cfg(feature="tls")]
"wss" => Mode::Tls,
_ => return Err(Error::Url("URL scheme not supported".into()))
};
// Note that this function may block the current thread while resolution is performed.
let addrs = url.to_socket_addrs()?;
Ok(Handshake {
state: HandshakeState::Nothing(url),
alt_addresses: addrs,
})
}
enum Mode {
Plain,
Tls,
}
enum HandshakeState {
Nothing(Url),
WebSocket(ClientHandshake<TcpStream>),
}
pub struct Handshake {
state: HandshakeState,
alt_addresses: SocketAddrs,
}
impl HandshakeTrait for Handshake {
type Stream = WebSocket<TcpStream>;
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
match self.state {
HandshakeState::Nothing(url) => {
if let Some(addr) = self.alt_addresses.next() {
debug!("Trying to contact {} at {}...", url, addr);
let state = {
if let Ok(stream) = TcpStream::connect(addr) {
let hs = ClientHandshake::new(stream, Request { url: url });
HandshakeState::WebSocket(hs)
} else {
HandshakeState::Nothing(url)
}
};
Ok(HandshakeResult::Incomplete(Handshake {
state: state,
..self
}))
} else {
Err(Error::Url(format!("Unable to resolve {}", url).into()))
}
}
HandshakeState::WebSocket(ws) => {
let alt_addresses = self.alt_addresses;
ws.handshake().map(move |r| r.map(move |s| Handshake {
state: HandshakeState::WebSocket(s),
alt_addresses: alt_addresses,
}))
}
}
}
}

@ -0,0 +1,84 @@
//! Error handling.
use std::borrow::{Borrow, Cow};
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io;
use std::result;
use std::str;
use std::string;
use httparse;
pub type Result<T> = result::Result<T, Error>;
/// Possible WebSocket errors
#[derive(Debug)]
pub enum Error {
/// Input-output error
Io(io::Error),
/// Buffer capacity exhausted
Capacity(Cow<'static, str>),
/// Protocol violation
Protocol(Cow<'static, str>),
/// UTF coding error
Utf8(str::Utf8Error),
/// Invlid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Io(ref err) => write!(f, "IO error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::Utf8(ref err) => write!(f, "UTF-8 encoding error: {}", err),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code),
}
}
}
impl ErrorTrait for Error {
fn description(&self) -> &str {
match *self {
Error::Io(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(),
Error::Utf8(ref err) => err.description(),
Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "",
}
}
}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Error::Io(err)
}
}
impl From<str::Utf8Error> for Error {
fn from(err: str::Utf8Error) -> Self {
Error::Utf8(err)
}
}
impl From<string::FromUtf8Error> for Error {
fn from(err: string::FromUtf8Error) -> Self {
Error::Utf8(err.utf8_error())
}
}
impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self {
match err {
httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()),
e => Error::Protocol(Cow::Owned(e.description().into())),
}
}
}

@ -0,0 +1,259 @@
use std::io::{Read, Write, Cursor};
use base64;
use rand;
use bytes::Buf;
use httparse;
use httparse::Status;
use url::Url;
use input_buffer::InputBuffer;
use error::{Error, Result};
use super::{
Headers,
Httparse, FromHttparse,
Handshake, HandshakeResult,
convert_key,
MAX_HEADERS,
};
use protocol::{
WebSocket, Role,
};
const MIN_READ: usize = 4096;
/// Client request.
pub struct Request {
pub url: Url,
// TODO extra headers
}
impl Request {
/// The GET part of the request.
fn get_path(&self) -> String {
if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query)
} else {
self.url.path().into()
}
}
/// The Host: part of the request.
fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() {
format!("{host}:{port}", host = host, port = port)
} else {
host.into()
}
}
}
/// Client handshake.
pub struct ClientHandshake<Stream> {
stream: Stream,
state: HandshakeState,
verify_data: VerifyData,
}
impl<Stream: Read + Write> ClientHandshake<Stream> {
/// Initiate a WebSocket handshake over the given stream.
pub fn new(stream: Stream, request: Request) -> Self {
let key = generate_key();
let mut req = Vec::new();
write!(req, "\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
\r\n", host = request.get_host(), path = request.get_path(), key = key)
.unwrap();
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
stream: stream,
state: HandshakeState::SendingRequest(Cursor::new(req)),
verify_data: VerifyData {
accept_key: accept_key,
},
}
}
}
impl<Stream: Read + Write> Handshake for ClientHandshake<Stream> {
type Stream = WebSocket<Stream>;
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
debug!("Performing client handshake...");
match self.state {
HandshakeState::SendingRequest(mut req) => {
let size = self.stream.write(Buf::bytes(&req))?;
Buf::advance(&mut req, size);
let state = if req.has_remaining() {
HandshakeState::SendingRequest(req)
} else {
HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ))
};
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: state,
..self
}))
}
HandshakeState::ReceivingResponse(mut resp_buf) => {
resp_buf.reserve(MIN_READ, usize::max_value())
.map_err(|_| Error::Capacity("Header too long".into()))?;
resp_buf.read_from(&mut self.stream)?;
if let Some(resp) = Response::parse(&mut resp_buf)? {
self.verify_data.verify_response(&resp)?;
let ws = WebSocket::from_partially_read(self.stream,
resp_buf.into_vec(), Role::Client);
debug!("Client handshake done.");
Ok(HandshakeResult::Done(ws))
} else {
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: HandshakeState::ReceivingResponse(resp_buf),
..self
}))
}
}
}
}
}
/// Information for handshake verification.
struct VerifyData {
/// Accepted server key.
accept_key: String,
}
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.code != 101 {
return Err(Error::Http(response.code));
}
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
// header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Upgrade", "websocket") {
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Connection", "Upgrade") {
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
// not present in the client's handshake (the server has indicated a
// subprotocol not requested by the client), the client MUST _Fail
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(())
}
}
/// Internal state of the client handshake.
enum HandshakeState {
SendingRequest(Cursor<Vec<u8>>),
ReceivingResponse(InputBuffer),
}
/// Server response.
pub struct Response {
code: u16,
headers: Headers,
}
impl Response {
/// Parse the response from a stream.
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Response::parse_http(input)
}
}
impl Httparse for Response {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Response::new(&mut hbuffer);
Ok(match req.parse(buf)? {
Status::Partial => None,
Status::Complete(size) => Some((size, Response::from_httparse(req)?)),
})
}
}
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
headers: Headers::from_httparse(raw.headers)?,
})
}
}
/// Generate a random key for the `Sec-WebSocket-Key` header.
fn generate_key() -> String {
// a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455)
let r: [u8; 16] = rand::random();
base64::encode(&r)
}
#[cfg(test)]
mod tests {
use super::{Response, generate_key};
use std::io::Cursor;
#[test]
fn random_keys() {
let k1 = generate_key();
println!("Generated random key 1: {}", k1);
let k2 = generate_key();
println!("Generated random key 2: {}", k2);
assert_ne!(k1, k2);
assert_eq!(k1.len(), k2.len());
assert_eq!(k1.len(), 24);
assert_eq!(k2.len(), 24);
assert!(k1.ends_with("=="));
assert!(k2.ends_with("=="));
assert!(k1[..22].find("=").is_none());
assert!(k2[..22].find("=").is_none());
}
#[test]
fn response_parsing() {
const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let mut inp = Cursor::new(data);
let req = Response::parse(&mut inp).unwrap().unwrap();
assert_eq!(req.code, 200);
assert_eq!(req.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
}
}

@ -0,0 +1,183 @@
pub mod client;
pub mod server;
#[cfg(feature="tls")]
pub mod tls;
use std::ascii::AsciiExt;
use std::str::from_utf8;
use base64;
use bytes::Buf;
use httparse;
use httparse::Status;
use sha1::Sha1;
use error::Result;
// Limit the number of header lines.
const MAX_HEADERS: usize = 124;
/// A handshake state.
pub trait Handshake: Sized {
/// Resulting stream of this handshake.
type Stream;
/// Perform a single handshake round.
fn handshake(self) -> Result<HandshakeResult<Self>>;
/// Perform handshake to the end in a blocking mode.
fn handshake_wait(self) -> Result<Self::Stream> {
let mut hs = self;
loop {
hs = match hs.handshake()? {
HandshakeResult::Done(stream) => return Ok(stream),
HandshakeResult::Incomplete(s) => s,
}
}
}
}
/// A handshake result.
pub enum HandshakeResult<H: Handshake> {
/// Handshake is done, a WebSocket stream is ready.
Done(H::Stream),
/// Handshake is not done, call handshake() again.
Incomplete(H),
}
impl<H: Handshake> HandshakeResult<H> {
pub fn map<R, F>(self, func: F) -> HandshakeResult<R>
where R: Handshake<Stream = H::Stream>,
F: FnOnce(H) -> R,
{
match self {
HandshakeResult::Done(s) => HandshakeResult::Done(s),
HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)),
}
}
}
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
fn convert_key(input: &[u8]) -> Result<String> {
// ... field is constructed by concatenating /key/ ...
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut sha1 = Sha1::new();
sha1.update(input);
sha1.update(WS_GUID);
Ok(base64::encode(&sha1.digest().bytes()))
}
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.data.iter()
.find(|&&(ref n, _)| n.eq_ignore_ascii_case(name))
.map(|&(_, ref v)| v.as_ref())
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
/// Try to parse data and return headers, if any.
fn parse<B: Buf>(input: &mut B) -> Result<Option<Headers>> {
Headers::parse_http(input)
}
}
/// Trait to read HTTP parseable objects.
trait Httparse: Sized {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>>;
fn parse_http<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Ok(match Self::httparse(input.bytes())? {
Some((size, obj)) => {
input.advance(size);
Some(obj)
},
None => None,
})
}
}
/// Trait to convert raw objects into HTTP parseables.
trait FromHttparse<T>: Sized {
fn from_httparse(raw: T) -> Result<Self>;
}
impl Httparse for Headers {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
#[cfg(test)]
mod tests {
use super::{Headers, convert_key};
use std::io::Cursor;
#[test]
fn key_conversion() {
// example from RFC 6455
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
#[test]
fn headers() {
const data: &'static [u8] =
b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
}
#[test]
fn headers_incomplete() {
const data: &'static [u8] =
b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap();
assert!(hdr.is_none());
}
}

@ -0,0 +1,90 @@
use bytes::Buf;
use httparse;
use httparse::Status;
use error::{Error, Result};
use super::{Headers, Httparse, FromHttparse, convert_key, MAX_HEADERS};
/// Request from the client.
pub struct Request {
path: String,
headers: Headers,
}
impl Request {
/// Parse the request from a stream.
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Request::parse_http(input)
}
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n", convert_key(key)?);
Ok(reply.into())
}
}
impl Httparse for Request {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut hbuffer);
Ok(match req.parse(buf)? {
Status::Partial => None,
Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
})
}
}
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into()));
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?
})
}
}
#[cfg(test)]
mod tests {
use super::Request;
use std::io::Cursor;
#[test]
fn request_parsing() {
const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let mut inp = Cursor::new(data);
let req = Request::parse(&mut inp).unwrap().unwrap();
assert_eq!(req.path, "/script.ws");
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..]));
}
#[test]
fn request_replying() {
const data: &'static [u8] = b"\
GET /script.ws HTTP/1.1\r\n\
Host: foo.com\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n";
let mut inp = Cursor::new(data);
let req = Request::parse(&mut inp).unwrap().unwrap();
let reply = req.reply().unwrap();
}
}

@ -0,0 +1,14 @@
use native_tls;
use stream::Stream;
use super::{Handshake, HandshakeResult};
pub struct TlsHandshake {
}
impl Handshale for TlsHandshake {
type Stream = Stream;
fn handshake(self) -> Result<HandshakeResult<Self>> {
}
}

@ -0,0 +1,85 @@
use std::io::{Cursor, Read, Result as IoResult};
use bytes::{Buf, BufMut};
/// A FIFO buffer for reading packets from network.
pub struct InputBuffer(Cursor<Vec<u8>>);
/// Size limit error.
pub struct SizeLimit;
impl InputBuffer {
/// Create a new empty one.
pub fn with_capacity(capacity: usize) -> Self {
InputBuffer(Cursor::new(Vec::with_capacity(capacity)))
}
/// Create a new one from partially read data.
pub fn from_partially_read(part: Vec<u8>) -> Self {
InputBuffer(Cursor::new(part))
}
/// Reserve the given amount of space.
pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{
if self.inp_mut().remaining_mut() >= space {
// We have enough space right now.
Ok(())
} else {
let pos = self.out().position() as usize;
self.inp_mut().drain(0..pos);
self.out_mut().set_position(0);
let avail = self.inp_mut().capacity() - self.inp_mut().len();
if space <= avail {
Ok(())
} else if self.inp_mut().capacity() + space > limit {
Err(SizeLimit)
} else {
self.inp_mut().reserve(space - avail);
Ok(())
}
}
}
/// Read data from stream into the buffer.
pub fn read_from<S: Read>(&mut self, stream: &mut S) -> IoResult<usize> {
let size;
let buf = self.inp_mut();
unsafe {
size = stream.read(buf.bytes_mut())?;
buf.advance_mut(size);
}
Ok(size)
}
/// Get the rest of the buffer and destroy the buffer.
pub fn into_vec(mut self) -> Vec<u8> {
let pos = self.out().position() as usize;
self.inp_mut().drain(0..pos);
self.0.into_inner()
}
/// The output end (to the application).
pub fn out(&self) -> &Cursor<Vec<u8>> {
&self.0 // the cursor itself
}
/// The output end (to the application).
pub fn out_mut(&mut self) -> &mut Cursor<Vec<u8>> {
&mut self.0 // the cursor itself
}
/// The input end (to the network).
fn inp_mut(&mut self) -> &mut Vec<u8> {
self.0.get_mut() // underlying vector
}
}
impl Buf for InputBuffer {
fn remaining(&self) -> usize {
Buf::remaining(self.out())
}
fn bytes(&self) -> &[u8] {
Buf::bytes(self.out())
}
fn advance(&mut self, size: usize) {
Buf::advance(self.out_mut(), size)
}
}

@ -0,0 +1,28 @@
//! Lightweight, flexible WebSockets for Rust.
#![deny(
missing_copy_implementations,
trivial_casts, trivial_numeric_casts,
unstable_features,
unused_must_use,
unused_mut,
unused_imports,
unused_import_braces)]
#[macro_use] extern crate log;
extern crate base64;
extern crate byteorder;
extern crate bytes;
extern crate httparse;
extern crate rand;
extern crate sha1;
extern crate url;
extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls;
pub mod error;
pub mod protocol;
pub mod client;
pub mod handshake;
mod input_buffer;
mod stream;

@ -0,0 +1,274 @@
use std::fmt;
use std::convert::{Into, From};
/// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum OpCode {
Data(Data),
Control(Control),
}
/// Data opcodes as in RFC 6455
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Data {
/// 0x0 denotes a continuation frame
Continue,
/// 0x1 denotes a text frame
Text,
/// 0x2 denotes a binary frame
Binary,
/// 0x3-7 are reserved for further non-control frames
Reserved(u8),
}
/// Control opcodes as in RFC 6455
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Control {
/// 0x8 denotes a connection close
Close,
/// 0x9 denotes a ping
Ping,
/// 0xa denotes a pong
Pong,
/// 0xb-f are reserved for further control frames
Reserved(u8),
}
impl fmt::Display for Data {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Data::Continue => write!(f, "CONTINUE"),
Data::Text => write!(f, "TEXT"),
Data::Binary => write!(f, "BINARY"),
Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x),
}
}
}
impl fmt::Display for Control {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Control::Close => write!(f, "CLOSE"),
Control::Ping => write!(f, "PING"),
Control::Pong => write!(f, "PONG"),
Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x),
}
}
}
impl fmt::Display for OpCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
OpCode::Data(d) => d.fmt(f),
OpCode::Control(c) => c.fmt(f),
}
}
}
impl Into<u8> for OpCode {
fn into(self) -> u8 {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong};
use self::OpCode::*;
match self {
Data(Continue) => 0,
Data(Text) => 1,
Data(Binary) => 2,
Data(self::Data::Reserved(i)) => i,
Control(Close) => 8,
Control(Ping) => 9,
Control(Pong) => 10,
Control(self::Control::Reserved(i)) => i,
}
}
}
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong};
use self::OpCode::*;
match byte {
0 => Data(Continue),
1 => Data(Text),
2 => Data(Binary),
i @ 3 ... 7 => Data(self::Data::Reserved(i)),
8 => Control(Close),
9 => Control(Ping),
10 => Control(Pong),
i @ 11 ... 15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"),
}
}
}
use self::CloseCode::*;
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum CloseCode {
/// Indicates a normal closure, meaning that the purpose for
/// which the connection was established has been fulfilled.
Normal,
/// Indicates that an endpoint is "going away", such as a server
/// going down or a browser having navigated away from a page.
Away,
/// Indicates that an endpoint is terminating the connection due
/// to a protocol error.
Protocol,
/// Indicates that an endpoint is terminating the connection
/// because it has received a type of data it cannot accept (e.g., an
/// endpoint that understands only text data MAY send this if it
/// receives a binary message).
Unsupported,
/// Indicates that no status code was included in a closing frame. This
/// close code makes it possible to use a single method, `on_close` to
/// handle even cases where no close code was provided.
Status,
/// Indicates an abnormal closure. If the abnormal closure was due to an
/// error, this close code will not be used. Instead, the `on_error` method
/// of the handler will be called with the error. However, if the connection
/// is simply dropped, without an error, this close code will be sent to the
/// handler.
Abnormal,
/// Indicates that an endpoint is terminating the connection
/// because it has received data within a message that was not
/// consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
/// data within a text message).
Invalid,
/// Indicates that an endpoint is terminating the connection
/// because it has received a message that violates its policy. This
/// is a generic status code that can be returned when there is no
/// other more suitable status code (e.g., Unsupported or Size) or if there
/// is a need to hide specific details about the policy.
Policy,
/// Indicates that an endpoint is terminating the connection
/// because it has received a message that is too big for it to
/// process.
Size,
/// Indicates that an endpoint (client) is terminating the
/// connection because it has expected the server to negotiate one or
/// more extension, but the server didn't return them in the response
/// message of the WebSocket handshake. The list of extensions that
/// are needed should be given as the reason for closing.
/// Note that this status code is not used by the server, because it
/// can fail the WebSocket handshake instead.
Extension,
/// Indicates that a server is terminating the connection because
/// it encountered an unexpected condition that prevented it from
/// fulfilling the request.
Error,
/// Indicates that the server is restarting. A client may choose to reconnect,
/// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
Restart,
/// Indicates that the server is overloaded and the client should either connect
/// to a different IP (when multiple targets exist), or reconnect to the same IP
/// when a user has performed an action.
Again,
#[doc(hidden)]
Tls,
#[doc(hidden)]
Reserved(u16),
#[doc(hidden)]
Iana(u16),
#[doc(hidden)]
Library(u16),
#[doc(hidden)]
Bad(u16),
}
impl CloseCode {
/// Check if this CloseCode is allowed.
pub fn is_allowed(&self) -> bool {
match *self {
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
}
}
impl Into<u16> for CloseCode {
fn into(self) -> u16 {
match self {
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Reserved(code) => code,
Iana(code) => code,
Library(code) => code,
Bad(code) => code,
}
}
}
impl From<u16> for CloseCode {
fn from(code: u16) -> CloseCode {
match code {
1000 => Normal,
1001 => Away,
1002 => Protocol,
1003 => Unsupported,
1005 => Status,
1006 => Abnormal,
1007 => Invalid,
1008 => Policy,
1009 => Size,
1010 => Extension,
1011 => Error,
1012 => Restart,
1013 => Again,
1015 => Tls,
1...999 => Bad(code),
1000...2999 => Reserved(code),
3000...3999 => Iana(code),
4000...4999 => Library(code),
_ => Bad(code)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn opcode_from_u8() {
let byte = 2u8;
assert_eq!(OpCode::from(byte), OpCode::Data(Data::Binary));
}
#[test]
fn opcode_into_u8() {
let text = OpCode::Data(Data::Text);
let byte: u8 = text.into();
assert_eq!(byte, 1u8);
}
#[test]
fn closecode_from_u16() {
let byte = 1008u16;
assert_eq!(CloseCode::from(byte), CloseCode::Policy);
}
#[test]
fn closecode_into_u16() {
let text = CloseCode::Away;
let byte: u16 = text.into();
assert_eq!(byte, 1001u16);
}
}

@ -0,0 +1,516 @@
use std::fmt;
use std::mem::transmute;
use std::io::{Cursor, Read, Write};
use std::default::Default;
use std::iter::FromIterator;
use std::string::{String, FromUtf8Error};
use std::result::Result as StdResult;
use byteorder::{ByteOrder, NetworkEndian};
use bytes::BufMut;
use rand;
use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode};
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
let iter = buf.iter_mut().zip(mask.iter().cycle());
for (byte, &key) in iter {
*byte ^= key
}
}
#[inline]
fn generate_mask() -> [u8; 4] {
unsafe { transmute(rand::random::<u32>()) }
}
/// A struct representing a WebSocket frame.
#[derive(Debug, Clone)]
pub struct Frame {
finished: bool,
rsv1: bool,
rsv2: bool,
rsv3: bool,
opcode: OpCode,
mask: Option<[u8; 4]>,
payload: Vec<u8>,
}
impl Frame {
/// Get the length of the frame.
/// This is the length of the header + the length of the payload.
#[inline]
pub fn len(&self) -> usize {
let mut header_length = 2;
let payload_len = self.payload().len();
if payload_len > 125 {
if payload_len <= u16::max_value() as usize {
header_length += 2;
} else {
header_length += 8;
}
}
if self.is_masked() {
header_length += 4;
}
header_length + payload_len
}
/// Test whether the frame is a final frame.
#[inline]
pub fn is_final(&self) -> bool {
self.finished
}
/// Test whether the first reserved bit is set.
#[inline]
pub fn has_rsv1(&self) -> bool {
self.rsv1
}
/// Test whether the second reserved bit is set.
#[inline]
pub fn has_rsv2(&self) -> bool {
self.rsv2
}
/// Test whether the third reserved bit is set.
#[inline]
pub fn has_rsv3(&self) -> bool {
self.rsv3
}
/// Get the OpCode of the frame.
#[inline]
pub fn opcode(&self) -> OpCode {
self.opcode
}
/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &Vec<u8> {
&self.payload
}
// Test whether the frame is masked.
#[doc(hidden)]
#[inline]
pub fn is_masked(&self) -> bool {
self.mask.is_some()
}
// Get an optional reference to the frame's mask.
#[doc(hidden)]
#[allow(dead_code)]
#[inline]
pub fn mask(&self) -> Option<&[u8; 4]> {
self.mask.as_ref()
}
/// Make this frame a final frame.
#[allow(dead_code)]
#[inline]
pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
self.finished = is_final;
self
}
/// Set the first reserved bit.
#[inline]
pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
self.rsv1 = has_rsv1;
self
}
/// Set the second reserved bit.
#[inline]
pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
self.rsv2 = has_rsv2;
self
}
/// Set the third reserved bit.
#[inline]
pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
self.rsv3 = has_rsv3;
self
}
/// Set the OpCode.
#[allow(dead_code)]
#[inline]
pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
self.opcode = opcode;
self
}
/// Edit the frame's payload.
#[allow(dead_code)]
#[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload
}
// Generate a new mask for this frame.
//
// This method simply generates and stores the mask. It does not change the payload data.
// Instead, the payload data will be masked with the generated mask when the frame is sent
// to the other endpoint.
#[doc(hidden)]
#[inline]
pub fn set_mask(&mut self) -> &mut Frame {
self.mask = Some(generate_mask());
self
}
// This method unmasks the payload and should only be called on frames that are actually
// masked. In other words, those frames that have just been received from a client endpoint.
#[doc(hidden)]
#[inline]
pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask))
});
self.mask = None;
}
/// Consume the frame into its payload as binary.
#[inline]
pub fn into_data(self) -> Vec<u8> {
self.payload
}
/// Consume the frame into its payload as string.
#[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload)
}
/// Consume the frame into a closing frame.
#[inline]
pub fn into_close(self) -> Result<Option<(CloseCode, String)>> {
match self.payload.len() {
0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some((code, text)))
}
}
}
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
debug_assert!(match code {
OpCode::Data(_) => true,
_ => false,
}, "Invalid opcode for data frame.");
Frame {
finished: finished,
opcode: code,
payload: data,
.. Frame::default()
}
}
/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {
Frame {
opcode: OpCode::Control(Control::Pong),
payload: data,
.. Frame::default()
}
}
/// Create a new Ping control frame.
#[inline]
pub fn ping(data: Vec<u8>) -> Frame {
Frame {
opcode: OpCode::Control(Control::Ping),
payload: data,
.. Frame::default()
}
}
/// Create a new Close control frame.
#[inline]
pub fn close(msg: Option<(CloseCode, &str)>) -> Frame {
let payload = if let Some((code, reason)) = msg {
let raw: [u8; 2] = unsafe {
let u: u16 = code.into();
transmute(u.to_be())
};
Vec::from_iter(
raw[..].iter()
.chain(reason.as_bytes().iter())
.map(|&b| b))
} else {
Vec::new()
};
Frame {
payload: payload,
.. Frame::default()
}
}
/// Parse the input stream into a frame.
pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> {
let size = cursor.get_ref().len() as u64 - cursor.position();
let initial = cursor.position();
trace!("Position in buffer {}", initial);
let mut head = [0u8; 2];
if try!(cursor.read(&mut head)) != 2 {
cursor.set_position(initial);
return Ok(None)
}
trace!("Parsed headers {:?}", head);
let first = head[0];
let second = head[1];
trace!("First: {:b}", first);
trace!("Second: {:b}", second);
let finished = first & 0x80 != 0;
let rsv1 = first & 0x40 != 0;
let rsv2 = first & 0x20 != 0;
let rsv3 = first & 0x10 != 0;
let opcode = OpCode::from(first & 0x0F);
trace!("Opcode: {:?}", opcode);
let masked = second & 0x80 != 0;
trace!("Masked: {:?}", masked);
let mut header_length = 2;
let mut length = (second & 0x7F) as u64;
if length == 126 {
let mut length_bytes = [0u8; 2];
if try!(cursor.read(&mut length_bytes)) != 2 {
cursor.set_position(initial);
return Ok(None)
}
length = unsafe {
let mut wide: u16 = transmute(length_bytes);
wide = u16::from_be(wide);
wide
} as u64;
header_length += 2;
} else if length == 127 {
let mut length_bytes = [0u8; 8];
if try!(cursor.read(&mut length_bytes)) != 8 {
cursor.set_position(initial);
return Ok(None)
}
unsafe { length = transmute(length_bytes); }
length = u64::from_be(length);
header_length += 8;
}
trace!("Payload length: {}", length);
let mask = if masked {
let mut mask_bytes = [0u8; 4];
if try!(cursor.read(&mut mask_bytes)) != 4 {
cursor.set_position(initial);
return Ok(None)
} else {
header_length += 4;
Some(mask_bytes)
}
} else {
None
};
if size < length + header_length {
cursor.set_position(initial);
return Ok(None)
}
let mut data = Vec::with_capacity(length as usize);
if length > 0 {
unsafe {
try!(cursor.read_exact(data.bytes_mut()));
data.advance_mut(length as usize);
}
}
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
}
_ => ()
}
let frame = Frame {
finished: finished,
rsv1: rsv1,
rsv2: rsv2,
rsv3: rsv3,
opcode: opcode,
mask: mask,
payload: data,
};
Ok(Some(frame))
}
/// Write a frame out to a buffer
pub fn format<W>(mut self, w: &mut W) -> Result<()>
where W: Write
{
let mut one = 0u8;
let code: u8 = self.opcode.into();
if self.is_final() {
one |= 0x80;
}
if self.has_rsv1() {
one |= 0x40;
}
if self.has_rsv2() {
one |= 0x20;
}
if self.has_rsv3() {
one |= 0x10;
}
one |= code;
let mut two = 0u8;
if self.is_masked() {
two |= 0x80;
}
if self.payload.len() < 126 {
two |= self.payload.len() as u8;
let headers = [one, two];
try!(w.write(&headers));
} else if self.payload.len() <= 65535 {
two |= 126;
let length_bytes: [u8; 2] = unsafe {
let short = self.payload.len() as u16;
transmute(short.to_be())
};
let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers));
} else {
two |= 127;
let length_bytes: [u8; 8] = unsafe {
let long = self.payload.len() as u64;
transmute(long.to_be())
};
let headers = [
one,
two,
length_bytes[0],
length_bytes[1],
length_bytes[2],
length_bytes[3],
length_bytes[4],
length_bytes[5],
length_bytes[6],
length_bytes[7],
];
try!(w.write(&headers));
}
if self.is_masked() {
let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask);
try!(w.write(&mask));
}
try!(w.write(&self.payload));
Ok(())
}
}
impl Default for Frame {
fn default() -> Frame {
Frame {
finished: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: OpCode::Control(Control::Close),
mask: None,
payload: Vec::new(),
}
}
}
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f,
"
<FRAME>
final: {}
reserved: {} {} {}
opcode: {}
length: {}
payload length: {}
payload: 0x{}
",
self.finished,
self.rsv1,
self.rsv2,
self.rsv3,
self.opcode,
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>())
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::super::coding::{OpCode, Data};
use std::io::Cursor;
#[test]
fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
]);
let frame = Frame::parse(&mut raw).unwrap().unwrap();
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
}
#[test]
fn format() {
let frame = Frame::ping(vec![0x01, 0x02]);
let mut buf = Vec::with_capacity(frame.len());
frame.format(&mut buf).unwrap();
assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
}
#[test]
fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
let view = format!("{}", f);
view.contains("payload:");
}
}

@ -0,0 +1,135 @@
pub mod coding;
mod frame;
pub use self::frame::Frame;
use std::io::{Read, Write};
use input_buffer;
use error::{Error, Result};
const MIN_READ: usize = 4096;
/// A reader and writer for WebSocket frames.
pub struct FrameSocket<Stream> {
stream: Stream,
in_buffer: input_buffer::InputBuffer,
out_buffer: Vec<u8>,
}
impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket.
pub fn new(stream: Stream) -> Self {
FrameSocket {
stream: stream,
in_buffer: input_buffer::InputBuffer::with_capacity(MIN_READ),
out_buffer: Vec::new(),
}
}
/// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket {
stream: stream,
in_buffer: input_buffer::InputBuffer::from_partially_read(part),
out_buffer: Vec::new(),
}
}
/// Extract a stream from the socket.
pub fn into_inner(self) -> (Stream, Vec<u8>) {
(self.stream, self.in_buffer.into_vec())
}
}
impl<Stream> FrameSocket<Stream>
where Stream: Read
{
/// Read a frame from stream.
pub fn read_frame(&mut self) -> Result<Option<Frame>> {
loop {
if let Some(frame) = Frame::parse(&mut self.in_buffer.out_mut())? {
debug!("received frame {}", frame);
return Ok(Some(frame));
}
// No full frames in buffer.
self.in_buffer.reserve(MIN_READ, usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?;
let size = self.in_buffer.read_from(&mut self.stream)?;
if size == 0 {
debug!("no frame received");
return Ok(None)
}
}
}
}
impl<Stream> FrameSocket<Stream>
where Stream: Write
{
/// Write a frame to stream.
pub fn write_frame(&mut self, frame: Frame) -> Result<()> {
debug!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer)?;
let len = self.stream.write(&self.out_buffer)?;
self.out_buffer.drain(0..len);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::{Frame, FrameSocket};
use std::io::Cursor;
#[test]
fn read_frames() {
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x82, 0x03, 0x03, 0x02, 0x01,
0x99,
]);
let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame().unwrap().is_none());
let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]);
}
#[test]
fn from_partially_read() {
let raw = Cursor::new(vec![
0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
}
#[test]
fn write_frames() {
let mut sock = FrameSocket::new(Vec::new());
let frame = Frame::ping(vec![0x04, 0x05]);
sock.write_frame(frame).unwrap();
let frame = Frame::pong(vec![0x01]);
sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![
0x89, 0x02, 0x04, 0x05,
0x8a, 0x01, 0x01
]);
}
}

@ -0,0 +1,260 @@
use std::convert::{From, Into, AsRef};
use std::fmt;
use std::result::Result as StdResult;
use std::str;
use error::Result;
mod string_collect {
use utf8;
use error::{Error, Result};
pub struct StringCollector {
data: String,
decoder: utf8::Decoder,
}
impl StringCollector {
pub fn new() -> Self {
StringCollector {
data: String::new(),
decoder: utf8::Decoder::new(),
}
}
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
let (sym, text, result) = self.decoder.decode(tail.as_ref());
self.data.push_str(&sym);
self.data.push_str(text);
match result {
utf8::Result::Ok | utf8::Result::Incomplete =>
Ok(()),
utf8::Result::Error { remaining_input_after_error: _ } =>
Err(Error::Protocol("Invalid UTF8".into())), // FIXME
}
}
pub fn into_string(self) -> Result<String> {
if self.decoder.has_incomplete_sequence() {
Err(Error::Protocol("Invalid UTF8".into())) // FIXME
} else {
Ok(self.data)
}
}
}
}
use self::string_collect::StringCollector;
/// A struct representing the incomplete message.
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
}
enum IncompleteMessageCollector {
Text(StringCollector),
Binary(Vec<u8>),
}
impl IncompleteMessage {
/// Create new.
pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary =>
IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text =>
IncompleteMessageCollector::Text(StringCollector::new()),
}
}
}
/// Add more data to an existing message.
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
match self.collector {
IncompleteMessageCollector::Binary(ref mut v) => {
v.extend(tail.as_ref());
Ok(())
}
IncompleteMessageCollector::Text(ref mut t) => {
t.extend(tail)
}
}
}
/// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> {
match self.collector {
IncompleteMessageCollector::Binary(v) => {
Ok(Message::Binary(v))
}
IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?;
Ok(Message::Text(text))
}
}
}
}
/// The type of incomplete message.
pub enum IncompleteMessageType {
Text,
Binary,
}
/// An enum representing the various forms of a WebSocket message.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
/// A binary WebSocket message
Binary(Vec<u8>),
}
impl Message {
/// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message
where S: Into<String>
{
Message::Text(string.into())
}
/// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message
where B: Into<Vec<u8>>
{
Message::Binary(bin.into())
}
/// Indicates whether a message is a text message.
pub fn is_text(&self) -> bool {
match *self {
Message::Text(_) => true,
Message::Binary(_) => false,
}
}
/// Indicates whether a message is a binary message.
pub fn is_binary(&self) -> bool {
match *self {
Message::Text(_) => false,
Message::Binary(_) => true,
}
}
/// Get the length of the WebSocket message.
pub fn len(&self) -> usize {
match *self {
Message::Text(ref string) => string.len(),
Message::Binary(ref data) => data.len(),
}
}
/// Returns true if the WebSocket message has no content.
/// For example, if the other side of the connection sent an empty string.
pub fn is_empty(&self) -> bool {
match *self {
Message::Text(ref string) => string.is_empty(),
Message::Binary(ref data) => data.is_empty(),
}
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
match self {
Message::Text(string) => string.into_bytes(),
Message::Binary(data) => data,
}
}
/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String> {
match self {
Message::Text(string) => Ok(string),
Message::Binary(data) => Ok(try!(
String::from_utf8(data).map_err(|err| err.utf8_error()))),
}
}
/// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str> {
match *self {
Message::Text(ref string) => Ok(string),
Message::Binary(ref data) => Ok(try!(str::from_utf8(data))),
}
}
}
impl From<String> for Message {
fn from(string: String) -> Message {
Message::text(string)
}
}
impl<'s> From<&'s str> for Message {
fn from(string: &'s str) -> Message {
Message::text(string)
}
}
impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Message {
Message::binary(data)
}
}
impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Message {
Message::binary(data)
}
}
impl fmt::Display for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
if let Ok(string) = self.to_text() {
write!(f, "{}", string)
} else {
write!(f, "Binary Data<length={}>", self.len())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let t = Message::text(format!("test"));
assert_eq!(t.to_string(), "test".to_owned());
let bin = Message::binary(vec![0, 1, 3, 4, 241]);
assert_eq!(bin.to_string(), "Binary Data<length=5>".to_owned());
}
#[test]
fn binary_convert() {
let bin = [6u8, 7, 8, 9, 10, 241];
let msg = Message::from(&bin[..]);
assert!(msg.is_binary());
assert!(msg.into_text().is_err());
}
#[test]
fn binary_convert_vec() {
let bin = vec![6u8, 7, 8, 9, 10, 241];
let msg = Message::from(bin);
assert!(msg.is_binary());
assert!(msg.into_text().is_err());
}
#[test]
fn text_convert() {
let s = "kiwotsukete";
let msg = Message::from(s);
assert!(msg.is_text());
}
}

@ -0,0 +1,374 @@
//! Generic WebSocket protocol implementation
mod frame;
mod message;
pub use self::message::Message;
use self::message::{IncompleteMessage, IncompleteMessageType};
use std::collections::VecDeque;
use std::io::{Read, Write};
use std::mem::replace;
use error::{Error, Result};
use self::frame::{Frame, FrameSocket};
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode};
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy)]
pub enum Role {
/// This socket is a server
Server,
/// This socket is a client
Client,
}
/// WebSocket input-output stream
pub struct WebSocket<Stream> {
/// Server or client?
role: Role,
/// The underlying socket.
socket: FrameSocket<Stream>,
/// The state of processing, either "active" or "closing".
state: WebSocketState,
/// Receive: an incomplete message being processed.
incomplete: Option<IncompleteMessage>,
/// Send: a data send queue.
send_queue: VecDeque<Frame>,
/// Send: an OOB pong message.
pong: Option<Frame>,
}
impl<Stream> WebSocket<Stream>
where Stream: Read + Write
{
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket(stream: Stream, role: Role) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role)
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_partially_read(stream: Stream, part: Vec<u8>, role: Role) -> Self {
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role)
}
/// Read a message from stream, if possible.
pub fn read_message(&mut self) -> Result<Message> {
loop {
self.send_pending()?; // FIXME
if let Some(message) = self.read_message_frame()? {
debug!("Received message {}", message);
return Ok(message)
}
}
}
/// Send a message to stream, if possible.
pub fn write_message(&mut self, message: Message) -> Result<()> {
let frame = {
let opcode = match message {
Message::Text(_) => OpData::Text,
Message::Binary(_) => OpData::Binary,
};
Frame::message(message.into_data(), OpCode::Data(opcode), true)
};
self.send_queue.push_back(frame);
self.send_pending()
}
/// Close the connection.
pub fn close(&mut self) -> Result<()> {
match self.state {
WebSocketState::Active => {
self.state = WebSocketState::ClosedByUs;
// TODO
}
_ => {
// already closed, nothing to do
}
}
Ok(())
}
/// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self {
WebSocket {
role: role,
socket: socket,
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
}
}
/// Try to decode one message frame. May return None.
fn read_message_frame(&mut self) -> Result<Option<Message>> {
if let Some(mut frame) = self.socket.read_frame()? {
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() {
return Err(Error::Protocol("Reserved bits are non-zero".into()))
}
match self.role {
Role::Server => {
if frame.is_masked() {
// A server MUST remove masking for data frames received from a client
// as described in Section 5.3. (RFC 6455)
frame.remove_mask()
} else {
// The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455)
return Err(Error::Protocol("Received an unmasked frame from client".into()))
}
}
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into()))
}
}
}
match frame.opcode() {
OpCode::Control(ctl) => {
(match ctl {
// All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455)
_ if !frame.is_final() => {
Err(Error::Protocol("Fragmented control frame".into()))
}
_ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => {
self.do_close(frame.into_close()?)
}
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
// No ping processing while closing.
Ok(())
}
OpCtl::Ping => {
self.do_ping(frame.into_data())
}
OpCtl::Pong => {
self.do_pong(frame.into_data())
}
}).map(|_| None)
}
OpCode::Data(_) if !self.state.is_active() => {
// No data processing while closing.
Ok(None)
}
OpCode::Data(data) => {
let fin = frame.is_final();
match data {
OpData::Continue => {
if let Some(ref mut msg) = self.incomplete {
// TODO if msg too big
msg.extend(frame.into_data())?;
} else {
return Err(Error::Protocol("Continue frame but nothing to continue".into()))
}
if fin {
Ok(Some(replace(&mut self.incomplete, None).unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into()
))
}
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data())?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
self.incomplete = Some(msg);
Ok(None)
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
}
}
}
} // match opcode
} else {
//Ok(None) // TODO handle EOF?
Err(Error::Protocol("Connection reset without closing handshake".into()))
}
}
/// Received a close frame.
fn do_close(&mut self, close: Option<(CloseCode, String)>) -> Result<()> {
match self.state {
WebSocketState::Active => {
self.state = WebSocketState::ClosedByPeer;
let reply = if let Some((code, _)) = close {
if code.is_allowed() {
Frame::close(Some((CloseCode::Normal, "")))
} else {
Frame::close(Some((CloseCode::Protocol, "Protocol violation")))
}
} else {
Frame::close(None)
};
self.send_queue.push_back(reply);
}
WebSocketState::ClosedByPeer => {
// It is already closed, just ignore.
}
WebSocketState::ClosedByUs => {
// We received a reply.
match self.role {
Role::Client => {
// Client waits for the server to close the connection.
}
Role::Server => {
// Server closes the connection.
// TODO
}
}
}
}
//unimplemented!()
Ok(())
}
/// Received a ping frame.
fn do_ping(&mut self, ping: Vec<u8>) -> Result<()> {
// If an endpoint receives a Ping frame and has not yet sent Pong
// frame(s) in response to previous Ping frame(s), the endpoint MAY
// elect to send a Pong frame for only the most recently processed Ping
// frame. (RFC 6455)
// We do exactly that, keeping a "queue" from one and only Pong frame.
self.pong = Some(Frame::pong(ping));
Ok(())
}
/// Received a pong frame.
fn do_pong(&mut self, _: Vec<u8>) -> Result<()> {
// A Pong frame MAY be sent unsolicited. This serves as a
// unidirectional heartbeat. A response to an unsolicited Pong frame is
// not expected. (RFC 6455)
// Due to this, we just don't check pongs right now.
// TODO: check if there was a reply to our ping at all...
Ok(())
}
/// Flush the pending send queue.
fn send_pending(&mut self) -> Result<()> {
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
// response, unless it already received a Close frame. It SHOULD
// respond with Pong frame as soon as is practical. (RFC 6455)
if let Some(pong) = replace(&mut self.pong, None) {
self.send_one_frame(pong)?;
}
// If we have any unsent frames, send them.
while let Some(data) = self.send_queue.pop_front() {
self.send_one_frame(data)?;
}
Ok(())
}
/// Send a single pending frame.
fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> {
match self.role {
Role::Server => {
}
Role::Client => {
// 5. If the data is being sent by the client, the frame(s) MUST be
// masked as defined in Section 5.3. (RFC 6455)
frame.set_mask();
}
}
self.socket.write_frame(frame)?;
Ok(())
}
}
/// The current connection state.
enum WebSocketState {
Active,
ClosedByUs,
ClosedByPeer,
}
impl WebSocketState {
/// Tell if we're allowed to process normal messages.
fn is_active(&self) -> bool {
match *self {
WebSocketState::Active => true,
_ => false,
}
}
}
#[cfg(test)]
mod tests {
use super::{WebSocket, Role, Message};
use std::io;
use std::io::Cursor;
struct WriteMoc<Stream>(Stream);
impl<Stream> io::Write for WriteMoc<Stream> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<Stream: io::Read> io::Read for WriteMoc<Stream> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
#[test]
fn receive_messages() {
let incoming = Cursor::new(vec![
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client);
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
}
}

@ -0,0 +1,39 @@
#[cfg(feature="tls")]
use native_tls::TlsStream;
use std::net::TcpStream;
use std::io::{Read, Write, Result as IoResult};
/// Stream, either plain TCP or TLS.
pub enum Stream {
Plain(TcpStream),
#[cfg(feature="tls")]
Tls(TlsStream<TcpStream>),
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.read(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.read(buf),
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.write(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.flush(),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.flush(),
}
}
}
Loading…
Cancel
Save