Merge pull request #18 from snapview/headers

Add support for altering request/response headers during the websocket handshake
pull/19/head
Alexey Galakhov 7 years ago committed by GitHub
commit 13c382ae89
  1. 2
      Cargo.toml
  2. 8
      README.md
  3. 6
      examples/autobahn-client.rs
  4. 5
      examples/autobahn-server.rs
  5. 9
      examples/client.rs
  6. 38
      examples/server.rs
  7. 15
      src/client.rs
  8. 42
      src/handshake/client.rs
  9. 5
      src/handshake/headers.rs
  10. 48
      src/handshake/mod.rs
  11. 86
      src/handshake/server.rs
  12. 10
      src/server.rs

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.2.4" documentation = "https://docs.rs/tungstenite/0.2.4"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.3.0" version = "0.4.0"
[features] [features]
default = ["tls"] default = ["tls"]

@ -8,15 +8,21 @@ Lightweight stream-based WebSocket implementation for [Rust](http://www.rust-lan
let server = TcpListener::bind("127.0.0.1:9001").unwrap(); let server = TcpListener::bind("127.0.0.1:9001").unwrap();
for stream in server.incoming() { for stream in server.incoming() {
spawn (move || { spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap(); let mut websocket = accept(stream.unwrap(), None).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap(); websocket.write_message(msg).unwrap();
} }
}
}); });
} }
``` ```
Take a look at the examples section to see how to write a simple client/server.
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT) [![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE) [![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite) [![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)

@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message};
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let mut socket = connect( let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap() Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?; )?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
@ -19,7 +19,7 @@ fn get_case_count() -> Result<u32> {
} }
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let mut socket = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?; )?;
socket.close(None)?; socket.close(None)?;
@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse( let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap(); ).unwrap();
let mut socket = connect(case_url)?; let (mut socket, _) = connect(case_url)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn; use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message}; use tungstenite::{accept, HandshakeError, Error, Result, Message};
use tungstenite::handshake::HandshakeRole;
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"), HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
@ -15,7 +16,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept(stream, None).map_err(must_not_block)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -8,9 +8,16 @@ use tungstenite::{Message, connect};
fn main() { fn main() {
env_logger::init().unwrap(); env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect"); .expect("Can't connect");
println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
println!("Response contains the following headers:");
for &(ref header, _ /*value*/) in response.headers.iter() {
println!("* {}", header);
}
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop { loop {
let msg = socket.read_message().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");

@ -0,0 +1,38 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener;
use tungstenite::accept;
use tungstenite::handshake::server::Request;
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |req: &Request| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.path);
println!("The request's headers are:");
for &(ref header, _ /* value */) in req.headers.iter() {
println!("* {}", header);
}
// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")),
];
Ok(Some(extra_headers))
};
let mut websocket = accept(stream.unwrap(), Some(Box::new(callback))).unwrap();
loop {
let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap();
}
}
});
}
}

@ -6,6 +6,8 @@ use std::io::{Read, Write};
use url::Url; use url::Url;
use handshake::client::Response;
#[cfg(feature="tls")] #[cfg(feature="tls")]
mod encryption { mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
@ -75,7 +77,9 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) -> Result<WebSocket<AutoStream>> { pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Response)>
{
let request: Request = request.into(); let request: Request = request.into();
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?; let addrs = request.url.to_socket_addrs()?;
@ -120,9 +124,12 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream) pub fn client<'t, Stream, Req>(
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>> request: Req,
where Stream: Read + Write, stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,
{ {
ClientHandshake::start(stream, request.into()).handshake() ClientHandshake::start(stream, request.into()).handshake()

@ -1,18 +1,19 @@
//! Client handshake machine. //! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use base64; use base64;
use rand;
use httparse;
use httparse::Status; use httparse::Status;
use std::io::Write; use httparse;
use rand;
use url::Url; use url::Url;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
/// Client request. /// Client request.
pub struct Request<'t> { pub struct Request<'t> {
@ -52,13 +53,14 @@ impl From<Url> for Request<'static> {
} }
/// Client handshake role. /// Client handshake role.
pub struct ClientHandshake { pub struct ClientHandshake<S> {
verify_data: VerifyData, verify_data: VerifyData,
_marker: PhantomData<S>,
} }
impl ClientHandshake { impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> { pub fn start(stream: S, request: Request) -> MidHandshake<Self> {
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
@ -83,9 +85,8 @@ impl ClientHandshake {
let client = { let client = {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake {
verify_data: VerifyData { verify_data: VerifyData { accept_key },
accept_key: accept_key, _marker: PhantomData,
},
} }
}; };
@ -94,10 +95,12 @@ impl ClientHandshake {
} }
} }
impl HandshakeRole for ClientHandshake { impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response; type IncomingData = Response;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream = S;
-> Result<ProcessingResult<Stream>> type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{ {
Ok(match finish { Ok(match finish {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
@ -106,7 +109,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result))
} }
}) })
} }
@ -166,8 +170,10 @@ impl VerifyData {
/// Server response. /// Server response.
pub struct Response { pub struct Response {
code: u16, /// HTTP response code of the response.
headers: Headers, pub code: u16,
/// Received headers.
pub headers: Headers,
} }
impl TryParse for Response { impl TryParse for Response {
@ -203,7 +209,6 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Response, generate_key}; use super::{Response, generate_key};
use super::super::machine::TryParse; use super::super::machine::TryParse;
@ -230,5 +235,4 @@ mod tests {
assert_eq!(resp.code, 200); assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
} }
} }

@ -49,6 +49,11 @@ impl Headers {
.unwrap_or(false) .unwrap_or(false)
} }
/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
} }
/// The iterator over headers. /// The iterator over headers.

@ -14,30 +14,17 @@ use base64;
use sha1::Sha1; use sha1::Sha1;
use error::Error; use error::Error;
use protocol::WebSocket;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake. /// A WebSocket handshake.
pub struct MidHandshake<Stream, Role> { pub struct MidHandshake<Role: HandshakeRole> {
role: Role, role: Role,
machine: HandshakeMachine<Stream>, machine: HandshakeMachine<Role::InternalStream>,
}
impl<Stream, Role> MidHandshake<Stream, Role> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
self.machine.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
}
} }
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> { impl<Role: HandshakeRole> MidHandshake<Role> {
/// Restarts the handshake process. /// Restarts the handshake process.
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> { pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
let mut mach = self.machine; let mut mach = self.machine;
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
@ -48,7 +35,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => { RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? { match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m, ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws) => return Ok(ws), ProcessingResult::Done(result) => return Ok(result),
} }
} }
} }
@ -57,14 +44,14 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
} }
/// A handshake result. /// A handshake result.
pub enum HandshakeError<Stream, Role> { pub enum HandshakeError<Role: HandshakeRole> {
/// Handshake was interrupted (would block). /// Handshake was interrupted (would block).
Interrupted(MidHandshake<Stream, Role>), Interrupted(MidHandshake<Role>),
/// Handshake failed. /// Handshake failed.
Failure(Error), Failure(Error),
} }
impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"), HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
@ -73,7 +60,7 @@ impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"), HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
@ -82,7 +69,7 @@ impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {
fn description(&self) -> &str { fn description(&self) -> &str {
match *self { match *self {
HandshakeError::Interrupted(_) => "Interrupted handshake", HandshakeError::Interrupted(_) => "Interrupted handshake",
@ -91,7 +78,7 @@ impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
} }
} }
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> { impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
fn from(err: Error) -> Self { fn from(err: Error) -> Self {
HandshakeError::Failure(err) HandshakeError::Failure(err)
} }
@ -102,15 +89,19 @@ pub trait HandshakeRole {
#[doc(hidden)] #[doc(hidden)]
type IncomingData: TryParse; type IncomingData: TryParse;
#[doc(hidden)] #[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream: Read + Write;
-> Result<ProcessingResult<Stream>, Error>; #[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
} }
/// Stage processing result. /// Stage processing result.
#[doc(hidden)] #[doc(hidden)]
pub enum ProcessingResult<Stream> { pub enum ProcessingResult<Stream, FinalResult> {
Continue(HandshakeMachine<Stream>), Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>), Done(FinalResult),
} }
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
@ -126,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result<String, Error> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::convert_key; use super::convert_key;
#[test] #[test]

@ -1,5 +1,10 @@
//! Server handshake machine. //! Server handshake machine.
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem::replace;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
@ -19,15 +24,23 @@ pub struct Request {
impl Request { impl Request {
/// Reply to the response. /// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> { pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key") let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\ let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\ HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n\ Sec-WebSocket-Accept: {}\r\n",
\r\n", convert_key(key)?); convert_key(key)?
);
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
Ok(reply.into()) Ok(reply.into())
} }
} }
@ -58,32 +71,63 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
} }
/// The callback type, the callback is called when the server receives an incoming WebSocket
/// handshake request from the client, specifying a callback allows you to analyze incoming headers
/// and add additional headers to the response that server sends to the client and/or reject the
/// connection based on the incoming headers. Due to usability problems which are caused by a
/// static dispatch when using callbacks in such places, the callback is boxed.
///
/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current
/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets
/// stable.
pub type Callback = Box<FnMut(&Request) -> Result<Option<Vec<(String, String)>>>>;
/// Server handshake role. /// Server handshake role.
#[allow(missing_copy_implementations)] #[allow(missing_copy_implementations)]
pub struct ServerHandshake; pub struct ServerHandshake<S> {
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<Callback>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl ServerHandshake { impl<S: Read + Write> ServerHandshake<S> {
/// Start server handshake. /// Start server handshake. `callback` specifies a custom callback which the user can pass to
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> { /// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: Option<Callback>) -> MidHandshake<Self> {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake, role: ServerHandshake { callback, _marker: PhantomData },
} }
} }
} }
impl HandshakeRole for ServerHandshake { impl<S: Read + Write> HandshakeRole for ServerHandshake<S> {
type IncomingData = Request; type IncomingData = Request;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) type InternalStream = S;
-> Result<ProcessingResult<Stream>> type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{ {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading { stream, result, tail } => {
if ! tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())) return Err(Error::Protocol("Junk after client request".into()))
} }
let response = result.reply()?; let extra_headers = {
if let Some(mut callback) = replace(&mut self.callback, None) {
callback(&result)?
} else {
None
}
};
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
@ -96,9 +140,9 @@ impl HandshakeRole for ServerHandshake {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Request; use super::Request;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::super::client::Response;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
@ -119,7 +163,15 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply().unwrap(); let _ = req.reply(None).unwrap();
}
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
} }

@ -3,6 +3,7 @@
pub use handshake::server::ServerHandshake; pub use handshake::server::ServerHandshake;
use handshake::HandshakeError; use handshake::HandshakeError;
use handshake::server::Callback;
use protocol::WebSocket; use protocol::WebSocket;
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -12,9 +13,10 @@ use std::io::{Read, Write};
/// This function starts a server WebSocket handshake over the given stream. /// This function starts a server WebSocket handshake over the given stream.
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others. You can also pass an optional `callback` which will
pub fn accept<Stream: Read + Write>(stream: Stream) /// be called when the websocket request is received from an incoming client.
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>> pub fn accept<S: Read + Write>(stream: S, callback: Option<Callback>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S>>>
{ {
ServerHandshake::start(stream).handshake() ServerHandshake::start(stream, callback).handshake()
} }

Loading…
Cancel
Save