Implements #6 (client/server headers access)

pull/18/head
Daniel Abramov 8 years ago
parent f34c488217
commit 44a15c9eab
  1. 5
      examples/autobahn-server.rs
  2. 13
      src/client.rs
  3. 43
      src/handshake/client.rs
  4. 49
      src/handshake/mod.rs
  5. 91
      src/handshake/server.rs
  6. 11
      src/server.rs

@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn; use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message}; use tungstenite::{accept, HandshakeError, Error, Result, Message};
use tungstenite::handshake::HandshakeRole;
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
@ -15,7 +16,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let (mut socket, _) = accept(stream).map_err(must_not_block)?; let mut socket = accept(stream, None).map_err(must_not_block)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -6,7 +6,7 @@ use std::io::{Read, Write};
use url::Url; use url::Url;
use handshake::headers::Headers; use handshake::client::Response;
#[cfg(feature="tls")] #[cfg(feature="tls")]
mod encryption { mod encryption {
@ -78,7 +78,7 @@ use error::{Error, Result};
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Headers)> -> Result<(WebSocket<AutoStream>, Response)>
{ {
let request: Request = request.into(); let request: Request = request.into();
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
@ -124,9 +124,12 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream) pub fn client<'t, Stream, Req>(
-> StdResult<(WebSocket<Stream>, Headers), HandshakeError<Stream, ClientHandshake>> request: Req,
where Stream: Read + Write, stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,
{ {
ClientHandshake::start(stream, request.into()).handshake() ClientHandshake::start(stream, request.into()).handshake()

@ -1,18 +1,19 @@
//! Client handshake machine. //! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use base64; use base64;
use rand;
use httparse;
use httparse::Status; use httparse::Status;
use std::io::Write; use httparse;
use rand;
use url::Url; use url::Url;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
/// Client request. /// Client request.
pub struct Request<'t> { pub struct Request<'t> {
@ -52,13 +53,14 @@ impl From<Url> for Request<'static> {
} }
/// Client handshake role. /// Client handshake role.
pub struct ClientHandshake { pub struct ClientHandshake<S> {
verify_data: VerifyData, verify_data: VerifyData,
_marker: PhantomData<S>,
} }
impl ClientHandshake { impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> { pub fn start(stream: S, request: Request) -> MidHandshake<Self> {
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
@ -83,9 +85,8 @@ impl ClientHandshake {
let client = { let client = {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake {
verify_data: VerifyData { verify_data: VerifyData { accept_key },
accept_key: accept_key, _marker: PhantomData,
},
} }
}; };
@ -94,10 +95,12 @@ impl ClientHandshake {
} }
} }
impl HandshakeRole for ClientHandshake { impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response; type IncomingData = Response;
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream = S;
-> Result<ProcessingResult<Stream>> type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{ {
Ok(match finish { Ok(match finish {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
@ -106,8 +109,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client), ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result.headers) result))
} }
}) })
} }
@ -167,8 +170,10 @@ impl VerifyData {
/// Server response. /// Server response.
pub struct Response { pub struct Response {
code: u16, /// HTTP response code of the response.
headers: Headers, pub code: u16,
/// Received headers.
pub headers: Headers,
} }
impl TryParse for Response { impl TryParse for Response {
@ -204,7 +209,6 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Response, generate_key}; use super::{Response, generate_key};
use super::super::machine::TryParse; use super::super::machine::TryParse;
@ -231,5 +235,4 @@ mod tests {
assert_eq!(resp.code, 200); assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
} }
} }

@ -14,31 +14,17 @@ use base64;
use sha1::Sha1; use sha1::Sha1;
use error::Error; use error::Error;
use protocol::WebSocket;
use self::headers::Headers;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake. /// A WebSocket handshake.
pub struct MidHandshake<Stream, Role> { pub struct MidHandshake<Role: HandshakeRole> {
role: Role, role: Role,
machine: HandshakeMachine<Stream>, machine: HandshakeMachine<Role::InternalStream>,
}
impl<Stream, Role> MidHandshake<Stream, Role> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
self.machine.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
}
} }
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> { impl<Role: HandshakeRole> MidHandshake<Role> {
/// Restarts the handshake process. /// Restarts the handshake process.
pub fn handshake(mut self) -> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, Role>> { pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
let mut mach = self.machine; let mut mach = self.machine;
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
@ -49,7 +35,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => { RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? { match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m, ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws, headers) => return Ok((ws, headers)), ProcessingResult::Done(result) => return Ok(result),
} }
} }
} }
@ -58,14 +44,14 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
} }
/// A handshake result. /// A handshake result.
pub enum HandshakeError<Stream, Role> { pub enum HandshakeError<Role: HandshakeRole> {
/// Handshake was interrupted (would block). /// Handshake was interrupted (would block).
Interrupted(MidHandshake<Stream, Role>), Interrupted(MidHandshake<Role>),
/// Handshake failed. /// Handshake failed.
Failure(Error), Failure(Error),
} }
impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"), HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
@ -74,7 +60,7 @@ impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"), HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
@ -83,7 +69,7 @@ impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {
fn description(&self) -> &str { fn description(&self) -> &str {
match *self { match *self {
HandshakeError::Interrupted(_) => "Interrupted handshake", HandshakeError::Interrupted(_) => "Interrupted handshake",
@ -92,7 +78,7 @@ impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
fn from(err: Error) -> Self { fn from(err: Error) -> Self {
HandshakeError::Failure(err) HandshakeError::Failure(err)
} }
@ -103,15 +89,19 @@ pub trait HandshakeRole {
#[doc(hidden)] #[doc(hidden)]
type IncomingData: TryParse; type IncomingData: TryParse;
#[doc(hidden)] #[doc(hidden)]
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream: Read + Write;
-> Result<ProcessingResult<Stream>, Error>; #[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
} }
/// Stage processing result. /// Stage processing result.
#[doc(hidden)] #[doc(hidden)]
pub enum ProcessingResult<Stream> { pub enum ProcessingResult<Stream, FinalResult> {
Continue(HandshakeMachine<Stream>), Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>, Headers), Done(FinalResult),
} }
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
@ -127,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result<String, Error> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::convert_key; use super::convert_key;
#[test] #[test]

@ -1,5 +1,10 @@
//! Server handshake machine. //! Server handshake machine.
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem::replace;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
@ -19,15 +24,23 @@ pub struct Request {
impl Request { impl Request {
/// Reply to the response. /// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> { pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key") let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\ let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\ HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n\ Sec-WebSocket-Accept: {}\r\n",
\r\n", convert_key(key)?); convert_key(key)?
);
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
Ok(reply.into()) Ok(reply.into())
} }
} }
@ -58,42 +71,68 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
} }
/// The callback type, the callback is called when the server receives an incoming WebSocket
/// handshake request from the client, specifying a callback allows you to analyze incoming headers
/// and add additional headers to the response that server sends to the client and/or reject the
/// connection based on the incoming headers. Due to usability problems which are caused by a
/// static dispatch when using callbacks in such places, the callback is boxed.
///
/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current
/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets
/// stable.
pub type Callback = Box<FnMut(&Request) -> Result<Option<Vec<(String, String)>>>>;
/// Server handshake role. /// Server handshake role.
#[allow(missing_copy_implementations)] #[allow(missing_copy_implementations)]
pub struct ServerHandshake { pub struct ServerHandshake<S> {
/// Incoming headers, received from the client. /// Callback which is called whenever the server read the request from the client and is ready
headers: Option<Headers>, /// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<Callback>,
/// Internal stream type.
_marker: PhantomData<S>,
} }
impl ServerHandshake { impl<S: Read + Write> ServerHandshake<S> {
/// Start server handshake. /// Start server handshake. `callback` specifies a custom callback which the user can pass to
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> { /// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: Option<Callback>) -> MidHandshake<Self> {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { headers: None }, role: ServerHandshake { callback, _marker: PhantomData },
} }
} }
} }
impl HandshakeRole for ServerHandshake { impl<S: Read + Write> HandshakeRole for ServerHandshake<S> {
type IncomingData = Request; type IncomingData = Request;
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream = S;
-> Result<ProcessingResult<Stream>> type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{ {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
if ! tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())) return Err(Error::Protocol("Junk after client request".into()))
} }
let response = result.reply()?; let extra_headers = {
self.headers = Some(result.headers); if let Some(mut callback) = replace(&mut self.callback, None) {
callback(&result)?
} else {
None
}
};
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
debug!("Server handshake done."); debug!("Server handshake done.");
let headers = self.headers.take().expect("Bug: accepted client without headers"); ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers)
} }
}) })
} }
@ -101,9 +140,9 @@ impl HandshakeRole for ServerHandshake {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Request; use super::Request;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::super::client::Response;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
@ -124,7 +163,15 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply().unwrap(); let _ = req.reply(None).unwrap();
}
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
} }

@ -3,7 +3,7 @@
pub use handshake::server::ServerHandshake; pub use handshake::server::ServerHandshake;
use handshake::HandshakeError; use handshake::HandshakeError;
use handshake::headers::Headers; use handshake::server::Callback;
use protocol::WebSocket; use protocol::WebSocket;
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -13,9 +13,10 @@ use std::io::{Read, Write};
/// This function starts a server WebSocket handshake over the given stream. /// This function starts a server WebSocket handshake over the given stream.
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others. You can also pass an optional `callback` which will
pub fn accept<Stream: Read + Write>(stream: Stream) /// be called when the websocket request is received from an incoming client.
-> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, ServerHandshake>> pub fn accept<S: Read + Write>(stream: S, callback: Option<Callback>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S>>>
{ {
ServerHandshake::start(stream).handshake() ServerHandshake::start(stream, callback).handshake()
} }

Loading…
Cancel
Save