Merge pull request #20 from snapview/request_minor

Minor improvements in `Request`
pull/24/merge
Alexey Galakhov 7 years ago committed by GitHub
commit 3a1e5dfb1f
  1. 2
      Cargo.toml
  2. 22
      src/handshake/client.rs

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.4.0"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.4.0"
version = "0.5.0"
[features]
default = ["tls"]

@ -1,5 +1,6 @@
//! Client handshake machine.
use std::borrow::Cow;
use std::io::{Read, Write};
use std::marker::PhantomData;
@ -20,11 +21,11 @@ pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to.
pub url: Url,
/// Extra HTTP headers to append to the request.
pub extra_headers: Option<&'t [(&'t str, &'t str)]>,
pub extra_headers: Option<Vec<(Cow<'t, str>, Cow<'t, str>)>>,
}
impl<'t> Request<'t> {
/// The GET part of the request.
/// Returns the GET part of the request.
fn get_path(&self) -> String {
if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query)
@ -32,7 +33,8 @@ impl<'t> Request<'t> {
self.url.path().into()
}
}
/// The Host: part of the request.
/// Returns the host part of the request.
fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() {
@ -41,6 +43,18 @@ impl<'t> Request<'t> {
host.into()
}
}
/// Adds a WebSocket protocol to the request.
pub fn add_protocol(&mut self, protocol: Cow<'t, str>) {
self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol);
}
/// Adds a custom header to the request.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) {
let mut headers = self.extra_headers.take().unwrap_or(vec![]);
headers.push((name, value));
self.extra_headers = Some(headers);
}
}
impl From<Url> for Request<'static> {
@ -74,7 +88,7 @@ impl<S: Read + Write> ClientHandshake<S> {
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), path = request.get_path(), key = key).unwrap();
if let Some(eh) = request.extra_headers {
for &(k, v) in eh {
for (k, v) in eh {
write!(req, "{}: {}\r\n", k, v).unwrap();
}
}

Loading…
Cancel
Save