diff --git a/src/client.rs b/src/client.rs index 5300ab2..7a6e170 100644 --- a/src/client.rs +++ b/src/client.rs @@ -119,6 +119,6 @@ pub fn url_mode(url: &Url) -> Result { pub fn client(url: Url, stream: Stream) -> StdResult, HandshakeError> { - let request = Request { url: url }; + let request = Request { url: url, extra_headers: None }; ClientHandshake::start(stream, request).handshake() } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 9beb5eb..157f309 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -13,12 +13,12 @@ use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::machine::{HandshakeMachine, StageResult, TryParse}; /// Client request. -pub struct Request { +pub struct Request<'t> { pub url: Url, - // TODO extra headers + pub extra_headers: Option<&'t [(&'t str, &'t str)]>, } -impl Request { +impl<'t> Request<'t> { /// The GET part of the request. fn get_path(&self) -> String { if let Some(query) = self.url.query() { @@ -56,9 +56,14 @@ impl ClientHandshake { Connection: upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n\ - \r\n", host = request.get_host(), path = request.get_path(), key = key) - .unwrap(); + Sec-WebSocket-Key: {key}\r\n", + host = request.get_host(), path = request.get_path(), key = key).unwrap(); + if let Some(eh) = request.extra_headers { + for &(k, v) in eh { + write!(req, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(req, "\r\n").unwrap(); HandshakeMachine::start_write(stream, req) };