|
|
|
@ -5,11 +5,11 @@ use std::io::{Read, Write}; |
|
|
|
|
use std::marker::PhantomData; |
|
|
|
|
use std::result::Result as StdResult; |
|
|
|
|
|
|
|
|
|
use http::StatusCode; |
|
|
|
|
use http::{HeaderMap, StatusCode}; |
|
|
|
|
use httparse::Status; |
|
|
|
|
use log::*; |
|
|
|
|
|
|
|
|
|
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; |
|
|
|
|
use super::headers::{FromHttparse, MAX_HEADERS}; |
|
|
|
|
use super::machine::{HandshakeMachine, StageResult, TryParse}; |
|
|
|
|
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; |
|
|
|
|
use crate::error::{Error, Result}; |
|
|
|
@ -21,15 +21,15 @@ pub struct Request { |
|
|
|
|
/// Path part of the URL.
|
|
|
|
|
pub path: String, |
|
|
|
|
/// HTTP headers.
|
|
|
|
|
pub headers: Headers, |
|
|
|
|
pub headers: HeaderMap, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl Request { |
|
|
|
|
/// Reply to the response.
|
|
|
|
|
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> { |
|
|
|
|
pub fn reply(&self, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> { |
|
|
|
|
let key = self |
|
|
|
|
.headers |
|
|
|
|
.find_first("Sec-WebSocket-Key") |
|
|
|
|
.get("Sec-WebSocket-Key") |
|
|
|
|
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; |
|
|
|
|
let mut reply = format!( |
|
|
|
|
"\ |
|
|
|
@ -37,20 +37,24 @@ impl Request { |
|
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
|
Sec-WebSocket-Accept: {}\r\n", |
|
|
|
|
convert_key(key)? |
|
|
|
|
convert_key(key.as_bytes())? |
|
|
|
|
); |
|
|
|
|
add_headers(&mut reply, extra_headers); |
|
|
|
|
add_headers(&mut reply, extra_headers)?; |
|
|
|
|
Ok(reply.into()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { |
|
|
|
|
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> { |
|
|
|
|
if let Some(eh) = extra_headers { |
|
|
|
|
for (k, v) in eh { |
|
|
|
|
writeln!(reply, "{}: {}\r", k, v).unwrap(); |
|
|
|
|
if let Some(k) = k { |
|
|
|
|
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
writeln!(reply, "\r").unwrap(); |
|
|
|
|
|
|
|
|
|
Ok(()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl TryParse for Request { |
|
|
|
@ -76,21 +80,18 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { |
|
|
|
|
} |
|
|
|
|
Ok(Request { |
|
|
|
|
path: raw.path.expect("Bug: no path in header").into(), |
|
|
|
|
headers: Headers::from_httparse(raw.headers)?, |
|
|
|
|
headers: HeaderMap::from_httparse(raw.headers)?, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Extra headers for responses.
|
|
|
|
|
pub type ExtraHeaders = Vec<(String, String)>; |
|
|
|
|
|
|
|
|
|
/// An error response sent to the client.
|
|
|
|
|
#[derive(Debug)] |
|
|
|
|
pub struct ErrorResponse { |
|
|
|
|
/// HTTP error code.
|
|
|
|
|
pub error_code: StatusCode, |
|
|
|
|
/// Extra response headers, if any.
|
|
|
|
|
pub headers: Option<ExtraHeaders>, |
|
|
|
|
pub headers: Option<HeaderMap>, |
|
|
|
|
/// Response body, if any.
|
|
|
|
|
pub body: Option<String>, |
|
|
|
|
} |
|
|
|
@ -115,14 +116,14 @@ pub trait Callback: Sized { |
|
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
|
/// May return additional reply headers.
|
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; |
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl<F> Callback for F |
|
|
|
|
where |
|
|
|
|
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>, |
|
|
|
|
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>, |
|
|
|
|
{ |
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { |
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { |
|
|
|
|
self(request) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -132,7 +133,7 @@ where |
|
|
|
|
pub struct NoCallback; |
|
|
|
|
|
|
|
|
|
impl Callback for NoCallback { |
|
|
|
|
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { |
|
|
|
|
fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { |
|
|
|
|
Ok(None) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -214,7 +215,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
error_code.as_str(), |
|
|
|
|
error_code.canonical_reason().unwrap_or("") |
|
|
|
|
); |
|
|
|
|
add_headers(&mut response, headers); |
|
|
|
|
add_headers(&mut response, headers)?; |
|
|
|
|
if let Some(body) = body { |
|
|
|
|
response += &body; |
|
|
|
|
} |
|
|
|
@ -241,14 +242,15 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
mod tests { |
|
|
|
|
use super::super::client::Response; |
|
|
|
|
use super::super::machine::TryParse; |
|
|
|
|
use super::Request; |
|
|
|
|
use super::{HeaderMap, Request}; |
|
|
|
|
use http::header::HeaderName; |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn request_parsing() { |
|
|
|
|
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; |
|
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
|
assert_eq!(req.path, "/script.ws"); |
|
|
|
|
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); |
|
|
|
|
assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
@ -264,19 +266,25 @@ mod tests { |
|
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
|
let _ = req.reply(None).unwrap(); |
|
|
|
|
|
|
|
|
|
let extra_headers = Some(vec![ |
|
|
|
|
( |
|
|
|
|
String::from("MyCustomHeader"), |
|
|
|
|
String::from("MyCustomValue"), |
|
|
|
|
), |
|
|
|
|
(String::from("MyVersion"), String::from("LOL")), |
|
|
|
|
]); |
|
|
|
|
let reply = req.reply(extra_headers).unwrap(); |
|
|
|
|
let extra_headers = { |
|
|
|
|
let mut headers = HeaderMap::new(); |
|
|
|
|
headers.insert( |
|
|
|
|
HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(), |
|
|
|
|
"MyCustomValue".parse().unwrap(), |
|
|
|
|
); |
|
|
|
|
headers.insert( |
|
|
|
|
HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(), |
|
|
|
|
"LOL".parse().unwrap(), |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
headers |
|
|
|
|
}; |
|
|
|
|
let reply = req.reply(Some(extra_headers)).unwrap(); |
|
|
|
|
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); |
|
|
|
|
assert_eq!( |
|
|
|
|
req.headers.find_first("MyCustomHeader"), |
|
|
|
|
Some(b"MyCustomValue".as_ref()) |
|
|
|
|
req.headers.get("MyCustomHeader").unwrap(), |
|
|
|
|
b"MyCustomValue".as_ref() |
|
|
|
|
); |
|
|
|
|
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); |
|
|
|
|
assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|