diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index db8ee9b..75c15e8 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -5,9 +5,7 @@ extern crate url; use url::Url; -use tungstenite::protocol::Message; use tungstenite::client::connect; -use tungstenite::handshake::Handshake; use tungstenite::error::{Error, Result}; const AGENT: &'static str = "Tungstenite"; @@ -15,17 +13,17 @@ const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { let mut socket = connect( Url::parse("ws://localhost:9001/getCaseCount").unwrap() - )?.handshake_wait()?; + )?; let msg = socket.read_message()?; - socket.close(); + socket.close()?; Ok(msg.into_text()?.parse::().unwrap()) } fn update_reports() -> Result<()> { let mut socket = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() - )?.handshake_wait()?; - socket.close(); + )?; + socket.close()?; Ok(()) } @@ -34,13 +32,11 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let mut socket = connect(case_url)?.handshake_wait()?; + let mut socket = connect(case_url)?; loop { let msg = socket.read_message()?; socket.write_message(msg)?; } - socket.close(); - Ok(()) } fn main() { diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 0e5aac2..4ff44bd 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -6,11 +6,18 @@ use std::net::{TcpListener, TcpStream}; use std::thread::spawn; use tungstenite::server::accept; -use tungstenite::error::Result; -use tungstenite::handshake::Handshake; +use tungstenite::handshake::HandshakeError; +use tungstenite::error::{Error, Result}; + +fn must_not_block(err: HandshakeError) -> Error { + match err { + HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), + HandshakeError::Failure(f) => f, + } +} fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream).handshake_wait()?; + let mut socket = accept(stream).map_err(must_not_block)?; loop { let msg = socket.read_message()?; socket.write_message(msg)?; diff --git a/examples/client.rs b/examples/client.rs index f5a7964..22ae38c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -5,15 +5,12 @@ extern crate env_logger; use url::Url; use tungstenite::protocol::Message; use tungstenite::client::connect; -use tungstenite::handshake::Handshake; fn main() { env_logger::init().unwrap(); let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) - .expect("Can't connect") - .handshake_wait() - .expect("Handshake error"); + .expect("Can't connect"); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); loop { diff --git a/src/client.rs b/src/client.rs index 0b18450..5dc0804 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,75 +1,97 @@ -use std::net::{TcpStream, ToSocketAddrs}; -use url::{Url, SocketAddrs}; +use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; +use std::result::Result as StdResult; +use std::io::{Read, Write}; + +use url::Url; + +#[cfg(feature="tls")] +use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError}; use protocol::WebSocket; -use handshake::{Handshake as HandshakeTrait, HandshakeResult}; +use handshake::HandshakeError; use handshake::client::{ClientHandshake, Request}; +use stream::Mode; use error::{Error, Result}; -/// Connect to the given WebSocket. -/// -/// Note that this function may block the current thread while DNS resolution is performed. -pub fn connect(url: Url) -> Result { - let mode = match url.scheme() { - "ws" => Mode::Plain, - #[cfg(feature="tls")] - "wss" => Mode::Tls, - _ => return Err(Error::Url("URL scheme not supported".into())) - }; +#[cfg(feature="tls")] +use stream::Stream as StreamSwitcher; - // Note that this function may block the current thread while resolution is performed. - let addrs = url.to_socket_addrs()?; - Ok(Handshake { - state: HandshakeState::Nothing(url), - alt_addresses: addrs, - }) -} +#[cfg(feature="tls")] +pub type AutoStream = StreamSwitcher>; +#[cfg(not(feature="tls"))] +pub type AutoStream = TcpStream; -enum Mode { - Plain, - Tls, +/// Connect to the given WebSocket in blocking mode. +/// +/// The URL may be either ws:// or wss://. +/// To support wss:// URLs, feature "tls" must be turned on. +pub fn connect(url: Url) -> Result> { + let mode = url_mode(&url)?; + let addrs = url.to_socket_addrs()?; + let stream = connect_to_some(addrs, &url, mode)?; + client(url.clone(), stream) + .map_err(|e| match e { + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) } -enum HandshakeState { - Nothing(Url), - WebSocket(ClientHandshake), +#[cfg(feature="tls")] +fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(stream)), + Mode::Tls => { + let connector = TlsConnector::builder()?.build()?; + connector.connect(domain, stream) + .map_err(|e| match e { + TlsHandshakeError::Failure(f) => f.into(), + TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), + }) + .map(|s| StreamSwitcher::Tls(s)) + } + } } -pub struct Handshake { - state: HandshakeState, - alt_addresses: SocketAddrs, +#[cfg(not(feature="tls"))] +fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(stream), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + } } -impl HandshakeTrait for Handshake { - type Stream = WebSocket; - fn handshake(mut self) -> Result> { - match self.state { - HandshakeState::Nothing(url) => { - if let Some(addr) = self.alt_addresses.next() { - debug!("Trying to contact {} at {}...", url, addr); - let state = { - if let Ok(stream) = TcpStream::connect(addr) { - let hs = ClientHandshake::new(stream, Request { url: url }); - HandshakeState::WebSocket(hs) - } else { - HandshakeState::Nothing(url) - } - }; - Ok(HandshakeResult::Incomplete(Handshake { - state: state, - ..self - })) - } else { - Err(Error::Url(format!("Unable to resolve {}", url).into())) - } - } - HandshakeState::WebSocket(ws) => { - let alt_addresses = self.alt_addresses; - ws.handshake().map(move |r| r.map(move |s| Handshake { - state: HandshakeState::WebSocket(s), - alt_addresses: alt_addresses, - })) +fn connect_to_some(addrs: A, url: &Url, mode: Mode) -> Result + where A: Iterator +{ + let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?; + for addr in addrs { + debug!("Trying to contact {} at {}...", url, addr); + if let Ok(raw_stream) = TcpStream::connect(addr) { + if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { + return Ok(stream) } } } + Err(Error::Url(format!("Unable to connect to {}", url).into())) +} + +/// Get the mode of the given URL. +/// +/// This function may be used in non-blocking implementations. +pub fn url_mode(url: &Url) -> Result { + match url.scheme() { + "ws" => Ok(Mode::Plain), + "wss" => Ok(Mode::Tls), + _ => Err(Error::Url("URL scheme not supported".into())) + } +} + +/// Do the client handshake over the given stream. +/// +/// Use this function if you need a nonblocking handshake support. +pub fn client(url: Url, stream: Stream) + -> StdResult, HandshakeError> +{ + let request = Request { url: url }; + ClientHandshake::start(stream, request).handshake() } diff --git a/src/error.rs b/src/error.rs index b2c446d..1a9c7f5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,9 @@ use std::string; use httparse; +#[cfg(feature="tls")] +use native_tls; + pub type Result = result::Result; /// Possible WebSocket errors @@ -20,6 +23,9 @@ pub enum Error { ConnectionClosed, /// Input-output error Io(io::Error), + #[cfg(feature="tls")] + /// TLS error + Tls(native_tls::Error), /// Buffer capacity exhausted Capacity(Cow<'static, str>), /// Protocol violation @@ -37,6 +43,8 @@ impl fmt::Display for Error { match *self { Error::ConnectionClosed => write!(f, "Connection closed"), Error::Io(ref err) => write!(f, "IO error: {}", err), + #[cfg(feature="tls")] + Error::Tls(ref err) => write!(f, "TLS error: {}", err), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), Error::Utf8 => write!(f, "UTF-8 encoding error"), @@ -51,6 +59,8 @@ impl ErrorTrait for Error { match *self { Error::ConnectionClosed => "", Error::Io(ref err) => err.description(), + #[cfg(feature="tls")] + Error::Tls(ref err) => err.description(), Error::Capacity(ref msg) => msg.borrow(), Error::Protocol(ref msg) => msg.borrow(), Error::Utf8 => "", @@ -78,6 +88,13 @@ impl From for Error { } } +#[cfg(feature="tls")] +impl From for Error { + fn from(err: native_tls::Error) -> Self { + Error::Tls(err) + } +} + impl From for Error { fn from(err: httparse::Error) -> Self { match err { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 86b2c60..6e54cfb 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,25 +1,16 @@ -use std::io::{Read, Write, Cursor}; - use base64; use rand; -use bytes::Buf; use httparse; use httparse::Status; +use std::io::Write; use url::Url; -use input_buffer::{InputBuffer, MIN_READ}; use error::{Error, Result}; -use protocol::{ - WebSocket, Role, -}; -use super::{ - Headers, - Httparse, FromHttparse, - Handshake, HandshakeResult, - convert_key, - MAX_HEADERS, -}; -use util::NonBlockingResult; +use protocol::{WebSocket, Role}; + +use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +use super::machine::{HandshakeMachine, StageResult, TryParse}; /// Client request. pub struct Request { @@ -47,77 +38,59 @@ impl Request { } } -/// Client handshake. -pub struct ClientHandshake { - stream: Stream, - state: HandshakeState, +/// Client handshake role. +pub struct ClientHandshake { verify_data: VerifyData, } -impl ClientHandshake { - /// Initiate a WebSocket handshake over the given stream. - pub fn new(stream: Stream, request: Request) -> Self { +impl ClientHandshake { + /// Initiate a client handshake. + pub fn start(stream: Stream, request: Request) -> MidHandshake { let key = generate_key(); - let mut req = Vec::new(); - write!(req, "\ - GET {path} HTTP/1.1\r\n\ - Host: {host}\r\n\ - Connection: upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n\ - \r\n", host = request.get_host(), path = request.get_path(), key = key) - .unwrap(); - - let accept_key = convert_key(key.as_ref()).unwrap(); + let machine = { + let mut req = Vec::new(); + write!(req, "\ + GET {path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Connection: upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: {key}\r\n\ + \r\n", host = request.get_host(), path = request.get_path(), key = key) + .unwrap(); + HandshakeMachine::start_write(stream, req) + }; + + let client = { + let accept_key = convert_key(key.as_ref()).unwrap(); + ClientHandshake { + verify_data: VerifyData { + accept_key: accept_key, + }, + } + }; - ClientHandshake { - stream: stream, - state: HandshakeState::SendingRequest(Cursor::new(req)), - verify_data: VerifyData { - accept_key: accept_key, - }, - } + debug!("Client handshake initiated."); + MidHandshake { role: client, machine: machine } } } -impl Handshake for ClientHandshake { - type Stream = WebSocket; - fn handshake(mut self) -> Result> { - debug!("Performing client handshake..."); - match self.state { - HandshakeState::SendingRequest(mut req) => { - let size = self.stream.write(Buf::bytes(&req)).no_block()?.unwrap_or(0); - Buf::advance(&mut req, size); - let state = if req.has_remaining() { - HandshakeState::SendingRequest(req) - } else { - HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ)) - }; - Ok(HandshakeResult::Incomplete(ClientHandshake { - state: state, - ..self - })) +impl HandshakeRole for ClientHandshake { + type IncomingData = Response; + fn stage_finished(&self, finish: StageResult) + -> Result> + { + Ok(match finish { + StageResult::DoneWriting(stream) => { + ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } - HandshakeState::ReceivingResponse(mut resp_buf) => { - resp_buf.reserve(MIN_READ, usize::max_value()) - .map_err(|_| Error::Capacity("Header too long".into()))?; - resp_buf.read_from(&mut self.stream).no_block()?; - if let Some(resp) = Response::parse(&mut resp_buf)? { - self.verify_data.verify_response(&resp)?; - let ws = WebSocket::from_partially_read(self.stream, - resp_buf.into_vec(), Role::Client); - debug!("Client handshake done."); - Ok(HandshakeResult::Done(ws)) - } else { - Ok(HandshakeResult::Incomplete(ClientHandshake { - state: HandshakeState::ReceivingResponse(resp_buf), - ..self - })) - } + StageResult::DoneReading { stream, result, tail, } => { + self.verify_data.verify_response(&result)?; + debug!("Client handshake done."); + ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) } - } + }) } } @@ -173,27 +146,14 @@ impl VerifyData { } } -/// Internal state of the client handshake. -enum HandshakeState { - SendingRequest(Cursor>), - ReceivingResponse(InputBuffer), -} - /// Server response. pub struct Response { code: u16, headers: Headers, } -impl Response { - /// Parse the response from a stream. - pub fn parse(input: &mut B) -> Result> { - Response::parse_http(input) - } -} - -impl Httparse for Response { - fn httparse(buf: &[u8]) -> Result> { +impl TryParse for Response { + fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Response::new(&mut hbuffer); Ok(match req.parse(buf)? { @@ -227,8 +187,7 @@ fn generate_key() -> String { mod tests { use super::{Response, generate_key}; - - use std::io::Cursor; + use super::super::machine::TryParse; #[test] fn random_keys() { @@ -249,10 +208,9 @@ mod tests { #[test] fn response_parsing() { const data: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; - let mut inp = Cursor::new(data); - let req = Response::parse(&mut inp).unwrap().unwrap(); - assert_eq!(req.code, 200); - assert_eq!(req.headers.find_first("Content-Type"), Some(&b"text/html"[..])); + let (_, resp) = Response::try_parse(data).unwrap().unwrap(); + assert_eq!(resp.code, 200); + assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs new file mode 100644 index 0000000..c765caf --- /dev/null +++ b/src/handshake/headers.rs @@ -0,0 +1,145 @@ +use std::ascii::AsciiExt; +use std::str::from_utf8; +use std::slice; + +use httparse; +use httparse::Status; + +use error::Result; +use super::machine::TryParse; + +// Limit the number of header lines. +pub const MAX_HEADERS: usize = 124; + +/// HTTP request or response headers. +#[derive(Debug)] +pub struct Headers { + data: Vec<(String, Box<[u8]>)>, +} + +impl Headers { + + /// Get first header with the given name, if any. + pub fn find_first(&self, name: &str) -> Option<&[u8]> { + self.find(name).next() + } + + /// Iterate over all headers with the given name. + pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { + HeadersIter { + name: name, + iter: self.data.iter() + } + } + + /// Check if the given header has the given value. + pub fn header_is(&self, name: &str, value: &str) -> bool { + self.find_first(name) + .map(|v| v == value.as_bytes()) + .unwrap_or(false) + } + + /// Check if the given header has the given value (case-insensitive). + pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { + self.find_first(name).ok_or(()) + .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) + .map(|val| val.eq_ignore_ascii_case(value)) + .unwrap_or(false) + } + +} + +/// The iterator over headers. +pub struct HeadersIter<'name, 'headers> { + name: &'name str, + iter: slice::Iter<'headers, (String, Box<[u8]>)>, +} + +impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { + type Item = &'headers [u8]; + fn next(&mut self) -> Option { + while let Some(&(ref name, ref value)) = self.iter.next() { + if name.eq_ignore_ascii_case(self.name) { + return Some(value) + } + } + None + } +} + + +/// Trait to convert raw objects into HTTP parseables. +pub trait FromHttparse: Sized { + fn from_httparse(raw: T) -> Result; +} + +impl TryParse for Headers { + fn try_parse(buf: &[u8]) -> Result> { + let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; + Ok(match httparse::parse_headers(buf, &mut hbuffer)? { + Status::Partial => None, + Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), + }) + } +} + +impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { + fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { + Ok(Headers { + data: raw.iter() + .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) + .collect(), + }) + } +} + +#[cfg(test)] +mod tests { + + use super::Headers; + use super::super::machine::TryParse; + + #[test] + fn headers() { + const data: &'static [u8] = + b"Host: foo.com\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + \r\n"; + let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); + assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); + assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); + assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); + + assert!(hdr.header_is("upgrade", "websocket")); + assert!(!hdr.header_is("upgrade", "Websocket")); + assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); + } + + #[test] + fn headers_iter() { + const data: &'static [u8] = + b"Host: foo.com\r\n\ + Sec-WebSocket-Extensions: permessage-deflate\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ + Upgrade: websocket\r\n\ + \r\n"; + let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); + let mut iter = hdr.find("Sec-WebSocket-Extensions"); + assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); + assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); + assert_eq!(iter.next(), None); + } + + #[test] + fn headers_incomplete() { + const data: &'static [u8] = + b"Host: foo.com\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n"; + let hdr = Headers::try_parse(data).unwrap(); + assert!(hdr.is_none()); + } + +} diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs new file mode 100644 index 0000000..36c3e42 --- /dev/null +++ b/src/handshake/machine.rs @@ -0,0 +1,119 @@ +use std::io::{Cursor, Read, Write}; +use bytes::Buf; + +use input_buffer::{InputBuffer, MIN_READ}; +use error::{Error, Result}; +use util::NonBlockingResult; + +/// A generic handshake state machine. +pub struct HandshakeMachine { + stream: Stream, + state: HandshakeState, +} + +impl HandshakeMachine { + /// Start reading data from the peer. + pub fn start_read(stream: Stream) -> Self { + HandshakeMachine { + stream: stream, + state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), + } + } + /// Start writing data to the peer. + pub fn start_write>>(stream: Stream, data: D) -> Self { + HandshakeMachine { + stream: stream, + state: HandshakeState::Writing(Cursor::new(data.into())), + } + } + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &Stream { + &self.stream + } + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut Stream { + &mut self.stream + } +} + +impl HandshakeMachine { + /// Perform a single handshake round. + pub fn single_round(mut self) -> Result> { + Ok(match self.state { + HandshakeState::Reading(mut buf) => { + buf.reserve(MIN_READ, usize::max_value()) // TODO limit size + .map_err(|_| Error::Capacity("Header too long".into()))?; + if let Some(_) = buf.read_from(&mut self.stream).no_block()? { + if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { + buf.advance(size); + RoundResult::StageFinished(StageResult::DoneReading { + result: obj, + stream: self.stream, + tail: buf.into_vec(), + }) + } else { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + }) + } + } else { + RoundResult::WouldBlock(HandshakeMachine { + state: HandshakeState::Reading(buf), + ..self + }) + } + } + HandshakeState::Writing(mut buf) => { + if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? { + buf.advance(size); + if buf.has_remaining() { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Writing(buf), + ..self + }) + } else { + RoundResult::StageFinished(StageResult::DoneWriting(self.stream)) + } + } else { + RoundResult::WouldBlock(HandshakeMachine { + state: HandshakeState::Writing(buf), + ..self + }) + } + } + }) + } +} + +/// The result of the round. +pub enum RoundResult { + /// Round not done, I/O would block. + WouldBlock(HandshakeMachine), + /// Round done, state unchanged. + Incomplete(HandshakeMachine), + /// Stage complete. + StageFinished(StageResult), +} + +/// The result of the stage. +pub enum StageResult { + /// Reading round finished. + DoneReading { result: Obj, stream: Stream, tail: Vec }, + /// Writing round finished. + DoneWriting(Stream), +} + +/// The parseable object. +pub trait TryParse: Sized { + /// Return Ok(None) if incomplete, Err on syntax error. + fn try_parse(data: &[u8]) -> Result>; +} + +/// The handshake state. +enum HandshakeState { + /// Reading data from the peer. + Reading(InputBuffer), + /// Sending data to the peer. + Writing(Cursor>), +} diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 61220d3..859ff4d 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -1,63 +1,89 @@ +pub mod headers; pub mod client; pub mod server; -#[cfg(feature="tls")] -pub mod tls; -use std::ascii::AsciiExt; -use std::str::from_utf8; -use std::slice; +mod machine; + +use std::io::{Read, Write}; use base64; -use bytes::Buf; -use httparse; -use httparse::Status; use sha1::Sha1; -use error::Result; +use error::Error; +use protocol::WebSocket; + +use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; + +/// A WebSocket handshake. +pub struct MidHandshake { + role: Role, + machine: HandshakeMachine, +} -// Limit the number of header lines. -const MAX_HEADERS: usize = 124; +impl MidHandshake { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &Stream { + self.machine.get_ref() + } + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut Stream { + self.machine.get_mut() + } +} -/// A handshake state. -pub trait Handshake: Sized { - /// Resulting stream of this handshake. - type Stream; - /// Perform a single handshake round. - fn handshake(self) -> Result>; - /// Perform handshake to the end in a blocking mode. - fn handshake_wait(self) -> Result { - let mut hs = self; +impl MidHandshake { + /// Restarts the handshake process. + pub fn handshake(self) -> Result, HandshakeError> { + let mut mach = self.machine; loop { - hs = match hs.handshake()? { - HandshakeResult::Done(stream) => return Ok(stream), - HandshakeResult::Incomplete(s) => s, + mach = match mach.single_round()? { + RoundResult::WouldBlock(m) => { + return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) + } + RoundResult::Incomplete(m) => m, + RoundResult::StageFinished(s) => { + match self.role.stage_finished(s)? { + ProcessingResult::Continue(m) => m, + ProcessingResult::Done(ws) => return Ok(ws), + } + } } } } } /// A handshake result. -pub enum HandshakeResult { - /// Handshake is done, a WebSocket stream is ready. - Done(H::Stream), - /// Handshake is not done, call handshake() again. - Incomplete(H), +pub enum HandshakeError { + /// Handshake was interrupted (would block). + Interrupted(MidHandshake), + /// Handshake failed. + Failure(Error), } -impl HandshakeResult { - pub fn map(self, func: F) -> HandshakeResult - where R: Handshake, - F: FnOnce(H) -> R, - { - match self { - HandshakeResult::Done(s) => HandshakeResult::Done(s), - HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)), - } +impl From for HandshakeError { + fn from(err: Error) -> Self { + HandshakeError::Failure(err) } } +/// Handshake role. +pub trait HandshakeRole { + #[doc(hidden)] + type IncomingData: TryParse; + #[doc(hidden)] + fn stage_finished(&self, finish: StageResult) + -> Result, Error>; +} + +/// Stage processing result. +#[doc(hidden)] +pub enum ProcessingResult { + Continue(HandshakeMachine), + Done(WebSocket), +} + /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. -fn convert_key(input: &[u8]) -> Result { +fn convert_key(input: &[u8]) -> Result { // ... field is constructed by concatenating /key/ ... // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -67,113 +93,10 @@ fn convert_key(input: &[u8]) -> Result { Ok(base64::encode(&sha1.digest().bytes())) } -/// HTTP request or response headers. -#[derive(Debug)] -pub struct Headers { - data: Vec<(String, Box<[u8]>)>, -} - -impl Headers { - - /// Get first header with the given name, if any. - pub fn find_first(&self, name: &str) -> Option<&[u8]> { - self.find(name).next() - } - - /// Iterate over all headers with the given name. - pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { - HeadersIter { - name: name, - iter: self.data.iter() - } - } - - /// Check if the given header has the given value. - pub fn header_is(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .map(|v| v == value.as_bytes()) - .unwrap_or(false) - } - - /// Check if the given header has the given value (case-insensitive). - pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name).ok_or(()) - .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) - .map(|val| val.eq_ignore_ascii_case(value)) - .unwrap_or(false) - } - - /// Try to parse data and return headers, if any. - fn parse(input: &mut B) -> Result> { - Headers::parse_http(input) - } - -} - -/// The iterator over headers. -pub struct HeadersIter<'name, 'headers> { - name: &'name str, - iter: slice::Iter<'headers, (String, Box<[u8]>)>, -} - -impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { - type Item = &'headers [u8]; - fn next(&mut self) -> Option { - while let Some(&(ref name, ref value)) = self.iter.next() { - if name.eq_ignore_ascii_case(self.name) { - return Some(value) - } - } - None - } -} - - -/// Trait to read HTTP parseable objects. -trait Httparse: Sized { - fn httparse(buf: &[u8]) -> Result>; - fn parse_http(input: &mut B) -> Result> { - Ok(match Self::httparse(input.bytes())? { - Some((size, obj)) => { - input.advance(size); - Some(obj) - }, - None => None, - }) - } -} - -/// Trait to convert raw objects into HTTP parseables. -trait FromHttparse: Sized { - fn from_httparse(raw: T) -> Result; -} - -impl Httparse for Headers { - fn httparse(buf: &[u8]) -> Result> { - let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; - Ok(match httparse::parse_headers(buf, &mut hbuffer)? { - Status::Partial => None, - Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), - }) - } -} - -impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { - fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { - Ok(Headers { - data: raw.iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), - }) - } -} - #[cfg(test)] mod tests { - use super::{Headers, convert_key}; - - use std::io::Cursor; + use super::convert_key; #[test] fn key_conversion() { @@ -182,50 +105,4 @@ mod tests { "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); } - #[test] - fn headers() { - const data: &'static [u8] = - b"Host: foo.com\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - \r\n"; - let mut inp = Cursor::new(data); - let hdr = Headers::parse(&mut inp).unwrap().unwrap(); - assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); - assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); - assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); - - assert!(hdr.header_is("upgrade", "websocket")); - assert!(!hdr.header_is("upgrade", "Websocket")); - assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); - } - - #[test] - fn headers_iter() { - const data: &'static [u8] = - b"Host: foo.com\r\n\ - Sec-WebSocket-Extensions: permessage-deflate\r\n\ - Connection: Upgrade\r\n\ - Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ - Upgrade: websocket\r\n\ - \r\n"; - let mut inp = Cursor::new(data); - let hdr = Headers::parse(&mut inp).unwrap().unwrap(); - let mut iter = hdr.find("Sec-WebSocket-Extensions"); - assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); - assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); - assert_eq!(iter.next(), None); - } - - #[test] - fn headers_incomplete() { - const data: &'static [u8] = - b"Host: foo.com\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n"; - let mut inp = Cursor::new(data); - let hdr = Headers::parse(&mut inp).unwrap(); - assert!(hdr.is_none()); - } - } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index f94f3b3..3615bf5 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,33 +1,20 @@ -use std::io::{Cursor, Read, Write}; -use bytes::Buf; use httparse; use httparse::Status; -use input_buffer::{InputBuffer, MIN_READ}; +//use input_buffer::{InputBuffer, MIN_READ}; use error::{Error, Result}; use protocol::{WebSocket, Role}; -use super::{ - Handshake, - HandshakeResult, - Headers, - Httparse, - FromHttparse, - convert_key, - MAX_HEADERS -}; -use util::NonBlockingResult; +use super::headers::{Headers, FromHttparse, MAX_HEADERS}; +use super::machine::{HandshakeMachine, StageResult, TryParse}; +use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; /// Request from the client. pub struct Request { - path: String, - headers: Headers, + pub path: String, + pub headers: Headers, } impl Request { - /// Parse the request from a stream. - pub fn parse(input: &mut B) -> Result> { - Request::parse_http(input) - } /// Reply to the response. pub fn reply(&self) -> Result> { let key = self.headers.find_first("Sec-WebSocket-Key") @@ -42,8 +29,8 @@ impl Request { } } -impl Httparse for Request { - fn httparse(buf: &[u8]) -> Result> { +impl TryParse for Request { + fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Request::new(&mut hbuffer); Ok(match req.parse(buf)? { @@ -68,76 +55,50 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } -/// Server handshake -pub struct ServerHandshake { - stream: Stream, - state: HandshakeState, -} +/// Server handshake role. +#[allow(missing_copy_implementations)] +pub struct ServerHandshake; -impl ServerHandshake { - /// Start a new server handshake on top of given stream. - pub fn new(stream: Stream) -> Self { - ServerHandshake { - stream: stream, - state: HandshakeState::ReceivingRequest(InputBuffer::with_capacity(MIN_READ)), +impl ServerHandshake { + /// Start server handshake. + pub fn start(stream: Stream) -> MidHandshake { + MidHandshake { + machine: HandshakeMachine::start_read(stream), + role: ServerHandshake, } } } -impl Handshake for ServerHandshake { - type Stream = WebSocket; - fn handshake(mut self) -> Result> { - debug!("Performing server handshake..."); - match self.state { - HandshakeState::ReceivingRequest(mut req_buf) => { - req_buf.reserve(MIN_READ, usize::max_value()) - .map_err(|_| Error::Capacity("Header too long".into()))?; - req_buf.read_from(&mut self.stream).no_block()?; - let state = if let Some(req) = Request::parse(&mut req_buf)? { - let resp = req.reply()?; - HandshakeState::SendingResponse(Cursor::new(resp)) - } else { - HandshakeState::ReceivingRequest(req_buf) - }; - Ok(HandshakeResult::Incomplete(ServerHandshake { - state: state, - ..self - })) - } - HandshakeState::SendingResponse(mut resp) => { - let size = self.stream.write(Buf::bytes(&resp)).no_block()?.unwrap_or(0); - Buf::advance(&mut resp, size); - if resp.has_remaining() { - Ok(HandshakeResult::Incomplete(ServerHandshake { - state: HandshakeState::SendingResponse(resp), - ..self - })) - } else { - let ws = WebSocket::from_raw_socket(self.stream, Role::Server); - Ok(HandshakeResult::Done(ws)) +impl HandshakeRole for ServerHandshake { + type IncomingData = Request; + fn stage_finished(&self, finish: StageResult) + -> Result> + { + Ok(match finish { + StageResult::DoneReading { stream, result, tail } => { + if ! tail.is_empty() { + return Err(Error::Protocol("Junk after client request".into())) } + let response = result.reply()?; + ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } - } + StageResult::DoneWriting(stream) => { + ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) + } + }) } } -enum HandshakeState { - ReceivingRequest(InputBuffer), - SendingResponse(Cursor>), -} - #[cfg(test)] mod tests { use super::Request; - - use std::io::Cursor; + use super::super::machine::TryParse; #[test] fn request_parsing() { const data: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; - let mut inp = Cursor::new(data); - let req = Request::parse(&mut inp).unwrap().unwrap(); + let (_, req) = Request::try_parse(data).unwrap().unwrap(); assert_eq!(req.path, "/script.ws"); assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); } @@ -152,9 +113,8 @@ mod tests { Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; - let mut inp = Cursor::new(data); - let req = Request::parse(&mut inp).unwrap().unwrap(); - let reply = req.reply().unwrap(); + let (_, req) = Request::try_parse(data).unwrap().unwrap(); + let _ = req.reply().unwrap(); } } diff --git a/src/handshake/tls.rs b/src/handshake/tls.rs deleted file mode 100644 index c093f2f..0000000 --- a/src/handshake/tls.rs +++ /dev/null @@ -1,14 +0,0 @@ -use native_tls; - -use stream::Stream; -use super::{Handshake, HandshakeResult}; - -pub struct TlsHandshake { - -} - -impl Handshale for TlsHandshake { - type Stream = Stream; - fn handshake(self) -> Result> { - } -} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index c8ddccf..efe0d40 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -40,10 +40,7 @@ pub struct WebSocket { pong: Option, } -impl WebSocket - where Stream: Read + Write -{ - +impl WebSocket { /// Convert a raw socket into a WebSocket without performing a handshake. pub fn from_raw_socket(stream: Stream, role: Role) -> Self { WebSocket::from_frame_socket(FrameSocket::new(stream), role) @@ -54,6 +51,20 @@ impl WebSocket WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) } + /// Convert a frame socket into a WebSocket. + fn from_frame_socket(socket: FrameSocket, role: Role) -> Self { + WebSocket { + role: role, + socket: socket, + state: WebSocketState::Active, + incomplete: None, + send_queue: VecDeque::new(), + pong: None, + } + } +} + +impl WebSocket { /// Read a message from stream, if possible. /// /// This function sends pong and close responses automatically. @@ -141,18 +152,6 @@ impl WebSocket } } - /// Convert a frame socket into a WebSocket. - fn from_frame_socket(socket: FrameSocket, role: Role) -> Self { - WebSocket { - role: role, - socket: socket, - state: WebSocketState::Active, - incomplete: None, - send_queue: VecDeque::new(), - pong: None, - } - } - /// Try to decode one message frame. May return None. fn read_message_frame(&mut self) -> Result> { if let Some(mut frame) = self.socket.read_frame()? { diff --git a/src/server.rs b/src/server.rs index fca1368..9c5e4f0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,13 @@ -use std::net::TcpStream; +pub use handshake::server::ServerHandshake; -use handshake::server::ServerHandshake; +use handshake::HandshakeError; +use protocol::WebSocket; + +use std::io::{Read, Write}; /// Accept the given TcpStream as a WebSocket. -pub fn accept(stream: TcpStream) -> ServerHandshake { - ServerHandshake::new(stream) +pub fn accept(stream: Stream) + -> Result, HandshakeError> +{ + ServerHandshake::start(stream).handshake() } diff --git a/src/stream.rs b/src/stream.rs index 3818bb3..c8c88d0 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,38 +1,37 @@ -#[cfg(feature="tls")] -use native_tls::TlsStream; - -use std::net::TcpStream; use std::io::{Read, Write, Result as IoResult}; +/// Stream mode, either plain TCP or TLS. +#[derive(Clone, Copy)] +pub enum Mode { + Plain, + Tls, +} + /// Stream, either plain TCP or TLS. -pub enum Stream { - Plain(TcpStream), - #[cfg(feature="tls")] - Tls(TlsStream), +pub enum Stream { + Plain(S), + Tls(T), } -impl Read for Stream { +impl Read for Stream { fn read(&mut self, buf: &mut [u8]) -> IoResult { match *self { Stream::Plain(ref mut s) => s.read(buf), - #[cfg(feature="tls")] Stream::Tls(ref mut s) => s.read(buf), } } } -impl Write for Stream { +impl Write for Stream { fn write(&mut self, buf: &[u8]) -> IoResult { match *self { Stream::Plain(ref mut s) => s.write(buf), - #[cfg(feature="tls")] Stream::Tls(ref mut s) => s.write(buf), } } fn flush(&mut self) -> IoResult<()> { match *self { Stream::Plain(ref mut s) => s.flush(), - #[cfg(feature="tls")] Stream::Tls(ref mut s) => s.flush(), } }