commit
087636c3ad
@ -1,75 +1,97 @@ |
|||||||
use std::net::{TcpStream, ToSocketAddrs}; |
use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; |
||||||
use url::{Url, SocketAddrs}; |
use std::result::Result as StdResult; |
||||||
|
use std::io::{Read, Write}; |
||||||
|
|
||||||
|
use url::Url; |
||||||
|
|
||||||
|
#[cfg(feature="tls")] |
||||||
|
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError}; |
||||||
|
|
||||||
use protocol::WebSocket; |
use protocol::WebSocket; |
||||||
use handshake::{Handshake as HandshakeTrait, HandshakeResult}; |
use handshake::HandshakeError; |
||||||
use handshake::client::{ClientHandshake, Request}; |
use handshake::client::{ClientHandshake, Request}; |
||||||
|
use stream::Mode; |
||||||
use error::{Error, Result}; |
use error::{Error, Result}; |
||||||
|
|
||||||
/// Connect to the given WebSocket.
|
#[cfg(feature="tls")] |
||||||
///
|
use stream::Stream as StreamSwitcher; |
||||||
/// Note that this function may block the current thread while DNS resolution is performed.
|
|
||||||
pub fn connect(url: Url) -> Result<Handshake> { |
|
||||||
let mode = match url.scheme() { |
|
||||||
"ws" => Mode::Plain, |
|
||||||
#[cfg(feature="tls")] |
|
||||||
"wss" => Mode::Tls, |
|
||||||
_ => return Err(Error::Url("URL scheme not supported".into())) |
|
||||||
}; |
|
||||||
|
|
||||||
// Note that this function may block the current thread while resolution is performed.
|
#[cfg(feature="tls")] |
||||||
let addrs = url.to_socket_addrs()?; |
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>; |
||||||
Ok(Handshake { |
#[cfg(not(feature="tls"))] |
||||||
state: HandshakeState::Nothing(url), |
pub type AutoStream = TcpStream; |
||||||
alt_addresses: addrs, |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
enum Mode { |
/// Connect to the given WebSocket in blocking mode.
|
||||||
Plain, |
///
|
||||||
Tls, |
/// The URL may be either ws:// or wss://.
|
||||||
|
/// To support wss:// URLs, feature "tls" must be turned on.
|
||||||
|
pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> { |
||||||
|
let mode = url_mode(&url)?; |
||||||
|
let addrs = url.to_socket_addrs()?; |
||||||
|
let stream = connect_to_some(addrs, &url, mode)?; |
||||||
|
client(url.clone(), stream) |
||||||
|
.map_err(|e| match e { |
||||||
|
HandshakeError::Failure(f) => f, |
||||||
|
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), |
||||||
|
}) |
||||||
} |
} |
||||||
|
|
||||||
enum HandshakeState { |
#[cfg(feature="tls")] |
||||||
Nothing(Url), |
fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> { |
||||||
WebSocket(ClientHandshake<TcpStream>), |
match mode { |
||||||
|
Mode::Plain => Ok(StreamSwitcher::Plain(stream)), |
||||||
|
Mode::Tls => { |
||||||
|
let connector = TlsConnector::builder()?.build()?; |
||||||
|
connector.connect(domain, stream) |
||||||
|
.map_err(|e| match e { |
||||||
|
TlsHandshakeError::Failure(f) => f.into(), |
||||||
|
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), |
||||||
|
}) |
||||||
|
.map(|s| StreamSwitcher::Tls(s)) |
||||||
|
} |
||||||
|
} |
||||||
} |
} |
||||||
|
|
||||||
pub struct Handshake { |
#[cfg(not(feature="tls"))] |
||||||
state: HandshakeState, |
fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> { |
||||||
alt_addresses: SocketAddrs, |
match mode { |
||||||
|
Mode::Plain => Ok(stream), |
||||||
|
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), |
||||||
|
} |
||||||
} |
} |
||||||
|
|
||||||
impl HandshakeTrait for Handshake { |
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream> |
||||||
type Stream = WebSocket<TcpStream>; |
where A: Iterator<Item=SocketAddr> |
||||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> { |
{ |
||||||
match self.state { |
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?; |
||||||
HandshakeState::Nothing(url) => { |
for addr in addrs { |
||||||
if let Some(addr) = self.alt_addresses.next() { |
debug!("Trying to contact {} at {}...", url, addr); |
||||||
debug!("Trying to contact {} at {}...", url, addr); |
if let Ok(raw_stream) = TcpStream::connect(addr) { |
||||||
let state = { |
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { |
||||||
if let Ok(stream) = TcpStream::connect(addr) { |
return Ok(stream) |
||||||
let hs = ClientHandshake::new(stream, Request { url: url }); |
|
||||||
HandshakeState::WebSocket(hs) |
|
||||||
} else { |
|
||||||
HandshakeState::Nothing(url) |
|
||||||
} |
|
||||||
}; |
|
||||||
Ok(HandshakeResult::Incomplete(Handshake { |
|
||||||
state: state, |
|
||||||
..self |
|
||||||
})) |
|
||||||
} else { |
|
||||||
Err(Error::Url(format!("Unable to resolve {}", url).into())) |
|
||||||
} |
|
||||||
} |
|
||||||
HandshakeState::WebSocket(ws) => { |
|
||||||
let alt_addresses = self.alt_addresses; |
|
||||||
ws.handshake().map(move |r| r.map(move |s| Handshake { |
|
||||||
state: HandshakeState::WebSocket(s), |
|
||||||
alt_addresses: alt_addresses, |
|
||||||
})) |
|
||||||
} |
} |
||||||
} |
} |
||||||
} |
} |
||||||
|
Err(Error::Url(format!("Unable to connect to {}", url).into())) |
||||||
|
} |
||||||
|
|
||||||
|
/// Get the mode of the given URL.
|
||||||
|
///
|
||||||
|
/// This function may be used in non-blocking implementations.
|
||||||
|
pub fn url_mode(url: &Url) -> Result<Mode> { |
||||||
|
match url.scheme() { |
||||||
|
"ws" => Ok(Mode::Plain), |
||||||
|
"wss" => Ok(Mode::Tls), |
||||||
|
_ => Err(Error::Url("URL scheme not supported".into())) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// Do the client handshake over the given stream.
|
||||||
|
///
|
||||||
|
/// Use this function if you need a nonblocking handshake support.
|
||||||
|
pub fn client<Stream: Read + Write>(url: Url, stream: Stream) |
||||||
|
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>> |
||||||
|
{ |
||||||
|
let request = Request { url: url }; |
||||||
|
ClientHandshake::start(stream, request).handshake() |
||||||
} |
} |
||||||
|
@ -0,0 +1,145 @@ |
|||||||
|
use std::ascii::AsciiExt; |
||||||
|
use std::str::from_utf8; |
||||||
|
use std::slice; |
||||||
|
|
||||||
|
use httparse; |
||||||
|
use httparse::Status; |
||||||
|
|
||||||
|
use error::Result; |
||||||
|
use super::machine::TryParse; |
||||||
|
|
||||||
|
// Limit the number of header lines.
|
||||||
|
pub const MAX_HEADERS: usize = 124; |
||||||
|
|
||||||
|
/// HTTP request or response headers.
|
||||||
|
#[derive(Debug)] |
||||||
|
pub struct Headers { |
||||||
|
data: Vec<(String, Box<[u8]>)>, |
||||||
|
} |
||||||
|
|
||||||
|
impl Headers { |
||||||
|
|
||||||
|
/// Get first header with the given name, if any.
|
||||||
|
pub fn find_first(&self, name: &str) -> Option<&[u8]> { |
||||||
|
self.find(name).next() |
||||||
|
} |
||||||
|
|
||||||
|
/// Iterate over all headers with the given name.
|
||||||
|
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { |
||||||
|
HeadersIter { |
||||||
|
name: name, |
||||||
|
iter: self.data.iter() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// Check if the given header has the given value.
|
||||||
|
pub fn header_is(&self, name: &str, value: &str) -> bool { |
||||||
|
self.find_first(name) |
||||||
|
.map(|v| v == value.as_bytes()) |
||||||
|
.unwrap_or(false) |
||||||
|
} |
||||||
|
|
||||||
|
/// Check if the given header has the given value (case-insensitive).
|
||||||
|
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { |
||||||
|
self.find_first(name).ok_or(()) |
||||||
|
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) |
||||||
|
.map(|val| val.eq_ignore_ascii_case(value)) |
||||||
|
.unwrap_or(false) |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
/// The iterator over headers.
|
||||||
|
pub struct HeadersIter<'name, 'headers> { |
||||||
|
name: &'name str, |
||||||
|
iter: slice::Iter<'headers, (String, Box<[u8]>)>, |
||||||
|
} |
||||||
|
|
||||||
|
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { |
||||||
|
type Item = &'headers [u8]; |
||||||
|
fn next(&mut self) -> Option<Self::Item> { |
||||||
|
while let Some(&(ref name, ref value)) = self.iter.next() { |
||||||
|
if name.eq_ignore_ascii_case(self.name) { |
||||||
|
return Some(value) |
||||||
|
} |
||||||
|
} |
||||||
|
None |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
/// Trait to convert raw objects into HTTP parseables.
|
||||||
|
pub trait FromHttparse<T>: Sized { |
||||||
|
fn from_httparse(raw: T) -> Result<Self>; |
||||||
|
} |
||||||
|
|
||||||
|
impl TryParse for Headers { |
||||||
|
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
||||||
|
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
||||||
|
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { |
||||||
|
Status::Partial => None, |
||||||
|
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { |
||||||
|
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> { |
||||||
|
Ok(Headers { |
||||||
|
data: raw.iter() |
||||||
|
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) |
||||||
|
.collect(), |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
#[cfg(test)] |
||||||
|
mod tests { |
||||||
|
|
||||||
|
use super::Headers; |
||||||
|
use super::super::machine::TryParse; |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn headers() { |
||||||
|
const data: &'static [u8] = |
||||||
|
b"Host: foo.com\r\n\ |
||||||
|
Connection: Upgrade\r\n\ |
||||||
|
Upgrade: websocket\r\n\ |
||||||
|
\r\n"; |
||||||
|
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); |
||||||
|
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); |
||||||
|
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); |
||||||
|
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); |
||||||
|
|
||||||
|
assert!(hdr.header_is("upgrade", "websocket")); |
||||||
|
assert!(!hdr.header_is("upgrade", "Websocket")); |
||||||
|
assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn headers_iter() { |
||||||
|
const data: &'static [u8] = |
||||||
|
b"Host: foo.com\r\n\ |
||||||
|
Sec-WebSocket-Extensions: permessage-deflate\r\n\ |
||||||
|
Connection: Upgrade\r\n\ |
||||||
|
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ |
||||||
|
Upgrade: websocket\r\n\ |
||||||
|
\r\n"; |
||||||
|
let (_, hdr) = Headers::try_parse(data).unwrap().unwrap(); |
||||||
|
let mut iter = hdr.find("Sec-WebSocket-Extensions"); |
||||||
|
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); |
||||||
|
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); |
||||||
|
assert_eq!(iter.next(), None); |
||||||
|
} |
||||||
|
|
||||||
|
#[test] |
||||||
|
fn headers_incomplete() { |
||||||
|
const data: &'static [u8] = |
||||||
|
b"Host: foo.com\r\n\ |
||||||
|
Connection: Upgrade\r\n\ |
||||||
|
Upgrade: websocket\r\n"; |
||||||
|
let hdr = Headers::try_parse(data).unwrap(); |
||||||
|
assert!(hdr.is_none()); |
||||||
|
} |
||||||
|
|
||||||
|
} |
@ -0,0 +1,119 @@ |
|||||||
|
use std::io::{Cursor, Read, Write}; |
||||||
|
use bytes::Buf; |
||||||
|
|
||||||
|
use input_buffer::{InputBuffer, MIN_READ}; |
||||||
|
use error::{Error, Result}; |
||||||
|
use util::NonBlockingResult; |
||||||
|
|
||||||
|
/// A generic handshake state machine.
|
||||||
|
pub struct HandshakeMachine<Stream> { |
||||||
|
stream: Stream, |
||||||
|
state: HandshakeState, |
||||||
|
} |
||||||
|
|
||||||
|
impl<Stream> HandshakeMachine<Stream> { |
||||||
|
/// Start reading data from the peer.
|
||||||
|
pub fn start_read(stream: Stream) -> Self { |
||||||
|
HandshakeMachine { |
||||||
|
stream: stream, |
||||||
|
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), |
||||||
|
} |
||||||
|
} |
||||||
|
/// Start writing data to the peer.
|
||||||
|
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { |
||||||
|
HandshakeMachine { |
||||||
|
stream: stream, |
||||||
|
state: HandshakeState::Writing(Cursor::new(data.into())), |
||||||
|
} |
||||||
|
} |
||||||
|
/// Returns a shared reference to the inner stream.
|
||||||
|
pub fn get_ref(&self) -> &Stream { |
||||||
|
&self.stream |
||||||
|
} |
||||||
|
/// Returns a mutable reference to the inner stream.
|
||||||
|
pub fn get_mut(&mut self) -> &mut Stream { |
||||||
|
&mut self.stream |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
impl<Stream: Read + Write> HandshakeMachine<Stream> { |
||||||
|
/// Perform a single handshake round.
|
||||||
|
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { |
||||||
|
Ok(match self.state { |
||||||
|
HandshakeState::Reading(mut buf) => { |
||||||
|
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
|
||||||
|
.map_err(|_| Error::Capacity("Header too long".into()))?; |
||||||
|
if let Some(_) = buf.read_from(&mut self.stream).no_block()? { |
||||||
|
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { |
||||||
|
buf.advance(size); |
||||||
|
RoundResult::StageFinished(StageResult::DoneReading { |
||||||
|
result: obj, |
||||||
|
stream: self.stream, |
||||||
|
tail: buf.into_vec(), |
||||||
|
}) |
||||||
|
} else { |
||||||
|
RoundResult::Incomplete(HandshakeMachine { |
||||||
|
state: HandshakeState::Reading(buf), |
||||||
|
..self |
||||||
|
}) |
||||||
|
} |
||||||
|
} else { |
||||||
|
RoundResult::WouldBlock(HandshakeMachine { |
||||||
|
state: HandshakeState::Reading(buf), |
||||||
|
..self |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
HandshakeState::Writing(mut buf) => { |
||||||
|
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? { |
||||||
|
buf.advance(size); |
||||||
|
if buf.has_remaining() { |
||||||
|
RoundResult::Incomplete(HandshakeMachine { |
||||||
|
state: HandshakeState::Writing(buf), |
||||||
|
..self |
||||||
|
}) |
||||||
|
} else { |
||||||
|
RoundResult::StageFinished(StageResult::DoneWriting(self.stream)) |
||||||
|
} |
||||||
|
} else { |
||||||
|
RoundResult::WouldBlock(HandshakeMachine { |
||||||
|
state: HandshakeState::Writing(buf), |
||||||
|
..self |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
/// The result of the round.
|
||||||
|
pub enum RoundResult<Obj, Stream> { |
||||||
|
/// Round not done, I/O would block.
|
||||||
|
WouldBlock(HandshakeMachine<Stream>), |
||||||
|
/// Round done, state unchanged.
|
||||||
|
Incomplete(HandshakeMachine<Stream>), |
||||||
|
/// Stage complete.
|
||||||
|
StageFinished(StageResult<Obj, Stream>), |
||||||
|
} |
||||||
|
|
||||||
|
/// The result of the stage.
|
||||||
|
pub enum StageResult<Obj, Stream> { |
||||||
|
/// Reading round finished.
|
||||||
|
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> }, |
||||||
|
/// Writing round finished.
|
||||||
|
DoneWriting(Stream), |
||||||
|
} |
||||||
|
|
||||||
|
/// The parseable object.
|
||||||
|
pub trait TryParse: Sized { |
||||||
|
/// Return Ok(None) if incomplete, Err on syntax error.
|
||||||
|
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>; |
||||||
|
} |
||||||
|
|
||||||
|
/// The handshake state.
|
||||||
|
enum HandshakeState { |
||||||
|
/// Reading data from the peer.
|
||||||
|
Reading(InputBuffer), |
||||||
|
/// Sending data to the peer.
|
||||||
|
Writing(Cursor<Vec<u8>>), |
||||||
|
} |
@ -1,14 +0,0 @@ |
|||||||
use native_tls; |
|
||||||
|
|
||||||
use stream::Stream; |
|
||||||
use super::{Handshake, HandshakeResult}; |
|
||||||
|
|
||||||
pub struct TlsHandshake { |
|
||||||
|
|
||||||
} |
|
||||||
|
|
||||||
impl Handshale for TlsHandshake { |
|
||||||
type Stream = Stream; |
|
||||||
fn handshake(self) -> Result<HandshakeResult<Self>> { |
|
||||||
} |
|
||||||
} |
|
@ -1,8 +1,13 @@ |
|||||||
use std::net::TcpStream; |
pub use handshake::server::ServerHandshake; |
||||||
|
|
||||||
use handshake::server::ServerHandshake; |
use handshake::HandshakeError; |
||||||
|
use protocol::WebSocket; |
||||||
|
|
||||||
|
use std::io::{Read, Write}; |
||||||
|
|
||||||
/// Accept the given TcpStream as a WebSocket.
|
/// Accept the given TcpStream as a WebSocket.
|
||||||
pub fn accept(stream: TcpStream) -> ServerHandshake<TcpStream> { |
pub fn accept<Stream: Read + Write>(stream: Stream) |
||||||
ServerHandshake::new(stream) |
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>> |
||||||
|
{ |
||||||
|
ServerHandshake::start(stream).handshake() |
||||||
} |
} |
||||||
|
Loading…
Reference in new issue