Remove custom Headers type and use http::HeaderMap instead

Fixes https://github.com/snapview/tungstenite-rs/issues/92
pull/93/head
Sebastian Dröge 5 years ago
parent bb801430c8
commit 38a7d1a375
  1. 46
      src/error.rs
  2. 26
      src/handshake/client.rs
  3. 118
      src/handshake/headers.rs
  4. 72
      src/handshake/server.rs

@ -45,7 +45,7 @@ pub enum Error {
/// connection when it really shouldn't anymore, so this really indicates a programmer /// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part. /// error on your part.
AlreadyClosed, AlreadyClosed,
/// Input-output error. Appart from WouldBlock, these are generally errors with the /// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal. /// underlying connection and you should probably consider them fatal.
Io(io::Error), Io(io::Error),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
@ -61,10 +61,12 @@ pub enum Error {
SendQueueFull(Message), SendQueueFull(Message),
/// UTF coding error /// UTF coding error
Utf8, Utf8,
/// Invlid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error. /// HTTP error.
Http(u16), Http(u16),
/// HTTP format error.
HttpFormat(http::Error),
} }
impl fmt::Display for Error { impl fmt::Display for Error {
@ -80,7 +82,8 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code), Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
} }
} }
} }
@ -99,6 +102,7 @@ impl ErrorTrait for Error {
Error::Utf8 => "", Error::Utf8 => "",
Error::Url(ref msg) => msg.borrow(), Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "", Error::Http(_) => "",
Error::HttpFormat(ref err) => err.description(),
} }
} }
} }
@ -121,6 +125,42 @@ impl From<string::FromUtf8Error> for Error {
} }
} }
impl From<http::header::InvalidHeaderValue> for Error {
fn from(err: http::header::InvalidHeaderValue) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::InvalidHeaderName> for Error {
fn from(err: http::header::InvalidHeaderName) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::ToStrError> for Error {
fn from(_: http::header::ToStrError) -> Self {
Error::Utf8
}
}
impl From<http::uri::InvalidUri> for Error {
fn from(err: http::uri::InvalidUri) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::status::InvalidStatusCode> for Error {
fn from(err: http::status::InvalidStatusCode) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
impl From<tls::Error> for Error { impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self { fn from(err: tls::Error) -> Self {

@ -4,11 +4,12 @@ use std::borrow::Cow;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use http::HeaderMap;
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use url::Url; use url::Url;
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -171,7 +172,10 @@ impl VerifyData {
// _Fail the WebSocket Connection_. (RFC 6455) // _Fail the WebSocket Connection_. (RFC 6455)
if !response if !response
.headers .headers
.header_is_ignore_case("Upgrade", "websocket") .get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(), "No \"Upgrade: websocket\" in server reply".into(),
@ -183,7 +187,10 @@ impl VerifyData {
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response if !response
.headers .headers
.header_is_ignore_case("Connection", "Upgrade") .get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(), "No \"Connection: upgrade\" in server reply".into(),
@ -195,7 +202,9 @@ impl VerifyData {
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !response if !response
.headers .headers
.header_is("Sec-WebSocket-Accept", &self.accept_key) .get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(), "Key mismatch in Sec-WebSocket-Accept".into(),
@ -225,7 +234,7 @@ pub struct Response {
/// HTTP response code of the response. /// HTTP response code of the response.
pub code: u16, pub code: u16,
/// Received headers. /// Received headers.
pub headers: Headers, pub headers: HeaderMap,
} }
impl TryParse for Response { impl TryParse for Response {
@ -248,7 +257,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
} }
Ok(Response { Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"), code: raw.code.expect("Bug: no HTTP response code"),
headers: Headers::from_httparse(raw.headers)?, headers: HeaderMap::from_httparse(raw.headers)?,
}) })
} }
} }
@ -287,9 +296,6 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200); assert_eq!(resp.code, 200);
assert_eq!( assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],);
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
} }
} }

@ -1,8 +1,7 @@
//! HTTP Request and response header handling. //! HTTP Request and response header handling.
use std::slice; use http;
use std::str::from_utf8; use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse; use httparse;
use httparse::Status; use httparse::Status;
@ -12,90 +11,31 @@ use crate::error::Result;
/// Limit for the number of header lines. /// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124; pub const MAX_HEADERS: usize = 124;
/// HTTP request or response headers. /// Trait to convert raw objects into HTTP parseables.
#[derive(Debug)] pub(crate) trait FromHttparse<T>: Sized {
pub struct Headers { /// Convert raw object into parsed HTTP headers.
data: Vec<(String, Box<[u8]>)>, fn from_httparse(raw: T) -> Result<Self>;
} }
impl Headers { impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap {
/// Get first header with the given name, if any. fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
pub fn find_first(&self, name: &str) -> Option<&[u8]> { let mut headers = HeaderMap::new();
self.find(name).next() for h in raw {
} headers.append(
HeaderName::from_bytes(h.name.as_bytes())?,
/// Iterate over all headers with the given name. HeaderValue::from_bytes(h.value)?,
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { );
HeadersIter {
name,
iter: self.data.iter(),
} }
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers. Ok(headers)
#[derive(Debug)]
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value);
}
}
None
} }
} }
impl TryParse for HeaderMap {
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}
impl TryParse for Headers {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None, Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
}) })
} }
} }
@ -104,7 +44,7 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
mod tests { mod tests {
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::Headers; use super::HeaderMap;
#[test] #[test]
fn headers() { fn headers() {
@ -112,14 +52,10 @@ mod tests {
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
\r\n"; \r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]);
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]);
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]);
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
} }
#[test] #[test]
@ -130,10 +66,10 @@ mod tests {
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
\r\n"; \r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions"); let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter();
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]);
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]);
assert_eq!(iter.next(), None); assert_eq!(iter.next(), None);
} }
@ -142,7 +78,7 @@ mod tests {
const DATA: &'static [u8] = b"Host: foo.com\r\n\ const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n"; Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap(); let hdr = HeaderMap::try_parse(DATA).unwrap();
assert!(hdr.is_none()); assert!(hdr.is_none());
} }
} }

@ -5,11 +5,11 @@ use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::StatusCode; use http::{HeaderMap, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -21,15 +21,15 @@ pub struct Request {
/// Path part of the URL. /// Path part of the URL.
pub path: String, pub path: String,
/// HTTP headers. /// HTTP headers.
pub headers: Headers, pub headers: HeaderMap,
} }
impl Request { impl Request {
/// Reply to the response. /// Reply to the response.
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> { pub fn reply(&self, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> {
let key = self let key = self
.headers .headers
.find_first("Sec-WebSocket-Key") .get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!( let mut reply = format!(
"\ "\
@ -37,20 +37,24 @@ impl Request {
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n", Sec-WebSocket-Accept: {}\r\n",
convert_key(key)? convert_key(key.as_bytes())?
); );
add_headers(&mut reply, extra_headers); add_headers(&mut reply, extra_headers)?;
Ok(reply.into()) Ok(reply.into())
} }
} }
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> {
if let Some(eh) = extra_headers { if let Some(eh) = extra_headers {
for (k, v) in eh { for (k, v) in eh {
writeln!(reply, "{}: {}\r", k, v).unwrap(); if let Some(k) = k {
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
}
} }
} }
writeln!(reply, "\r").unwrap(); writeln!(reply, "\r").unwrap();
Ok(())
} }
impl TryParse for Request { impl TryParse for Request {
@ -76,21 +80,18 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
Ok(Request { Ok(Request {
path: raw.path.expect("Bug: no path in header").into(), path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?, headers: HeaderMap::from_httparse(raw.headers)?,
}) })
} }
} }
/// Extra headers for responses.
pub type ExtraHeaders = Vec<(String, String)>;
/// An error response sent to the client. /// An error response sent to the client.
#[derive(Debug)] #[derive(Debug)]
pub struct ErrorResponse { pub struct ErrorResponse {
/// HTTP error code. /// HTTP error code.
pub error_code: StatusCode, pub error_code: StatusCode,
/// Extra response headers, if any. /// Extra response headers, if any.
pub headers: Option<ExtraHeaders>, pub headers: Option<HeaderMap>,
/// Response body, if any. /// Response body, if any.
pub body: Option<String>, pub body: Option<String>,
} }
@ -115,14 +116,14 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>, F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>,
{ {
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
self(request) self(request)
} }
} }
@ -132,7 +133,7 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
Ok(None) Ok(None)
} }
} }
@ -214,7 +215,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
error_code.as_str(), error_code.as_str(),
error_code.canonical_reason().unwrap_or("") error_code.canonical_reason().unwrap_or("")
); );
add_headers(&mut response, headers); add_headers(&mut response, headers)?;
if let Some(body) = body { if let Some(body) = body {
response += &body; response += &body;
} }
@ -241,14 +242,15 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
mod tests { mod tests {
use super::super::client::Response; use super::super::client::Response;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::Request; use super::{HeaderMap, Request};
use http::header::HeaderName;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
assert_eq!(req.path, "/script.ws"); assert_eq!(req.path, "/script.ws");
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]);
} }
#[test] #[test]
@ -264,19 +266,25 @@ mod tests {
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap(); let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![ let extra_headers = {
( let mut headers = HeaderMap::new();
String::from("MyCustomHeader"), headers.insert(
String::from("MyCustomValue"), HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(),
), "MyCustomValue".parse().unwrap(),
(String::from("MyVersion"), String::from("LOL")), );
]); headers.insert(
let reply = req.reply(extra_headers).unwrap(); HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(),
"LOL".parse().unwrap(),
);
headers
};
let reply = req.reply(Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!( assert_eq!(
req.headers.find_first("MyCustomHeader"), req.headers.get("MyCustomHeader").unwrap(),
Some(b"MyCustomValue".as_ref()) b"MyCustomValue".as_ref()
); );
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref());
} }
} }

Loading…
Cancel
Save