Implements #6 (client/server headers access)

pull/18/head
Daniel Abramov 7 years ago
parent f34c488217
commit 44a15c9eab
  1. 5
      examples/autobahn-server.rs
  2. 15
      src/client.rs
  3. 43
      src/handshake/client.rs
  4. 49
      src/handshake/mod.rs
  5. 97
      src/handshake/server.rs
  6. 11
      src/server.rs

@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message};
use tungstenite::handshake::HandshakeRole;
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f,
@ -15,7 +16,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let (mut socket, _) = accept(stream).map_err(must_not_block)?;
let mut socket = accept(stream, None).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -6,7 +6,7 @@ use std::io::{Read, Write};
use url::Url;
use handshake::headers::Headers;
use handshake::client::Response;
#[cfg(feature="tls")]
mod encryption {
@ -78,7 +78,7 @@ use error::{Error, Result};
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Headers)>
-> Result<(WebSocket<AutoStream>, Response)>
{
let request: Request = request.into();
let mode = url_mode(&request.url)?;
@ -124,10 +124,13 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<(WebSocket<Stream>, Headers), HandshakeError<Stream, ClientHandshake>>
where Stream: Read + Write,
Req: Into<Request<'t>>,
pub fn client<'t, Stream, Req>(
request: Req,
stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
{
ClientHandshake::start(stream, request.into()).handshake()
}

@ -1,18 +1,19 @@
//! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use base64;
use rand;
use httparse;
use httparse::Status;
use std::io::Write;
use httparse;
use rand;
use url::Url;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
/// Client request.
pub struct Request<'t> {
@ -52,13 +53,14 @@ impl From<Url> for Request<'static> {
}
/// Client handshake role.
pub struct ClientHandshake {
pub struct ClientHandshake<S> {
verify_data: VerifyData,
_marker: PhantomData<S>,
}
impl ClientHandshake {
impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake.
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> {
pub fn start(stream: S, request: Request) -> MidHandshake<Self> {
let key = generate_key();
let machine = {
@ -83,9 +85,8 @@ impl ClientHandshake {
let client = {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData {
accept_key: accept_key,
},
verify_data: VerifyData { accept_key },
_marker: PhantomData,
}
};
@ -94,10 +95,12 @@ impl ClientHandshake {
}
}
impl HandshakeRole for ClientHandshake {
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
Ok(match finish {
StageResult::DoneWriting(stream) => {
@ -106,8 +109,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client),
result.headers)
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result))
}
})
}
@ -167,8 +170,10 @@ impl VerifyData {
/// Server response.
pub struct Response {
code: u16,
headers: Headers,
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: Headers,
}
impl TryParse for Response {
@ -204,7 +209,6 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{Response, generate_key};
use super::super::machine::TryParse;
@ -231,5 +235,4 @@ mod tests {
assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
}
}

@ -14,31 +14,17 @@ use base64;
use sha1::Sha1;
use error::Error;
use protocol::WebSocket;
use self::headers::Headers;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake.
pub struct MidHandshake<Stream, Role> {
pub struct MidHandshake<Role: HandshakeRole> {
role: Role,
machine: HandshakeMachine<Stream>,
}
impl<Stream, Role> MidHandshake<Stream, Role> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
self.machine.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
}
machine: HandshakeMachine<Role::InternalStream>,
}
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
impl<Role: HandshakeRole> MidHandshake<Role> {
/// Restarts the handshake process.
pub fn handshake(mut self) -> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, Role>> {
pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
let mut mach = self.machine;
loop {
mach = match mach.single_round()? {
@ -49,7 +35,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws, headers) => return Ok((ws, headers)),
ProcessingResult::Done(result) => return Ok(result),
}
}
}
@ -58,14 +44,14 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
}
/// A handshake result.
pub enum HandshakeError<Stream, Role> {
pub enum HandshakeError<Role: HandshakeRole> {
/// Handshake was interrupted (would block).
Interrupted(MidHandshake<Stream, Role>),
Interrupted(MidHandshake<Role>),
/// Handshake failed.
Failure(Error),
}
impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
@ -74,7 +60,7 @@ impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
@ -83,7 +69,7 @@ impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {
fn description(&self) -> &str {
match *self {
HandshakeError::Interrupted(_) => "Interrupted handshake",
@ -92,7 +78,7 @@ impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
fn from(err: Error) -> Self {
HandshakeError::Failure(err)
}
@ -103,15 +89,19 @@ pub trait HandshakeRole {
#[doc(hidden)]
type IncomingData: TryParse;
#[doc(hidden)]
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>;
type InternalStream: Read + Write;
#[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
/// Stage processing result.
#[doc(hidden)]
pub enum ProcessingResult<Stream> {
pub enum ProcessingResult<Stream, FinalResult> {
Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>, Headers),
Done(FinalResult),
}
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
@ -127,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result<String, Error> {
#[cfg(test)]
mod tests {
use super::convert_key;
#[test]

@ -1,5 +1,10 @@
//! Server handshake machine.
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem::replace;
use httparse;
use httparse::Status;
@ -19,15 +24,23 @@ pub struct Request {
impl Request {
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n", convert_key(key)?);
let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
);
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
Ok(reply.into())
}
}
@ -58,42 +71,68 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
}
/// The callback type, the callback is called when the server receives an incoming WebSocket
/// handshake request from the client, specifying a callback allows you to analyze incoming headers
/// and add additional headers to the response that server sends to the client and/or reject the
/// connection based on the incoming headers. Due to usability problems which are caused by a
/// static dispatch when using callbacks in such places, the callback is boxed.
///
/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current
/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets
/// stable.
pub type Callback = Box<FnMut(&Request) -> Result<Option<Vec<(String, String)>>>>;
/// Server handshake role.
#[allow(missing_copy_implementations)]
pub struct ServerHandshake {
/// Incoming headers, received from the client.
headers: Option<Headers>,
pub struct ServerHandshake<S> {
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<Callback>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl ServerHandshake {
/// Start server handshake.
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> {
impl<S: Read + Write> ServerHandshake<S> {
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: Option<Callback>) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { headers: None },
role: ServerHandshake { callback, _marker: PhantomData },
}
}
}
impl HandshakeRole for ServerHandshake {
impl<S: Read + Write> HandshakeRole for ServerHandshake<S> {
type IncomingData = Request;
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
type InternalStream = S;
type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
if ! tail.is_empty() {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
}
let response = result.reply()?;
self.headers = Some(result.headers);
let extra_headers = {
if let Some(mut callback) = replace(&mut self.callback, None) {
callback(&result)?
} else {
None
}
};
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
StageResult::DoneWriting(stream) => {
debug!("Server handshake done.");
let headers = self.headers.take().expect("Bug: accepted client without headers");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers)
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
}
})
}
@ -101,9 +140,9 @@ impl HandshakeRole for ServerHandshake {
#[cfg(test)]
mod tests {
use super::Request;
use super::super::machine::TryParse;
use super::super::client::Response;
#[test]
fn request_parsing() {
@ -124,7 +163,15 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply().unwrap();
}
let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
}

@ -3,7 +3,7 @@
pub use handshake::server::ServerHandshake;
use handshake::HandshakeError;
use handshake::headers::Headers;
use handshake::server::Callback;
use protocol::WebSocket;
use std::io::{Read, Write};
@ -13,9 +13,10 @@ use std::io::{Read, Write};
/// This function starts a server WebSocket handshake over the given stream.
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<Stream: Read + Write>(stream: Stream)
-> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, ServerHandshake>>
/// those from `Mio` and others. You can also pass an optional `callback` which will
/// be called when the websocket request is received from an incoming client.
pub fn accept<S: Read + Write>(stream: S, callback: Option<Callback>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S>>>
{
ServerHandshake::start(stream).handshake()
ServerHandshake::start(stream, callback).handshake()
}

Loading…
Cancel
Save