commit
087636c3ad
@ -1,75 +1,97 @@ |
||||
use std::net::{TcpStream, ToSocketAddrs}; |
||||
use url::{Url, SocketAddrs}; |
||||
use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; |
||||
use std::result::Result as StdResult; |
||||
use std::io::{Read, Write}; |
||||
|
||||
use url::Url; |
||||
|
||||
#[cfg(feature="tls")] |
||||
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError}; |
||||
|
||||
use protocol::WebSocket; |
||||
use handshake::{Handshake as HandshakeTrait, HandshakeResult}; |
||||
use handshake::HandshakeError; |
||||
use handshake::client::{ClientHandshake, Request}; |
||||
use stream::Mode; |
||||
use error::{Error, Result}; |
||||
|
||||
/// Connect to the given WebSocket.
|
||||
///
|
||||
/// Note that this function may block the current thread while DNS resolution is performed.
|
||||
pub fn connect(url: Url) -> Result<Handshake> { |
||||
let mode = match url.scheme() { |
||||
"ws" => Mode::Plain, |
||||
#[cfg(feature="tls")] |
||||
"wss" => Mode::Tls, |
||||
_ => return Err(Error::Url("URL scheme not supported".into())) |
||||
}; |
||||
#[cfg(feature="tls")] |
||||
use stream::Stream as StreamSwitcher; |
||||
|
||||
// Note that this function may block the current thread while resolution is performed.
|
||||
let addrs = url.to_socket_addrs()?; |
||||
Ok(Handshake { |
||||
state: HandshakeState::Nothing(url), |
||||
alt_addresses: addrs, |
||||
}) |
||||
} |
||||
#[cfg(feature="tls")] |
||||
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>; |
||||
#[cfg(not(feature="tls"))] |
||||
pub type AutoStream = TcpStream; |
||||
|
||||
enum Mode { |
||||
Plain, |
||||
Tls, |
||||
/// Connect to the given WebSocket in blocking mode.
|
||||
///
|
||||
/// The URL may be either ws:// or wss://.
|
||||
/// To support wss:// URLs, feature "tls" must be turned on.
|
||||
pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> { |
||||
let mode = url_mode(&url)?; |
||||
let addrs = url.to_socket_addrs()?; |
||||
let stream = connect_to_some(addrs, &url, mode)?; |
||||
client(url.clone(), stream) |
||||
.map_err(|e| match e { |
||||
HandshakeError::Failure(f) => f, |
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), |
||||
}) |
||||
} |
||||
|
||||
enum HandshakeState { |
||||
Nothing(Url), |
||||
WebSocket(ClientHandshake<TcpStream>), |
||||
#[cfg(feature="tls")] |
||||
fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> { |
||||
match mode { |
||||
Mode::Plain => Ok(StreamSwitcher::Plain(stream)), |
||||
Mode::Tls => { |
||||
let connector = TlsConnector::builder()?.build()?; |
||||
connector.connect(domain, stream) |
||||
.map_err(|e| match e { |
||||
TlsHandshakeError::Failure(f) => f.into(), |
||||
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), |
||||
}) |
||||
.map(|s| StreamSwitcher::Tls(s)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub struct Handshake { |
||||
state: HandshakeState, |
||||
alt_addresses: SocketAddrs, |
||||
#[cfg(not(feature="tls"))] |
||||
fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> { |
||||
match mode { |
||||
Mode::Plain => Ok(stream), |
||||
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), |
||||
} |
||||
} |
||||
|
||||
impl HandshakeTrait for Handshake { |
||||
type Stream = WebSocket<TcpStream>; |
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> { |
||||
match self.state { |
||||
HandshakeState::Nothing(url) => { |
||||
if let Some(addr) = self.alt_addresses.next() { |
||||
debug!("Trying to contact {} at {}...", url, addr); |
||||
let state = { |
||||
if let Ok(stream) = TcpStream::connect(addr) { |
||||
let hs = ClientHandshake::new(stream, Request { url: url }); |
||||
HandshakeState::WebSocket(hs) |
||||
} else { |
||||
HandshakeState::Nothing(url) |
||||
} |
||||
}; |
||||
Ok(HandshakeResult::Incomplete(Handshake { |
||||
state: state, |
||||
..self |
||||
})) |
||||
} else { |
||||
Err(Error::Url(format!("Unable to resolve {}", url).into())) |
||||
} |
||||
} |
||||
HandshakeState::WebSocket(ws) => { |
||||
let alt_addresses = self.alt_addresses; |
||||
ws.handshake().map(move |r| r.map(move |s| Handshake { |
||||
state: HandshakeState::WebSocket(s), |
||||
alt_addresses: alt_addresses, |
||||
})) |
||||
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream> |
||||
where A: Iterator<Item=SocketAddr> |
||||
{ |
||||
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?; |
||||
for addr in addrs { |
||||
debug!("Trying to contact {} at {}...", url, addr); |
||||
if let Ok(raw_stream) = TcpStream::connect(addr) { |
||||
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { |
||||
return Ok(stream) |
||||
} |
||||
} |
||||
} |
||||
Err(Error::Url(format!("Unable to connect to {}", url).into())) |
||||
} |
||||
|
||||
/// Get the mode of the given URL.
|
||||
///
|
||||
/// This function may be used in non-blocking implementations.
|
||||
pub fn url_mode(url: &Url) -> Result<Mode> { |
||||
match url.scheme() { |
||||
"ws" => Ok(Mode::Plain), |
||||
"wss" => Ok(Mode::Tls), |
||||
_ => Err(Error::Url("URL scheme not supported".into())) |
||||
} |
||||
} |
||||
|
||||
/// Do the client handshake over the given stream.
|
||||
///
|
||||
/// Use this function if you need a nonblocking handshake support.
|
||||
pub fn client<Stream: Read + Write>(url: Url, stream: Stream) |
||||
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>> |
||||
{ |
||||
let request = Request { url: url }; |
||||
ClientHandshake::start(stream, request).handshake() |
||||
} |
||||
|
@ -0,0 +1,145 @@ |
||||
use std::ascii::AsciiExt; |
||||
use std::str::from_utf8; |
||||
use std::slice; |
||||
|
||||
use httparse; |
||||
use httparse::Status; |
||||
|
||||
use error::Result; |
||||
use super::machine::TryParse; |
||||
|
||||
// Limit the number of header lines.
|
||||
pub const MAX_HEADERS: usize = 124; |
||||
|
||||
/// HTTP request or response headers.
|
||||
#[derive(Debug)] |
||||
pub struct Headers { |
||||
data: Vec<(String, Box<[u8]>)>, |
||||
} |
||||
|
||||
impl Headers { |
||||
|
||||
/// Get first header with the given name, if any.
|
||||
pub fn find_first(&self, name: &str) -> Option<&[u8]> { |
||||
self.find(name).next() |
||||
} |
||||
|
||||
/// Iterate over all headers with the given name.
|
||||
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { |
||||
HeadersIter { |
||||
name: name, |
||||
iter: self.data.iter() |
||||
} |
||||
} |
||||
|
||||
/// Check if the given header has the given value.
|
||||
pub fn header_is(&self, name: &str, value: &str) -> bool { |
||||
self.find_first(name) |
||||
.map(|v| v == value.as_bytes()) |
||||
.unwrap_or(false) |
||||
} |
||||
|
||||
/// Check if the given header has the given value (case-insensitive).
|
||||
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { |
||||
self.find_first(name).ok_or(()) |
||||
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) |
||||
.map(|val| val.eq_ignore_ascii_case(value)) |
||||
.unwrap_or(false) |
||||
} |
||||
|
||||
} |
||||
|
||||
/// The iterator over headers.
|
||||
pub struct HeadersIter<'name, 'headers> { |
||||
name: &'name str, |
||||
iter: slice::Iter<'headers, (String, Box<[u8]>)>, |
||||
} |
||||
|
||||
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { |
||||
type Item = &'headers [u8]; |
||||
fn next(&mut self) -> Option<Self::Item> { |
||||
while let Some(&(ref name, ref value)) = self.iter.next() { |
||||
if name.eq_ignore_ascii_case(self.name) { |
||||
return Some(value) |
||||
} |
||||
} |
||||
None |
||||
} |
||||
} |
||||
|
||||
|
||||
/// Trait to convert raw objects into HTTP parseables.
|
||||
pub trait FromHttparse<T>: Sized { |
||||
fn from_httparse(raw: T) -> Result<Self>; |
||||
} |
||||
|
||||
impl TryParse for Headers { |
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
||||
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { |
||||
Status::Partial => None, |
||||
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { |
||||
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> { |
||||
Ok(Headers { |
||||
data: raw.iter() |
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) |
||||
.collect(), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
|
||||
use super::Headers; |
||||
use super::super::machine::TryParse; |
||||
|
||||
#[test] |
||||
fn headers() { |
||||
const data: &'static [u8] = |
||||
b"Host: foo.com\r\n\ |
||||
Connection: Upgrade\r\n\ |
||||
Upgrade: websocket\r\n\ |
||||
\r\n"; |
||||
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); |
||||
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); |
||||
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); |
||||
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); |
||||
|
||||
assert!(hdr.header_is("upgrade", "websocket")); |
||||
assert!(!hdr.header_is("upgrade", "Websocket")); |
||||
assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); |
||||
} |
||||
|
||||
#[test] |
||||
fn headers_iter() { |
||||
const data: &'static [u8] = |
||||
b"Host: foo.com\r\n\ |
||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\ |
||||
Connection: Upgrade\r\n\ |
||||
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ |
||||
Upgrade: websocket\r\n\ |
||||
\r\n"; |
||||
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); |
||||
let mut iter = hdr.find("Sec-WebSocket-Extensions"); |
||||
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); |
||||
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); |
||||
assert_eq!(iter.next(), None); |
||||
} |
||||
|
||||
#[test] |
||||
fn headers_incomplete() { |
||||
const data: &'static [u8] = |
||||
b"Host: foo.com\r\n\ |
||||
Connection: Upgrade\r\n\ |
||||
Upgrade: websocket\r\n"; |
||||
let hdr = Headers::try_parse(data).unwrap(); |
||||
assert!(hdr.is_none()); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,119 @@ |
||||
use std::io::{Cursor, Read, Write}; |
||||
use bytes::Buf; |
||||
|
||||
use input_buffer::{InputBuffer, MIN_READ}; |
||||
use error::{Error, Result}; |
||||
use util::NonBlockingResult; |
||||
|
||||
/// A generic handshake state machine.
|
||||
pub struct HandshakeMachine<Stream> { |
||||
stream: Stream, |
||||
state: HandshakeState, |
||||
} |
||||
|
||||
impl<Stream> HandshakeMachine<Stream> { |
||||
/// Start reading data from the peer.
|
||||
pub fn start_read(stream: Stream) -> Self { |
||||
HandshakeMachine { |
||||
stream: stream, |
||||
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), |
||||
} |
||||
} |
||||
/// Start writing data to the peer.
|
||||
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { |
||||
HandshakeMachine { |
||||
stream: stream, |
||||
state: HandshakeState::Writing(Cursor::new(data.into())), |
||||
} |
||||
} |
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &Stream { |
||||
&self.stream |
||||
} |
||||
/// Returns a mutable reference to the inner stream.
|
||||
pub fn get_mut(&mut self) -> &mut Stream { |
||||
&mut self.stream |
||||
} |
||||
} |
||||
|
||||
impl<Stream: Read + Write> HandshakeMachine<Stream> { |
||||
/// Perform a single handshake round.
|
||||
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { |
||||
Ok(match self.state { |
||||
HandshakeState::Reading(mut buf) => { |
||||
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
|
||||
.map_err(|_| Error::Capacity("Header too long".into()))?; |
||||
if let Some(_) = buf.read_from(&mut self.stream).no_block()? { |
||||
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { |
||||
buf.advance(size); |
||||
RoundResult::StageFinished(StageResult::DoneReading { |
||||
result: obj, |
||||
stream: self.stream, |
||||
tail: buf.into_vec(), |
||||
}) |
||||
} else { |
||||
RoundResult::Incomplete(HandshakeMachine { |
||||
state: HandshakeState::Reading(buf), |
||||
..self |
||||
}) |
||||
} |
||||
} else { |
||||
RoundResult::WouldBlock(HandshakeMachine { |
||||
state: HandshakeState::Reading(buf), |
||||
..self |
||||
}) |
||||
} |
||||
} |
||||
HandshakeState::Writing(mut buf) => { |
||||
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? { |
||||
buf.advance(size); |
||||
if buf.has_remaining() { |
||||
RoundResult::Incomplete(HandshakeMachine { |
||||
state: HandshakeState::Writing(buf), |
||||
..self |
||||
}) |
||||
} else { |
||||
RoundResult::StageFinished(StageResult::DoneWriting(self.stream)) |
||||
} |
||||
} else { |
||||
RoundResult::WouldBlock(HandshakeMachine { |
||||
state: HandshakeState::Writing(buf), |
||||
..self |
||||
}) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
/// The result of the round.
|
||||
pub enum RoundResult<Obj, Stream> { |
||||
/// Round not done, I/O would block.
|
||||
WouldBlock(HandshakeMachine<Stream>), |
||||
/// Round done, state unchanged.
|
||||
Incomplete(HandshakeMachine<Stream>), |
||||
/// Stage complete.
|
||||
StageFinished(StageResult<Obj, Stream>), |
||||
} |
||||
|
||||
/// The result of the stage.
|
||||
pub enum StageResult<Obj, Stream> { |
||||
/// Reading round finished.
|
||||
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> }, |
||||
/// Writing round finished.
|
||||
DoneWriting(Stream), |
||||
} |
||||
|
||||
/// The parseable object.
|
||||
pub trait TryParse: Sized { |
||||
/// Return Ok(None) if incomplete, Err on syntax error.
|
||||
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>; |
||||
} |
||||
|
||||
/// The handshake state.
|
||||
enum HandshakeState { |
||||
/// Reading data from the peer.
|
||||
Reading(InputBuffer), |
||||
/// Sending data to the peer.
|
||||
Writing(Cursor<Vec<u8>>), |
||||
} |
@ -1,14 +0,0 @@ |
||||
use native_tls; |
||||
|
||||
use stream::Stream; |
||||
use super::{Handshake, HandshakeResult}; |
||||
|
||||
pub struct TlsHandshake { |
||||
|
||||
} |
||||
|
||||
impl Handshale for TlsHandshake { |
||||
type Stream = Stream; |
||||
fn handshake(self) -> Result<HandshakeResult<Self>> { |
||||
} |
||||
} |
@ -1,8 +1,13 @@ |
||||
use std::net::TcpStream; |
||||
pub use handshake::server::ServerHandshake; |
||||
|
||||
use handshake::server::ServerHandshake; |
||||
use handshake::HandshakeError; |
||||
use protocol::WebSocket; |
||||
|
||||
use std::io::{Read, Write}; |
||||
|
||||
/// Accept the given TcpStream as a WebSocket.
|
||||
pub fn accept(stream: TcpStream) -> ServerHandshake<TcpStream> { |
||||
ServerHandshake::new(stream) |
||||
pub fn accept<Stream: Read + Write>(stream: Stream) |
||||
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>> |
||||
{ |
||||
ServerHandshake::start(stream).handshake() |
||||
} |
||||
|
Loading…
Reference in new issue