Add basic support for examining headers (#6)

pull/18/head
Daniel Abramov 7 years ago
parent 147ca9e4d3
commit f34c488217
  1. 6
      examples/autobahn-client.rs
  2. 2
      examples/autobahn-server.rs
  3. 2
      examples/client.rs
  4. 8
      src/client.rs
  5. 5
      src/handshake/client.rs
  6. 9
      src/handshake/mod.rs
  7. 13
      src/handshake/server.rs
  8. 3
      src/server.rs

@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message};
const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let mut socket = connect(
let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?;
let msg = socket.read_message()?;
@ -19,7 +19,7 @@ fn get_case_count() -> Result<u32> {
}
fn update_reports() -> Result<()> {
let mut socket = connect(
let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?;
socket.close(None)?;
@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let mut socket = connect(case_url)?;
let (mut socket, _) = connect(case_url)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -15,7 +15,7 @@ fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let (mut socket, _) = accept(stream).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -8,7 +8,7 @@ use tungstenite::{Message, connect};
fn main() {
env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
let (mut socket, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect");
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();

@ -6,6 +6,8 @@ use std::io::{Read, Write};
use url::Url;
use handshake::headers::Headers;
#[cfg(feature="tls")]
mod encryption {
use std::net::TcpStream;
@ -75,7 +77,9 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) -> Result<WebSocket<AutoStream>> {
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Headers)>
{
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?;
@ -121,7 +125,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>>
-> StdResult<(WebSocket<Stream>, Headers), HandshakeError<Stream, ClientHandshake>>
where Stream: Read + Write,
Req: Into<Request<'t>>,
{

@ -96,7 +96,7 @@ impl ClientHandshake {
impl HandshakeRole for ClientHandshake {
type IncomingData = Response;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
{
Ok(match finish {
@ -106,7 +106,8 @@ impl HandshakeRole for ClientHandshake {
StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client))
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client),
result.headers)
}
})
}

@ -16,6 +16,7 @@ use sha1::Sha1;
use error::Error;
use protocol::WebSocket;
use self::headers::Headers;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A WebSocket handshake.
@ -37,7 +38,7 @@ impl<Stream, Role> MidHandshake<Stream, Role> {
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
/// Restarts the handshake process.
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> {
pub fn handshake(mut self) -> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, Role>> {
let mut mach = self.machine;
loop {
mach = match mach.single_round()? {
@ -48,7 +49,7 @@ impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws) => return Ok(ws),
ProcessingResult::Done(ws, headers) => return Ok((ws, headers)),
}
}
}
@ -102,7 +103,7 @@ pub trait HandshakeRole {
#[doc(hidden)]
type IncomingData: TryParse;
#[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>;
}
@ -110,7 +111,7 @@ pub trait HandshakeRole {
#[doc(hidden)]
pub enum ProcessingResult<Stream> {
Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>),
Done(WebSocket<Stream>, Headers),
}
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.

@ -60,7 +60,10 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
/// Server handshake role.
#[allow(missing_copy_implementations)]
pub struct ServerHandshake;
pub struct ServerHandshake {
/// Incoming headers, received from the client.
headers: Option<Headers>,
}
impl ServerHandshake {
/// Start server handshake.
@ -68,14 +71,14 @@ impl ServerHandshake {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake,
role: ServerHandshake { headers: None },
}
}
}
impl HandshakeRole for ServerHandshake {
type IncomingData = Request;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
fn stage_finished<Stream>(&mut self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
{
Ok(match finish {
@ -84,11 +87,13 @@ impl HandshakeRole for ServerHandshake {
return Err(Error::Protocol("Junk after client request".into()))
}
let response = result.reply()?;
self.headers = Some(result.headers);
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
StageResult::DoneWriting(stream) => {
debug!("Server handshake done.");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
let headers = self.headers.take().expect("Bug: accepted client without headers");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers)
}
})
}

@ -3,6 +3,7 @@
pub use handshake::server::ServerHandshake;
use handshake::HandshakeError;
use handshake::headers::Headers;
use protocol::WebSocket;
use std::io::{Read, Write};
@ -14,7 +15,7 @@ use std::io::{Read, Write};
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<Stream: Read + Write>(stream: Stream)
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>>
-> Result<(WebSocket<Stream>, Headers), HandshakeError<Stream, ServerHandshake>>
{
ServerHandshake::start(stream).handshake()
}

Loading…
Cancel
Save