commit
e63f594a14
@ -0,0 +1,2 @@ |
||||
target |
||||
Cargo.lock |
@ -0,0 +1,24 @@ |
||||
[package] |
||||
name = "ws2" |
||||
version = "0.1.0" |
||||
authors = ["Alexey Galakhov"] |
||||
|
||||
[features] |
||||
default = [] |
||||
tls = ["native-tls"] |
||||
|
||||
[dependencies] |
||||
base64 = "*" |
||||
byteorder = "*" |
||||
bytes = { git = "https://github.com/carllerche/bytes.git" } |
||||
httparse = "*" |
||||
env_logger = "*" |
||||
log = "*" |
||||
rand = "*" |
||||
sha1 = "*" |
||||
url = "*" |
||||
utf-8 = "*" |
||||
|
||||
[dependencies.native-tls] |
||||
optional = true |
||||
version = "*" |
@ -0,0 +1,20 @@ |
||||
Copyright (c) 2016 Alexey Galakhov |
||||
Copyright (c) 2016 Jason Housley |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy |
||||
of this software and associated documentation files (the "Software"), to deal |
||||
in the Software without restriction, including without limitation the rights |
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
||||
copies of the Software, and to permit persons to whom the Software is |
||||
furnished to do so, subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in |
||||
all copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
||||
THE SOFTWARE. |
@ -0,0 +1,62 @@ |
||||
#[macro_use] extern crate log; |
||||
extern crate env_logger; |
||||
extern crate ws2; |
||||
extern crate url; |
||||
|
||||
use url::Url; |
||||
|
||||
use ws2::protocol::Message; |
||||
use ws2::client::connect; |
||||
use ws2::handshake::Handshake; |
||||
use ws2::error::{Error, Result}; |
||||
|
||||
const AGENT: &'static str = "WS2-RS"; |
||||
|
||||
fn get_case_count() -> Result<u32> { |
||||
let mut socket = connect( |
||||
Url::parse("ws://localhost:9001/getCaseCount").unwrap() |
||||
)?.handshake_wait()?; |
||||
let msg = socket.read_message()?; |
||||
socket.close(); |
||||
Ok(msg.into_text()?.parse::<u32>().unwrap()) |
||||
} |
||||
|
||||
fn update_reports() -> Result<()> { |
||||
let mut socket = connect( |
||||
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() |
||||
)?.handshake_wait()?; |
||||
socket.close(); |
||||
Ok(()) |
||||
} |
||||
|
||||
fn run_test(case: u32) -> Result<()> { |
||||
info!("Running test case {}", case); |
||||
let case_url = Url::parse( |
||||
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) |
||||
).unwrap(); |
||||
let mut socket = connect(case_url)?.handshake_wait()?; |
||||
loop { |
||||
let msg = socket.read_message()?; |
||||
socket.write_message(msg)?; |
||||
} |
||||
socket.close(); |
||||
Ok(()) |
||||
} |
||||
|
||||
fn main() { |
||||
env_logger::init().unwrap(); |
||||
|
||||
let total = get_case_count().unwrap(); |
||||
|
||||
for case in 1..(total + 1) { |
||||
if let Err(e) = run_test(case) { |
||||
match e { |
||||
Error::Protocol(_) => { } |
||||
err => { warn!("test: {}", err); } |
||||
} |
||||
} |
||||
} |
||||
|
||||
update_reports().unwrap(); |
||||
} |
||||
|
@ -0,0 +1,25 @@ |
||||
extern crate ws2; |
||||
extern crate url; |
||||
extern crate env_logger; |
||||
|
||||
use url::Url; |
||||
use ws2::protocol::Message; |
||||
use ws2::client::connect; |
||||
use ws2::protocol::handshake::Handshake; |
||||
|
||||
fn main() { |
||||
env_logger::init().unwrap(); |
||||
|
||||
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) |
||||
.expect("Can't connect") |
||||
.handshake_wait() |
||||
.expect("Handshake error"); |
||||
|
||||
socket.write_message(Message::Text("Hello WebSocket".into())); |
||||
loop { |
||||
let msg = socket.read_message().expect("Error reading message"); |
||||
println!("Received: {}", msg); |
||||
} |
||||
// socket.close();
|
||||
|
||||
} |
@ -0,0 +1,75 @@ |
||||
use std::net::{TcpStream, ToSocketAddrs}; |
||||
use url::{Url, SocketAddrs}; |
||||
|
||||
use protocol::WebSocket; |
||||
use handshake::{Handshake as HandshakeTrait, HandshakeResult}; |
||||
use handshake::client::{ClientHandshake, Request}; |
||||
use error::{Error, Result}; |
||||
|
||||
/// Connect to the given WebSocket.
|
||||
///
|
||||
/// Note that this function may block the current thread while DNS resolution is performed.
|
||||
pub fn connect(url: Url) -> Result<Handshake> { |
||||
let mode = match url.scheme() { |
||||
"ws" => Mode::Plain, |
||||
#[cfg(feature="tls")] |
||||
"wss" => Mode::Tls, |
||||
_ => return Err(Error::Url("URL scheme not supported".into())) |
||||
}; |
||||
|
||||
// Note that this function may block the current thread while resolution is performed.
|
||||
let addrs = url.to_socket_addrs()?; |
||||
Ok(Handshake { |
||||
state: HandshakeState::Nothing(url), |
||||
alt_addresses: addrs, |
||||
}) |
||||
} |
||||
|
||||
enum Mode { |
||||
Plain, |
||||
Tls, |
||||
} |
||||
|
||||
enum HandshakeState { |
||||
Nothing(Url), |
||||
WebSocket(ClientHandshake<TcpStream>), |
||||
} |
||||
|
||||
pub struct Handshake { |
||||
state: HandshakeState, |
||||
alt_addresses: SocketAddrs, |
||||
} |
||||
|
||||
impl HandshakeTrait for Handshake { |
||||
type Stream = WebSocket<TcpStream>; |
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> { |
||||
match self.state { |
||||
HandshakeState::Nothing(url) => { |
||||
if let Some(addr) = self.alt_addresses.next() { |
||||
debug!("Trying to contact {} at {}...", url, addr); |
||||
let state = { |
||||
if let Ok(stream) = TcpStream::connect(addr) { |
||||
let hs = ClientHandshake::new(stream, Request { url: url }); |
||||
HandshakeState::WebSocket(hs) |
||||
} else { |
||||
HandshakeState::Nothing(url) |
||||
} |
||||
}; |
||||
Ok(HandshakeResult::Incomplete(Handshake { |
||||
state: state, |
||||
..self |
||||
})) |
||||
} else { |
||||
Err(Error::Url(format!("Unable to resolve {}", url).into())) |
||||
} |
||||
} |
||||
HandshakeState::WebSocket(ws) => { |
||||
let alt_addresses = self.alt_addresses; |
||||
ws.handshake().map(move |r| r.map(move |s| Handshake { |
||||
state: HandshakeState::WebSocket(s), |
||||
alt_addresses: alt_addresses, |
||||
})) |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,84 @@ |
||||
//! Error handling.
|
||||
|
||||
use std::borrow::{Borrow, Cow}; |
||||
|
||||
use std::error::Error as ErrorTrait; |
||||
use std::fmt; |
||||
use std::io; |
||||
use std::result; |
||||
use std::str; |
||||
use std::string; |
||||
|
||||
use httparse; |
||||
|
||||
pub type Result<T> = result::Result<T, Error>; |
||||
|
||||
/// Possible WebSocket errors
|
||||
#[derive(Debug)] |
||||
pub enum Error { |
||||
/// Input-output error
|
||||
Io(io::Error), |
||||
/// Buffer capacity exhausted
|
||||
Capacity(Cow<'static, str>), |
||||
/// Protocol violation
|
||||
Protocol(Cow<'static, str>), |
||||
/// UTF coding error
|
||||
Utf8(str::Utf8Error), |
||||
/// Invlid URL.
|
||||
Url(Cow<'static, str>), |
||||
/// HTTP error.
|
||||
Http(u16), |
||||
} |
||||
|
||||
impl fmt::Display for Error { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
||||
match *self { |
||||
Error::Io(ref err) => write!(f, "IO error: {}", err), |
||||
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), |
||||
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), |
||||
Error::Utf8(ref err) => write!(f, "UTF-8 encoding error: {}", err), |
||||
Error::Url(ref msg) => write!(f, "URL error: {}", msg), |
||||
Error::Http(code) => write!(f, "HTTP code: {}", code), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl ErrorTrait for Error { |
||||
fn description(&self) -> &str { |
||||
match *self { |
||||
Error::Io(ref err) => err.description(), |
||||
Error::Capacity(ref msg) => msg.borrow(), |
||||
Error::Protocol(ref msg) => msg.borrow(), |
||||
Error::Utf8(ref err) => err.description(), |
||||
Error::Url(ref msg) => msg.borrow(), |
||||
Error::Http(_) => "", |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl From<io::Error> for Error { |
||||
fn from(err: io::Error) -> Self { |
||||
Error::Io(err) |
||||
} |
||||
} |
||||
|
||||
impl From<str::Utf8Error> for Error { |
||||
fn from(err: str::Utf8Error) -> Self { |
||||
Error::Utf8(err) |
||||
} |
||||
} |
||||
|
||||
impl From<string::FromUtf8Error> for Error { |
||||
fn from(err: string::FromUtf8Error) -> Self { |
||||
Error::Utf8(err.utf8_error()) |
||||
} |
||||
} |
||||
|
||||
impl From<httparse::Error> for Error { |
||||
fn from(err: httparse::Error) -> Self { |
||||
match err { |
||||
httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), |
||||
e => Error::Protocol(Cow::Owned(e.description().into())), |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,259 @@ |
||||
use std::io::{Read, Write, Cursor}; |
||||
|
||||
use base64; |
||||
use rand; |
||||
use bytes::Buf; |
||||
use httparse; |
||||
use httparse::Status; |
||||
use url::Url; |
||||
|
||||
use input_buffer::InputBuffer; |
||||
use error::{Error, Result}; |
||||
use super::{ |
||||
Headers, |
||||
Httparse, FromHttparse, |
||||
Handshake, HandshakeResult, |
||||
convert_key, |
||||
MAX_HEADERS, |
||||
}; |
||||
use protocol::{ |
||||
WebSocket, Role, |
||||
}; |
||||
|
||||
const MIN_READ: usize = 4096; |
||||
|
||||
/// Client request.
|
||||
pub struct Request { |
||||
pub url: Url, |
||||
// TODO extra headers
|
||||
} |
||||
|
||||
impl Request { |
||||
/// The GET part of the request.
|
||||
fn get_path(&self) -> String { |
||||
if let Some(query) = self.url.query() { |
||||
format!("{path}?{query}", path = self.url.path(), query = query) |
||||
} else { |
||||
self.url.path().into() |
||||
} |
||||
} |
||||
/// The Host: part of the request.
|
||||
fn get_host(&self) -> String { |
||||
let host = self.url.host_str().expect("Bug: URL without host"); |
||||
if let Some(port) = self.url.port() { |
||||
format!("{host}:{port}", host = host, port = port) |
||||
} else { |
||||
host.into() |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Client handshake.
|
||||
pub struct ClientHandshake<Stream> { |
||||
stream: Stream, |
||||
state: HandshakeState, |
||||
verify_data: VerifyData, |
||||
} |
||||
|
||||
impl<Stream: Read + Write> ClientHandshake<Stream> { |
||||
/// Initiate a WebSocket handshake over the given stream.
|
||||
pub fn new(stream: Stream, request: Request) -> Self { |
||||
let key = generate_key(); |
||||
|
||||
let mut req = Vec::new(); |
||||
write!(req, "\ |
||||
GET {path} HTTP/1.1\r\n\ |
||||
Host: {host}\r\n\ |
||||
Connection: upgrade\r\n\ |
||||
Upgrade: websocket\r\n\ |
||||
Sec-WebSocket-Version: 13\r\n\ |
||||
Sec-WebSocket-Key: {key}\r\n\ |
||||
\r\n", host = request.get_host(), path = request.get_path(), key = key) |
||||
.unwrap(); |
||||
|
||||
let accept_key = convert_key(key.as_ref()).unwrap(); |
||||
|
||||
ClientHandshake { |
||||
stream: stream, |
||||
state: HandshakeState::SendingRequest(Cursor::new(req)), |
||||
verify_data: VerifyData { |
||||
accept_key: accept_key, |
||||
}, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<Stream: Read + Write> Handshake for ClientHandshake<Stream> { |
||||
type Stream = WebSocket<Stream>; |
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> { |
||||
debug!("Performing client handshake..."); |
||||
match self.state { |
||||
HandshakeState::SendingRequest(mut req) => { |
||||
let size = self.stream.write(Buf::bytes(&req))?; |
||||
Buf::advance(&mut req, size); |
||||
let state = if req.has_remaining() { |
||||
HandshakeState::SendingRequest(req) |
||||
} else { |
||||
HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ)) |
||||
}; |
||||
Ok(HandshakeResult::Incomplete(ClientHandshake { |
||||
state: state, |
||||
..self |
||||
})) |
||||
} |
||||
HandshakeState::ReceivingResponse(mut resp_buf) => { |
||||
resp_buf.reserve(MIN_READ, usize::max_value()) |
||||
.map_err(|_| Error::Capacity("Header too long".into()))?; |
||||
resp_buf.read_from(&mut self.stream)?; |
||||
if let Some(resp) = Response::parse(&mut resp_buf)? { |
||||
self.verify_data.verify_response(&resp)?; |
||||
let ws = WebSocket::from_partially_read(self.stream, |
||||
resp_buf.into_vec(), Role::Client); |
||||
debug!("Client handshake done."); |
||||
Ok(HandshakeResult::Done(ws)) |
||||
} else { |
||||
Ok(HandshakeResult::Incomplete(ClientHandshake { |
||||
state: HandshakeState::ReceivingResponse(resp_buf), |
||||
..self |
||||
})) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Information for handshake verification.
|
||||
struct VerifyData { |
||||
/// Accepted server key.
|
||||
accept_key: String, |
||||
} |
||||
|
||||
impl VerifyData { |
||||
pub fn verify_response(&self, response: &Response) -> Result<()> { |
||||
// 1. If the status code received from the server is not 101, the
|
||||
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
||||
if response.code != 101 { |
||||
return Err(Error::Http(response.code)); |
||||
} |
||||
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
|
||||
// header field contains a value that is not an ASCII case-
|
||||
// insensitive match for the value "websocket", the client MUST
|
||||
// _Fail the WebSocket Connection_. (RFC 6455)
|
||||
if !response.headers.header_is_ignore_case("Upgrade", "websocket") { |
||||
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); |
||||
} |
||||
// 3. If the response lacks a |Connection| header field or the
|
||||
// |Connection| header field doesn't contain a token that is an
|
||||
// ASCII case-insensitive match for the value "Upgrade", the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
if !response.headers.header_is_ignore_case("Connection", "Upgrade") { |
||||
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); |
||||
} |
||||
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
|
||||
// the |Sec-WebSocket-Accept| contains a value other than the
|
||||
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
|
||||
// Connection_. (RFC 6455)
|
||||
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) { |
||||
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); |
||||
} |
||||
// 5. If the response includes a |Sec-WebSocket-Extensions| header
|
||||
// field and this header field indicates the use of an extension
|
||||
// that was not present in the client's handshake (the server has
|
||||
// indicated an extension not requested by the client), the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
|
||||
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
||||
// and this header field indicates the use of a subprotocol that was
|
||||
// not present in the client's handshake (the server has indicated a
|
||||
// subprotocol not requested by the client), the client MUST _Fail
|
||||
// the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
|
||||
Ok(()) |
||||
} |
||||
} |
||||
|
||||
/// Internal state of the client handshake.
|
||||
enum HandshakeState { |
||||
SendingRequest(Cursor<Vec<u8>>), |
||||
ReceivingResponse(InputBuffer), |
||||
} |
||||
|
||||
/// Server response.
|
||||
pub struct Response { |
||||
code: u16, |
||||
headers: Headers, |
||||
} |
||||
|
||||
impl Response { |
||||
/// Parse the response from a stream.
|
||||
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> { |
||||
Response::parse_http(input) |
||||
} |
||||
} |
||||
|
||||
impl Httparse for Response { |
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
||||
let mut req = httparse::Response::new(&mut hbuffer); |
||||
Ok(match req.parse(buf)? { |
||||
Status::Partial => None, |
||||
Status::Complete(size) => Some((size, Response::from_httparse(req)?)), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { |
||||
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { |
||||
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
||||
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); |
||||
} |
||||
Ok(Response { |
||||
code: raw.code.expect("Bug: no HTTP response code"), |
||||
headers: Headers::from_httparse(raw.headers)?, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
/// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||
fn generate_key() -> String { |
||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||
// when decoded, is 16 bytes in length (RFC 6455)
|
||||
let r: [u8; 16] = rand::random(); |
||||
base64::encode(&r) |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
|
||||
use super::{Response, generate_key}; |
||||
|
||||
use std::io::Cursor; |
||||
|
||||
#[test] |
||||
fn random_keys() { |
||||
let k1 = generate_key(); |
||||
println!("Generated random key 1: {}", k1); |
||||
let k2 = generate_key(); |
||||
println!("Generated random key 2: {}", k2); |
||||
assert_ne!(k1, k2); |
||||
assert_eq!(k1.len(), k2.len()); |
||||
assert_eq!(k1.len(), 24); |
||||
assert_eq!(k2.len(), 24); |
||||
assert!(k1.ends_with("==")); |
||||
assert!(k2.ends_with("==")); |
||||
assert!(k1[..22].find("=").is_none()); |
||||
assert!(k2[..22].find("=").is_none()); |
||||
} |
||||
|
||||
#[test] |
||||
fn response_parsing() { |
||||
const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; |
||||
let mut inp = Cursor::new(data); |
||||
let req = Response::parse(&mut inp).unwrap().unwrap(); |
||||
assert_eq!(req.code, 200); |
||||
assert_eq!(req.headers.find_first("Content-Type"), Some(&b"text/html"[..])); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,183 @@ |
||||
pub mod client; |
||||
pub mod server; |
||||
#[cfg(feature="tls")] |
||||
pub mod tls; |
||||
|
||||
use std::ascii::AsciiExt; |
||||
use std::str::from_utf8; |
||||
|
||||
use base64; |
||||
use bytes::Buf; |
||||
use httparse; |
||||
use httparse::Status; |
||||
use sha1::Sha1; |
||||
|
||||
use error::Result; |
||||
|
||||
// Limit the number of header lines.
|
||||
const MAX_HEADERS: usize = 124; |
||||
|
||||
/// A handshake state.
|
||||
pub trait Handshake: Sized { |
||||
/// Resulting stream of this handshake.
|
||||
type Stream; |
||||
/// Perform a single handshake round.
|
||||
fn handshake(self) -> Result<HandshakeResult<Self>>; |
||||
/// Perform handshake to the end in a blocking mode.
|
||||
fn handshake_wait(self) -> Result<Self::Stream> { |
||||
let mut hs = self; |
||||
loop { |
||||
hs = match hs.handshake()? { |
||||
HandshakeResult::Done(stream) => return Ok(stream), |
||||
HandshakeResult::Incomplete(s) => s, |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// A handshake result.
|
||||
pub enum HandshakeResult<H: Handshake> { |
||||
/// Handshake is done, a WebSocket stream is ready.
|
||||
Done(H::Stream), |
||||
/// Handshake is not done, call handshake() again.
|
||||
Incomplete(H), |
||||
} |
||||
|
||||
impl<H: Handshake> HandshakeResult<H> { |
||||
pub fn map<R, F>(self, func: F) -> HandshakeResult<R> |
||||
where R: Handshake<Stream = H::Stream>, |
||||
F: FnOnce(H) -> R, |
||||
{ |
||||
match self { |
||||
HandshakeResult::Done(s) => HandshakeResult::Done(s), |
||||
HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)), |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
|
||||
fn convert_key(input: &[u8]) -> Result<String> { |
||||
// ... field is constructed by concatenating /key/ ...
|
||||
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
|
||||
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
||||
let mut sha1 = Sha1::new(); |
||||
sha1.update(input); |
||||
sha1.update(WS_GUID); |
||||
Ok(base64::encode(&sha1.digest().bytes())) |
||||
} |
||||
|
||||
/// HTTP request or response headers.
|
||||
#[derive(Debug)] |
||||
pub struct Headers { |
||||
data: Vec<(String, Box<[u8]>)>, |
||||
} |
||||
|
||||
impl Headers { |
||||
|
||||
/// Get first header with the given name, if any.
|
||||
pub fn find_first(&self, name: &str) -> Option<&[u8]> { |
||||
self.data.iter() |
||||
.find(|&&(ref n, _)| n.eq_ignore_ascii_case(name)) |
||||
.map(|&(_, ref v)| v.as_ref()) |
||||
} |
||||
|
||||
/// Check if the given header has the given value.
|
||||
pub fn header_is(&self, name: &str, value: &str) -> bool { |
||||
self.find_first(name) |
||||
.map(|v| v == value.as_bytes()) |
||||
.unwrap_or(false) |
||||
} |
||||
|
||||
/// Check if the given header has the given value (case-insensitive).
|
||||
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { |
||||
self.find_first(name).ok_or(()) |
||||
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) |
||||
.map(|val| val.eq_ignore_ascii_case(value)) |
||||
.unwrap_or(false) |
||||
} |
||||
|
||||
/// Try to parse data and return headers, if any.
|
||||
fn parse<B: Buf>(input: &mut B) -> Result<Option<Headers>> { |
||||
Headers::parse_http(input) |
||||
} |
||||
|
||||
} |
||||
|
||||
/// Trait to read HTTP parseable objects.
|
||||
trait Httparse: Sized { |
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>>; |
||||
fn parse_http<B: Buf>(input: &mut B) -> Result<Option<Self>> { |
||||
Ok(match Self::httparse(input.bytes())? { |
||||
Some((size, obj)) => { |
||||
input.advance(size); |
||||
Some(obj) |
||||
}, |
||||
None => None, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
/// Trait to convert raw objects into HTTP parseables.
|
||||
trait FromHttparse<T>: Sized { |
||||
fn from_httparse(raw: T) -> Result<Self>; |
||||
} |
||||
|
||||
impl Httparse for Headers { |
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
||||
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { |
||||
Status::Partial => None, |
||||
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { |
||||
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> { |
||||
Ok(Headers { |
||||
data: raw.iter() |
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) |
||||
.collect(), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
|
||||
use super::{Headers, convert_key}; |
||||
|
||||
use std::io::Cursor; |
||||
|
||||
#[test] |
||||
fn key_conversion() { |
||||
// example from RFC 6455
|
||||
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), |
||||
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); |
||||
} |
||||
|
||||
#[test] |
||||
fn headers() { |
||||
const data: &'static [u8] = |
||||
b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"; |
||||
let mut inp = Cursor::new(data); |
||||
let hdr = Headers::parse(&mut inp).unwrap().unwrap(); |
||||
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); |
||||
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); |
||||
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); |
||||
|
||||
assert!(hdr.header_is("upgrade", "websocket")); |
||||
assert!(!hdr.header_is("upgrade", "Websocket")); |
||||
assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); |
||||
} |
||||
|
||||
#[test] |
||||
fn headers_incomplete() { |
||||
const data: &'static [u8] = |
||||
b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"; |
||||
let mut inp = Cursor::new(data); |
||||
let hdr = Headers::parse(&mut inp).unwrap(); |
||||
assert!(hdr.is_none()); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,90 @@ |
||||
use bytes::Buf; |
||||
use httparse; |
||||
use httparse::Status; |
||||
|
||||
use error::{Error, Result}; |
||||
use super::{Headers, Httparse, FromHttparse, convert_key, MAX_HEADERS}; |
||||
|
||||
/// Request from the client.
|
||||
pub struct Request { |
||||
path: String, |
||||
headers: Headers, |
||||
} |
||||
|
||||
impl Request { |
||||
/// Parse the request from a stream.
|
||||
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> { |
||||
Request::parse_http(input) |
||||
} |
||||
/// Reply to the response.
|
||||
pub fn reply(&self) -> Result<Vec<u8>> { |
||||
let key = self.headers.find_first("Sec-WebSocket-Key") |
||||
.ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?; |
||||
let reply = format!("\ |
||||
HTTP/1.1 101 Switching Protocols\r\n\ |
||||
Connection: Upgrade\r\n\ |
||||
Upgrade: websocket\r\n\ |
||||
Sec-WebSocket-Accept: {}\r\n\ |
||||
\r\n", convert_key(key)?); |
||||
Ok(reply.into()) |
||||
} |
||||
} |
||||
|
||||
impl Httparse for Request { |
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
||||
let mut req = httparse::Request::new(&mut hbuffer); |
||||
Ok(match req.parse(buf)? { |
||||
Status::Partial => None, |
||||
Status::Complete(size) => Some((size, Request::from_httparse(req)?)), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { |
||||
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> { |
||||
if raw.method.expect("Bug: no method in header") != "GET" { |
||||
return Err(Error::Protocol("Method is not GET".into())); |
||||
} |
||||
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
||||
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); |
||||
} |
||||
Ok(Request { |
||||
path: raw.path.expect("Bug: no path in header").into(), |
||||
headers: Headers::from_httparse(raw.headers)? |
||||
}) |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
|
||||
use super::Request; |
||||
|
||||
use std::io::Cursor; |
||||
|
||||
#[test] |
||||
fn request_parsing() { |
||||
const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; |
||||
let mut inp = Cursor::new(data); |
||||
let req = Request::parse(&mut inp).unwrap().unwrap(); |
||||
assert_eq!(req.path, "/script.ws"); |
||||
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); |
||||
} |
||||
|
||||
#[test] |
||||
fn request_replying() { |
||||
const data: &'static [u8] = b"\ |
||||
GET /script.ws HTTP/1.1\r\n\ |
||||
Host: foo.com\r\n\ |
||||
Connection: upgrade\r\n\ |
||||
Upgrade: websocket\r\n\ |
||||
Sec-WebSocket-Version: 13\r\n\ |
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ |
||||
\r\n"; |
||||
let mut inp = Cursor::new(data); |
||||
let req = Request::parse(&mut inp).unwrap().unwrap(); |
||||
let reply = req.reply().unwrap(); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,14 @@ |
||||
use native_tls; |
||||
|
||||
use stream::Stream; |
||||
use super::{Handshake, HandshakeResult}; |
||||
|
||||
pub struct TlsHandshake { |
||||
|
||||
} |
||||
|
||||
impl Handshale for TlsHandshake { |
||||
type Stream = Stream; |
||||
fn handshake(self) -> Result<HandshakeResult<Self>> { |
||||
} |
||||
} |
@ -0,0 +1,85 @@ |
||||
use std::io::{Cursor, Read, Result as IoResult}; |
||||
use bytes::{Buf, BufMut}; |
||||
|
||||
/// A FIFO buffer for reading packets from network.
|
||||
pub struct InputBuffer(Cursor<Vec<u8>>); |
||||
|
||||
/// Size limit error.
|
||||
pub struct SizeLimit; |
||||
|
||||
impl InputBuffer { |
||||
/// Create a new empty one.
|
||||
pub fn with_capacity(capacity: usize) -> Self { |
||||
InputBuffer(Cursor::new(Vec::with_capacity(capacity))) |
||||
} |
||||
|
||||
/// Create a new one from partially read data.
|
||||
pub fn from_partially_read(part: Vec<u8>) -> Self { |
||||
InputBuffer(Cursor::new(part)) |
||||
} |
||||
|
||||
/// Reserve the given amount of space.
|
||||
pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{ |
||||
if self.inp_mut().remaining_mut() >= space { |
||||
// We have enough space right now.
|
||||
Ok(()) |
||||
} else { |
||||
let pos = self.out().position() as usize; |
||||
self.inp_mut().drain(0..pos); |
||||
self.out_mut().set_position(0); |
||||
let avail = self.inp_mut().capacity() - self.inp_mut().len(); |
||||
if space <= avail { |
||||
Ok(()) |
||||
} else if self.inp_mut().capacity() + space > limit { |
||||
Err(SizeLimit) |
||||
} else { |
||||
self.inp_mut().reserve(space - avail); |
||||
Ok(()) |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Read data from stream into the buffer.
|
||||
pub fn read_from<S: Read>(&mut self, stream: &mut S) -> IoResult<usize> { |
||||
let size; |
||||
let buf = self.inp_mut(); |
||||
unsafe { |
||||
size = stream.read(buf.bytes_mut())?; |
||||
buf.advance_mut(size); |
||||
} |
||||
Ok(size) |
||||
} |
||||
|
||||
/// Get the rest of the buffer and destroy the buffer.
|
||||
pub fn into_vec(mut self) -> Vec<u8> { |
||||
let pos = self.out().position() as usize; |
||||
self.inp_mut().drain(0..pos); |
||||
self.0.into_inner() |
||||
} |
||||
|
||||
/// The output end (to the application).
|
||||
pub fn out(&self) -> &Cursor<Vec<u8>> { |
||||
&self.0 // the cursor itself
|
||||
} |
||||
/// The output end (to the application).
|
||||
pub fn out_mut(&mut self) -> &mut Cursor<Vec<u8>> { |
||||
&mut self.0 // the cursor itself
|
||||
} |
||||
|
||||
/// The input end (to the network).
|
||||
fn inp_mut(&mut self) -> &mut Vec<u8> { |
||||
self.0.get_mut() // underlying vector
|
||||
} |
||||
} |
||||
|
||||
impl Buf for InputBuffer { |
||||
fn remaining(&self) -> usize { |
||||
Buf::remaining(self.out()) |
||||
} |
||||
fn bytes(&self) -> &[u8] { |
||||
Buf::bytes(self.out()) |
||||
} |
||||
fn advance(&mut self, size: usize) { |
||||
Buf::advance(self.out_mut(), size) |
||||
} |
||||
} |
@ -0,0 +1,28 @@ |
||||
//! Lightweight, flexible WebSockets for Rust.
|
||||
#![deny(
|
||||
missing_copy_implementations, |
||||
trivial_casts, trivial_numeric_casts, |
||||
unstable_features, |
||||
unused_must_use, |
||||
unused_mut, |
||||
unused_imports, |
||||
unused_import_braces)] |
||||
|
||||
#[macro_use] extern crate log; |
||||
extern crate base64; |
||||
extern crate byteorder; |
||||
extern crate bytes; |
||||
extern crate httparse; |
||||
extern crate rand; |
||||
extern crate sha1; |
||||
extern crate url; |
||||
extern crate utf8; |
||||
#[cfg(feature="tls")] extern crate native_tls; |
||||
|
||||
pub mod error; |
||||
pub mod protocol; |
||||
pub mod client; |
||||
pub mod handshake; |
||||
|
||||
mod input_buffer; |
||||
mod stream; |
@ -0,0 +1,274 @@ |
||||
use std::fmt; |
||||
use std::convert::{Into, From}; |
||||
|
||||
/// WebSocket message opcode as in RFC 6455.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)] |
||||
pub enum OpCode { |
||||
Data(Data), |
||||
Control(Control), |
||||
} |
||||
|
||||
/// Data opcodes as in RFC 6455
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)] |
||||
pub enum Data { |
||||
/// 0x0 denotes a continuation frame
|
||||
Continue, |
||||
/// 0x1 denotes a text frame
|
||||
Text, |
||||
/// 0x2 denotes a binary frame
|
||||
Binary, |
||||
/// 0x3-7 are reserved for further non-control frames
|
||||
Reserved(u8), |
||||
} |
||||
|
||||
/// Control opcodes as in RFC 6455
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)] |
||||
pub enum Control { |
||||
/// 0x8 denotes a connection close
|
||||
Close, |
||||
/// 0x9 denotes a ping
|
||||
Ping, |
||||
/// 0xa denotes a pong
|
||||
Pong, |
||||
/// 0xb-f are reserved for further control frames
|
||||
Reserved(u8), |
||||
} |
||||
|
||||
impl fmt::Display for Data { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
||||
match *self { |
||||
Data::Continue => write!(f, "CONTINUE"), |
||||
Data::Text => write!(f, "TEXT"), |
||||
Data::Binary => write!(f, "BINARY"), |
||||
Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl fmt::Display for Control { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
||||
match *self { |
||||
Control::Close => write!(f, "CLOSE"), |
||||
Control::Ping => write!(f, "PING"), |
||||
Control::Pong => write!(f, "PONG"), |
||||
Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl fmt::Display for OpCode { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
||||
match *self { |
||||
OpCode::Data(d) => d.fmt(f), |
||||
OpCode::Control(c) => c.fmt(f), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl Into<u8> for OpCode { |
||||
fn into(self) -> u8 { |
||||
use self::Data::{Continue, Text, Binary}; |
||||
use self::Control::{Close, Ping, Pong}; |
||||
use self::OpCode::*; |
||||
match self { |
||||
Data(Continue) => 0, |
||||
Data(Text) => 1, |
||||
Data(Binary) => 2, |
||||
Data(self::Data::Reserved(i)) => i, |
||||
|
||||
Control(Close) => 8, |
||||
Control(Ping) => 9, |
||||
Control(Pong) => 10, |
||||
Control(self::Control::Reserved(i)) => i, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl From<u8> for OpCode { |
||||
fn from(byte: u8) -> OpCode { |
||||
use self::Data::{Continue, Text, Binary}; |
||||
use self::Control::{Close, Ping, Pong}; |
||||
use self::OpCode::*; |
||||
match byte { |
||||
0 => Data(Continue), |
||||
1 => Data(Text), |
||||
2 => Data(Binary), |
||||
i @ 3 ... 7 => Data(self::Data::Reserved(i)), |
||||
8 => Control(Close), |
||||
9 => Control(Ping), |
||||
10 => Control(Pong), |
||||
i @ 11 ... 15 => Control(self::Control::Reserved(i)), |
||||
_ => panic!("Bug: OpCode out of range"), |
||||
} |
||||
} |
||||
} |
||||
|
||||
use self::CloseCode::*; |
||||
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
|
||||
#[derive(Debug, Eq, PartialEq, Clone, Copy)] |
||||
pub enum CloseCode { |
||||
/// Indicates a normal closure, meaning that the purpose for
|
||||
/// which the connection was established has been fulfilled.
|
||||
Normal, |
||||
/// Indicates that an endpoint is "going away", such as a server
|
||||
/// going down or a browser having navigated away from a page.
|
||||
Away, |
||||
/// Indicates that an endpoint is terminating the connection due
|
||||
/// to a protocol error.
|
||||
Protocol, |
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a type of data it cannot accept (e.g., an
|
||||
/// endpoint that understands only text data MAY send this if it
|
||||
/// receives a binary message).
|
||||
Unsupported, |
||||
/// Indicates that no status code was included in a closing frame. This
|
||||
/// close code makes it possible to use a single method, `on_close` to
|
||||
/// handle even cases where no close code was provided.
|
||||
Status, |
||||
/// Indicates an abnormal closure. If the abnormal closure was due to an
|
||||
/// error, this close code will not be used. Instead, the `on_error` method
|
||||
/// of the handler will be called with the error. However, if the connection
|
||||
/// is simply dropped, without an error, this close code will be sent to the
|
||||
/// handler.
|
||||
Abnormal, |
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received data within a message that was not
|
||||
/// consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
|
||||
/// data within a text message).
|
||||
Invalid, |
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a message that violates its policy. This
|
||||
/// is a generic status code that can be returned when there is no
|
||||
/// other more suitable status code (e.g., Unsupported or Size) or if there
|
||||
/// is a need to hide specific details about the policy.
|
||||
Policy, |
||||
/// Indicates that an endpoint is terminating the connection
|
||||
/// because it has received a message that is too big for it to
|
||||
/// process.
|
||||
Size, |
||||
/// Indicates that an endpoint (client) is terminating the
|
||||
/// connection because it has expected the server to negotiate one or
|
||||
/// more extension, but the server didn't return them in the response
|
||||
/// message of the WebSocket handshake. The list of extensions that
|
||||
/// are needed should be given as the reason for closing.
|
||||
/// Note that this status code is not used by the server, because it
|
||||
/// can fail the WebSocket handshake instead.
|
||||
Extension, |
||||
/// Indicates that a server is terminating the connection because
|
||||
/// it encountered an unexpected condition that prevented it from
|
||||
/// fulfilling the request.
|
||||
Error, |
||||
/// Indicates that the server is restarting. A client may choose to reconnect,
|
||||
/// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
|
||||
Restart, |
||||
/// Indicates that the server is overloaded and the client should either connect
|
||||
/// to a different IP (when multiple targets exist), or reconnect to the same IP
|
||||
/// when a user has performed an action.
|
||||
Again, |
||||
#[doc(hidden)] |
||||
Tls, |
||||
#[doc(hidden)] |
||||
Reserved(u16), |
||||
#[doc(hidden)] |
||||
Iana(u16), |
||||
#[doc(hidden)] |
||||
Library(u16), |
||||
#[doc(hidden)] |
||||
Bad(u16), |
||||
} |
||||
|
||||
impl CloseCode { |
||||
/// Check if this CloseCode is allowed.
|
||||
pub fn is_allowed(&self) -> bool { |
||||
match *self { |
||||
Bad(_) => false, |
||||
Reserved(_) => false, |
||||
Status => false, |
||||
Abnormal => false, |
||||
Tls => false, |
||||
_ => true, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl Into<u16> for CloseCode { |
||||
fn into(self) -> u16 { |
||||
match self { |
||||
Normal => 1000, |
||||
Away => 1001, |
||||
Protocol => 1002, |
||||
Unsupported => 1003, |
||||
Status => 1005, |
||||
Abnormal => 1006, |
||||
Invalid => 1007, |
||||
Policy => 1008, |
||||
Size => 1009, |
||||
Extension => 1010, |
||||
Error => 1011, |
||||
Restart => 1012, |
||||
Again => 1013, |
||||
Tls => 1015, |
||||
Reserved(code) => code, |
||||
Iana(code) => code, |
||||
Library(code) => code, |
||||
Bad(code) => code, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl From<u16> for CloseCode { |
||||
fn from(code: u16) -> CloseCode { |
||||
match code { |
||||
1000 => Normal, |
||||
1001 => Away, |
||||
1002 => Protocol, |
||||
1003 => Unsupported, |
||||
1005 => Status, |
||||
1006 => Abnormal, |
||||
1007 => Invalid, |
||||
1008 => Policy, |
||||
1009 => Size, |
||||
1010 => Extension, |
||||
1011 => Error, |
||||
1012 => Restart, |
||||
1013 => Again, |
||||
1015 => Tls, |
||||
1...999 => Bad(code), |
||||
1000...2999 => Reserved(code), |
||||
3000...3999 => Iana(code), |
||||
4000...4999 => Library(code), |
||||
_ => Bad(code) |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
use super::*; |
||||
|
||||
#[test] |
||||
fn opcode_from_u8() { |
||||
let byte = 2u8; |
||||
assert_eq!(OpCode::from(byte), OpCode::Data(Data::Binary)); |
||||
} |
||||
|
||||
#[test] |
||||
fn opcode_into_u8() { |
||||
let text = OpCode::Data(Data::Text); |
||||
let byte: u8 = text.into(); |
||||
assert_eq!(byte, 1u8); |
||||
} |
||||
|
||||
#[test] |
||||
fn closecode_from_u16() { |
||||
let byte = 1008u16; |
||||
assert_eq!(CloseCode::from(byte), CloseCode::Policy); |
||||
} |
||||
|
||||
#[test] |
||||
fn closecode_into_u16() { |
||||
let text = CloseCode::Away; |
||||
let byte: u16 = text.into(); |
||||
assert_eq!(byte, 1001u16); |
||||
} |
||||
} |
@ -0,0 +1,516 @@ |
||||
use std::fmt; |
||||
use std::mem::transmute; |
||||
use std::io::{Cursor, Read, Write}; |
||||
use std::default::Default; |
||||
use std::iter::FromIterator; |
||||
use std::string::{String, FromUtf8Error}; |
||||
use std::result::Result as StdResult; |
||||
use byteorder::{ByteOrder, NetworkEndian}; |
||||
use bytes::BufMut; |
||||
|
||||
use rand; |
||||
|
||||
use error::{Error, Result}; |
||||
use super::coding::{OpCode, Control, Data, CloseCode}; |
||||
|
||||
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { |
||||
let iter = buf.iter_mut().zip(mask.iter().cycle()); |
||||
for (byte, &key) in iter { |
||||
*byte ^= key |
||||
} |
||||
} |
||||
|
||||
#[inline] |
||||
fn generate_mask() -> [u8; 4] { |
||||
unsafe { transmute(rand::random::<u32>()) } |
||||
} |
||||
|
||||
/// A struct representing a WebSocket frame.
|
||||
#[derive(Debug, Clone)] |
||||
pub struct Frame { |
||||
finished: bool, |
||||
rsv1: bool, |
||||
rsv2: bool, |
||||
rsv3: bool, |
||||
opcode: OpCode, |
||||
|
||||
mask: Option<[u8; 4]>, |
||||
|
||||
payload: Vec<u8>, |
||||
} |
||||
|
||||
impl Frame { |
||||
|
||||
/// Get the length of the frame.
|
||||
/// This is the length of the header + the length of the payload.
|
||||
#[inline] |
||||
pub fn len(&self) -> usize { |
||||
let mut header_length = 2; |
||||
let payload_len = self.payload().len(); |
||||
if payload_len > 125 { |
||||
if payload_len <= u16::max_value() as usize { |
||||
header_length += 2; |
||||
} else { |
||||
header_length += 8; |
||||
} |
||||
} |
||||
|
||||
if self.is_masked() { |
||||
header_length += 4; |
||||
} |
||||
|
||||
header_length + payload_len |
||||
} |
||||
|
||||
/// Test whether the frame is a final frame.
|
||||
#[inline] |
||||
pub fn is_final(&self) -> bool { |
||||
self.finished |
||||
} |
||||
|
||||
/// Test whether the first reserved bit is set.
|
||||
#[inline] |
||||
pub fn has_rsv1(&self) -> bool { |
||||
self.rsv1 |
||||
} |
||||
|
||||
/// Test whether the second reserved bit is set.
|
||||
#[inline] |
||||
pub fn has_rsv2(&self) -> bool { |
||||
self.rsv2 |
||||
} |
||||
|
||||
/// Test whether the third reserved bit is set.
|
||||
#[inline] |
||||
pub fn has_rsv3(&self) -> bool { |
||||
self.rsv3 |
||||
} |
||||
|
||||
/// Get the OpCode of the frame.
|
||||
#[inline] |
||||
pub fn opcode(&self) -> OpCode { |
||||
self.opcode |
||||
} |
||||
|
||||
/// Get a reference to the frame's payload.
|
||||
#[inline] |
||||
pub fn payload(&self) -> &Vec<u8> { |
||||
&self.payload |
||||
} |
||||
|
||||
// Test whether the frame is masked.
|
||||
#[doc(hidden)] |
||||
#[inline] |
||||
pub fn is_masked(&self) -> bool { |
||||
self.mask.is_some() |
||||
} |
||||
|
||||
// Get an optional reference to the frame's mask.
|
||||
#[doc(hidden)] |
||||
#[allow(dead_code)] |
||||
#[inline] |
||||
pub fn mask(&self) -> Option<&[u8; 4]> { |
||||
self.mask.as_ref() |
||||
} |
||||
|
||||
/// Make this frame a final frame.
|
||||
#[allow(dead_code)] |
||||
#[inline] |
||||
pub fn set_final(&mut self, is_final: bool) -> &mut Frame { |
||||
self.finished = is_final; |
||||
self |
||||
} |
||||
|
||||
/// Set the first reserved bit.
|
||||
#[inline] |
||||
pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { |
||||
self.rsv1 = has_rsv1; |
||||
self |
||||
} |
||||
|
||||
/// Set the second reserved bit.
|
||||
#[inline] |
||||
pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { |
||||
self.rsv2 = has_rsv2; |
||||
self |
||||
} |
||||
|
||||
/// Set the third reserved bit.
|
||||
#[inline] |
||||
pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { |
||||
self.rsv3 = has_rsv3; |
||||
self |
||||
} |
||||
|
||||
/// Set the OpCode.
|
||||
#[allow(dead_code)] |
||||
#[inline] |
||||
pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { |
||||
self.opcode = opcode; |
||||
self |
||||
} |
||||
|
||||
/// Edit the frame's payload.
|
||||
#[allow(dead_code)] |
||||
#[inline] |
||||
pub fn payload_mut(&mut self) -> &mut Vec<u8> { |
||||
&mut self.payload |
||||
} |
||||
|
||||
// Generate a new mask for this frame.
|
||||
//
|
||||
// This method simply generates and stores the mask. It does not change the payload data.
|
||||
// Instead, the payload data will be masked with the generated mask when the frame is sent
|
||||
// to the other endpoint.
|
||||
#[doc(hidden)] |
||||
#[inline] |
||||
pub fn set_mask(&mut self) -> &mut Frame { |
||||
self.mask = Some(generate_mask()); |
||||
self |
||||
} |
||||
|
||||
// This method unmasks the payload and should only be called on frames that are actually
|
||||
// masked. In other words, those frames that have just been received from a client endpoint.
|
||||
#[doc(hidden)] |
||||
#[inline] |
||||
pub fn remove_mask(&mut self) { |
||||
self.mask.and_then(|mask| { |
||||
Some(apply_mask(&mut self.payload, &mask)) |
||||
}); |
||||
self.mask = None; |
||||
} |
||||
|
||||
/// Consume the frame into its payload as binary.
|
||||
#[inline] |
||||
pub fn into_data(self) -> Vec<u8> { |
||||
self.payload |
||||
} |
||||
|
||||
/// Consume the frame into its payload as string.
|
||||
#[inline] |
||||
pub fn into_string(self) -> StdResult<String, FromUtf8Error> { |
||||
String::from_utf8(self.payload) |
||||
} |
||||
|
||||
/// Consume the frame into a closing frame.
|
||||
#[inline] |
||||
pub fn into_close(self) -> Result<Option<(CloseCode, String)>> { |
||||
match self.payload.len() { |
||||
0 => Ok(None), |
||||
1 => Err(Error::Protocol("Invalid close sequence".into())), |
||||
_ => { |
||||
let mut data = self.payload; |
||||
let code = NetworkEndian::read_u16(&data[0..2]).into(); |
||||
data.drain(0..2); |
||||
let text = String::from_utf8(data)?; |
||||
Ok(Some((code, text))) |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Create a new data frame.
|
||||
#[inline] |
||||
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame { |
||||
debug_assert!(match code { |
||||
OpCode::Data(_) => true, |
||||
_ => false, |
||||
}, "Invalid opcode for data frame."); |
||||
|
||||
Frame { |
||||
finished: finished, |
||||
opcode: code, |
||||
payload: data, |
||||
.. Frame::default() |
||||
} |
||||
} |
||||
|
||||
/// Create a new Pong control frame.
|
||||
#[inline] |
||||
pub fn pong(data: Vec<u8>) -> Frame { |
||||
Frame { |
||||
opcode: OpCode::Control(Control::Pong), |
||||
payload: data, |
||||
.. Frame::default() |
||||
} |
||||
} |
||||
|
||||
/// Create a new Ping control frame.
|
||||
#[inline] |
||||
pub fn ping(data: Vec<u8>) -> Frame { |
||||
Frame { |
||||
opcode: OpCode::Control(Control::Ping), |
||||
payload: data, |
||||
.. Frame::default() |
||||
} |
||||
} |
||||
|
||||
/// Create a new Close control frame.
|
||||
#[inline] |
||||
pub fn close(msg: Option<(CloseCode, &str)>) -> Frame { |
||||
let payload = if let Some((code, reason)) = msg { |
||||
let raw: [u8; 2] = unsafe { |
||||
let u: u16 = code.into(); |
||||
transmute(u.to_be()) |
||||
}; |
||||
Vec::from_iter( |
||||
raw[..].iter() |
||||
.chain(reason.as_bytes().iter()) |
||||
.map(|&b| b)) |
||||
} else { |
||||
Vec::new() |
||||
}; |
||||
|
||||
Frame { |
||||
payload: payload, |
||||
.. Frame::default() |
||||
} |
||||
} |
||||
|
||||
/// Parse the input stream into a frame.
|
||||
pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> { |
||||
let size = cursor.get_ref().len() as u64 - cursor.position(); |
||||
let initial = cursor.position(); |
||||
trace!("Position in buffer {}", initial); |
||||
|
||||
let mut head = [0u8; 2]; |
||||
if try!(cursor.read(&mut head)) != 2 { |
||||
cursor.set_position(initial); |
||||
return Ok(None) |
||||
} |
||||
|
||||
trace!("Parsed headers {:?}", head); |
||||
|
||||
let first = head[0]; |
||||
let second = head[1]; |
||||
trace!("First: {:b}", first); |
||||
trace!("Second: {:b}", second); |
||||
|
||||
let finished = first & 0x80 != 0; |
||||
|
||||
let rsv1 = first & 0x40 != 0; |
||||
let rsv2 = first & 0x20 != 0; |
||||
let rsv3 = first & 0x10 != 0; |
||||
|
||||
let opcode = OpCode::from(first & 0x0F); |
||||
trace!("Opcode: {:?}", opcode); |
||||
|
||||
let masked = second & 0x80 != 0; |
||||
trace!("Masked: {:?}", masked); |
||||
|
||||
let mut header_length = 2; |
||||
|
||||
let mut length = (second & 0x7F) as u64; |
||||
|
||||
if length == 126 { |
||||
let mut length_bytes = [0u8; 2]; |
||||
if try!(cursor.read(&mut length_bytes)) != 2 { |
||||
cursor.set_position(initial); |
||||
return Ok(None) |
||||
} |
||||
|
||||
length = unsafe { |
||||
let mut wide: u16 = transmute(length_bytes); |
||||
wide = u16::from_be(wide); |
||||
wide |
||||
} as u64; |
||||
header_length += 2; |
||||
} else if length == 127 { |
||||
let mut length_bytes = [0u8; 8]; |
||||
if try!(cursor.read(&mut length_bytes)) != 8 { |
||||
cursor.set_position(initial); |
||||
return Ok(None) |
||||
} |
||||
|
||||
unsafe { length = transmute(length_bytes); } |
||||
length = u64::from_be(length); |
||||
header_length += 8; |
||||
} |
||||
trace!("Payload length: {}", length); |
||||
|
||||
let mask = if masked { |
||||
let mut mask_bytes = [0u8; 4]; |
||||
if try!(cursor.read(&mut mask_bytes)) != 4 { |
||||
cursor.set_position(initial); |
||||
return Ok(None) |
||||
} else { |
||||
header_length += 4; |
||||
Some(mask_bytes) |
||||
} |
||||
} else { |
||||
None |
||||
}; |
||||
|
||||
if size < length + header_length { |
||||
cursor.set_position(initial); |
||||
return Ok(None) |
||||
} |
||||
|
||||
let mut data = Vec::with_capacity(length as usize); |
||||
if length > 0 { |
||||
unsafe { |
||||
try!(cursor.read_exact(data.bytes_mut())); |
||||
data.advance_mut(length as usize); |
||||
} |
||||
} |
||||
|
||||
// Disallow bad opcode
|
||||
match opcode { |
||||
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { |
||||
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) |
||||
} |
||||
_ => () |
||||
} |
||||
|
||||
let frame = Frame { |
||||
finished: finished, |
||||
rsv1: rsv1, |
||||
rsv2: rsv2, |
||||
rsv3: rsv3, |
||||
opcode: opcode, |
||||
mask: mask, |
||||
payload: data, |
||||
}; |
||||
|
||||
|
||||
Ok(Some(frame)) |
||||
} |
||||
|
||||
/// Write a frame out to a buffer
|
||||
pub fn format<W>(mut self, w: &mut W) -> Result<()> |
||||
where W: Write |
||||
{ |
||||
let mut one = 0u8; |
||||
let code: u8 = self.opcode.into(); |
||||
if self.is_final() { |
||||
one |= 0x80; |
||||
} |
||||
if self.has_rsv1() { |
||||
one |= 0x40; |
||||
} |
||||
if self.has_rsv2() { |
||||
one |= 0x20; |
||||
} |
||||
if self.has_rsv3() { |
||||
one |= 0x10; |
||||
} |
||||
one |= code; |
||||
|
||||
let mut two = 0u8; |
||||
|
||||
if self.is_masked() { |
||||
two |= 0x80; |
||||
} |
||||
|
||||
if self.payload.len() < 126 { |
||||
two |= self.payload.len() as u8; |
||||
let headers = [one, two]; |
||||
try!(w.write(&headers)); |
||||
} else if self.payload.len() <= 65535 { |
||||
two |= 126; |
||||
let length_bytes: [u8; 2] = unsafe { |
||||
let short = self.payload.len() as u16; |
||||
transmute(short.to_be()) |
||||
}; |
||||
let headers = [one, two, length_bytes[0], length_bytes[1]]; |
||||
try!(w.write(&headers)); |
||||
} else { |
||||
two |= 127; |
||||
let length_bytes: [u8; 8] = unsafe { |
||||
let long = self.payload.len() as u64; |
||||
transmute(long.to_be()) |
||||
}; |
||||
let headers = [ |
||||
one, |
||||
two, |
||||
length_bytes[0], |
||||
length_bytes[1], |
||||
length_bytes[2], |
||||
length_bytes[3], |
||||
length_bytes[4], |
||||
length_bytes[5], |
||||
length_bytes[6], |
||||
length_bytes[7], |
||||
]; |
||||
try!(w.write(&headers)); |
||||
} |
||||
|
||||
if self.is_masked() { |
||||
let mask = self.mask.take().unwrap(); |
||||
apply_mask(&mut self.payload, &mask); |
||||
try!(w.write(&mask)); |
||||
} |
||||
|
||||
try!(w.write(&self.payload)); |
||||
Ok(()) |
||||
} |
||||
} |
||||
|
||||
impl Default for Frame { |
||||
fn default() -> Frame { |
||||
Frame { |
||||
finished: true, |
||||
rsv1: false, |
||||
rsv2: false, |
||||
rsv3: false, |
||||
opcode: OpCode::Control(Control::Close), |
||||
mask: None, |
||||
payload: Vec::new(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl fmt::Display for Frame { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
||||
write!(f, |
||||
" |
||||
<FRAME> |
||||
final: {} |
||||
reserved: {} {} {} |
||||
opcode: {} |
||||
length: {} |
||||
payload length: {} |
||||
payload: 0x{} |
||||
", |
||||
self.finished, |
||||
self.rsv1, |
||||
self.rsv2, |
||||
self.rsv3, |
||||
self.opcode, |
||||
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
|
||||
self.len(), |
||||
self.payload.len(), |
||||
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()) |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
use super::*; |
||||
|
||||
use super::super::coding::{OpCode, Data}; |
||||
use std::io::Cursor; |
||||
|
||||
#[test] |
||||
fn parse() { |
||||
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![ |
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 |
||||
]); |
||||
let frame = Frame::parse(&mut raw).unwrap().unwrap(); |
||||
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); |
||||
} |
||||
|
||||
#[test] |
||||
fn format() { |
||||
let frame = Frame::ping(vec![0x01, 0x02]); |
||||
let mut buf = Vec::with_capacity(frame.len()); |
||||
frame.format(&mut buf).unwrap(); |
||||
assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); |
||||
} |
||||
|
||||
#[test] |
||||
fn display() { |
||||
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); |
||||
let view = format!("{}", f); |
||||
view.contains("payload:"); |
||||
} |
||||
} |
@ -0,0 +1,135 @@ |
||||
pub mod coding; |
||||
|
||||
mod frame; |
||||
|
||||
pub use self::frame::Frame; |
||||
|
||||
use std::io::{Read, Write}; |
||||
|
||||
use input_buffer; |
||||
use error::{Error, Result}; |
||||
|
||||
const MIN_READ: usize = 4096; |
||||
|
||||
/// A reader and writer for WebSocket frames.
|
||||
pub struct FrameSocket<Stream> { |
||||
stream: Stream, |
||||
in_buffer: input_buffer::InputBuffer, |
||||
out_buffer: Vec<u8>, |
||||
} |
||||
|
||||
impl<Stream> FrameSocket<Stream> { |
||||
/// Create a new frame socket.
|
||||
pub fn new(stream: Stream) -> Self { |
||||
FrameSocket { |
||||
stream: stream, |
||||
in_buffer: input_buffer::InputBuffer::with_capacity(MIN_READ), |
||||
out_buffer: Vec::new(), |
||||
} |
||||
} |
||||
/// Create a new frame socket from partially read data.
|
||||
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self { |
||||
FrameSocket { |
||||
stream: stream, |
||||
in_buffer: input_buffer::InputBuffer::from_partially_read(part), |
||||
out_buffer: Vec::new(), |
||||
} |
||||
} |
||||
/// Extract a stream from the socket.
|
||||
pub fn into_inner(self) -> (Stream, Vec<u8>) { |
||||
(self.stream, self.in_buffer.into_vec()) |
||||
} |
||||
} |
||||
|
||||
impl<Stream> FrameSocket<Stream> |
||||
where Stream: Read |
||||
{ |
||||
/// Read a frame from stream.
|
||||
pub fn read_frame(&mut self) -> Result<Option<Frame>> { |
||||
loop { |
||||
if let Some(frame) = Frame::parse(&mut self.in_buffer.out_mut())? { |
||||
debug!("received frame {}", frame); |
||||
return Ok(Some(frame)); |
||||
} |
||||
// No full frames in buffer.
|
||||
self.in_buffer.reserve(MIN_READ, usize::max_value()) |
||||
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?; |
||||
let size = self.in_buffer.read_from(&mut self.stream)?; |
||||
if size == 0 { |
||||
debug!("no frame received"); |
||||
return Ok(None) |
||||
} |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
impl<Stream> FrameSocket<Stream> |
||||
where Stream: Write |
||||
{ |
||||
/// Write a frame to stream.
|
||||
pub fn write_frame(&mut self, frame: Frame) -> Result<()> { |
||||
debug!("writing frame {}", frame); |
||||
self.out_buffer.reserve(frame.len()); |
||||
frame.format(&mut self.out_buffer)?; |
||||
let len = self.stream.write(&self.out_buffer)?; |
||||
self.out_buffer.drain(0..len); |
||||
Ok(()) |
||||
} |
||||
} |
||||
|
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
|
||||
use super::{Frame, FrameSocket}; |
||||
|
||||
use std::io::Cursor; |
||||
|
||||
#[test] |
||||
fn read_frames() { |
||||
let raw = Cursor::new(vec![ |
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, |
||||
0x82, 0x03, 0x03, 0x02, 0x01, |
||||
0x99, |
||||
]); |
||||
let mut sock = FrameSocket::new(raw); |
||||
|
||||
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), |
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); |
||||
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), |
||||
vec![0x03, 0x02, 0x01]); |
||||
assert!(sock.read_frame().unwrap().is_none()); |
||||
|
||||
let (_, rest) = sock.into_inner(); |
||||
assert_eq!(rest, vec![0x99]); |
||||
} |
||||
|
||||
#[test] |
||||
fn from_partially_read() { |
||||
let raw = Cursor::new(vec![ |
||||
0x02, 0x03, 0x04, 0x05, 0x06, 0x07, |
||||
]); |
||||
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); |
||||
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), |
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); |
||||
} |
||||
|
||||
#[test] |
||||
fn write_frames() { |
||||
let mut sock = FrameSocket::new(Vec::new()); |
||||
|
||||
let frame = Frame::ping(vec![0x04, 0x05]); |
||||
sock.write_frame(frame).unwrap(); |
||||
|
||||
let frame = Frame::pong(vec![0x01]); |
||||
sock.write_frame(frame).unwrap(); |
||||
|
||||
let (buf, _) = sock.into_inner(); |
||||
assert_eq!(buf, vec![ |
||||
0x89, 0x02, 0x04, 0x05, |
||||
0x8a, 0x01, 0x01 |
||||
]); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,260 @@ |
||||
use std::convert::{From, Into, AsRef}; |
||||
use std::fmt; |
||||
use std::result::Result as StdResult; |
||||
use std::str; |
||||
|
||||
use error::Result; |
||||
|
||||
mod string_collect { |
||||
|
||||
use utf8; |
||||
|
||||
use error::{Error, Result}; |
||||
|
||||
pub struct StringCollector { |
||||
data: String, |
||||
decoder: utf8::Decoder, |
||||
} |
||||
|
||||
impl StringCollector { |
||||
pub fn new() -> Self { |
||||
StringCollector { |
||||
data: String::new(), |
||||
decoder: utf8::Decoder::new(), |
||||
} |
||||
} |
||||
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> { |
||||
let (sym, text, result) = self.decoder.decode(tail.as_ref()); |
||||
self.data.push_str(&sym); |
||||
self.data.push_str(text); |
||||
match result { |
||||
utf8::Result::Ok | utf8::Result::Incomplete => |
||||
Ok(()), |
||||
utf8::Result::Error { remaining_input_after_error: _ } => |
||||
Err(Error::Protocol("Invalid UTF8".into())), // FIXME
|
||||
} |
||||
} |
||||
pub fn into_string(self) -> Result<String> { |
||||
if self.decoder.has_incomplete_sequence() { |
||||
Err(Error::Protocol("Invalid UTF8".into())) // FIXME
|
||||
} else { |
||||
Ok(self.data) |
||||
} |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
use self::string_collect::StringCollector; |
||||
|
||||
/// A struct representing the incomplete message.
|
||||
pub struct IncompleteMessage { |
||||
collector: IncompleteMessageCollector, |
||||
} |
||||
|
||||
enum IncompleteMessageCollector { |
||||
Text(StringCollector), |
||||
Binary(Vec<u8>), |
||||
} |
||||
|
||||
impl IncompleteMessage { |
||||
/// Create new.
|
||||
pub fn new(message_type: IncompleteMessageType) -> Self { |
||||
IncompleteMessage { |
||||
collector: match message_type { |
||||
IncompleteMessageType::Binary => |
||||
IncompleteMessageCollector::Binary(Vec::new()), |
||||
IncompleteMessageType::Text => |
||||
IncompleteMessageCollector::Text(StringCollector::new()), |
||||
} |
||||
} |
||||
} |
||||
/// Add more data to an existing message.
|
||||
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> { |
||||
match self.collector { |
||||
IncompleteMessageCollector::Binary(ref mut v) => { |
||||
v.extend(tail.as_ref()); |
||||
Ok(()) |
||||
} |
||||
IncompleteMessageCollector::Text(ref mut t) => { |
||||
t.extend(tail) |
||||
} |
||||
} |
||||
} |
||||
/// Convert an incomplete message into a complete one.
|
||||
pub fn complete(self) -> Result<Message> { |
||||
match self.collector { |
||||
IncompleteMessageCollector::Binary(v) => { |
||||
Ok(Message::Binary(v)) |
||||
} |
||||
IncompleteMessageCollector::Text(t) => { |
||||
let text = t.into_string()?; |
||||
Ok(Message::Text(text)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// The type of incomplete message.
|
||||
pub enum IncompleteMessageType { |
||||
Text, |
||||
Binary, |
||||
} |
||||
|
||||
/// An enum representing the various forms of a WebSocket message.
|
||||
#[derive(Debug, Eq, PartialEq, Clone)] |
||||
pub enum Message { |
||||
/// A text WebSocket message
|
||||
Text(String), |
||||
/// A binary WebSocket message
|
||||
Binary(Vec<u8>), |
||||
} |
||||
|
||||
impl Message { |
||||
|
||||
/// Create a new text WebSocket message from a stringable.
|
||||
pub fn text<S>(string: S) -> Message |
||||
where S: Into<String> |
||||
{ |
||||
Message::Text(string.into()) |
||||
} |
||||
|
||||
/// Create a new binary WebSocket message by converting to Vec<u8>.
|
||||
pub fn binary<B>(bin: B) -> Message |
||||
where B: Into<Vec<u8>> |
||||
{ |
||||
Message::Binary(bin.into()) |
||||
} |
||||
|
||||
/// Indicates whether a message is a text message.
|
||||
pub fn is_text(&self) -> bool { |
||||
match *self { |
||||
Message::Text(_) => true, |
||||
Message::Binary(_) => false, |
||||
} |
||||
} |
||||
|
||||
/// Indicates whether a message is a binary message.
|
||||
pub fn is_binary(&self) -> bool { |
||||
match *self { |
||||
Message::Text(_) => false, |
||||
Message::Binary(_) => true, |
||||
} |
||||
} |
||||
|
||||
/// Get the length of the WebSocket message.
|
||||
pub fn len(&self) -> usize { |
||||
match *self { |
||||
Message::Text(ref string) => string.len(), |
||||
Message::Binary(ref data) => data.len(), |
||||
} |
||||
} |
||||
|
||||
/// Returns true if the WebSocket message has no content.
|
||||
/// For example, if the other side of the connection sent an empty string.
|
||||
pub fn is_empty(&self) -> bool { |
||||
match *self { |
||||
Message::Text(ref string) => string.is_empty(), |
||||
Message::Binary(ref data) => data.is_empty(), |
||||
} |
||||
} |
||||
|
||||
/// Consume the WebSocket and return it as binary data.
|
||||
pub fn into_data(self) -> Vec<u8> { |
||||
match self { |
||||
Message::Text(string) => string.into_bytes(), |
||||
Message::Binary(data) => data, |
||||
} |
||||
} |
||||
|
||||
/// Attempt to consume the WebSocket message and convert it to a String.
|
||||
pub fn into_text(self) -> Result<String> { |
||||
match self { |
||||
Message::Text(string) => Ok(string), |
||||
Message::Binary(data) => Ok(try!( |
||||
String::from_utf8(data).map_err(|err| err.utf8_error()))), |
||||
} |
||||
} |
||||
|
||||
/// Attempt to get a &str from the WebSocket message,
|
||||
/// this will try to convert binary data to utf8.
|
||||
pub fn to_text(&self) -> Result<&str> { |
||||
match *self { |
||||
Message::Text(ref string) => Ok(string), |
||||
Message::Binary(ref data) => Ok(try!(str::from_utf8(data))), |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
impl From<String> for Message { |
||||
fn from(string: String) -> Message { |
||||
Message::text(string) |
||||
} |
||||
} |
||||
|
||||
impl<'s> From<&'s str> for Message { |
||||
fn from(string: &'s str) -> Message { |
||||
Message::text(string) |
||||
} |
||||
} |
||||
|
||||
impl<'b> From<&'b [u8]> for Message { |
||||
fn from(data: &'b [u8]) -> Message { |
||||
Message::binary(data) |
||||
} |
||||
} |
||||
|
||||
impl From<Vec<u8>> for Message { |
||||
fn from(data: Vec<u8>) -> Message { |
||||
Message::binary(data) |
||||
} |
||||
} |
||||
|
||||
impl fmt::Display for Message { |
||||
fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> { |
||||
if let Ok(string) = self.to_text() { |
||||
write!(f, "{}", string) |
||||
} else { |
||||
write!(f, "Binary Data<length={}>", self.len()) |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
use super::*; |
||||
|
||||
#[test] |
||||
fn display() { |
||||
let t = Message::text(format!("test")); |
||||
assert_eq!(t.to_string(), "test".to_owned()); |
||||
|
||||
let bin = Message::binary(vec![0, 1, 3, 4, 241]); |
||||
assert_eq!(bin.to_string(), "Binary Data<length=5>".to_owned()); |
||||
} |
||||
|
||||
#[test] |
||||
fn binary_convert() { |
||||
let bin = [6u8, 7, 8, 9, 10, 241]; |
||||
let msg = Message::from(&bin[..]); |
||||
assert!(msg.is_binary()); |
||||
assert!(msg.into_text().is_err()); |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn binary_convert_vec() { |
||||
let bin = vec![6u8, 7, 8, 9, 10, 241]; |
||||
let msg = Message::from(bin); |
||||
assert!(msg.is_binary()); |
||||
assert!(msg.into_text().is_err()); |
||||
} |
||||
|
||||
#[test] |
||||
fn text_convert() { |
||||
let s = "kiwotsukete"; |
||||
let msg = Message::from(s); |
||||
assert!(msg.is_text()); |
||||
} |
||||
} |
@ -0,0 +1,374 @@ |
||||
//! Generic WebSocket protocol implementation
|
||||
|
||||
mod frame; |
||||
mod message; |
||||
|
||||
pub use self::message::Message; |
||||
|
||||
use self::message::{IncompleteMessage, IncompleteMessageType}; |
||||
use std::collections::VecDeque; |
||||
use std::io::{Read, Write}; |
||||
use std::mem::replace; |
||||
|
||||
use error::{Error, Result}; |
||||
use self::frame::{Frame, FrameSocket}; |
||||
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode}; |
||||
|
||||
/// Indicates a Client or Server role of the websocket
|
||||
#[derive(Debug, Clone, Copy)] |
||||
pub enum Role { |
||||
/// This socket is a server
|
||||
Server, |
||||
/// This socket is a client
|
||||
Client, |
||||
} |
||||
|
||||
/// WebSocket input-output stream
|
||||
pub struct WebSocket<Stream> { |
||||
/// Server or client?
|
||||
role: Role, |
||||
/// The underlying socket.
|
||||
socket: FrameSocket<Stream>, |
||||
/// The state of processing, either "active" or "closing".
|
||||
state: WebSocketState, |
||||
/// Receive: an incomplete message being processed.
|
||||
incomplete: Option<IncompleteMessage>, |
||||
/// Send: a data send queue.
|
||||
send_queue: VecDeque<Frame>, |
||||
/// Send: an OOB pong message.
|
||||
pong: Option<Frame>, |
||||
} |
||||
|
||||
impl<Stream> WebSocket<Stream> |
||||
where Stream: Read + Write |
||||
{ |
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
pub fn from_raw_socket(stream: Stream, role: Role) -> Self { |
||||
WebSocket::from_frame_socket(FrameSocket::new(stream), role) |
||||
} |
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
pub fn from_partially_read(stream: Stream, part: Vec<u8>, role: Role) -> Self { |
||||
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) |
||||
} |
||||
|
||||
/// Read a message from stream, if possible.
|
||||
pub fn read_message(&mut self) -> Result<Message> { |
||||
loop { |
||||
self.send_pending()?; // FIXME
|
||||
if let Some(message) = self.read_message_frame()? { |
||||
debug!("Received message {}", message); |
||||
return Ok(message) |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Send a message to stream, if possible.
|
||||
pub fn write_message(&mut self, message: Message) -> Result<()> { |
||||
let frame = { |
||||
let opcode = match message { |
||||
Message::Text(_) => OpData::Text, |
||||
Message::Binary(_) => OpData::Binary, |
||||
}; |
||||
Frame::message(message.into_data(), OpCode::Data(opcode), true) |
||||
}; |
||||
self.send_queue.push_back(frame); |
||||
self.send_pending() |
||||
} |
||||
|
||||
/// Close the connection.
|
||||
pub fn close(&mut self) -> Result<()> { |
||||
match self.state { |
||||
WebSocketState::Active => { |
||||
self.state = WebSocketState::ClosedByUs; |
||||
// TODO
|
||||
} |
||||
_ => { |
||||
// already closed, nothing to do
|
||||
} |
||||
} |
||||
Ok(()) |
||||
} |
||||
|
||||
/// Convert a frame socket into a WebSocket.
|
||||
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self { |
||||
WebSocket { |
||||
role: role, |
||||
socket: socket, |
||||
state: WebSocketState::Active, |
||||
incomplete: None, |
||||
send_queue: VecDeque::new(), |
||||
pong: None, |
||||
} |
||||
} |
||||
|
||||
/// Try to decode one message frame. May return None.
|
||||
fn read_message_frame(&mut self) -> Result<Option<Message>> { |
||||
if let Some(mut frame) = self.socket.read_frame()? { |
||||
|
||||
// MUST be 0 unless an extension is negotiated that defines meanings
|
||||
// for non-zero values. If a nonzero value is received and none of
|
||||
// the negotiated extensions defines the meaning of such a nonzero
|
||||
// value, the receiving endpoint MUST _Fail the WebSocket
|
||||
// Connection_.
|
||||
if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { |
||||
return Err(Error::Protocol("Reserved bits are non-zero".into())) |
||||
} |
||||
|
||||
match self.role { |
||||
Role::Server => { |
||||
if frame.is_masked() { |
||||
// A server MUST remove masking for data frames received from a client
|
||||
// as described in Section 5.3. (RFC 6455)
|
||||
frame.remove_mask() |
||||
} else { |
||||
// The server MUST close the connection upon receiving a
|
||||
// frame that is not masked. (RFC 6455)
|
||||
return Err(Error::Protocol("Received an unmasked frame from client".into())) |
||||
} |
||||
} |
||||
Role::Client => { |
||||
if frame.is_masked() { |
||||
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
|
||||
return Err(Error::Protocol("Received a masked frame from server".into())) |
||||
} |
||||
} |
||||
} |
||||
|
||||
match frame.opcode() { |
||||
|
||||
OpCode::Control(ctl) => { |
||||
(match ctl { |
||||
// All control frames MUST have a payload length of 125 bytes or less
|
||||
// and MUST NOT be fragmented. (RFC 6455)
|
||||
_ if !frame.is_final() => { |
||||
Err(Error::Protocol("Fragmented control frame".into())) |
||||
} |
||||
_ if frame.payload().len() > 125 => { |
||||
Err(Error::Protocol("Control frame too big".into())) |
||||
} |
||||
OpCtl::Close => { |
||||
self.do_close(frame.into_close()?) |
||||
} |
||||
OpCtl::Reserved(i) => { |
||||
Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) |
||||
} |
||||
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => { |
||||
// No ping processing while closing.
|
||||
Ok(()) |
||||
} |
||||
OpCtl::Ping => { |
||||
self.do_ping(frame.into_data()) |
||||
} |
||||
OpCtl::Pong => { |
||||
self.do_pong(frame.into_data()) |
||||
} |
||||
}).map(|_| None) |
||||
} |
||||
|
||||
OpCode::Data(_) if !self.state.is_active() => { |
||||
// No data processing while closing.
|
||||
Ok(None) |
||||
} |
||||
|
||||
OpCode::Data(data) => { |
||||
let fin = frame.is_final(); |
||||
match data { |
||||
OpData::Continue => { |
||||
if let Some(ref mut msg) = self.incomplete { |
||||
// TODO if msg too big
|
||||
msg.extend(frame.into_data())?; |
||||
} else { |
||||
return Err(Error::Protocol("Continue frame but nothing to continue".into())) |
||||
} |
||||
if fin { |
||||
Ok(Some(replace(&mut self.incomplete, None).unwrap().complete()?)) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
c if self.incomplete.is_some() => { |
||||
Err(Error::Protocol( |
||||
format!("Received {} while waiting for more fragments", c).into() |
||||
)) |
||||
} |
||||
OpData::Text | OpData::Binary => { |
||||
let msg = { |
||||
let message_type = match data { |
||||
OpData::Text => IncompleteMessageType::Text, |
||||
OpData::Binary => IncompleteMessageType::Binary, |
||||
_ => panic!("Bug: message is not text nor binary"), |
||||
}; |
||||
let mut m = IncompleteMessage::new(message_type); |
||||
m.extend(frame.into_data())?; |
||||
m |
||||
}; |
||||
if fin { |
||||
Ok(Some(msg.complete()?)) |
||||
} else { |
||||
self.incomplete = Some(msg); |
||||
Ok(None) |
||||
} |
||||
} |
||||
OpData::Reserved(i) => { |
||||
Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) |
||||
} |
||||
} |
||||
} |
||||
|
||||
} // match opcode
|
||||
|
||||
} else { |
||||
//Ok(None) // TODO handle EOF?
|
||||
Err(Error::Protocol("Connection reset without closing handshake".into())) |
||||
} |
||||
} |
||||
|
||||
/// Received a close frame.
|
||||
fn do_close(&mut self, close: Option<(CloseCode, String)>) -> Result<()> { |
||||
match self.state { |
||||
WebSocketState::Active => { |
||||
self.state = WebSocketState::ClosedByPeer; |
||||
let reply = if let Some((code, _)) = close { |
||||
if code.is_allowed() { |
||||
Frame::close(Some((CloseCode::Normal, ""))) |
||||
} else { |
||||
Frame::close(Some((CloseCode::Protocol, "Protocol violation"))) |
||||
} |
||||
} else { |
||||
Frame::close(None) |
||||
}; |
||||
self.send_queue.push_back(reply); |
||||
} |
||||
WebSocketState::ClosedByPeer => { |
||||
// It is already closed, just ignore.
|
||||
} |
||||
WebSocketState::ClosedByUs => { |
||||
// We received a reply.
|
||||
match self.role { |
||||
Role::Client => { |
||||
// Client waits for the server to close the connection.
|
||||
} |
||||
Role::Server => { |
||||
// Server closes the connection.
|
||||
// TODO
|
||||
} |
||||
} |
||||
} |
||||
} |
||||
//unimplemented!()
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Received a ping frame.
|
||||
fn do_ping(&mut self, ping: Vec<u8>) -> Result<()> { |
||||
// If an endpoint receives a Ping frame and has not yet sent Pong
|
||||
// frame(s) in response to previous Ping frame(s), the endpoint MAY
|
||||
// elect to send a Pong frame for only the most recently processed Ping
|
||||
// frame. (RFC 6455)
|
||||
// We do exactly that, keeping a "queue" from one and only Pong frame.
|
||||
self.pong = Some(Frame::pong(ping)); |
||||
Ok(()) |
||||
} |
||||
|
||||
/// Received a pong frame.
|
||||
fn do_pong(&mut self, _: Vec<u8>) -> Result<()> { |
||||
// A Pong frame MAY be sent unsolicited. This serves as a
|
||||
// unidirectional heartbeat. A response to an unsolicited Pong frame is
|
||||
// not expected. (RFC 6455)
|
||||
// Due to this, we just don't check pongs right now.
|
||||
// TODO: check if there was a reply to our ping at all...
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Flush the pending send queue.
|
||||
fn send_pending(&mut self) -> Result<()> { |
||||
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
|
||||
// response, unless it already received a Close frame. It SHOULD
|
||||
// respond with Pong frame as soon as is practical. (RFC 6455)
|
||||
if let Some(pong) = replace(&mut self.pong, None) { |
||||
self.send_one_frame(pong)?; |
||||
} |
||||
// If we have any unsent frames, send them.
|
||||
while let Some(data) = self.send_queue.pop_front() { |
||||
self.send_one_frame(data)?; |
||||
} |
||||
Ok(()) |
||||
} |
||||
|
||||
/// Send a single pending frame.
|
||||
fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> { |
||||
match self.role { |
||||
Role::Server => { |
||||
} |
||||
Role::Client => { |
||||
// 5. If the data is being sent by the client, the frame(s) MUST be
|
||||
// masked as defined in Section 5.3. (RFC 6455)
|
||||
frame.set_mask(); |
||||
} |
||||
} |
||||
self.socket.write_frame(frame)?; |
||||
Ok(()) |
||||
} |
||||
|
||||
} |
||||
|
||||
/// The current connection state.
|
||||
enum WebSocketState { |
||||
Active, |
||||
ClosedByUs, |
||||
ClosedByPeer, |
||||
} |
||||
|
||||
impl WebSocketState { |
||||
/// Tell if we're allowed to process normal messages.
|
||||
fn is_active(&self) -> bool { |
||||
match *self { |
||||
WebSocketState::Active => true, |
||||
_ => false, |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
use super::{WebSocket, Role, Message}; |
||||
|
||||
use std::io; |
||||
use std::io::Cursor; |
||||
|
||||
struct WriteMoc<Stream>(Stream); |
||||
|
||||
impl<Stream> io::Write for WriteMoc<Stream> { |
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
||||
Ok(buf.len()) |
||||
} |
||||
fn flush(&mut self) -> io::Result<()> { |
||||
Ok(()) |
||||
} |
||||
} |
||||
|
||||
impl<Stream: io::Read> io::Read for WriteMoc<Stream> { |
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
||||
self.0.read(buf) |
||||
} |
||||
} |
||||
|
||||
|
||||
#[test] |
||||
fn receive_messages() { |
||||
let incoming = Cursor::new(vec![ |
||||
0x01, 0x07, |
||||
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, |
||||
0x80, 0x06, |
||||
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, |
||||
0x82, 0x03, |
||||
0x01, 0x02, 0x03, |
||||
]); |
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); |
||||
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); |
||||
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,39 @@ |
||||
#[cfg(feature="tls")] |
||||
use native_tls::TlsStream; |
||||
|
||||
use std::net::TcpStream; |
||||
use std::io::{Read, Write, Result as IoResult}; |
||||
|
||||
/// Stream, either plain TCP or TLS.
|
||||
pub enum Stream { |
||||
Plain(TcpStream), |
||||
#[cfg(feature="tls")] |
||||
Tls(TlsStream<TcpStream>), |
||||
} |
||||
|
||||
impl Read for Stream { |
||||
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { |
||||
match *self { |
||||
Stream::Plain(ref mut s) => s.read(buf), |
||||
#[cfg(feature="tls")] |
||||
Stream::Tls(ref mut s) => s.read(buf), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl Write for Stream { |
||||
fn write(&mut self, buf: &[u8]) -> IoResult<usize> { |
||||
match *self { |
||||
Stream::Plain(ref mut s) => s.write(buf), |
||||
#[cfg(feature="tls")] |
||||
Stream::Tls(ref mut s) => s.write(buf), |
||||
} |
||||
} |
||||
fn flush(&mut self) -> IoResult<()> { |
||||
match *self { |
||||
Stream::Plain(ref mut s) => s.flush(), |
||||
#[cfg(feature="tls")] |
||||
Stream::Tls(ref mut s) => s.flush(), |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue