Merge pull request #18 from snapview/headers

Add support for altering request/response headers during the websocket handshake
pull/19/head
Alexey Galakhov 8 years ago committed by GitHub
commit 13c382ae89
  1. 2
      Cargo.toml
  2. 8
      README.md
  3. 6
      examples/autobahn-client.rs
  4. 5
      examples/autobahn-server.rs
  5. 9
      examples/client.rs
  6. 38
      examples/server.rs
  7. 15
      src/client.rs
  8. 42
      src/handshake/client.rs
  9. 5
      src/handshake/headers.rs
  10. 48
      src/handshake/mod.rs
  11. 86
      src/handshake/server.rs
  12. 10
      src/server.rs

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.2.4"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.3.0"
version = "0.4.0"
[features]
default = ["tls"]

@ -8,15 +8,21 @@ Lightweight stream-based WebSocket implementation for [Rust](http://www.rust-lan
let server = TcpListener::bind("127.0.0.1:9001").unwrap();
for stream in server.incoming() {
spawn (move || {
let mut websocket = accept(stream.unwrap()).unwrap();
let mut websocket = accept(stream.unwrap(), None).unwrap();
loop {
let msg = websocket.read_message().unwrap();
// We do not want to send back ping/pong messages.
if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap();
}
}
});
}
```
Take a look at the examples section to see how to write a simple client/server.
[![MIT licensed](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE-MIT)
[![Apache-2.0 licensed](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE-APACHE)
[![Crates.io](https://img.shields.io/crates/v/tungstenite.svg?maxAge=2592000)](https://crates.io/crates/tungstenite)

@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message};
const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let mut socket = connect(
let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?;
let msg = socket.read_message()?;
@ -19,7 +19,7 @@ fn get_case_count() -> Result<u32> {
}
fn update_reports() -> Result<()> {
let mut socket = connect(
let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?;
socket.close(None)?;
@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let mut socket = connect(case_url)?;
let (mut socket, _) = connect(case_url)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -6,8 +6,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message};
use tungstenite::handshake::HandshakeRole;
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f,
@ -15,7 +16,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut socket = accept(stream, None).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -8,9 +8,16 @@ use tungstenite::{Message, connect};
fn main() {
env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect");
println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
println!("Response contains the following headers:");
for &(ref header, _ /*value*/) in response.headers.iter() {
println!("* {}", header);
}
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop {
let msg = socket.read_message().expect("Error reading message");

@ -0,0 +1,38 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener;
use tungstenite::accept;
use tungstenite::handshake::server::Request;
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |req: &Request| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.path);
println!("The request's headers are:");
for &(ref header, _ /* value */) in req.headers.iter() {
println!("* {}", header);
}
// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")),
];
Ok(Some(extra_headers))
};
let mut websocket = accept(stream.unwrap(), Some(Box::new(callback))).unwrap();
loop {
let msg = websocket.read_message().unwrap();
if msg.is_binary() || msg.is_text() {
websocket.write_message(msg).unwrap();
}
}
});
}
}

@ -6,6 +6,8 @@ use std::io::{Read, Write};
use url::Url;
use handshake::client::Response;
#[cfg(feature="tls")]
mod encryption {
use std::net::TcpStream;
@ -75,7 +77,9 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) -> Result<WebSocket<AutoStream>> {
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Response)>
{
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?;
@ -120,9 +124,12 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>>
where Stream: Read + Write,
pub fn client<'t, Stream, Req>(
request: Req,
stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
{
ClientHandshake::start(stream, request.into()).handshake()

@ -1,18 +1,19 @@
//! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use base64;
use rand;
use httparse;
use httparse::Status;
use std::io::Write;
use httparse;
use rand;
use url::Url;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
/// Client request.
pub struct Request<'t> {
@ -52,13 +53,14 @@ impl From<Url> for Request<'static> {
}
/// Client handshake role.
pub struct ClientHandshake {
pub struct ClientHandshake<S> {
verify_data: VerifyData,
_marker: PhantomData<S>,
}
impl ClientHandshake {
impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake.
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> {
pub fn start(stream: S, request: Request) -> MidHandshake<Self> {
let key = generate_key();
let machine = {
@ -83,9 +85,8 @@ impl ClientHandshake {
let client = {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData {
accept_key: accept_key,
},
verify_data: VerifyData { accept_key },
_marker: PhantomData,
}
};
@ -94,10 +95,12 @@ impl ClientHandshake {
}
}
impl HandshakeRole for ClientHandshake {
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
Ok(match finish {
StageResult::DoneWriting(stream) => {
@ -106,7 +109,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client))
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result))
}
})
}
@ -166,8 +170,10 @@ impl VerifyData {
/// Server response.
pub struct Response {
code: u16,
headers: Headers,
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: Headers,
}
impl TryParse for Response {
@ -203,7 +209,6 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{Response, generate_key};
use super::super::machine::TryParse;
@ -230,5 +235,4 @@ mod tests {
assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
}
}

@ -49,6 +49,11 @@ impl Headers {
.unwrap_or(false)
}
/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers.

@ -14,30 +14,17 @@ use base64;
use sha1::Sha1;
use error::Error;
use protocol::WebSocket;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake.
pub struct MidHandshake<Stream, Role> {
pub struct MidHandshake<Role: HandshakeRole> {
role: Role,
machine: HandshakeMachine<Stream>,
}
impl<Stream, Role> MidHandshake<Stream, Role> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
self.machine.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
}
machine: HandshakeMachine<Role::InternalStream>,
}
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
impl<Role: HandshakeRole> MidHandshake<Role> {
/// Restarts the handshake process.
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> {
pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
let mut mach = self.machine;
loop {
mach = match mach.single_round()? {
@ -48,7 +35,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws) => return Ok(ws),
ProcessingResult::Done(result) => return Ok(result),
}
}
}
@ -57,14 +44,14 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
}
/// A handshake result.
pub enum HandshakeError<Stream, Role> {
pub enum HandshakeError<Role: HandshakeRole> {
/// Handshake was interrupted (would block).
Interrupted(MidHandshake<Stream, Role>),
Interrupted(MidHandshake<Role>),
/// Handshake failed.
Failure(Error),
}
impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
@ -73,7 +60,7 @@ impl<Stream, Role> fmt::Debug for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
@ -82,7 +69,7 @@ impl<Stream, Role> fmt::Display for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {
fn description(&self) -> &str {
match *self {
HandshakeError::Interrupted(_) => "Interrupted handshake",
@ -91,7 +78,7 @@ impl<Stream, Role> ErrorTrait for HandshakeError<Stream, Role> {
}
}
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> {
impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
fn from(err: Error) -> Self {
HandshakeError::Failure(err)
}
@ -102,15 +89,19 @@ pub trait HandshakeRole {
#[doc(hidden)]
type IncomingData: TryParse;
#[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>;
type InternalStream: Read + Write;
#[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
/// Stage processing result.
#[doc(hidden)]
pub enum ProcessingResult<Stream> {
pub enum ProcessingResult<Stream, FinalResult> {
Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>),
Done(FinalResult),
}
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
@ -126,7 +117,6 @@ fn convert_key(input: &[u8]) -> Result<String, Error> {
#[cfg(test)]
mod tests {
use super::convert_key;
#[test]

@ -1,5 +1,10 @@
//! Server handshake machine.
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem::replace;
use httparse;
use httparse::Status;
@ -19,15 +24,23 @@ pub struct Request {
impl Request {
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\
let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n", convert_key(key)?);
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
);
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
Ok(reply.into())
}
}
@ -58,32 +71,63 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
}
/// The callback type, the callback is called when the server receives an incoming WebSocket
/// handshake request from the client, specifying a callback allows you to analyze incoming headers
/// and add additional headers to the response that server sends to the client and/or reject the
/// connection based on the incoming headers. Due to usability problems which are caused by a
/// static dispatch when using callbacks in such places, the callback is boxed.
///
/// The type uses `FnMut` instead of `FnOnce` as it is impossible to box `FnOnce` in the current
/// Rust version, `FnBox` is still unstable, this code has to be updated for `FnBox` when it gets
/// stable.
pub type Callback = Box<FnMut(&Request) -> Result<Option<Vec<(String, String)>>>>;
/// Server handshake role.
#[allow(missing_copy_implementations)]
pub struct ServerHandshake;
pub struct ServerHandshake<S> {
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<Callback>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl ServerHandshake {
/// Start server handshake.
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> {
impl<S: Read + Write> ServerHandshake<S> {
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: Option<Callback>) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake,
role: ServerHandshake { callback, _marker: PhantomData },
}
}
}
impl HandshakeRole for ServerHandshake {
impl<S: Read + Write> HandshakeRole for ServerHandshake<S> {
type IncomingData = Request;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
type InternalStream = S;
type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
if ! tail.is_empty() {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
}
let response = result.reply()?;
let extra_headers = {
if let Some(mut callback) = replace(&mut self.callback, None) {
callback(&result)?
} else {
None
}
};
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
StageResult::DoneWriting(stream) => {
@ -96,9 +140,9 @@ impl HandshakeRole for ServerHandshake {
#[cfg(test)]
mod tests {
use super::Request;
use super::super::machine::TryParse;
use super::super::client::Response;
#[test]
fn request_parsing() {
@ -119,7 +163,15 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply().unwrap();
}
let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
}

@ -3,6 +3,7 @@
pub use handshake::server::ServerHandshake;
use handshake::HandshakeError;
use handshake::server::Callback;
use protocol::WebSocket;
use std::io::{Read, Write};
@ -12,9 +13,10 @@ use std::io::{Read, Write};
/// This function starts a server WebSocket handshake over the given stream.
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<Stream: Read + Write>(stream: Stream)
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>>
/// those from `Mio` and others. You can also pass an optional `callback` which will
/// be called when the websocket request is received from an incoming client.
pub fn accept<S: Read + Write>(stream: S, callback: Option<Callback>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S>>>
{
ServerHandshake::start(stream).handshake()
ServerHandshake::start(stream, callback).handshake()
}

Loading…
Cancel
Save