Improve the `handshake::client::Request` structure

pull/20/head
Daniel Abramov 7 years ago
parent daa7fc1d45
commit c4013ccad3
  1. 22
      src/handshake/client.rs

@ -1,5 +1,6 @@
//! Client handshake machine. //! Client handshake machine.
use std::borrow::Cow;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
@ -20,11 +21,11 @@ pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to. /// `ws://` or `wss://` URL to connect to.
pub url: Url, pub url: Url,
/// Extra HTTP headers to append to the request. /// Extra HTTP headers to append to the request.
pub extra_headers: Option<&'t [(&'t str, &'t str)]>, pub extra_headers: Option<Vec<(Cow<'t, str>, Cow<'t, str>)>>,
} }
impl<'t> Request<'t> { impl<'t> Request<'t> {
/// The GET part of the request. /// Returns the GET part of the request.
fn get_path(&self) -> String { fn get_path(&self) -> String {
if let Some(query) = self.url.query() { if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query) format!("{path}?{query}", path = self.url.path(), query = query)
@ -32,7 +33,8 @@ impl<'t> Request<'t> {
self.url.path().into() self.url.path().into()
} }
} }
/// The Host: part of the request.
/// Returns the host part of the request.
fn get_host(&self) -> String { fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host"); let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() { if let Some(port) = self.url.port() {
@ -41,6 +43,18 @@ impl<'t> Request<'t> {
host.into() host.into()
} }
} }
/// Adds a WebSocket protocol to the request.
pub fn add_protocol(&mut self, protocol: Cow<'t, str>) {
self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol);
}
/// Adds a custom header to the request.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) {
let mut headers = self.extra_headers.take().unwrap_or(vec![]);
headers.push((name, value));
self.extra_headers = Some(headers);
}
} }
impl From<Url> for Request<'static> { impl From<Url> for Request<'static> {
@ -74,7 +88,7 @@ impl<S: Read + Write> ClientHandshake<S> {
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), path = request.get_path(), key = key).unwrap(); host = request.get_host(), path = request.get_path(), key = key).unwrap();
if let Some(eh) = request.extra_headers { if let Some(eh) = request.extra_headers {
for &(k, v) in eh { for (k, v) in eh {
write!(req, "{}: {}\r\n", k, v).unwrap(); write!(req, "{}: {}\r\n", k, v).unwrap();
} }
} }

Loading…
Cancel
Save