Merged in handshake-refactor (pull request #1)

Handshake refactor
pull/7/head
Alexey Galakhov 8 years ago committed by Daniel Abramov
commit 087636c3ad
  1. 14
      examples/autobahn-client.rs
  2. 13
      examples/autobahn-server.rs
  3. 5
      examples/client.rs
  4. 132
      src/client.rs
  5. 17
      src/error.rs
  6. 118
      src/handshake/client.rs
  7. 145
      src/handshake/headers.rs
  8. 119
      src/handshake/machine.rs
  9. 247
      src/handshake/mod.rs
  10. 110
      src/handshake/server.rs
  11. 14
      src/handshake/tls.rs
  12. 31
      src/protocol/mod.rs
  13. 13
      src/server.rs
  14. 25
      src/stream.rs

@ -5,9 +5,7 @@ extern crate url;
use url::Url; use url::Url;
use tungstenite::protocol::Message;
use tungstenite::client::connect; use tungstenite::client::connect;
use tungstenite::handshake::Handshake;
use tungstenite::error::{Error, Result}; use tungstenite::error::{Error, Result};
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
@ -15,17 +13,17 @@ const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let mut socket = connect( let mut socket = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap() Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?.handshake_wait()?; )?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.close(); socket.close()?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
} }
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let mut socket = connect( let mut socket = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?.handshake_wait()?; )?;
socket.close(); socket.close()?;
Ok(()) Ok(())
} }
@ -34,13 +32,11 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse( let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap(); ).unwrap();
let mut socket = connect(case_url)?.handshake_wait()?; let mut socket = connect(case_url)?;
loop { loop {
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.write_message(msg)?; socket.write_message(msg)?;
} }
socket.close();
Ok(())
} }
fn main() { fn main() {

@ -6,11 +6,18 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn; use std::thread::spawn;
use tungstenite::server::accept; use tungstenite::server::accept;
use tungstenite::error::Result; use tungstenite::handshake::HandshakeError;
use tungstenite::handshake::Handshake; use tungstenite::error::{Error, Result};
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f,
}
}
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).handshake_wait()?; let mut socket = accept(stream).map_err(must_not_block)?;
loop { loop {
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.write_message(msg)?; socket.write_message(msg)?;

@ -5,15 +5,12 @@ extern crate env_logger;
use url::Url; use url::Url;
use tungstenite::protocol::Message; use tungstenite::protocol::Message;
use tungstenite::client::connect; use tungstenite::client::connect;
use tungstenite::handshake::Handshake;
fn main() { fn main() {
env_logger::init().unwrap(); env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect") .expect("Can't connect");
.handshake_wait()
.expect("Handshake error");
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop { loop {

@ -1,75 +1,97 @@
use std::net::{TcpStream, ToSocketAddrs}; use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use url::{Url, SocketAddrs}; use std::result::Result as StdResult;
use std::io::{Read, Write};
use url::Url;
#[cfg(feature="tls")]
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError};
use protocol::WebSocket; use protocol::WebSocket;
use handshake::{Handshake as HandshakeTrait, HandshakeResult}; use handshake::HandshakeError;
use handshake::client::{ClientHandshake, Request}; use handshake::client::{ClientHandshake, Request};
use stream::Mode;
use error::{Error, Result}; use error::{Error, Result};
/// Connect to the given WebSocket. #[cfg(feature="tls")]
/// use stream::Stream as StreamSwitcher;
/// Note that this function may block the current thread while DNS resolution is performed.
pub fn connect(url: Url) -> Result<Handshake> {
let mode = match url.scheme() {
"ws" => Mode::Plain,
#[cfg(feature="tls")]
"wss" => Mode::Tls,
_ => return Err(Error::Url("URL scheme not supported".into()))
};
// Note that this function may block the current thread while resolution is performed. #[cfg(feature="tls")]
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
#[cfg(not(feature="tls"))]
pub type AutoStream = TcpStream;
/// Connect to the given WebSocket in blocking mode.
///
/// The URL may be either ws:// or wss://.
/// To support wss:// URLs, feature "tls" must be turned on.
pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> {
let mode = url_mode(&url)?;
let addrs = url.to_socket_addrs()?; let addrs = url.to_socket_addrs()?;
Ok(Handshake { let stream = connect_to_some(addrs, &url, mode)?;
state: HandshakeState::Nothing(url), client(url.clone(), stream)
alt_addresses: addrs, .map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) })
} }
enum Mode { #[cfg(feature="tls")]
Plain, fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
Tls, match mode {
} Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
enum HandshakeState { let connector = TlsConnector::builder()?.build()?;
Nothing(Url), connector.connect(domain, stream)
WebSocket(ClientHandshake<TcpStream>), .map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
}
}
} }
pub struct Handshake { #[cfg(not(feature="tls"))]
state: HandshakeState, fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
alt_addresses: SocketAddrs, match mode {
Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
} }
impl HandshakeTrait for Handshake { fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
type Stream = WebSocket<TcpStream>; where A: Iterator<Item=SocketAddr>
fn handshake(mut self) -> Result<HandshakeResult<Self>> { {
match self.state { let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?;
HandshakeState::Nothing(url) => { for addr in addrs {
if let Some(addr) = self.alt_addresses.next() {
debug!("Trying to contact {} at {}...", url, addr); debug!("Trying to contact {} at {}...", url, addr);
let state = { if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
let hs = ClientHandshake::new(stream, Request { url: url }); return Ok(stream)
HandshakeState::WebSocket(hs)
} else {
HandshakeState::Nothing(url)
}
};
Ok(HandshakeResult::Incomplete(Handshake {
state: state,
..self
}))
} else {
Err(Error::Url(format!("Unable to resolve {}", url).into()))
}
} }
HandshakeState::WebSocket(ws) => {
let alt_addresses = self.alt_addresses;
ws.handshake().map(move |r| r.map(move |s| Handshake {
state: HandshakeState::WebSocket(s),
alt_addresses: alt_addresses,
}))
} }
} }
Err(Error::Url(format!("Unable to connect to {}", url).into()))
}
/// Get the mode of the given URL.
///
/// This function may be used in non-blocking implementations.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into()))
} }
} }
/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support.
pub fn client<Stream: Read + Write>(url: Url, stream: Stream)
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>>
{
let request = Request { url: url };
ClientHandshake::start(stream, request).handshake()
}

@ -11,6 +11,9 @@ use std::string;
use httparse; use httparse;
#[cfg(feature="tls")]
use native_tls;
pub type Result<T> = result::Result<T, Error>; pub type Result<T> = result::Result<T, Error>;
/// Possible WebSocket errors /// Possible WebSocket errors
@ -20,6 +23,9 @@ pub enum Error {
ConnectionClosed, ConnectionClosed,
/// Input-output error /// Input-output error
Io(io::Error), Io(io::Error),
#[cfg(feature="tls")]
/// TLS error
Tls(native_tls::Error),
/// Buffer capacity exhausted /// Buffer capacity exhausted
Capacity(Cow<'static, str>), Capacity(Cow<'static, str>),
/// Protocol violation /// Protocol violation
@ -37,6 +43,8 @@ impl fmt::Display for Error {
match *self { match *self {
Error::ConnectionClosed => write!(f, "Connection closed"), Error::ConnectionClosed => write!(f, "Connection closed"),
Error::Io(ref err) => write!(f, "IO error: {}", err), Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature="tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
@ -51,6 +59,8 @@ impl ErrorTrait for Error {
match *self { match *self {
Error::ConnectionClosed => "", Error::ConnectionClosed => "",
Error::Io(ref err) => err.description(), Error::Io(ref err) => err.description(),
#[cfg(feature="tls")]
Error::Tls(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(), Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(), Error::Protocol(ref msg) => msg.borrow(),
Error::Utf8 => "", Error::Utf8 => "",
@ -78,6 +88,13 @@ impl From<string::FromUtf8Error> for Error {
} }
} }
#[cfg(feature="tls")]
impl From<native_tls::Error> for Error {
fn from(err: native_tls::Error) -> Self {
Error::Tls(err)
}
}
impl From<httparse::Error> for Error { impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self { fn from(err: httparse::Error) -> Self {
match err { match err {

@ -1,25 +1,16 @@
use std::io::{Read, Write, Cursor};
use base64; use base64;
use rand; use rand;
use bytes::Buf;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
use std::io::Write;
use url::Url; use url::Url;
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result}; use error::{Error, Result};
use protocol::{ use protocol::{WebSocket, Role};
WebSocket, Role,
}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{ use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
Headers, use super::machine::{HandshakeMachine, StageResult, TryParse};
Httparse, FromHttparse,
Handshake, HandshakeResult,
convert_key,
MAX_HEADERS,
};
use util::NonBlockingResult;
/// Client request. /// Client request.
pub struct Request { pub struct Request {
@ -47,18 +38,17 @@ impl Request {
} }
} }
/// Client handshake. /// Client handshake role.
pub struct ClientHandshake<Stream> { pub struct ClientHandshake {
stream: Stream,
state: HandshakeState,
verify_data: VerifyData, verify_data: VerifyData,
} }
impl<Stream: Read + Write> ClientHandshake<Stream> { impl ClientHandshake {
/// Initiate a WebSocket handshake over the given stream. /// Initiate a client handshake.
pub fn new(stream: Stream, request: Request) -> Self { pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> {
let key = generate_key(); let key = generate_key();
let machine = {
let mut req = Vec::new(); let mut req = Vec::new();
write!(req, "\ write!(req, "\
GET {path} HTTP/1.1\r\n\ GET {path} HTTP/1.1\r\n\
@ -69,55 +59,38 @@ impl<Stream: Read + Write> ClientHandshake<Stream> {
Sec-WebSocket-Key: {key}\r\n\ Sec-WebSocket-Key: {key}\r\n\
\r\n", host = request.get_host(), path = request.get_path(), key = key) \r\n", host = request.get_host(), path = request.get_path(), key = key)
.unwrap(); .unwrap();
HandshakeMachine::start_write(stream, req)
};
let client = {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake {
stream: stream,
state: HandshakeState::SendingRequest(Cursor::new(req)),
verify_data: VerifyData { verify_data: VerifyData {
accept_key: accept_key, accept_key: accept_key,
}, },
} }
};
debug!("Client handshake initiated.");
MidHandshake { role: client, machine: machine }
} }
} }
impl<Stream: Read + Write> Handshake for ClientHandshake<Stream> { impl HandshakeRole for ClientHandshake {
type Stream = WebSocket<Stream>; type IncomingData = Response;
fn handshake(mut self) -> Result<HandshakeResult<Self>> { fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
debug!("Performing client handshake..."); -> Result<ProcessingResult<Stream>>
match self.state { {
HandshakeState::SendingRequest(mut req) => { Ok(match finish {
let size = self.stream.write(Buf::bytes(&req)).no_block()?.unwrap_or(0); StageResult::DoneWriting(stream) => {
Buf::advance(&mut req, size); ProcessingResult::Continue(HandshakeMachine::start_read(stream))
let state = if req.has_remaining() { }
HandshakeState::SendingRequest(req) StageResult::DoneReading { stream, result, tail, } => {
} else { self.verify_data.verify_response(&result)?;
HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ))
};
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: state,
..self
}))
}
HandshakeState::ReceivingResponse(mut resp_buf) => {
resp_buf.reserve(MIN_READ, usize::max_value())
.map_err(|_| Error::Capacity("Header too long".into()))?;
resp_buf.read_from(&mut self.stream).no_block()?;
if let Some(resp) = Response::parse(&mut resp_buf)? {
self.verify_data.verify_response(&resp)?;
let ws = WebSocket::from_partially_read(self.stream,
resp_buf.into_vec(), Role::Client);
debug!("Client handshake done."); debug!("Client handshake done.");
Ok(HandshakeResult::Done(ws)) ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client))
} else {
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: HandshakeState::ReceivingResponse(resp_buf),
..self
}))
}
}
} }
})
} }
} }
@ -173,27 +146,14 @@ impl VerifyData {
} }
} }
/// Internal state of the client handshake.
enum HandshakeState {
SendingRequest(Cursor<Vec<u8>>),
ReceivingResponse(InputBuffer),
}
/// Server response. /// Server response.
pub struct Response { pub struct Response {
code: u16, code: u16,
headers: Headers, headers: Headers,
} }
impl Response { impl TryParse for Response {
/// Parse the response from a stream. fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Response::parse_http(input)
}
}
impl Httparse for Response {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Response::new(&mut hbuffer); let mut req = httparse::Response::new(&mut hbuffer);
Ok(match req.parse(buf)? { Ok(match req.parse(buf)? {
@ -227,8 +187,7 @@ fn generate_key() -> String {
mod tests { mod tests {
use super::{Response, generate_key}; use super::{Response, generate_key};
use super::super::machine::TryParse;
use std::io::Cursor;
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -249,10 +208,9 @@ mod tests {
#[test] #[test]
fn response_parsing() { fn response_parsing() {
const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let mut inp = Cursor::new(data); let (_, resp) = Response::try_parse(data).unwrap().unwrap();
let req = Response::parse(&mut inp).unwrap().unwrap(); assert_eq!(resp.code, 200);
assert_eq!(req.code, 200); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
assert_eq!(req.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
} }
} }

@ -0,0 +1,145 @@
use std::ascii::AsciiExt;
use std::str::from_utf8;
use std::slice;
use httparse;
use httparse::Status;
use error::Result;
use super::machine::TryParse;
// Limit the number of header lines.
pub const MAX_HEADERS: usize = 124;
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
}
/// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name: name,
iter: self.data.iter()
}
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
}
/// The iterator over headers.
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
}
}
None
}
}
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
fn from_httparse(raw: T) -> Result<Self>;
}
impl TryParse for Headers {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
#[cfg(test)]
mod tests {
use super::Headers;
use super::super::machine::TryParse;
#[test]
fn headers() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
}
#[test]
fn headers_iter() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
assert_eq!(iter.next(), None);
}
#[test]
fn headers_incomplete() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let hdr = Headers::try_parse(data).unwrap();
assert!(hdr.is_none());
}
}

@ -0,0 +1,119 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf;
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use util::NonBlockingResult;
/// A generic handshake state machine.
pub struct HandshakeMachine<Stream> {
stream: Stream,
state: HandshakeState,
}
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
HandshakeMachine {
stream: stream,
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
}
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine {
stream: stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.stream
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
&mut self.stream
}
}
impl<Stream: Read + Write> HandshakeMachine<Stream> {
/// Perform a single handshake round.
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
Ok(match self.state {
HandshakeState::Reading(mut buf) => {
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))?;
if let Some(_) = buf.read_from(&mut self.stream).no_block()? {
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}
} else {
RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}
}
HandshakeState::Writing(mut buf) => {
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? {
buf.advance(size);
if buf.has_remaining() {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
})
} else {
RoundResult::StageFinished(StageResult::DoneWriting(self.stream))
}
} else {
RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
})
}
}
})
}
}
/// The result of the round.
pub enum RoundResult<Obj, Stream> {
/// Round not done, I/O would block.
WouldBlock(HandshakeMachine<Stream>),
/// Round done, state unchanged.
Incomplete(HandshakeMachine<Stream>),
/// Stage complete.
StageFinished(StageResult<Obj, Stream>),
}
/// The result of the stage.
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
/// Writing round finished.
DoneWriting(Stream),
}
/// The parseable object.
pub trait TryParse: Sized {
/// Return Ok(None) if incomplete, Err on syntax error.
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
}
/// The handshake state.
enum HandshakeState {
/// Reading data from the peer.
Reading(InputBuffer),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}

@ -1,179 +1,102 @@
pub mod headers;
pub mod client; pub mod client;
pub mod server; pub mod server;
#[cfg(feature="tls")]
pub mod tls;
use std::ascii::AsciiExt; mod machine;
use std::str::from_utf8;
use std::slice; use std::io::{Read, Write};
use base64; use base64;
use bytes::Buf;
use httparse;
use httparse::Status;
use sha1::Sha1; use sha1::Sha1;
use error::Result; use error::Error;
use protocol::WebSocket;
// Limit the number of header lines.
const MAX_HEADERS: usize = 124;
/// A handshake state. use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
pub trait Handshake: Sized {
/// Resulting stream of this handshake.
type Stream;
/// Perform a single handshake round.
fn handshake(self) -> Result<HandshakeResult<Self>>;
/// Perform handshake to the end in a blocking mode.
fn handshake_wait(self) -> Result<Self::Stream> {
let mut hs = self;
loop {
hs = match hs.handshake()? {
HandshakeResult::Done(stream) => return Ok(stream),
HandshakeResult::Incomplete(s) => s,
}
}
}
}
/// A handshake result. /// A WebSocket handshake.
pub enum HandshakeResult<H: Handshake> { pub struct MidHandshake<Stream, Role> {
/// Handshake is done, a WebSocket stream is ready. role: Role,
Done(H::Stream), machine: HandshakeMachine<Stream>,
/// Handshake is not done, call handshake() again.
Incomplete(H),
} }
impl<H: Handshake> HandshakeResult<H> { impl<Stream, Role> MidHandshake<Stream, Role> {
pub fn map<R, F>(self, func: F) -> HandshakeResult<R> /// Returns a shared reference to the inner stream.
where R: Handshake<Stream = H::Stream>, pub fn get_ref(&self) -> &Stream {
F: FnOnce(H) -> R, self.machine.get_ref()
{
match self {
HandshakeResult::Done(s) => HandshakeResult::Done(s),
HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)),
} }
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
} }
} }
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
fn convert_key(input: &[u8]) -> Result<String> { /// Restarts the handshake process.
// ... field is constructed by concatenating /key/ ... pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> {
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) let mut mach = self.machine;
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; loop {
let mut sha1 = Sha1::new(); mach = match mach.single_round()? {
sha1.update(input); RoundResult::WouldBlock(m) => {
sha1.update(WS_GUID); return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
Ok(base64::encode(&sha1.digest().bytes()))
}
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
} }
RoundResult::Incomplete(m) => m,
/// Iterate over all headers with the given name. RoundResult::StageFinished(s) => {
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { match self.role.stage_finished(s)? {
HeadersIter { ProcessingResult::Continue(m) => m,
name: name, ProcessingResult::Done(ws) => return Ok(ws),
iter: self.data.iter()
} }
} }
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
} }
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
} }
/// Try to parse data and return headers, if any.
fn parse<B: Buf>(input: &mut B) -> Result<Option<Headers>> {
Headers::parse_http(input)
} }
} }
/// The iterator over headers. /// A handshake result.
pub struct HeadersIter<'name, 'headers> { pub enum HandshakeError<Stream, Role> {
name: &'name str, /// Handshake was interrupted (would block).
iter: slice::Iter<'headers, (String, Box<[u8]>)>, Interrupted(MidHandshake<Stream, Role>),
} /// Handshake failed.
Failure(Error),
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
}
}
None
}
} }
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> {
/// Trait to read HTTP parseable objects. fn from(err: Error) -> Self {
trait Httparse: Sized { HandshakeError::Failure(err)
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>>;
fn parse_http<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Ok(match Self::httparse(input.bytes())? {
Some((size, obj)) => {
input.advance(size);
Some(obj)
},
None => None,
})
} }
} }
/// Trait to convert raw objects into HTTP parseables. /// Handshake role.
trait FromHttparse<T>: Sized { pub trait HandshakeRole {
fn from_httparse(raw: T) -> Result<Self>; #[doc(hidden)]
type IncomingData: TryParse;
#[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>;
} }
impl Httparse for Headers { /// Stage processing result.
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> { #[doc(hidden)]
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; pub enum ProcessingResult<Stream> {
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Continue(HandshakeMachine<Stream>),
Status::Partial => None, Done(WebSocket<Stream>),
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
} }
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> { fn convert_key(input: &[u8]) -> Result<String, Error> {
Ok(Headers { // ... field is constructed by concatenating /key/ ...
data: raw.iter() // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
.collect(), let mut sha1 = Sha1::new();
}) sha1.update(input);
} sha1.update(WS_GUID);
Ok(base64::encode(&sha1.digest().bytes()))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Headers, convert_key}; use super::convert_key;
use std::io::Cursor;
#[test] #[test]
fn key_conversion() { fn key_conversion() {
@ -182,50 +105,4 @@ mod tests {
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
} }
#[test]
fn headers() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
}
#[test]
fn headers_iter() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
assert_eq!(iter.next(), None);
}
#[test]
fn headers_incomplete() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap();
assert!(hdr.is_none());
}
} }

@ -1,33 +1,20 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
use input_buffer::{InputBuffer, MIN_READ}; //use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, Role};
use super::{ use super::headers::{Headers, FromHttparse, MAX_HEADERS};
Handshake, use super::machine::{HandshakeMachine, StageResult, TryParse};
HandshakeResult, use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
Headers,
Httparse,
FromHttparse,
convert_key,
MAX_HEADERS
};
use util::NonBlockingResult;
/// Request from the client. /// Request from the client.
pub struct Request { pub struct Request {
path: String, pub path: String,
headers: Headers, pub headers: Headers,
} }
impl Request { impl Request {
/// Parse the request from a stream.
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Request::parse_http(input)
}
/// Reply to the response. /// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> { pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key") let key = self.headers.find_first("Sec-WebSocket-Key")
@ -42,8 +29,8 @@ impl Request {
} }
} }
impl Httparse for Request { impl TryParse for Request {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut hbuffer); let mut req = httparse::Request::new(&mut hbuffer);
Ok(match req.parse(buf)? { Ok(match req.parse(buf)? {
@ -68,76 +55,50 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
} }
/// Server handshake /// Server handshake role.
pub struct ServerHandshake<Stream> { #[allow(missing_copy_implementations)]
stream: Stream, pub struct ServerHandshake;
state: HandshakeState,
}
impl<Stream: Read + Write> ServerHandshake<Stream> { impl ServerHandshake {
/// Start a new server handshake on top of given stream. /// Start server handshake.
pub fn new(stream: Stream) -> Self { pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> {
ServerHandshake { MidHandshake {
stream: stream, machine: HandshakeMachine::start_read(stream),
state: HandshakeState::ReceivingRequest(InputBuffer::with_capacity(MIN_READ)), role: ServerHandshake,
} }
} }
} }
impl<Stream: Read + Write> Handshake for ServerHandshake<Stream> { impl HandshakeRole for ServerHandshake {
type Stream = WebSocket<Stream>; type IncomingData = Request;
fn handshake(mut self) -> Result<HandshakeResult<Self>> { fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
debug!("Performing server handshake..."); -> Result<ProcessingResult<Stream>>
match self.state { {
HandshakeState::ReceivingRequest(mut req_buf) => { Ok(match finish {
req_buf.reserve(MIN_READ, usize::max_value()) StageResult::DoneReading { stream, result, tail } => {
.map_err(|_| Error::Capacity("Header too long".into()))?; if ! tail.is_empty() {
req_buf.read_from(&mut self.stream).no_block()?; return Err(Error::Protocol("Junk after client request".into()))
let state = if let Some(req) = Request::parse(&mut req_buf)? {
let resp = req.reply()?;
HandshakeState::SendingResponse(Cursor::new(resp))
} else {
HandshakeState::ReceivingRequest(req_buf)
};
Ok(HandshakeResult::Incomplete(ServerHandshake {
state: state,
..self
}))
}
HandshakeState::SendingResponse(mut resp) => {
let size = self.stream.write(Buf::bytes(&resp)).no_block()?.unwrap_or(0);
Buf::advance(&mut resp, size);
if resp.has_remaining() {
Ok(HandshakeResult::Incomplete(ServerHandshake {
state: HandshakeState::SendingResponse(resp),
..self
}))
} else {
let ws = WebSocket::from_raw_socket(self.stream, Role::Server);
Ok(HandshakeResult::Done(ws))
} }
let response = result.reply()?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
StageResult::DoneWriting(stream) => {
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
} }
})
} }
} }
enum HandshakeState {
ReceivingRequest(InputBuffer),
SendingResponse(Cursor<Vec<u8>>),
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Request; use super::Request;
use super::super::machine::TryParse;
use std::io::Cursor;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let mut inp = Cursor::new(data); let (_, req) = Request::try_parse(data).unwrap().unwrap();
let req = Request::parse(&mut inp).unwrap().unwrap();
assert_eq!(req.path, "/script.ws"); assert_eq!(req.path, "/script.ws");
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..]));
} }
@ -152,9 +113,8 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let mut inp = Cursor::new(data); let (_, req) = Request::try_parse(data).unwrap().unwrap();
let req = Request::parse(&mut inp).unwrap().unwrap(); let _ = req.reply().unwrap();
let reply = req.reply().unwrap();
} }
} }

@ -1,14 +0,0 @@
use native_tls;
use stream::Stream;
use super::{Handshake, HandshakeResult};
pub struct TlsHandshake {
}
impl Handshale for TlsHandshake {
type Stream = Stream;
fn handshake(self) -> Result<HandshakeResult<Self>> {
}
}

@ -40,10 +40,7 @@ pub struct WebSocket<Stream> {
pong: Option<Frame>, pong: Option<Frame>,
} }
impl<Stream> WebSocket<Stream> impl<Stream> WebSocket<Stream> {
where Stream: Read + Write
{
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket(stream: Stream, role: Role) -> Self { pub fn from_raw_socket(stream: Stream, role: Role) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role) WebSocket::from_frame_socket(FrameSocket::new(stream), role)
@ -54,6 +51,20 @@ impl<Stream> WebSocket<Stream>
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role)
} }
/// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self {
WebSocket {
role: role,
socket: socket,
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
}
}
}
impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
/// This function sends pong and close responses automatically. /// This function sends pong and close responses automatically.
@ -141,18 +152,6 @@ impl<Stream> WebSocket<Stream>
} }
} }
/// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self {
WebSocket {
role: role,
socket: socket,
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
pong: None,
}
}
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
fn read_message_frame(&mut self) -> Result<Option<Message>> { fn read_message_frame(&mut self) -> Result<Option<Message>> {
if let Some(mut frame) = self.socket.read_frame()? { if let Some(mut frame) = self.socket.read_frame()? {

@ -1,8 +1,13 @@
use std::net::TcpStream; pub use handshake::server::ServerHandshake;
use handshake::server::ServerHandshake; use handshake::HandshakeError;
use protocol::WebSocket;
use std::io::{Read, Write};
/// Accept the given TcpStream as a WebSocket. /// Accept the given TcpStream as a WebSocket.
pub fn accept(stream: TcpStream) -> ServerHandshake<TcpStream> { pub fn accept<Stream: Read + Write>(stream: Stream)
ServerHandshake::new(stream) -> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>>
{
ServerHandshake::start(stream).handshake()
} }

@ -1,38 +1,37 @@
#[cfg(feature="tls")]
use native_tls::TlsStream;
use std::net::TcpStream;
use std::io::{Read, Write, Result as IoResult}; use std::io::{Read, Write, Result as IoResult};
/// Stream mode, either plain TCP or TLS.
#[derive(Clone, Copy)]
pub enum Mode {
Plain,
Tls,
}
/// Stream, either plain TCP or TLS. /// Stream, either plain TCP or TLS.
pub enum Stream { pub enum Stream<S, T> {
Plain(TcpStream), Plain(S),
#[cfg(feature="tls")] Tls(T),
Tls(TlsStream<TcpStream>),
} }
impl Read for Stream { impl<S: Read, T: Read> Read for Stream<S, T> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self { match *self {
Stream::Plain(ref mut s) => s.read(buf), Stream::Plain(ref mut s) => s.read(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.read(buf), Stream::Tls(ref mut s) => s.read(buf),
} }
} }
} }
impl Write for Stream { impl<S: Write, T: Write> Write for Stream<S, T> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> { fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self { match *self {
Stream::Plain(ref mut s) => s.write(buf), Stream::Plain(ref mut s) => s.write(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.write(buf), Stream::Tls(ref mut s) => s.write(buf),
} }
} }
fn flush(&mut self) -> IoResult<()> { fn flush(&mut self) -> IoResult<()> {
match *self { match *self {
Stream::Plain(ref mut s) => s.flush(), Stream::Plain(ref mut s) => s.flush(),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.flush(), Stream::Tls(ref mut s) => s.flush(),
} }
} }

Loading…
Cancel
Save