Add basic support for examining headers (#6)

pull/18/head
Daniel Abramov 8 years ago
parent 147ca9e4d3
commit f34c488217
  1. 6
      examples/autobahn-client.rs
  2. 2
      examples/autobahn-server.rs
  3. 2
      examples/client.rs
  4. 8
      src/client.rs
  5. 5
      src/handshake/client.rs
  6. 9
      src/handshake/mod.rs
  7. 13
      src/handshake/server.rs
  8. 3
      src/server.rs

@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message};
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let mut socket = connect( let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap() Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?; )?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
@ -19,7 +19,7 @@ fn get_case_count() -> Result<u32> {
} }
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let mut socket = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?; )?;
socket.close(None)?; socket.close(None)?;
@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse( let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap(); ).unwrap();
let mut socket = connect(case_url)?; let (mut socket, _) = connect(case_url)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -15,7 +15,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let (mut socket, _) = accept(stream).map_err(must_not_block)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -8,7 +8,7 @@ use tungstenite::{Message, connect};
fn main() { fn main() {
env_logger::init().unwrap(); env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) let (mut socket, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect"); .expect("Can't connect");
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();

@ -6,6 +6,8 @@ use std::io::{Read, Write};
use url::Url; use url::Url;
use handshake::headers::Headers;
#[cfg(feature="tls")] #[cfg(feature="tls")]
mod encryption { mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
@ -75,7 +77,9 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) -> Result<WebSocket<AutoStream>> { pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Headers)>
{
let request: Request = request.into(); let request: Request = request.into();
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?; let addrs = request.url.to_socket_addrs()?;
@ -121,7 +125,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream) pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>> -> StdResult<(WebSocket<Stream>, Headers), HandshakeError<Stream, ClientHandshake>>
where Stream: Read + Write, where Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,
{ {

@ -96,7 +96,7 @@ impl ClientHandshake {
impl HandshakeRole for ClientHandshake { impl HandshakeRole for ClientHandshake {
type IncomingData = Response; type IncomingData = Response;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>> -> Result<ProcessingResult<Stream>>
{ {
Ok(match finish { Ok(match finish {
@ -106,7 +106,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client),
result.headers)
} }
}) })
} }

@ -16,6 +16,7 @@ use sha1::Sha1;
use error::Error; use error::Error;
use protocol::WebSocket; use protocol::WebSocket;
use self::headers::Headers;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake. /// A WebSocket handshake.
@ -37,7 +38,7 @@ impl<Stream, Role> MidHandshake<Stream, Role> {
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> { impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
/// Restarts the handshake process. /// Restarts the handshake process.
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> { pub fn handshake(mut self) -> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, Role>> {
let mut mach = self.machine; let mut mach = self.machine;
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
@ -48,7 +49,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => { RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? { match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m, ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws) => return Ok(ws), ProcessingResult::Done(ws, headers) => return Ok((ws, headers)),
} }
} }
} }
@ -102,7 +103,7 @@ pub trait HandshakeRole {
#[doc(hidden)] #[doc(hidden)]
type IncomingData: TryParse; type IncomingData: TryParse;
#[doc(hidden)] #[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>; -> Result<ProcessingResult<Stream>, Error>;
} }
@ -110,7 +111,7 @@ pub trait HandshakeRole {
#[doc(hidden)] #[doc(hidden)]
pub enum ProcessingResult<Stream> { pub enum ProcessingResult<Stream> {
Continue(HandshakeMachine<Stream>), Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>), Done(WebSocket<Stream>, Headers),
} }
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.

@ -60,7 +60,10 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
/// Server handshake role. /// Server handshake role.
#[allow(missing_copy_implementations)] #[allow(missing_copy_implementations)]
pub struct ServerHandshake; pub struct ServerHandshake {
/// Incoming headers, received from the client.
headers: Option<Headers>,
}
impl ServerHandshake { impl ServerHandshake {
/// Start server handshake. /// Start server handshake.
@ -68,14 +71,14 @@ impl ServerHandshake {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake, role: ServerHandshake { headers: None },
} }
} }
} }
impl HandshakeRole for ServerHandshake { impl HandshakeRole for ServerHandshake {
type IncomingData = Request; type IncomingData = Request;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>) fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>> -> Result<ProcessingResult<Stream>>
{ {
Ok(match finish { Ok(match finish {
@ -84,11 +87,13 @@ impl HandshakeRole for ServerHandshake {
return Err(Error::Protocol("Junk after client request".into())) return Err(Error::Protocol("Junk after client request".into()))
} }
let response = result.reply()?; let response = result.reply()?;
self.headers = Some(result.headers);
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
debug!("Server handshake done."); debug!("Server handshake done.");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) let headers = self.headers.take().expect("Bug: accepted client without headers");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers)
} }
}) })
} }

@ -3,6 +3,7 @@
pub use handshake::server::ServerHandshake; pub use handshake::server::ServerHandshake;
use handshake::HandshakeError; use handshake::HandshakeError;
use handshake::headers::Headers;
use protocol::WebSocket; use protocol::WebSocket;
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -14,7 +15,7 @@ use std::io::{Read, Write};
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others.
pub fn accept<Stream: Read + Write>(stream: Stream) pub fn accept<Stream: Read + Write>(stream: Stream)
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>> -> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, ServerHandshake>>
{ {
ServerHandshake::start(stream).handshake() ServerHandshake::start(stream).handshake()
} }

Loading…
Cancel
Save