|
|
@ -1,56 +1,108 @@ |
|
|
|
//! Server handshake machine.
|
|
|
|
//! Server handshake machine.
|
|
|
|
|
|
|
|
|
|
|
|
use std::fmt::Write as FmtWrite; |
|
|
|
use std::io::{self, Read, Write}; |
|
|
|
use std::io::{Read, Write}; |
|
|
|
|
|
|
|
use std::marker::PhantomData; |
|
|
|
use std::marker::PhantomData; |
|
|
|
use std::result::Result as StdResult; |
|
|
|
use std::result::Result as StdResult; |
|
|
|
|
|
|
|
|
|
|
|
use http::StatusCode; |
|
|
|
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; |
|
|
|
use httparse::Status; |
|
|
|
use httparse::Status; |
|
|
|
use log::*; |
|
|
|
use log::*; |
|
|
|
|
|
|
|
|
|
|
|
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; |
|
|
|
use super::headers::{FromHttparse, MAX_HEADERS}; |
|
|
|
use super::machine::{HandshakeMachine, StageResult, TryParse}; |
|
|
|
use super::machine::{HandshakeMachine, StageResult, TryParse}; |
|
|
|
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; |
|
|
|
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; |
|
|
|
use crate::error::{Error, Result}; |
|
|
|
use crate::error::{Error, Result}; |
|
|
|
use crate::protocol::{Role, WebSocket, WebSocketConfig}; |
|
|
|
use crate::protocol::{Role, WebSocket, WebSocketConfig}; |
|
|
|
|
|
|
|
|
|
|
|
/// Request from the client.
|
|
|
|
/// Server request type.
|
|
|
|
#[derive(Debug)] |
|
|
|
pub type Request = HttpRequest<()>; |
|
|
|
pub struct Request { |
|
|
|
|
|
|
|
/// Path part of the URL.
|
|
|
|
|
|
|
|
pub path: String, |
|
|
|
|
|
|
|
/// HTTP headers.
|
|
|
|
|
|
|
|
pub headers: Headers, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl Request { |
|
|
|
/// Server response type.
|
|
|
|
/// Reply to the response.
|
|
|
|
pub type Response = HttpResponse<()>; |
|
|
|
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> { |
|
|
|
|
|
|
|
let key = self |
|
|
|
/// Server error response type.
|
|
|
|
.headers |
|
|
|
pub type ErrorResponse = HttpResponse<Option<String>>; |
|
|
|
.find_first("Sec-WebSocket-Key") |
|
|
|
|
|
|
|
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; |
|
|
|
/// Create a response for the request.
|
|
|
|
let mut reply = format!( |
|
|
|
pub fn create_response(request: &Request) -> Result<Response> { |
|
|
|
"\ |
|
|
|
if request.method() != http::Method::GET { |
|
|
|
HTTP/1.1 101 Switching Protocols\r\n\ |
|
|
|
return Err(Error::Protocol("Method is not GET".into())); |
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
} |
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Accept: {}\r\n", |
|
|
|
if request.version() < http::Version::HTTP_11 { |
|
|
|
convert_key(key)? |
|
|
|
return Err(Error::Protocol( |
|
|
|
); |
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
add_headers(&mut reply, extra_headers); |
|
|
|
)); |
|
|
|
Ok(reply.into()) |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if !request |
|
|
|
|
|
|
|
.headers() |
|
|
|
|
|
|
|
.get("Connection") |
|
|
|
|
|
|
|
.and_then(|h| h.to_str().ok()) |
|
|
|
|
|
|
|
.map(|h| h.eq_ignore_ascii_case("Upgrade")) |
|
|
|
|
|
|
|
.unwrap_or(false) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
|
|
|
"No \"Connection: upgrade\" in client request".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { |
|
|
|
if !request |
|
|
|
if let Some(eh) = extra_headers { |
|
|
|
.headers() |
|
|
|
for (k, v) in eh { |
|
|
|
.get("Upgrade") |
|
|
|
writeln!(reply, "{}: {}\r", k, v).unwrap(); |
|
|
|
.and_then(|h| h.to_str().ok()) |
|
|
|
|
|
|
|
.map(|h| h.eq_ignore_ascii_case("websocket")) |
|
|
|
|
|
|
|
.unwrap_or(false) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
|
|
|
"No \"Upgrade: websocket\" in client request".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if !request |
|
|
|
|
|
|
|
.headers() |
|
|
|
|
|
|
|
.get("Sec-WebSocket-Version") |
|
|
|
|
|
|
|
.map(|h| h == "13") |
|
|
|
|
|
|
|
.unwrap_or(false) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
|
|
|
"No \"Sec-WebSocket-Version: 13\" in client request".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let key = request |
|
|
|
|
|
|
|
.headers() |
|
|
|
|
|
|
|
.get("Sec-WebSocket-Key") |
|
|
|
|
|
|
|
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let mut response = Response::builder(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
response.status(StatusCode::SWITCHING_PROTOCOLS); |
|
|
|
|
|
|
|
response.version(request.version()); |
|
|
|
|
|
|
|
response.header("Connection", "Upgrade"); |
|
|
|
|
|
|
|
response.header("Upgrade", "websocket"); |
|
|
|
|
|
|
|
response.header("Sec-WebSocket-Accept", convert_key(key.as_bytes())?); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(response.body(())?) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Assumes that this is a valid response
|
|
|
|
|
|
|
|
fn write_response<T>(w: &mut dyn io::Write, response: &HttpResponse<T>) -> Result<()> { |
|
|
|
|
|
|
|
writeln!( |
|
|
|
|
|
|
|
w, |
|
|
|
|
|
|
|
"{version:?} {status} {reason}\r", |
|
|
|
|
|
|
|
version = response.version(), |
|
|
|
|
|
|
|
status = response.status(), |
|
|
|
|
|
|
|
reason = response.status().canonical_reason().unwrap_or(""), |
|
|
|
|
|
|
|
)?; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (k, v) in response.headers() { |
|
|
|
|
|
|
|
writeln!(w, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
} |
|
|
|
} |
|
|
|
writeln!(reply, "\r").unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
writeln!(w, "\r")?; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(()) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
impl TryParse for Request { |
|
|
|
impl TryParse for Request { |
|
|
@ -69,39 +121,24 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { |
|
|
|
if raw.method.expect("Bug: no method in header") != "GET" { |
|
|
|
if raw.method.expect("Bug: no method in header") != "GET" { |
|
|
|
return Err(Error::Protocol("Method is not GET".into())); |
|
|
|
return Err(Error::Protocol("Method is not GET".into())); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
|
|
|
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol( |
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
)); |
|
|
|
)); |
|
|
|
} |
|
|
|
} |
|
|
|
Ok(Request { |
|
|
|
|
|
|
|
path: raw.path.expect("Bug: no path in header").into(), |
|
|
|
|
|
|
|
headers: Headers::from_httparse(raw.headers)?, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Extra headers for responses.
|
|
|
|
let headers = HeaderMap::from_httparse(raw.headers)?; |
|
|
|
pub type ExtraHeaders = Vec<(String, String)>; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// An error response sent to the client.
|
|
|
|
let mut request = Request::new(()); |
|
|
|
#[derive(Debug)] |
|
|
|
*request.method_mut() = http::Method::GET; |
|
|
|
pub struct ErrorResponse { |
|
|
|
*request.headers_mut() = headers; |
|
|
|
/// HTTP error code.
|
|
|
|
*request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?; |
|
|
|
pub error_code: StatusCode, |
|
|
|
// TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
|
|
|
|
/// Extra response headers, if any.
|
|
|
|
// so the only valid value we could get in the response would be 1.1.
|
|
|
|
pub headers: Option<ExtraHeaders>, |
|
|
|
*request.version_mut() = http::Version::HTTP_11; |
|
|
|
/// Response body, if any.
|
|
|
|
|
|
|
|
pub body: Option<String>, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl From<StatusCode> for ErrorResponse { |
|
|
|
Ok(request) |
|
|
|
fn from(error_code: StatusCode) -> Self { |
|
|
|
|
|
|
|
ErrorResponse { |
|
|
|
|
|
|
|
error_code, |
|
|
|
|
|
|
|
headers: None, |
|
|
|
|
|
|
|
body: None, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -115,15 +152,23 @@ pub trait Callback: Sized { |
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
/// May return additional reply headers.
|
|
|
|
/// May return additional reply headers.
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; |
|
|
|
fn on_request( |
|
|
|
|
|
|
|
self, |
|
|
|
|
|
|
|
request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse>; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
impl<F> Callback for F |
|
|
|
impl<F> Callback for F |
|
|
|
where |
|
|
|
where |
|
|
|
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>, |
|
|
|
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>, |
|
|
|
{ |
|
|
|
{ |
|
|
|
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { |
|
|
|
fn on_request( |
|
|
|
self(request) |
|
|
|
self, |
|
|
|
|
|
|
|
request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse> { |
|
|
|
|
|
|
|
self(request, response) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -132,8 +177,12 @@ where |
|
|
|
pub struct NoCallback; |
|
|
|
pub struct NoCallback; |
|
|
|
|
|
|
|
|
|
|
|
impl Callback for NoCallback { |
|
|
|
impl Callback for NoCallback { |
|
|
|
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { |
|
|
|
fn on_request( |
|
|
|
Ok(None) |
|
|
|
self, |
|
|
|
|
|
|
|
_request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse> { |
|
|
|
|
|
|
|
Ok(response) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -191,34 +240,35 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
return Err(Error::Protocol("Junk after client request".into())); |
|
|
|
return Err(Error::Protocol("Junk after client request".into())); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let response = create_response(&result)?; |
|
|
|
let callback_result = if let Some(callback) = self.callback.take() { |
|
|
|
let callback_result = if let Some(callback) = self.callback.take() { |
|
|
|
callback.on_request(&result) |
|
|
|
callback.on_request(&result, response) |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
Ok(None) |
|
|
|
Ok(response) |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
match callback_result { |
|
|
|
match callback_result { |
|
|
|
Ok(extra_headers) => { |
|
|
|
Ok(response) => { |
|
|
|
let response = result.reply(extra_headers)?; |
|
|
|
let mut output = vec![]; |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) |
|
|
|
write_response(&mut output, &response)?; |
|
|
|
|
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Err(ErrorResponse { |
|
|
|
Err(resp) => { |
|
|
|
error_code, |
|
|
|
if resp.status().is_success() { |
|
|
|
headers, |
|
|
|
return Err(Error::Protocol( |
|
|
|
body, |
|
|
|
"Custom response must not be successful".into(), |
|
|
|
}) => { |
|
|
|
)); |
|
|
|
self.error_code = Some(error_code.as_u16()); |
|
|
|
} |
|
|
|
let mut response = format!( |
|
|
|
|
|
|
|
"HTTP/1.1 {} {}\r\n", |
|
|
|
self.error_code = Some(resp.status().as_u16()); |
|
|
|
error_code.as_str(), |
|
|
|
|
|
|
|
error_code.canonical_reason().unwrap_or("") |
|
|
|
let mut output = vec![]; |
|
|
|
); |
|
|
|
write_response(&mut output, &resp)?; |
|
|
|
add_headers(&mut response, headers); |
|
|
|
if let Some(body) = resp.body() { |
|
|
|
if let Some(body) = body { |
|
|
|
output.extend_from_slice(body.as_bytes()); |
|
|
|
response += &body; |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -226,7 +276,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
StageResult::DoneWriting(stream) => { |
|
|
|
StageResult::DoneWriting(stream) => { |
|
|
|
if let Some(err) = self.error_code.take() { |
|
|
|
if let Some(err) = self.error_code.take() { |
|
|
|
debug!("Server handshake failed."); |
|
|
|
debug!("Server handshake failed."); |
|
|
|
return Err(Error::Http(err)); |
|
|
|
return Err(Error::Http(StatusCode::from_u16(err)?)); |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
debug!("Server handshake done."); |
|
|
|
debug!("Server handshake done."); |
|
|
|
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); |
|
|
|
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); |
|
|
@ -239,21 +289,21 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
|
|
|
|
|
|
|
|
#[cfg(test)] |
|
|
|
#[cfg(test)] |
|
|
|
mod tests { |
|
|
|
mod tests { |
|
|
|
use super::super::client::Response; |
|
|
|
|
|
|
|
use super::super::machine::TryParse; |
|
|
|
use super::super::machine::TryParse; |
|
|
|
|
|
|
|
use super::create_response; |
|
|
|
use super::Request; |
|
|
|
use super::Request; |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
|
fn request_parsing() { |
|
|
|
fn request_parsing() { |
|
|
|
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; |
|
|
|
const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; |
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
assert_eq!(req.path, "/script.ws"); |
|
|
|
assert_eq!(req.uri().path(), "/script.ws"); |
|
|
|
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); |
|
|
|
assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
|
fn request_replying() { |
|
|
|
fn request_replying() { |
|
|
|
const DATA: &'static [u8] = b"\ |
|
|
|
const DATA: &[u8] = b"\ |
|
|
|
GET /script.ws HTTP/1.1\r\n\ |
|
|
|
GET /script.ws HTTP/1.1\r\n\ |
|
|
|
Host: foo.com\r\n\ |
|
|
|
Host: foo.com\r\n\ |
|
|
|
Connection: upgrade\r\n\ |
|
|
|
Connection: upgrade\r\n\ |
|
|
@ -262,21 +312,11 @@ mod tests { |
|
|
|
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ |
|
|
|
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ |
|
|
|
\r\n"; |
|
|
|
\r\n"; |
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); |
|
|
|
let _ = req.reply(None).unwrap(); |
|
|
|
let response = create_response(&req).unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
let extra_headers = Some(vec![ |
|
|
|
|
|
|
|
( |
|
|
|
|
|
|
|
String::from("MyCustomHeader"), |
|
|
|
|
|
|
|
String::from("MyCustomValue"), |
|
|
|
|
|
|
|
), |
|
|
|
|
|
|
|
(String::from("MyVersion"), String::from("LOL")), |
|
|
|
|
|
|
|
]); |
|
|
|
|
|
|
|
let reply = req.reply(extra_headers).unwrap(); |
|
|
|
|
|
|
|
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); |
|
|
|
|
|
|
|
assert_eq!( |
|
|
|
assert_eq!( |
|
|
|
req.headers.find_first("MyCustomHeader"), |
|
|
|
response.headers().get("Sec-WebSocket-Accept").unwrap(), |
|
|
|
Some(b"MyCustomValue".as_ref()) |
|
|
|
b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref() |
|
|
|
); |
|
|
|
); |
|
|
|
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|