|
|
|
@ -5,7 +5,7 @@ use std::io::{Read, Write}; |
|
|
|
|
use std::marker::PhantomData; |
|
|
|
|
use std::result::Result as StdResult; |
|
|
|
|
|
|
|
|
|
use http::{HeaderMap, StatusCode}; |
|
|
|
|
use http::{HeaderMap, Request, Response, StatusCode}; |
|
|
|
|
use httparse::Status; |
|
|
|
|
use log::*; |
|
|
|
|
|
|
|
|
@ -15,20 +15,10 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; |
|
|
|
|
use crate::error::{Error, Result}; |
|
|
|
|
use crate::protocol::{Role, WebSocket, WebSocketConfig}; |
|
|
|
|
|
|
|
|
|
/// Request from the client.
|
|
|
|
|
#[derive(Debug)] |
|
|
|
|
pub struct Request { |
|
|
|
|
/// Path part of the URL.
|
|
|
|
|
pub path: String, |
|
|
|
|
/// HTTP headers.
|
|
|
|
|
pub headers: HeaderMap, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl Request { |
|
|
|
|
/// Reply to the response.
|
|
|
|
|
pub fn reply(&self, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> { |
|
|
|
|
let key = self |
|
|
|
|
.headers |
|
|
|
|
/// Reply to the response.
|
|
|
|
|
fn reply(request: &Request<()>, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> { |
|
|
|
|
let key = request |
|
|
|
|
.headers() |
|
|
|
|
.get("Sec-WebSocket-Key") |
|
|
|
|
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; |
|
|
|
|
let mut reply = format!( |
|
|
|
@ -39,25 +29,22 @@ impl Request { |
|
|
|
|
Sec-WebSocket-Accept: {}\r\n", |
|
|
|
|
convert_key(key.as_bytes())? |
|
|
|
|
); |
|
|
|
|
add_headers(&mut reply, extra_headers)?; |
|
|
|
|
add_headers(&mut reply, extra_headers.as_ref())?; |
|
|
|
|
Ok(reply.into()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> { |
|
|
|
|
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { |
|
|
|
|
if let Some(eh) = extra_headers { |
|
|
|
|
for (k, v) in eh { |
|
|
|
|
if let Some(k) = k { |
|
|
|
|
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
writeln!(reply, "\r").unwrap(); |
|
|
|
|
|
|
|
|
|
Ok(()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl TryParse for Request { |
|
|
|
|
impl TryParse for Request<()> { |
|
|
|
|
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { |
|
|
|
|
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; |
|
|
|
|
let mut req = httparse::Request::new(&mut hbuffer); |
|
|
|
@ -68,41 +55,29 @@ impl TryParse for Request { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { |
|
|
|
|
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request<()> { |
|
|
|
|
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> { |
|
|
|
|
if raw.method.expect("Bug: no method in header") != "GET" { |
|
|
|
|
return Err(Error::Protocol("Method is not GET".into())); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
Ok(Request { |
|
|
|
|
path: raw.path.expect("Bug: no path in header").into(), |
|
|
|
|
headers: HeaderMap::from_httparse(raw.headers)?, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// An error response sent to the client.
|
|
|
|
|
#[derive(Debug)] |
|
|
|
|
pub struct ErrorResponse { |
|
|
|
|
/// HTTP error code.
|
|
|
|
|
pub error_code: StatusCode, |
|
|
|
|
/// Extra response headers, if any.
|
|
|
|
|
pub headers: Option<HeaderMap>, |
|
|
|
|
/// Response body, if any.
|
|
|
|
|
pub body: Option<String>, |
|
|
|
|
} |
|
|
|
|
let headers = HeaderMap::from_httparse(raw.headers)?; |
|
|
|
|
|
|
|
|
|
impl From<StatusCode> for ErrorResponse { |
|
|
|
|
fn from(error_code: StatusCode) -> Self { |
|
|
|
|
ErrorResponse { |
|
|
|
|
error_code, |
|
|
|
|
headers: None, |
|
|
|
|
body: None, |
|
|
|
|
} |
|
|
|
|
let mut request = Request::new(()); |
|
|
|
|
*request.method_mut() = http::Method::GET; |
|
|
|
|
*request.headers_mut() = headers; |
|
|
|
|
*request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?; |
|
|
|
|
// TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
|
|
|
|
|
// so the only valid value we could get in the response would be 1.1.
|
|
|
|
|
*request.version_mut() = http::Version::HTTP_11; |
|
|
|
|
|
|
|
|
|
Ok(request) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -116,14 +91,20 @@ pub trait Callback: Sized { |
|
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
|
/// May return additional reply headers.
|
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>; |
|
|
|
|
fn on_request( |
|
|
|
|
self, |
|
|
|
|
request: &Request<()>, |
|
|
|
|
) -> StdResult<Option<HeaderMap>, Response<Option<String>>>; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl<F> Callback for F |
|
|
|
|
where |
|
|
|
|
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>, |
|
|
|
|
F: FnOnce(&Request<()>) -> StdResult<Option<HeaderMap>, Response<Option<String>>>, |
|
|
|
|
{ |
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { |
|
|
|
|
fn on_request( |
|
|
|
|
self, |
|
|
|
|
request: &Request<()>, |
|
|
|
|
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> { |
|
|
|
|
self(request) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -133,7 +114,10 @@ where |
|
|
|
|
pub struct NoCallback; |
|
|
|
|
|
|
|
|
|
impl Callback for NoCallback { |
|
|
|
|
fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { |
|
|
|
|
fn on_request( |
|
|
|
|
self, |
|
|
|
|
_request: &Request<()>, |
|
|
|
|
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> { |
|
|
|
|
Ok(None) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -174,7 +158,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
type IncomingData = Request; |
|
|
|
|
type IncomingData = Request<()>; |
|
|
|
|
type InternalStream = S; |
|
|
|
|
type FinalResult = WebSocket<S>; |
|
|
|
|
|
|
|
|
@ -200,23 +184,26 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
|
|
|
|
|
match callback_result { |
|
|
|
|
Ok(extra_headers) => { |
|
|
|
|
let response = result.reply(extra_headers)?; |
|
|
|
|
let response = reply(&result, extra_headers)?; |
|
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Err(ErrorResponse { |
|
|
|
|
error_code, |
|
|
|
|
headers, |
|
|
|
|
body, |
|
|
|
|
}) => { |
|
|
|
|
self.error_code = Some(error_code.as_u16()); |
|
|
|
|
Err(resp) => { |
|
|
|
|
if resp.status().is_success() { |
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
"Custom response must not be successful".into(), |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
self.error_code = Some(resp.status().as_u16()); |
|
|
|
|
let mut response = format!( |
|
|
|
|
"HTTP/1.1 {} {}\r\n", |
|
|
|
|
error_code.as_str(), |
|
|
|
|
error_code.canonical_reason().unwrap_or("") |
|
|
|
|
"{version:?} {status} {reason}\r\n", |
|
|
|
|
version = resp.version(), |
|
|
|
|
status = resp.status().as_u16(), |
|
|
|
|
reason = resp.status().canonical_reason().unwrap_or("") |
|
|
|
|
); |
|
|
|
|
add_headers(&mut response, headers)?; |
|
|
|
|
if let Some(body) = body { |
|
|
|
|
add_headers(&mut response, Some(resp.headers()))?; |
|
|
|
|
if let Some(body) = resp.body() { |
|
|
|
|
response += &body; |
|
|
|
|
} |
|
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) |
|
|
|
@ -241,6 +228,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
#[cfg(test)] |
|
|
|
|
mod tests { |
|
|
|
|
use super::super::machine::TryParse; |
|
|
|
|
use super::reply; |
|
|
|
|
use super::{HeaderMap, Request}; |
|
|
|
|
use http::header::HeaderName; |
|
|
|
|
use http::Response; |
|
|
|
@ -249,8 +237,8 @@ mod tests { |
|
|
|
|
fn request_parsing() { |
|
|
|
|
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; |
|
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
|
assert_eq!(req.path, "/script.ws"); |
|
|
|
|
assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); |
|
|
|
|
assert_eq!(req.uri().path(), "/script.ws"); |
|
|
|
|
assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
@ -264,7 +252,7 @@ mod tests { |
|
|
|
|
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ |
|
|
|
|
\r\n"; |
|
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
|
let _ = req.reply(None).unwrap(); |
|
|
|
|
let _ = reply(&req, None).unwrap(); |
|
|
|
|
|
|
|
|
|
let extra_headers = { |
|
|
|
|
let mut headers = HeaderMap::new(); |
|
|
|
@ -279,7 +267,7 @@ mod tests { |
|
|
|
|
|
|
|
|
|
headers |
|
|
|
|
}; |
|
|
|
|
let reply = req.reply(Some(extra_headers)).unwrap(); |
|
|
|
|
let reply = reply(&req, Some(extra_headers)).unwrap(); |
|
|
|
|
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); |
|
|
|
|
assert_eq!( |
|
|
|
|
req.headers().get("MyCustomHeader").unwrap(), |
|
|
|
|