commit e63f594a14049fdbf6db74942f2bab793ab269e8 Author: Alexey Galakhov Date: Thu Dec 22 00:42:09 2016 +0100 Initial commit, mostly working client diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9d37c5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e231a0c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "ws2" +version = "0.1.0" +authors = ["Alexey Galakhov"] + +[features] +default = [] +tls = ["native-tls"] + +[dependencies] +base64 = "*" +byteorder = "*" +bytes = { git = "https://github.com/carllerche/bytes.git" } +httparse = "*" +env_logger = "*" +log = "*" +rand = "*" +sha1 = "*" +url = "*" +utf-8 = "*" + +[dependencies.native-tls] +optional = true +version = "*" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..00b109a --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2016 Alexey Galakhov +Copyright (c) 2016 Jason Housley + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs new file mode 100644 index 0000000..1725403 --- /dev/null +++ b/examples/autobahn-client.rs @@ -0,0 +1,62 @@ +#[macro_use] extern crate log; +extern crate env_logger; +extern crate ws2; +extern crate url; + +use url::Url; + +use ws2::protocol::Message; +use ws2::client::connect; +use ws2::handshake::Handshake; +use ws2::error::{Error, Result}; + +const AGENT: &'static str = "WS2-RS"; + +fn get_case_count() -> Result { + let mut socket = connect( + Url::parse("ws://localhost:9001/getCaseCount").unwrap() + )?.handshake_wait()?; + let msg = socket.read_message()?; + socket.close(); + Ok(msg.into_text()?.parse::().unwrap()) +} + +fn update_reports() -> Result<()> { + let mut socket = connect( + Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() + )?.handshake_wait()?; + socket.close(); + Ok(()) +} + +fn run_test(case: u32) -> Result<()> { + info!("Running test case {}", case); + let case_url = Url::parse( + &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) + ).unwrap(); + let mut socket = connect(case_url)?.handshake_wait()?; + loop { + let msg = socket.read_message()?; + socket.write_message(msg)?; + } + socket.close(); + Ok(()) +} + +fn main() { + env_logger::init().unwrap(); + + let total = get_case_count().unwrap(); + + for case in 1..(total + 1) { + if let Err(e) = run_test(case) { + match e { + Error::Protocol(_) => { } + err => { warn!("test: {}", err); } + } + } + } + + update_reports().unwrap(); +} + diff --git a/examples/client.rs b/examples/client.rs new file mode 100644 index 0000000..2aa1336 --- /dev/null +++ b/examples/client.rs @@ -0,0 +1,25 @@ +extern crate ws2; +extern crate url; +extern crate env_logger; + +use url::Url; +use ws2::protocol::Message; +use ws2::client::connect; +use ws2::protocol::handshake::Handshake; + +fn main() { + env_logger::init().unwrap(); + + let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + .expect("Can't connect") + .handshake_wait() + .expect("Handshake error"); + + socket.write_message(Message::Text("Hello WebSocket".into())); + loop { + let msg = socket.read_message().expect("Error reading message"); + println!("Received: {}", msg); + } + // socket.close(); + +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..0b18450 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,75 @@ +use std::net::{TcpStream, ToSocketAddrs}; +use url::{Url, SocketAddrs}; + +use protocol::WebSocket; +use handshake::{Handshake as HandshakeTrait, HandshakeResult}; +use handshake::client::{ClientHandshake, Request}; +use error::{Error, Result}; + +/// Connect to the given WebSocket. +/// +/// Note that this function may block the current thread while DNS resolution is performed. +pub fn connect(url: Url) -> Result { + let mode = match url.scheme() { + "ws" => Mode::Plain, + #[cfg(feature="tls")] + "wss" => Mode::Tls, + _ => return Err(Error::Url("URL scheme not supported".into())) + }; + + // Note that this function may block the current thread while resolution is performed. + let addrs = url.to_socket_addrs()?; + Ok(Handshake { + state: HandshakeState::Nothing(url), + alt_addresses: addrs, + }) +} + +enum Mode { + Plain, + Tls, +} + +enum HandshakeState { + Nothing(Url), + WebSocket(ClientHandshake), +} + +pub struct Handshake { + state: HandshakeState, + alt_addresses: SocketAddrs, +} + +impl HandshakeTrait for Handshake { + type Stream = WebSocket; + fn handshake(mut self) -> Result> { + match self.state { + HandshakeState::Nothing(url) => { + if let Some(addr) = self.alt_addresses.next() { + debug!("Trying to contact {} at {}...", url, addr); + let state = { + if let Ok(stream) = TcpStream::connect(addr) { + let hs = ClientHandshake::new(stream, Request { url: url }); + HandshakeState::WebSocket(hs) + } else { + HandshakeState::Nothing(url) + } + }; + Ok(HandshakeResult::Incomplete(Handshake { + state: state, + ..self + })) + } else { + Err(Error::Url(format!("Unable to resolve {}", url).into())) + } + } + HandshakeState::WebSocket(ws) => { + let alt_addresses = self.alt_addresses; + ws.handshake().map(move |r| r.map(move |s| Handshake { + state: HandshakeState::WebSocket(s), + alt_addresses: alt_addresses, + })) + } + } + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..5a51d36 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,84 @@ +//! Error handling. + +use std::borrow::{Borrow, Cow}; + +use std::error::Error as ErrorTrait; +use std::fmt; +use std::io; +use std::result; +use std::str; +use std::string; + +use httparse; + +pub type Result = result::Result; + +/// Possible WebSocket errors +#[derive(Debug)] +pub enum Error { + /// Input-output error + Io(io::Error), + /// Buffer capacity exhausted + Capacity(Cow<'static, str>), + /// Protocol violation + Protocol(Cow<'static, str>), + /// UTF coding error + Utf8(str::Utf8Error), + /// Invlid URL. + Url(Cow<'static, str>), + /// HTTP error. + Http(u16), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Io(ref err) => write!(f, "IO error: {}", err), + Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), + Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), + Error::Utf8(ref err) => write!(f, "UTF-8 encoding error: {}", err), + Error::Url(ref msg) => write!(f, "URL error: {}", msg), + Error::Http(code) => write!(f, "HTTP code: {}", code), + } + } +} + +impl ErrorTrait for Error { + fn description(&self) -> &str { + match *self { + Error::Io(ref err) => err.description(), + Error::Capacity(ref msg) => msg.borrow(), + Error::Protocol(ref msg) => msg.borrow(), + Error::Utf8(ref err) => err.description(), + Error::Url(ref msg) => msg.borrow(), + Error::Http(_) => "", + } + } +} + +impl From for Error { + fn from(err: io::Error) -> Self { + Error::Io(err) + } +} + +impl From for Error { + fn from(err: str::Utf8Error) -> Self { + Error::Utf8(err) + } +} + +impl From for Error { + fn from(err: string::FromUtf8Error) -> Self { + Error::Utf8(err.utf8_error()) + } +} + +impl From for Error { + fn from(err: httparse::Error) -> Self { + match err { + httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), + e => Error::Protocol(Cow::Owned(e.description().into())), + } + } +} diff --git a/src/handshake/client.rs b/src/handshake/client.rs new file mode 100644 index 0000000..038880e --- /dev/null +++ b/src/handshake/client.rs @@ -0,0 +1,259 @@ +use std::io::{Read, Write, Cursor}; + +use base64; +use rand; +use bytes::Buf; +use httparse; +use httparse::Status; +use url::Url; + +use input_buffer::InputBuffer; +use error::{Error, Result}; +use super::{ + Headers, + Httparse, FromHttparse, + Handshake, HandshakeResult, + convert_key, + MAX_HEADERS, +}; +use protocol::{ + WebSocket, Role, +}; + +const MIN_READ: usize = 4096; + +/// Client request. +pub struct Request { + pub url: Url, + // TODO extra headers +} + +impl Request { + /// The GET part of the request. + fn get_path(&self) -> String { + if let Some(query) = self.url.query() { + format!("{path}?{query}", path = self.url.path(), query = query) + } else { + self.url.path().into() + } + } + /// The Host: part of the request. + fn get_host(&self) -> String { + let host = self.url.host_str().expect("Bug: URL without host"); + if let Some(port) = self.url.port() { + format!("{host}:{port}", host = host, port = port) + } else { + host.into() + } + } +} + +/// Client handshake. +pub struct ClientHandshake { + stream: Stream, + state: HandshakeState, + verify_data: VerifyData, +} + +impl ClientHandshake { + /// Initiate a WebSocket handshake over the given stream. + pub fn new(stream: Stream, request: Request) -> Self { + let key = generate_key(); + + let mut req = Vec::new(); + write!(req, "\ + GET {path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Connection: upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: {key}\r\n\ + \r\n", host = request.get_host(), path = request.get_path(), key = key) + .unwrap(); + + let accept_key = convert_key(key.as_ref()).unwrap(); + + ClientHandshake { + stream: stream, + state: HandshakeState::SendingRequest(Cursor::new(req)), + verify_data: VerifyData { + accept_key: accept_key, + }, + } + } +} + +impl Handshake for ClientHandshake { + type Stream = WebSocket; + fn handshake(mut self) -> Result> { + debug!("Performing client handshake..."); + match self.state { + HandshakeState::SendingRequest(mut req) => { + let size = self.stream.write(Buf::bytes(&req))?; + Buf::advance(&mut req, size); + let state = if req.has_remaining() { + HandshakeState::SendingRequest(req) + } else { + HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ)) + }; + Ok(HandshakeResult::Incomplete(ClientHandshake { + state: state, + ..self + })) + } + HandshakeState::ReceivingResponse(mut resp_buf) => { + resp_buf.reserve(MIN_READ, usize::max_value()) + .map_err(|_| Error::Capacity("Header too long".into()))?; + resp_buf.read_from(&mut self.stream)?; + if let Some(resp) = Response::parse(&mut resp_buf)? { + self.verify_data.verify_response(&resp)?; + let ws = WebSocket::from_partially_read(self.stream, + resp_buf.into_vec(), Role::Client); + debug!("Client handshake done."); + Ok(HandshakeResult::Done(ws)) + } else { + Ok(HandshakeResult::Incomplete(ClientHandshake { + state: HandshakeState::ReceivingResponse(resp_buf), + ..self + })) + } + } + } + } +} + +/// Information for handshake verification. +struct VerifyData { + /// Accepted server key. + accept_key: String, +} + +impl VerifyData { + pub fn verify_response(&self, response: &Response) -> Result<()> { + // 1. If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if response.code != 101 { + return Err(Error::Http(response.code)); + } + // 2. If the response lacks an |Upgrade| header field or the |Upgrade| + // header field contains a value that is not an ASCII case- + // insensitive match for the value "websocket", the client MUST + // _Fail the WebSocket Connection_. (RFC 6455) + if !response.headers.header_is_ignore_case("Upgrade", "websocket") { + return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + } + // 3. If the response lacks a |Connection| header field or the + // |Connection| header field doesn't contain a token that is an + // ASCII case-insensitive match for the value "Upgrade", the client + // MUST _Fail the WebSocket Connection_. (RFC 6455) + if !response.headers.header_is_ignore_case("Connection", "Upgrade") { + return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + } + // 4. If the response lacks a |Sec-WebSocket-Accept| header field or + // the |Sec-WebSocket-Accept| contains a value other than the + // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket + // Connection_. (RFC 6455) + if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) { + return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + } + // 5. If the response includes a |Sec-WebSocket-Extensions| header + // field and this header field indicates the use of an extension + // that was not present in the client's handshake (the server has + // indicated an extension not requested by the client), the client + // MUST _Fail the WebSocket Connection_. (RFC 6455) + // TODO + + // 6. If the response includes a |Sec-WebSocket-Protocol| header field + // and this header field indicates the use of a subprotocol that was + // not present in the client's handshake (the server has indicated a + // subprotocol not requested by the client), the client MUST _Fail + // the WebSocket Connection_. (RFC 6455) + // TODO + + Ok(()) + } +} + +/// Internal state of the client handshake. +enum HandshakeState { + SendingRequest(Cursor>), + ReceivingResponse(InputBuffer), +} + +/// Server response. +pub struct Response { + code: u16, + headers: Headers, +} + +impl Response { + /// Parse the response from a stream. + pub fn parse(input: &mut B) -> Result> { + Response::parse_http(input) + } +} + +impl Httparse for Response { + fn httparse(buf: &[u8]) -> Result> { + let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut req = httparse::Response::new(&mut hbuffer); + Ok(match req.parse(buf)? { + Status::Partial => None, + Status::Complete(size) => Some((size, Response::from_httparse(req)?)), + }) + } +} + +impl<'h, 'b: 'h> FromHttparse> for Response { + fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { + if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + } + Ok(Response { + code: raw.code.expect("Bug: no HTTP response code"), + headers: Headers::from_httparse(raw.headers)?, + }) + } +} + +/// Generate a random key for the `Sec-WebSocket-Key` header. +fn generate_key() -> String { + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let r: [u8; 16] = rand::random(); + base64::encode(&r) +} + +#[cfg(test)] +mod tests { + + use super::{Response, generate_key}; + + use std::io::Cursor; + + #[test] + fn random_keys() { + let k1 = generate_key(); + println!("Generated random key 1: {}", k1); + let k2 = generate_key(); + println!("Generated random key 2: {}", k2); + assert_ne!(k1, k2); + assert_eq!(k1.len(), k2.len()); + assert_eq!(k1.len(), 24); + assert_eq!(k2.len(), 24); + assert!(k1.ends_with("==")); + assert!(k2.ends_with("==")); + assert!(k1[..22].find("=").is_none()); + assert!(k2[..22].find("=").is_none()); + } + + #[test] + fn response_parsing() { + const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + let mut inp = Cursor::new(data); + let req = Response::parse(&mut inp).unwrap().unwrap(); + assert_eq!(req.code, 200); + assert_eq!(req.headers.find_first("Content-Type"), Some(&b"text/html"[..])); + } + +} diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs new file mode 100644 index 0000000..ba9eff7 --- /dev/null +++ b/src/handshake/mod.rs @@ -0,0 +1,183 @@ +pub mod client; +pub mod server; +#[cfg(feature="tls")] +pub mod tls; + +use std::ascii::AsciiExt; +use std::str::from_utf8; + +use base64; +use bytes::Buf; +use httparse; +use httparse::Status; +use sha1::Sha1; + +use error::Result; + +// Limit the number of header lines. +const MAX_HEADERS: usize = 124; + +/// A handshake state. +pub trait Handshake: Sized { + /// Resulting stream of this handshake. + type Stream; + /// Perform a single handshake round. + fn handshake(self) -> Result>; + /// Perform handshake to the end in a blocking mode. + fn handshake_wait(self) -> Result { + let mut hs = self; + loop { + hs = match hs.handshake()? { + HandshakeResult::Done(stream) => return Ok(stream), + HandshakeResult::Incomplete(s) => s, + } + } + } +} + +/// A handshake result. +pub enum HandshakeResult { + /// Handshake is done, a WebSocket stream is ready. + Done(H::Stream), + /// Handshake is not done, call handshake() again. + Incomplete(H), +} + +impl HandshakeResult { + pub fn map(self, func: F) -> HandshakeResult + where R: Handshake, + F: FnOnce(H) -> R, + { + match self { + HandshakeResult::Done(s) => HandshakeResult::Done(s), + HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)), + } + } +} + +/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. +fn convert_key(input: &[u8]) -> Result { + // ... field is constructed by concatenating /key/ ... + // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(input); + sha1.update(WS_GUID); + Ok(base64::encode(&sha1.digest().bytes())) +} + +/// HTTP request or response headers. +#[derive(Debug)] +pub struct Headers { + data: Vec<(String, Box<[u8]>)>, +} + +impl Headers { + + /// Get first header with the given name, if any. + pub fn find_first(&self, name: &str) -> Option<&[u8]> { + self.data.iter() + .find(|&&(ref n, _)| n.eq_ignore_ascii_case(name)) + .map(|&(_, ref v)| v.as_ref()) + } + + /// Check if the given header has the given value. + pub fn header_is(&self, name: &str, value: &str) -> bool { + self.find_first(name) + .map(|v| v == value.as_bytes()) + .unwrap_or(false) + } + + /// Check if the given header has the given value (case-insensitive). + pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { + self.find_first(name).ok_or(()) + .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) + .map(|val| val.eq_ignore_ascii_case(value)) + .unwrap_or(false) + } + + /// Try to parse data and return headers, if any. + fn parse(input: &mut B) -> Result> { + Headers::parse_http(input) + } + +} + +/// Trait to read HTTP parseable objects. +trait Httparse: Sized { + fn httparse(buf: &[u8]) -> Result>; + fn parse_http(input: &mut B) -> Result> { + Ok(match Self::httparse(input.bytes())? { + Some((size, obj)) => { + input.advance(size); + Some(obj) + }, + None => None, + }) + } +} + +/// Trait to convert raw objects into HTTP parseables. +trait FromHttparse: Sized { + fn from_httparse(raw: T) -> Result; +} + +impl Httparse for Headers { + fn httparse(buf: &[u8]) -> Result> { + let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; + Ok(match httparse::parse_headers(buf, &mut hbuffer)? { + Status::Partial => None, + Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), + }) + } +} + +impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { + fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { + Ok(Headers { + data: raw.iter() + .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) + .collect(), + }) + } +} + +#[cfg(test)] +mod tests { + + use super::{Headers, convert_key}; + + use std::io::Cursor; + + #[test] + fn key_conversion() { + // example from RFC 6455 + assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + #[test] + fn headers() { + const data: &'static [u8] = + b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"; + let mut inp = Cursor::new(data); + let hdr = Headers::parse(&mut inp).unwrap().unwrap(); + assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); + assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); + assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); + + assert!(hdr.header_is("upgrade", "websocket")); + assert!(!hdr.header_is("upgrade", "Websocket")); + assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); + } + + #[test] + fn headers_incomplete() { + const data: &'static [u8] = + b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"; + let mut inp = Cursor::new(data); + let hdr = Headers::parse(&mut inp).unwrap(); + assert!(hdr.is_none()); + } + +} diff --git a/src/handshake/server.rs b/src/handshake/server.rs new file mode 100644 index 0000000..3b3bc7b --- /dev/null +++ b/src/handshake/server.rs @@ -0,0 +1,90 @@ +use bytes::Buf; +use httparse; +use httparse::Status; + +use error::{Error, Result}; +use super::{Headers, Httparse, FromHttparse, convert_key, MAX_HEADERS}; + +/// Request from the client. +pub struct Request { + path: String, + headers: Headers, +} + +impl Request { + /// Parse the request from a stream. + pub fn parse(input: &mut B) -> Result> { + Request::parse_http(input) + } + /// Reply to the response. + pub fn reply(&self) -> Result> { + let key = self.headers.find_first("Sec-WebSocket-Key") + .ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + let reply = format!("\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n\ + \r\n", convert_key(key)?); + Ok(reply.into()) + } +} + +impl Httparse for Request { + fn httparse(buf: &[u8]) -> Result> { + let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut req = httparse::Request::new(&mut hbuffer); + Ok(match req.parse(buf)? { + Status::Partial => None, + Status::Complete(size) => Some((size, Request::from_httparse(req)?)), + }) + } +} + +impl<'h, 'b: 'h> FromHttparse> for Request { + fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { + if raw.method.expect("Bug: no method in header") != "GET" { + return Err(Error::Protocol("Method is not GET".into())); + } + if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { + return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + } + Ok(Request { + path: raw.path.expect("Bug: no path in header").into(), + headers: Headers::from_httparse(raw.headers)? + }) + } +} + +#[cfg(test)] +mod tests { + + use super::Request; + + use std::io::Cursor; + + #[test] + fn request_parsing() { + const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; + let mut inp = Cursor::new(data); + let req = Request::parse(&mut inp).unwrap().unwrap(); + assert_eq!(req.path, "/script.ws"); + assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); + } + + #[test] + fn request_replying() { + const data: &'static [u8] = b"\ + GET /script.ws HTTP/1.1\r\n\ + Host: foo.com\r\n\ + Connection: upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + \r\n"; + let mut inp = Cursor::new(data); + let req = Request::parse(&mut inp).unwrap().unwrap(); + let reply = req.reply().unwrap(); + } + +} diff --git a/src/handshake/tls.rs b/src/handshake/tls.rs new file mode 100644 index 0000000..c093f2f --- /dev/null +++ b/src/handshake/tls.rs @@ -0,0 +1,14 @@ +use native_tls; + +use stream::Stream; +use super::{Handshake, HandshakeResult}; + +pub struct TlsHandshake { + +} + +impl Handshale for TlsHandshake { + type Stream = Stream; + fn handshake(self) -> Result> { + } +} diff --git a/src/input_buffer.rs b/src/input_buffer.rs new file mode 100644 index 0000000..58458bc --- /dev/null +++ b/src/input_buffer.rs @@ -0,0 +1,85 @@ +use std::io::{Cursor, Read, Result as IoResult}; +use bytes::{Buf, BufMut}; + +/// A FIFO buffer for reading packets from network. +pub struct InputBuffer(Cursor>); + +/// Size limit error. +pub struct SizeLimit; + +impl InputBuffer { + /// Create a new empty one. + pub fn with_capacity(capacity: usize) -> Self { + InputBuffer(Cursor::new(Vec::with_capacity(capacity))) + } + + /// Create a new one from partially read data. + pub fn from_partially_read(part: Vec) -> Self { + InputBuffer(Cursor::new(part)) + } + + /// Reserve the given amount of space. + pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{ + if self.inp_mut().remaining_mut() >= space { + // We have enough space right now. + Ok(()) + } else { + let pos = self.out().position() as usize; + self.inp_mut().drain(0..pos); + self.out_mut().set_position(0); + let avail = self.inp_mut().capacity() - self.inp_mut().len(); + if space <= avail { + Ok(()) + } else if self.inp_mut().capacity() + space > limit { + Err(SizeLimit) + } else { + self.inp_mut().reserve(space - avail); + Ok(()) + } + } + } + + /// Read data from stream into the buffer. + pub fn read_from(&mut self, stream: &mut S) -> IoResult { + let size; + let buf = self.inp_mut(); + unsafe { + size = stream.read(buf.bytes_mut())?; + buf.advance_mut(size); + } + Ok(size) + } + + /// Get the rest of the buffer and destroy the buffer. + pub fn into_vec(mut self) -> Vec { + let pos = self.out().position() as usize; + self.inp_mut().drain(0..pos); + self.0.into_inner() + } + + /// The output end (to the application). + pub fn out(&self) -> &Cursor> { + &self.0 // the cursor itself + } + /// The output end (to the application). + pub fn out_mut(&mut self) -> &mut Cursor> { + &mut self.0 // the cursor itself + } + + /// The input end (to the network). + fn inp_mut(&mut self) -> &mut Vec { + self.0.get_mut() // underlying vector + } +} + +impl Buf for InputBuffer { + fn remaining(&self) -> usize { + Buf::remaining(self.out()) + } + fn bytes(&self) -> &[u8] { + Buf::bytes(self.out()) + } + fn advance(&mut self, size: usize) { + Buf::advance(self.out_mut(), size) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..fbd9281 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,28 @@ +//! Lightweight, flexible WebSockets for Rust. +#![deny( + missing_copy_implementations, + trivial_casts, trivial_numeric_casts, + unstable_features, + unused_must_use, + unused_mut, + unused_imports, + unused_import_braces)] + +#[macro_use] extern crate log; +extern crate base64; +extern crate byteorder; +extern crate bytes; +extern crate httparse; +extern crate rand; +extern crate sha1; +extern crate url; +extern crate utf8; +#[cfg(feature="tls")] extern crate native_tls; + +pub mod error; +pub mod protocol; +pub mod client; +pub mod handshake; + +mod input_buffer; +mod stream; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs new file mode 100644 index 0000000..626492b --- /dev/null +++ b/src/protocol/frame/coding.rs @@ -0,0 +1,274 @@ +use std::fmt; +use std::convert::{Into, From}; + +/// WebSocket message opcode as in RFC 6455. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum OpCode { + Data(Data), + Control(Control), +} + +/// Data opcodes as in RFC 6455 +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Data { + /// 0x0 denotes a continuation frame + Continue, + /// 0x1 denotes a text frame + Text, + /// 0x2 denotes a binary frame + Binary, + /// 0x3-7 are reserved for further non-control frames + Reserved(u8), +} + +/// Control opcodes as in RFC 6455 +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Control { + /// 0x8 denotes a connection close + Close, + /// 0x9 denotes a ping + Ping, + /// 0xa denotes a pong + Pong, + /// 0xb-f are reserved for further control frames + Reserved(u8), +} + +impl fmt::Display for Data { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Data::Continue => write!(f, "CONTINUE"), + Data::Text => write!(f, "TEXT"), + Data::Binary => write!(f, "BINARY"), + Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x), + } + } +} + +impl fmt::Display for Control { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Control::Close => write!(f, "CLOSE"), + Control::Ping => write!(f, "PING"), + Control::Pong => write!(f, "PONG"), + Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x), + } + } +} + +impl fmt::Display for OpCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + OpCode::Data(d) => d.fmt(f), + OpCode::Control(c) => c.fmt(f), + } + } +} + +impl Into for OpCode { + fn into(self) -> u8 { + use self::Data::{Continue, Text, Binary}; + use self::Control::{Close, Ping, Pong}; + use self::OpCode::*; + match self { + Data(Continue) => 0, + Data(Text) => 1, + Data(Binary) => 2, + Data(self::Data::Reserved(i)) => i, + + Control(Close) => 8, + Control(Ping) => 9, + Control(Pong) => 10, + Control(self::Control::Reserved(i)) => i, + } + } +} + +impl From for OpCode { + fn from(byte: u8) -> OpCode { + use self::Data::{Continue, Text, Binary}; + use self::Control::{Close, Ping, Pong}; + use self::OpCode::*; + match byte { + 0 => Data(Continue), + 1 => Data(Text), + 2 => Data(Binary), + i @ 3 ... 7 => Data(self::Data::Reserved(i)), + 8 => Control(Close), + 9 => Control(Ping), + 10 => Control(Pong), + i @ 11 ... 15 => Control(self::Control::Reserved(i)), + _ => panic!("Bug: OpCode out of range"), + } + } +} + +use self::CloseCode::*; +/// Status code used to indicate why an endpoint is closing the WebSocket connection. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum CloseCode { + /// Indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + Normal, + /// Indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + Away, + /// Indicates that an endpoint is terminating the connection due + /// to a protocol error. + Protocol, + /// Indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + Unsupported, + /// Indicates that no status code was included in a closing frame. This + /// close code makes it possible to use a single method, `on_close` to + /// handle even cases where no close code was provided. + Status, + /// Indicates an abnormal closure. If the abnormal closure was due to an + /// error, this close code will not be used. Instead, the `on_error` method + /// of the handler will be called with the error. However, if the connection + /// is simply dropped, without an error, this close code will be sent to the + /// handler. + Abnormal, + /// Indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// data within a text message). + Invalid, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., Unsupported or Size) or if there + /// is a need to hide specific details about the policy. + Policy, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + Size, + /// Indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + Extension, + /// Indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + Error, + /// Indicates that the server is restarting. A client may choose to reconnect, + /// and if it does, it should use a randomized delay of 5-30 seconds between attempts. + Restart, + /// Indicates that the server is overloaded and the client should either connect + /// to a different IP (when multiple targets exist), or reconnect to the same IP + /// when a user has performed an action. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Reserved(u16), + #[doc(hidden)] + Iana(u16), + #[doc(hidden)] + Library(u16), + #[doc(hidden)] + Bad(u16), +} + +impl CloseCode { + /// Check if this CloseCode is allowed. + pub fn is_allowed(&self) -> bool { + match *self { + Bad(_) => false, + Reserved(_) => false, + Status => false, + Abnormal => false, + Tls => false, + _ => true, + } + } +} + +impl Into for CloseCode { + fn into(self) -> u16 { + match self { + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Reserved(code) => code, + Iana(code) => code, + Library(code) => code, + Bad(code) => code, + } + } +} + +impl From for CloseCode { + fn from(code: u16) -> CloseCode { + match code { + 1000 => Normal, + 1001 => Away, + 1002 => Protocol, + 1003 => Unsupported, + 1005 => Status, + 1006 => Abnormal, + 1007 => Invalid, + 1008 => Policy, + 1009 => Size, + 1010 => Extension, + 1011 => Error, + 1012 => Restart, + 1013 => Again, + 1015 => Tls, + 1...999 => Bad(code), + 1000...2999 => Reserved(code), + 3000...3999 => Iana(code), + 4000...4999 => Library(code), + _ => Bad(code) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn opcode_from_u8() { + let byte = 2u8; + assert_eq!(OpCode::from(byte), OpCode::Data(Data::Binary)); + } + + #[test] + fn opcode_into_u8() { + let text = OpCode::Data(Data::Text); + let byte: u8 = text.into(); + assert_eq!(byte, 1u8); + } + + #[test] + fn closecode_from_u16() { + let byte = 1008u16; + assert_eq!(CloseCode::from(byte), CloseCode::Policy); + } + + #[test] + fn closecode_into_u16() { + let text = CloseCode::Away; + let byte: u16 = text.into(); + assert_eq!(byte, 1001u16); + } +} diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs new file mode 100644 index 0000000..3992bc5 --- /dev/null +++ b/src/protocol/frame/frame.rs @@ -0,0 +1,516 @@ +use std::fmt; +use std::mem::transmute; +use std::io::{Cursor, Read, Write}; +use std::default::Default; +use std::iter::FromIterator; +use std::string::{String, FromUtf8Error}; +use std::result::Result as StdResult; +use byteorder::{ByteOrder, NetworkEndian}; +use bytes::BufMut; + +use rand; + +use error::{Error, Result}; +use super::coding::{OpCode, Control, Data, CloseCode}; + +fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { + let iter = buf.iter_mut().zip(mask.iter().cycle()); + for (byte, &key) in iter { + *byte ^= key + } +} + +#[inline] +fn generate_mask() -> [u8; 4] { + unsafe { transmute(rand::random::()) } +} + +/// A struct representing a WebSocket frame. +#[derive(Debug, Clone)] +pub struct Frame { + finished: bool, + rsv1: bool, + rsv2: bool, + rsv3: bool, + opcode: OpCode, + + mask: Option<[u8; 4]>, + + payload: Vec, +} + +impl Frame { + + /// Get the length of the frame. + /// This is the length of the header + the length of the payload. + #[inline] + pub fn len(&self) -> usize { + let mut header_length = 2; + let payload_len = self.payload().len(); + if payload_len > 125 { + if payload_len <= u16::max_value() as usize { + header_length += 2; + } else { + header_length += 8; + } + } + + if self.is_masked() { + header_length += 4; + } + + header_length + payload_len + } + + /// Test whether the frame is a final frame. + #[inline] + pub fn is_final(&self) -> bool { + self.finished + } + + /// Test whether the first reserved bit is set. + #[inline] + pub fn has_rsv1(&self) -> bool { + self.rsv1 + } + + /// Test whether the second reserved bit is set. + #[inline] + pub fn has_rsv2(&self) -> bool { + self.rsv2 + } + + /// Test whether the third reserved bit is set. + #[inline] + pub fn has_rsv3(&self) -> bool { + self.rsv3 + } + + /// Get the OpCode of the frame. + #[inline] + pub fn opcode(&self) -> OpCode { + self.opcode + } + + /// Get a reference to the frame's payload. + #[inline] + pub fn payload(&self) -> &Vec { + &self.payload + } + + // Test whether the frame is masked. + #[doc(hidden)] + #[inline] + pub fn is_masked(&self) -> bool { + self.mask.is_some() + } + + // Get an optional reference to the frame's mask. + #[doc(hidden)] + #[allow(dead_code)] + #[inline] + pub fn mask(&self) -> Option<&[u8; 4]> { + self.mask.as_ref() + } + + /// Make this frame a final frame. + #[allow(dead_code)] + #[inline] + pub fn set_final(&mut self, is_final: bool) -> &mut Frame { + self.finished = is_final; + self + } + + /// Set the first reserved bit. + #[inline] + pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { + self.rsv1 = has_rsv1; + self + } + + /// Set the second reserved bit. + #[inline] + pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { + self.rsv2 = has_rsv2; + self + } + + /// Set the third reserved bit. + #[inline] + pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { + self.rsv3 = has_rsv3; + self + } + + /// Set the OpCode. + #[allow(dead_code)] + #[inline] + pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { + self.opcode = opcode; + self + } + + /// Edit the frame's payload. + #[allow(dead_code)] + #[inline] + pub fn payload_mut(&mut self) -> &mut Vec { + &mut self.payload + } + + // Generate a new mask for this frame. + // + // This method simply generates and stores the mask. It does not change the payload data. + // Instead, the payload data will be masked with the generated mask when the frame is sent + // to the other endpoint. + #[doc(hidden)] + #[inline] + pub fn set_mask(&mut self) -> &mut Frame { + self.mask = Some(generate_mask()); + self + } + + // This method unmasks the payload and should only be called on frames that are actually + // masked. In other words, those frames that have just been received from a client endpoint. + #[doc(hidden)] + #[inline] + pub fn remove_mask(&mut self) { + self.mask.and_then(|mask| { + Some(apply_mask(&mut self.payload, &mask)) + }); + self.mask = None; + } + + /// Consume the frame into its payload as binary. + #[inline] + pub fn into_data(self) -> Vec { + self.payload + } + + /// Consume the frame into its payload as string. + #[inline] + pub fn into_string(self) -> StdResult { + String::from_utf8(self.payload) + } + + /// Consume the frame into a closing frame. + #[inline] + pub fn into_close(self) -> Result> { + match self.payload.len() { + 0 => Ok(None), + 1 => Err(Error::Protocol("Invalid close sequence".into())), + _ => { + let mut data = self.payload; + let code = NetworkEndian::read_u16(&data[0..2]).into(); + data.drain(0..2); + let text = String::from_utf8(data)?; + Ok(Some((code, text))) + } + } + } + + /// Create a new data frame. + #[inline] + pub fn message(data: Vec, code: OpCode, finished: bool) -> Frame { + debug_assert!(match code { + OpCode::Data(_) => true, + _ => false, + }, "Invalid opcode for data frame."); + + Frame { + finished: finished, + opcode: code, + payload: data, + .. Frame::default() + } + } + + /// Create a new Pong control frame. + #[inline] + pub fn pong(data: Vec) -> Frame { + Frame { + opcode: OpCode::Control(Control::Pong), + payload: data, + .. Frame::default() + } + } + + /// Create a new Ping control frame. + #[inline] + pub fn ping(data: Vec) -> Frame { + Frame { + opcode: OpCode::Control(Control::Ping), + payload: data, + .. Frame::default() + } + } + + /// Create a new Close control frame. + #[inline] + pub fn close(msg: Option<(CloseCode, &str)>) -> Frame { + let payload = if let Some((code, reason)) = msg { + let raw: [u8; 2] = unsafe { + let u: u16 = code.into(); + transmute(u.to_be()) + }; + Vec::from_iter( + raw[..].iter() + .chain(reason.as_bytes().iter()) + .map(|&b| b)) + } else { + Vec::new() + }; + + Frame { + payload: payload, + .. Frame::default() + } + } + + /// Parse the input stream into a frame. + pub fn parse(cursor: &mut Cursor>) -> Result> { + let size = cursor.get_ref().len() as u64 - cursor.position(); + let initial = cursor.position(); + trace!("Position in buffer {}", initial); + + let mut head = [0u8; 2]; + if try!(cursor.read(&mut head)) != 2 { + cursor.set_position(initial); + return Ok(None) + } + + trace!("Parsed headers {:?}", head); + + let first = head[0]; + let second = head[1]; + trace!("First: {:b}", first); + trace!("Second: {:b}", second); + + let finished = first & 0x80 != 0; + + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; + + let opcode = OpCode::from(first & 0x0F); + trace!("Opcode: {:?}", opcode); + + let masked = second & 0x80 != 0; + trace!("Masked: {:?}", masked); + + let mut header_length = 2; + + let mut length = (second & 0x7F) as u64; + + if length == 126 { + let mut length_bytes = [0u8; 2]; + if try!(cursor.read(&mut length_bytes)) != 2 { + cursor.set_position(initial); + return Ok(None) + } + + length = unsafe { + let mut wide: u16 = transmute(length_bytes); + wide = u16::from_be(wide); + wide + } as u64; + header_length += 2; + } else if length == 127 { + let mut length_bytes = [0u8; 8]; + if try!(cursor.read(&mut length_bytes)) != 8 { + cursor.set_position(initial); + return Ok(None) + } + + unsafe { length = transmute(length_bytes); } + length = u64::from_be(length); + header_length += 8; + } + trace!("Payload length: {}", length); + + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if try!(cursor.read(&mut mask_bytes)) != 4 { + cursor.set_position(initial); + return Ok(None) + } else { + header_length += 4; + Some(mask_bytes) + } + } else { + None + }; + + if size < length + header_length { + cursor.set_position(initial); + return Ok(None) + } + + let mut data = Vec::with_capacity(length as usize); + if length > 0 { + unsafe { + try!(cursor.read_exact(data.bytes_mut())); + data.advance_mut(length as usize); + } + } + + // Disallow bad opcode + match opcode { + OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { + return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) + } + _ => () + } + + let frame = Frame { + finished: finished, + rsv1: rsv1, + rsv2: rsv2, + rsv3: rsv3, + opcode: opcode, + mask: mask, + payload: data, + }; + + + Ok(Some(frame)) + } + + /// Write a frame out to a buffer + pub fn format(mut self, w: &mut W) -> Result<()> + where W: Write + { + let mut one = 0u8; + let code: u8 = self.opcode.into(); + if self.is_final() { + one |= 0x80; + } + if self.has_rsv1() { + one |= 0x40; + } + if self.has_rsv2() { + one |= 0x20; + } + if self.has_rsv3() { + one |= 0x10; + } + one |= code; + + let mut two = 0u8; + + if self.is_masked() { + two |= 0x80; + } + + if self.payload.len() < 126 { + two |= self.payload.len() as u8; + let headers = [one, two]; + try!(w.write(&headers)); + } else if self.payload.len() <= 65535 { + two |= 126; + let length_bytes: [u8; 2] = unsafe { + let short = self.payload.len() as u16; + transmute(short.to_be()) + }; + let headers = [one, two, length_bytes[0], length_bytes[1]]; + try!(w.write(&headers)); + } else { + two |= 127; + let length_bytes: [u8; 8] = unsafe { + let long = self.payload.len() as u64; + transmute(long.to_be()) + }; + let headers = [ + one, + two, + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + length_bytes[4], + length_bytes[5], + length_bytes[6], + length_bytes[7], + ]; + try!(w.write(&headers)); + } + + if self.is_masked() { + let mask = self.mask.take().unwrap(); + apply_mask(&mut self.payload, &mask); + try!(w.write(&mask)); + } + + try!(w.write(&self.payload)); + Ok(()) + } +} + +impl Default for Frame { + fn default() -> Frame { + Frame { + finished: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: OpCode::Control(Control::Close), + mask: None, + payload: Vec::new(), + } + } +} + +impl fmt::Display for Frame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, + " + +final: {} +reserved: {} {} {} +opcode: {} +length: {} +payload length: {} +payload: 0x{} + ", + self.finished, + self.rsv1, + self.rsv2, + self.rsv3, + self.opcode, + // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), + self.len(), + self.payload.len(), + self.payload.iter().map(|byte| format!("{:x}", byte)).collect::()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use super::super::coding::{OpCode, Data}; + use std::io::Cursor; + + #[test] + fn parse() { + let mut raw: Cursor> = Cursor::new(vec![ + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 + ]); + let frame = Frame::parse(&mut raw).unwrap().unwrap(); + assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); + } + + #[test] + fn format() { + let frame = Frame::ping(vec![0x01, 0x02]); + let mut buf = Vec::with_capacity(frame.len()); + frame.format(&mut buf).unwrap(); + assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); + } + + #[test] + fn display() { + let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); + let view = format!("{}", f); + view.contains("payload:"); + } +} diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs new file mode 100644 index 0000000..af4fbde --- /dev/null +++ b/src/protocol/frame/mod.rs @@ -0,0 +1,135 @@ +pub mod coding; + +mod frame; + +pub use self::frame::Frame; + +use std::io::{Read, Write}; + +use input_buffer; +use error::{Error, Result}; + +const MIN_READ: usize = 4096; + +/// A reader and writer for WebSocket frames. +pub struct FrameSocket { + stream: Stream, + in_buffer: input_buffer::InputBuffer, + out_buffer: Vec, +} + +impl FrameSocket { + /// Create a new frame socket. + pub fn new(stream: Stream) -> Self { + FrameSocket { + stream: stream, + in_buffer: input_buffer::InputBuffer::with_capacity(MIN_READ), + out_buffer: Vec::new(), + } + } + /// Create a new frame socket from partially read data. + pub fn from_partially_read(stream: Stream, part: Vec) -> Self { + FrameSocket { + stream: stream, + in_buffer: input_buffer::InputBuffer::from_partially_read(part), + out_buffer: Vec::new(), + } + } + /// Extract a stream from the socket. + pub fn into_inner(self) -> (Stream, Vec) { + (self.stream, self.in_buffer.into_vec()) + } +} + +impl FrameSocket + where Stream: Read +{ + /// Read a frame from stream. + pub fn read_frame(&mut self) -> Result> { + loop { + if let Some(frame) = Frame::parse(&mut self.in_buffer.out_mut())? { + debug!("received frame {}", frame); + return Ok(Some(frame)); + } + // No full frames in buffer. + self.in_buffer.reserve(MIN_READ, usize::max_value()) + .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?; + let size = self.in_buffer.read_from(&mut self.stream)?; + if size == 0 { + debug!("no frame received"); + return Ok(None) + } + } + } + +} + +impl FrameSocket + where Stream: Write +{ + /// Write a frame to stream. + pub fn write_frame(&mut self, frame: Frame) -> Result<()> { + debug!("writing frame {}", frame); + self.out_buffer.reserve(frame.len()); + frame.format(&mut self.out_buffer)?; + let len = self.stream.write(&self.out_buffer)?; + self.out_buffer.drain(0..len); + Ok(()) + } +} + + +#[cfg(test)] +mod tests { + + use super::{Frame, FrameSocket}; + + use std::io::Cursor; + + #[test] + fn read_frames() { + let raw = Cursor::new(vec![ + 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x82, 0x03, 0x03, 0x02, 0x01, + 0x99, + ]); + let mut sock = FrameSocket::new(raw); + + assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + vec![0x03, 0x02, 0x01]); + assert!(sock.read_frame().unwrap().is_none()); + + let (_, rest) = sock.into_inner(); + assert_eq!(rest, vec![0x99]); + } + + #[test] + fn from_partially_read() { + let raw = Cursor::new(vec![ + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ]); + let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); + assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), + vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + } + + #[test] + fn write_frames() { + let mut sock = FrameSocket::new(Vec::new()); + + let frame = Frame::ping(vec![0x04, 0x05]); + sock.write_frame(frame).unwrap(); + + let frame = Frame::pong(vec![0x01]); + sock.write_frame(frame).unwrap(); + + let (buf, _) = sock.into_inner(); + assert_eq!(buf, vec![ + 0x89, 0x02, 0x04, 0x05, + 0x8a, 0x01, 0x01 + ]); + } + +} diff --git a/src/protocol/message.rs b/src/protocol/message.rs new file mode 100644 index 0000000..1b954d7 --- /dev/null +++ b/src/protocol/message.rs @@ -0,0 +1,260 @@ +use std::convert::{From, Into, AsRef}; +use std::fmt; +use std::result::Result as StdResult; +use std::str; + +use error::Result; + +mod string_collect { + + use utf8; + + use error::{Error, Result}; + + pub struct StringCollector { + data: String, + decoder: utf8::Decoder, + } + + impl StringCollector { + pub fn new() -> Self { + StringCollector { + data: String::new(), + decoder: utf8::Decoder::new(), + } + } + pub fn extend>(&mut self, tail: T) -> Result<()> { + let (sym, text, result) = self.decoder.decode(tail.as_ref()); + self.data.push_str(&sym); + self.data.push_str(text); + match result { + utf8::Result::Ok | utf8::Result::Incomplete => + Ok(()), + utf8::Result::Error { remaining_input_after_error: _ } => + Err(Error::Protocol("Invalid UTF8".into())), // FIXME + } + } + pub fn into_string(self) -> Result { + if self.decoder.has_incomplete_sequence() { + Err(Error::Protocol("Invalid UTF8".into())) // FIXME + } else { + Ok(self.data) + } + } + } + +} + +use self::string_collect::StringCollector; + +/// A struct representing the incomplete message. +pub struct IncompleteMessage { + collector: IncompleteMessageCollector, +} + +enum IncompleteMessageCollector { + Text(StringCollector), + Binary(Vec), +} + +impl IncompleteMessage { + /// Create new. + pub fn new(message_type: IncompleteMessageType) -> Self { + IncompleteMessage { + collector: match message_type { + IncompleteMessageType::Binary => + IncompleteMessageCollector::Binary(Vec::new()), + IncompleteMessageType::Text => + IncompleteMessageCollector::Text(StringCollector::new()), + } + } + } + /// Add more data to an existing message. + pub fn extend>(&mut self, tail: T) -> Result<()> { + match self.collector { + IncompleteMessageCollector::Binary(ref mut v) => { + v.extend(tail.as_ref()); + Ok(()) + } + IncompleteMessageCollector::Text(ref mut t) => { + t.extend(tail) + } + } + } + /// Convert an incomplete message into a complete one. + pub fn complete(self) -> Result { + match self.collector { + IncompleteMessageCollector::Binary(v) => { + Ok(Message::Binary(v)) + } + IncompleteMessageCollector::Text(t) => { + let text = t.into_string()?; + Ok(Message::Text(text)) + } + } + } +} + +/// The type of incomplete message. +pub enum IncompleteMessageType { + Text, + Binary, +} + +/// An enum representing the various forms of a WebSocket message. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Message { + /// A text WebSocket message + Text(String), + /// A binary WebSocket message + Binary(Vec), +} + +impl Message { + + /// Create a new text WebSocket message from a stringable. + pub fn text(string: S) -> Message + where S: Into + { + Message::Text(string.into()) + } + + /// Create a new binary WebSocket message by converting to Vec. + pub fn binary(bin: B) -> Message + where B: Into> + { + Message::Binary(bin.into()) + } + + /// Indicates whether a message is a text message. + pub fn is_text(&self) -> bool { + match *self { + Message::Text(_) => true, + Message::Binary(_) => false, + } + } + + /// Indicates whether a message is a binary message. + pub fn is_binary(&self) -> bool { + match *self { + Message::Text(_) => false, + Message::Binary(_) => true, + } + } + + /// Get the length of the WebSocket message. + pub fn len(&self) -> usize { + match *self { + Message::Text(ref string) => string.len(), + Message::Binary(ref data) => data.len(), + } + } + + /// Returns true if the WebSocket message has no content. + /// For example, if the other side of the connection sent an empty string. + pub fn is_empty(&self) -> bool { + match *self { + Message::Text(ref string) => string.is_empty(), + Message::Binary(ref data) => data.is_empty(), + } + } + + /// Consume the WebSocket and return it as binary data. + pub fn into_data(self) -> Vec { + match self { + Message::Text(string) => string.into_bytes(), + Message::Binary(data) => data, + } + } + + /// Attempt to consume the WebSocket message and convert it to a String. + pub fn into_text(self) -> Result { + match self { + Message::Text(string) => Ok(string), + Message::Binary(data) => Ok(try!( + String::from_utf8(data).map_err(|err| err.utf8_error()))), + } + } + + /// Attempt to get a &str from the WebSocket message, + /// this will try to convert binary data to utf8. + pub fn to_text(&self) -> Result<&str> { + match *self { + Message::Text(ref string) => Ok(string), + Message::Binary(ref data) => Ok(try!(str::from_utf8(data))), + } + } + +} + +impl From for Message { + fn from(string: String) -> Message { + Message::text(string) + } +} + +impl<'s> From<&'s str> for Message { + fn from(string: &'s str) -> Message { + Message::text(string) + } +} + +impl<'b> From<&'b [u8]> for Message { + fn from(data: &'b [u8]) -> Message { + Message::binary(data) + } +} + +impl From> for Message { + fn from(data: Vec) -> Message { + Message::binary(data) + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> { + if let Ok(string) = self.to_text() { + write!(f, "{}", string) + } else { + write!(f, "Binary Data", self.len()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let t = Message::text(format!("test")); + assert_eq!(t.to_string(), "test".to_owned()); + + let bin = Message::binary(vec![0, 1, 3, 4, 241]); + assert_eq!(bin.to_string(), "Binary Data".to_owned()); + } + + #[test] + fn binary_convert() { + let bin = [6u8, 7, 8, 9, 10, 241]; + let msg = Message::from(&bin[..]); + assert!(msg.is_binary()); + assert!(msg.into_text().is_err()); + } + + + #[test] + fn binary_convert_vec() { + let bin = vec![6u8, 7, 8, 9, 10, 241]; + let msg = Message::from(bin); + assert!(msg.is_binary()); + assert!(msg.into_text().is_err()); + } + + #[test] + fn text_convert() { + let s = "kiwotsukete"; + let msg = Message::from(s); + assert!(msg.is_text()); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..1dc9b63 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,374 @@ +//! Generic WebSocket protocol implementation + +mod frame; +mod message; + +pub use self::message::Message; + +use self::message::{IncompleteMessage, IncompleteMessageType}; +use std::collections::VecDeque; +use std::io::{Read, Write}; +use std::mem::replace; + +use error::{Error, Result}; +use self::frame::{Frame, FrameSocket}; +use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode}; + +/// Indicates a Client or Server role of the websocket +#[derive(Debug, Clone, Copy)] +pub enum Role { + /// This socket is a server + Server, + /// This socket is a client + Client, +} + +/// WebSocket input-output stream +pub struct WebSocket { + /// Server or client? + role: Role, + /// The underlying socket. + socket: FrameSocket, + /// The state of processing, either "active" or "closing". + state: WebSocketState, + /// Receive: an incomplete message being processed. + incomplete: Option, + /// Send: a data send queue. + send_queue: VecDeque, + /// Send: an OOB pong message. + pong: Option, +} + +impl WebSocket + where Stream: Read + Write +{ + + /// Convert a raw socket into a WebSocket without performing a handshake. + pub fn from_raw_socket(stream: Stream, role: Role) -> Self { + WebSocket::from_frame_socket(FrameSocket::new(stream), role) + } + + /// Convert a raw socket into a WebSocket without performing a handshake. + pub fn from_partially_read(stream: Stream, part: Vec, role: Role) -> Self { + WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) + } + + /// Read a message from stream, if possible. + pub fn read_message(&mut self) -> Result { + loop { + self.send_pending()?; // FIXME + if let Some(message) = self.read_message_frame()? { + debug!("Received message {}", message); + return Ok(message) + } + } + } + + /// Send a message to stream, if possible. + pub fn write_message(&mut self, message: Message) -> Result<()> { + let frame = { + let opcode = match message { + Message::Text(_) => OpData::Text, + Message::Binary(_) => OpData::Binary, + }; + Frame::message(message.into_data(), OpCode::Data(opcode), true) + }; + self.send_queue.push_back(frame); + self.send_pending() + } + + /// Close the connection. + pub fn close(&mut self) -> Result<()> { + match self.state { + WebSocketState::Active => { + self.state = WebSocketState::ClosedByUs; + // TODO + } + _ => { + // already closed, nothing to do + } + } + Ok(()) + } + + /// Convert a frame socket into a WebSocket. + fn from_frame_socket(socket: FrameSocket, role: Role) -> Self { + WebSocket { + role: role, + socket: socket, + state: WebSocketState::Active, + incomplete: None, + send_queue: VecDeque::new(), + pong: None, + } + } + + /// Try to decode one message frame. May return None. + fn read_message_frame(&mut self) -> Result> { + if let Some(mut frame) = self.socket.read_frame()? { + + // MUST be 0 unless an extension is negotiated that defines meanings + // for non-zero values. If a nonzero value is received and none of + // the negotiated extensions defines the meaning of such a nonzero + // value, the receiving endpoint MUST _Fail the WebSocket + // Connection_. + if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { + return Err(Error::Protocol("Reserved bits are non-zero".into())) + } + + match self.role { + Role::Server => { + if frame.is_masked() { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + frame.remove_mask() + } else { + // The server MUST close the connection upon receiving a + // frame that is not masked. (RFC 6455) + return Err(Error::Protocol("Received an unmasked frame from client".into())) + } + } + Role::Client => { + if frame.is_masked() { + // A client MUST close a connection if it detects a masked frame. (RFC 6455) + return Err(Error::Protocol("Received a masked frame from server".into())) + } + } + } + + match frame.opcode() { + + OpCode::Control(ctl) => { + (match ctl { + // All control frames MUST have a payload length of 125 bytes or less + // and MUST NOT be fragmented. (RFC 6455) + _ if !frame.is_final() => { + Err(Error::Protocol("Fragmented control frame".into())) + } + _ if frame.payload().len() > 125 => { + Err(Error::Protocol("Control frame too big".into())) + } + OpCtl::Close => { + self.do_close(frame.into_close()?) + } + OpCtl::Reserved(i) => { + Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + } + OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => { + // No ping processing while closing. + Ok(()) + } + OpCtl::Ping => { + self.do_ping(frame.into_data()) + } + OpCtl::Pong => { + self.do_pong(frame.into_data()) + } + }).map(|_| None) + } + + OpCode::Data(_) if !self.state.is_active() => { + // No data processing while closing. + Ok(None) + } + + OpCode::Data(data) => { + let fin = frame.is_final(); + match data { + OpData::Continue => { + if let Some(ref mut msg) = self.incomplete { + // TODO if msg too big + msg.extend(frame.into_data())?; + } else { + return Err(Error::Protocol("Continue frame but nothing to continue".into())) + } + if fin { + Ok(Some(replace(&mut self.incomplete, None).unwrap().complete()?)) + } else { + Ok(None) + } + } + c if self.incomplete.is_some() => { + Err(Error::Protocol( + format!("Received {} while waiting for more fragments", c).into() + )) + } + OpData::Text | OpData::Binary => { + let msg = { + let message_type = match data { + OpData::Text => IncompleteMessageType::Text, + OpData::Binary => IncompleteMessageType::Binary, + _ => panic!("Bug: message is not text nor binary"), + }; + let mut m = IncompleteMessage::new(message_type); + m.extend(frame.into_data())?; + m + }; + if fin { + Ok(Some(msg.complete()?)) + } else { + self.incomplete = Some(msg); + Ok(None) + } + } + OpData::Reserved(i) => { + Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + } + } + } + + } // match opcode + + } else { + //Ok(None) // TODO handle EOF? + Err(Error::Protocol("Connection reset without closing handshake".into())) + } + } + + /// Received a close frame. + fn do_close(&mut self, close: Option<(CloseCode, String)>) -> Result<()> { + match self.state { + WebSocketState::Active => { + self.state = WebSocketState::ClosedByPeer; + let reply = if let Some((code, _)) = close { + if code.is_allowed() { + Frame::close(Some((CloseCode::Normal, ""))) + } else { + Frame::close(Some((CloseCode::Protocol, "Protocol violation"))) + } + } else { + Frame::close(None) + }; + self.send_queue.push_back(reply); + } + WebSocketState::ClosedByPeer => { + // It is already closed, just ignore. + } + WebSocketState::ClosedByUs => { + // We received a reply. + match self.role { + Role::Client => { + // Client waits for the server to close the connection. + } + Role::Server => { + // Server closes the connection. + // TODO + } + } + } + } + //unimplemented!() + Ok(()) + } + + /// Received a ping frame. + fn do_ping(&mut self, ping: Vec) -> Result<()> { + // If an endpoint receives a Ping frame and has not yet sent Pong + // frame(s) in response to previous Ping frame(s), the endpoint MAY + // elect to send a Pong frame for only the most recently processed Ping + // frame. (RFC 6455) + // We do exactly that, keeping a "queue" from one and only Pong frame. + self.pong = Some(Frame::pong(ping)); + Ok(()) + } + + /// Received a pong frame. + fn do_pong(&mut self, _: Vec) -> Result<()> { + // A Pong frame MAY be sent unsolicited. This serves as a + // unidirectional heartbeat. A response to an unsolicited Pong frame is + // not expected. (RFC 6455) + // Due to this, we just don't check pongs right now. + // TODO: check if there was a reply to our ping at all... + Ok(()) + } + + /// Flush the pending send queue. + fn send_pending(&mut self) -> Result<()> { + // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in + // response, unless it already received a Close frame. It SHOULD + // respond with Pong frame as soon as is practical. (RFC 6455) + if let Some(pong) = replace(&mut self.pong, None) { + self.send_one_frame(pong)?; + } + // If we have any unsent frames, send them. + while let Some(data) = self.send_queue.pop_front() { + self.send_one_frame(data)?; + } + Ok(()) + } + + /// Send a single pending frame. + fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> { + match self.role { + Role::Server => { + } + Role::Client => { + // 5. If the data is being sent by the client, the frame(s) MUST be + // masked as defined in Section 5.3. (RFC 6455) + frame.set_mask(); + } + } + self.socket.write_frame(frame)?; + Ok(()) + } + +} + +/// The current connection state. +enum WebSocketState { + Active, + ClosedByUs, + ClosedByPeer, +} + +impl WebSocketState { + /// Tell if we're allowed to process normal messages. + fn is_active(&self) -> bool { + match *self { + WebSocketState::Active => true, + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::{WebSocket, Role, Message}; + + use std::io; + use std::io::Cursor; + + struct WriteMoc(Stream); + + impl io::Write for WriteMoc { + fn write(&mut self, buf: &[u8]) -> io::Result { + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + impl io::Read for WriteMoc { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + } + + + #[test] + fn receive_messages() { + let incoming = Cursor::new(vec![ + 0x01, 0x07, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, + 0x80, 0x06, + 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, + 0x82, 0x03, + 0x01, 0x02, 0x03, + ]); + let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); + assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); + assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + } + +} diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..3818bb3 --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,39 @@ +#[cfg(feature="tls")] +use native_tls::TlsStream; + +use std::net::TcpStream; +use std::io::{Read, Write, Result as IoResult}; + +/// Stream, either plain TCP or TLS. +pub enum Stream { + Plain(TcpStream), + #[cfg(feature="tls")] + Tls(TlsStream), +} + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + match *self { + Stream::Plain(ref mut s) => s.read(buf), + #[cfg(feature="tls")] + Stream::Tls(ref mut s) => s.read(buf), + } + } +} + +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> IoResult { + match *self { + Stream::Plain(ref mut s) => s.write(buf), + #[cfg(feature="tls")] + Stream::Tls(ref mut s) => s.write(buf), + } + } + fn flush(&mut self) -> IoResult<()> { + match *self { + Stream::Plain(ref mut s) => s.flush(), + #[cfg(feature="tls")] + Stream::Tls(ref mut s) => s.flush(), + } + } +}