Remove custom Headers type and use http::HeaderMap instead

Fixes https://github.com/snapview/tungstenite-rs/issues/92
pull/93/head
Sebastian Dröge 5 years ago
parent bb801430c8
commit 38a7d1a375
  1. 46
      src/error.rs
  2. 26
      src/handshake/client.rs
  3. 118
      src/handshake/headers.rs
  4. 72
      src/handshake/server.rs

@ -45,7 +45,7 @@ pub enum Error {
/// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part.
AlreadyClosed,
/// Input-output error. Appart from WouldBlock, these are generally errors with the
/// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal.
Io(io::Error),
#[cfg(feature = "tls")]
@ -61,10 +61,12 @@ pub enum Error {
SendQueueFull(Message),
/// UTF coding error
Utf8,
/// Invlid URL.
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
/// HTTP format error.
HttpFormat(http::Error),
}
impl fmt::Display for Error {
@ -80,7 +82,8 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
}
@ -99,6 +102,7 @@ impl ErrorTrait for Error {
Error::Utf8 => "",
Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "",
Error::HttpFormat(ref err) => err.description(),
}
}
}
@ -121,6 +125,42 @@ impl From<string::FromUtf8Error> for Error {
}
}
impl From<http::header::InvalidHeaderValue> for Error {
fn from(err: http::header::InvalidHeaderValue) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::InvalidHeaderName> for Error {
fn from(err: http::header::InvalidHeaderName) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::ToStrError> for Error {
fn from(_: http::header::ToStrError) -> Self {
Error::Utf8
}
}
impl From<http::uri::InvalidUri> for Error {
fn from(err: http::uri::InvalidUri) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::status::InvalidStatusCode> for Error {
fn from(err: http::status::InvalidStatusCode) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}
#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {

@ -4,11 +4,12 @@ use std::borrow::Cow;
use std::io::{Read, Write};
use std::marker::PhantomData;
use http::HeaderMap;
use httparse::Status;
use log::*;
use url::Url;
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
@ -171,7 +172,10 @@ impl VerifyData {
// _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
.header_is_ignore_case("Upgrade", "websocket")
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
@ -183,7 +187,10 @@ impl VerifyData {
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
.header_is_ignore_case("Connection", "Upgrade")
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
@ -195,7 +202,9 @@ impl VerifyData {
// Connection_. (RFC 6455)
if !response
.headers
.header_is("Sec-WebSocket-Accept", &self.accept_key)
.get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
@ -225,7 +234,7 @@ pub struct Response {
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: Headers,
pub headers: HeaderMap,
}
impl TryParse for Response {
@ -248,7 +257,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
headers: Headers::from_httparse(raw.headers)?,
headers: HeaderMap::from_httparse(raw.headers)?,
})
}
}
@ -287,9 +296,6 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200);
assert_eq!(
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],);
}
}

@ -1,8 +1,7 @@
//! HTTP Request and response header handling.
use std::slice;
use std::str::from_utf8;
use http;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse;
use httparse::Status;
@ -12,90 +11,31 @@ use crate::error::Result;
/// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124;
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
/// Trait to convert raw objects into HTTP parseables.
pub(crate) trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
}
/// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name,
iter: self.data.iter(),
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
let mut headers = HeaderMap::new();
for h in raw {
headers.append(
HeaderName::from_bytes(h.name.as_bytes())?,
HeaderValue::from_bytes(h.value)?,
);
}
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers.
#[derive(Debug)]
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value);
}
}
None
Ok(headers)
}
}
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}
impl TryParse for Headers {
impl TryParse for HeaderMap {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)),
})
}
}
@ -104,7 +44,7 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
mod tests {
use super::super::machine::TryParse;
use super::Headers;
use super::HeaderMap;
#[test]
fn headers() {
@ -112,14 +52,10 @@ mod tests {
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]);
assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]);
assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]);
}
#[test]
@ -130,10 +66,10 @@ mod tests {
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter();
assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]);
assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]);
assert_eq!(iter.next(), None);
}
@ -142,7 +78,7 @@ mod tests {
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap();
let hdr = HeaderMap::try_parse(DATA).unwrap();
assert!(hdr.is_none());
}
}

@ -5,11 +5,11 @@ use std::io::{Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;
use http::StatusCode;
use http::{HeaderMap, StatusCode};
use httparse::Status;
use log::*;
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
@ -21,15 +21,15 @@ pub struct Request {
/// Path part of the URL.
pub path: String,
/// HTTP headers.
pub headers: Headers,
pub headers: HeaderMap,
}
impl Request {
/// Reply to the response.
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
pub fn reply(&self, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> {
let key = self
.headers
.find_first("Sec-WebSocket-Key")
.get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!(
"\
@ -37,20 +37,24 @@ impl Request {
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
convert_key(key.as_bytes())?
);
add_headers(&mut reply, extra_headers);
add_headers(&mut reply, extra_headers)?;
Ok(reply.into())
}
}
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> {
if let Some(eh) = extra_headers {
for (k, v) in eh {
writeln!(reply, "{}: {}\r", k, v).unwrap();
if let Some(k) = k {
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
}
}
}
writeln!(reply, "\r").unwrap();
Ok(())
}
impl TryParse for Request {
@ -76,21 +80,18 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?,
headers: HeaderMap::from_httparse(raw.headers)?,
})
}
}
/// Extra headers for responses.
pub type ExtraHeaders = Vec<(String, String)>;
/// An error response sent to the client.
#[derive(Debug)]
pub struct ErrorResponse {
/// HTTP error code.
pub error_code: StatusCode,
/// Extra response headers, if any.
pub headers: Option<ExtraHeaders>,
pub headers: Option<HeaderMap>,
/// Response body, if any.
pub body: Option<String>,
}
@ -115,14 +116,14 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>;
}
impl<F> Callback for F
where
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>,
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>,
{
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
self(request)
}
}
@ -132,7 +133,7 @@ where
pub struct NoCallback;
impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
Ok(None)
}
}
@ -214,7 +215,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
error_code.as_str(),
error_code.canonical_reason().unwrap_or("")
);
add_headers(&mut response, headers);
add_headers(&mut response, headers)?;
if let Some(body) = body {
response += &body;
}
@ -241,14 +242,15 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
mod tests {
use super::super::client::Response;
use super::super::machine::TryParse;
use super::Request;
use super::{HeaderMap, Request};
use http::header::HeaderName;
#[test]
fn request_parsing() {
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
assert_eq!(req.path, "/script.ws");
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]);
}
#[test]
@ -264,19 +266,25 @@ mod tests {
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![
(
String::from("MyCustomHeader"),
String::from("MyCustomValue"),
),
(String::from("MyVersion"), String::from("LOL")),
]);
let reply = req.reply(extra_headers).unwrap();
let extra_headers = {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(),
"MyCustomValue".parse().unwrap(),
);
headers.insert(
HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(),
"LOL".parse().unwrap(),
);
headers
};
let reply = req.reply(Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(
req.headers.find_first("MyCustomHeader"),
Some(b"MyCustomValue".as_ref())
req.headers.get("MyCustomHeader").unwrap(),
b"MyCustomValue".as_ref()
);
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref());
}
}

Loading…
Cancel
Save