diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 700e2d7..8a08e4b 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -10,7 +10,7 @@ use tungstenite::{connect, Error, Result, Message}; const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse("ws://localhost:9001/getCaseCount").unwrap() )?; let msg = socket.read_message()?; @@ -19,7 +19,7 @@ fn get_case_count() -> Result { } fn update_reports() -> Result<()> { - let mut socket = connect( + let (mut socket, _) = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() )?; socket.close(None)?; @@ -31,7 +31,7 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let mut socket = connect(case_url)?; + let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 458fa26..3492c7b 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -15,7 +15,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream).map_err(must_not_block)?; + let (mut socket, _) = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/client.rs b/examples/client.rs index 1c97aba..ca35805 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,7 +8,7 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init().unwrap(); - let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + let (mut socket, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) .expect("Can't connect"); socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); diff --git a/src/client.rs b/src/client.rs index 3274ad0..33823ce 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,6 +6,8 @@ use std::io::{Read, Write}; use url::Url; +use handshake::headers::Headers; + #[cfg(feature="tls")] mod encryption { use std::net::TcpStream; @@ -75,7 +77,9 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) -> Result> { +pub fn connect<'t, Req: Into>>(request: Req) + -> Result<(WebSocket, Headers)> +{ let request: Request = request.into(); let mode = url_mode(&request.url)?; let addrs = request.url.to_socket_addrs()?; @@ -121,7 +125,7 @@ pub fn url_mode(url: &Url) -> Result { /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. pub fn client<'t, Stream, Req>(request: Req, stream: Stream) - -> StdResult, HandshakeError> + -> StdResult<(WebSocket, Headers), HandshakeError> where Stream: Read + Write, Req: Into>, { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index ff07d37..8fa9b82 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -96,7 +96,7 @@ impl ClientHandshake { impl HandshakeRole for ClientHandshake { type IncomingData = Response; - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result> { Ok(match finish { @@ -106,7 +106,8 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client)) + ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client), + result.headers) } }) } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 64d31f1..28aa0f6 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -16,6 +16,7 @@ use sha1::Sha1; use error::Error; use protocol::WebSocket; +use self::headers::Headers; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; /// A WebSocket handshake. @@ -37,7 +38,7 @@ impl MidHandshake { impl MidHandshake { /// Restarts the handshake process. - pub fn handshake(self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result<(WebSocket, Headers), HandshakeError> { let mut mach = self.machine; loop { mach = match mach.single_round()? { @@ -48,7 +49,7 @@ impl MidHandshake { RoundResult::StageFinished(s) => { match self.role.stage_finished(s)? { ProcessingResult::Continue(m) => m, - ProcessingResult::Done(ws) => return Ok(ws), + ProcessingResult::Done(ws, headers) => return Ok((ws, headers)), } } } @@ -102,7 +103,7 @@ pub trait HandshakeRole { #[doc(hidden)] type IncomingData: TryParse; #[doc(hidden)] - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result, Error>; } @@ -110,7 +111,7 @@ pub trait HandshakeRole { #[doc(hidden)] pub enum ProcessingResult { Continue(HandshakeMachine), - Done(WebSocket), + Done(WebSocket, Headers), } /// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept. diff --git a/src/handshake/server.rs b/src/handshake/server.rs index a3a5fab..a09ce65 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -60,7 +60,10 @@ impl<'h, 'b: 'h> FromHttparse> for Request { /// Server handshake role. #[allow(missing_copy_implementations)] -pub struct ServerHandshake; +pub struct ServerHandshake { + /// Incoming headers, received from the client. + headers: Option, +} impl ServerHandshake { /// Start server handshake. @@ -68,14 +71,14 @@ impl ServerHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake, + role: ServerHandshake { headers: None }, } } } impl HandshakeRole for ServerHandshake { type IncomingData = Request; - fn stage_finished(&self, finish: StageResult) + fn stage_finished(&mut self, finish: StageResult) -> Result> { Ok(match finish { @@ -84,11 +87,13 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol("Junk after client request".into())) } let response = result.reply()?; + self.headers = Some(result.headers); ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } StageResult::DoneWriting(stream) => { debug!("Server handshake done."); - ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) + let headers = self.headers.take().expect("Bug: accepted client without headers"); + ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server), headers) } }) } diff --git a/src/server.rs b/src/server.rs index 3217006..04083f5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; +use handshake::headers::Headers; use protocol::WebSocket; use std::io::{Read, Write}; @@ -14,7 +15,7 @@ use std::io::{Read, Write}; /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. pub fn accept(stream: Stream) - -> Result, HandshakeError> + -> Result<(WebSocket, Headers), HandshakeError> { ServerHandshake::start(stream).handshake() }