Remove custom Request/Response types from server code

Fixes https://github.com/snapview/tungstenite-rs/issues/92
pull/93/head
Sebastian Dröge 5 years ago
parent 9020840f84
commit 09a9b7ceef
  1. 144
      src/handshake/server.rs

@ -5,7 +5,7 @@ use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::{HeaderMap, StatusCode}; use http::{HeaderMap, Request, Response, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
@ -15,41 +15,28 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Request from the client. /// Reply to the response.
#[derive(Debug)] fn reply(request: &Request<()>, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> {
pub struct Request { let key = request
/// Path part of the URL. .headers()
pub path: String, .get("Sec-WebSocket-Key")
/// HTTP headers. .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
pub headers: HeaderMap, let mut reply = format!(
} "\
HTTP/1.1 101 Switching Protocols\r\n\
impl Request { Connection: Upgrade\r\n\
/// Reply to the response. Upgrade: websocket\r\n\
pub fn reply(&self, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> { Sec-WebSocket-Accept: {}\r\n",
let key = self convert_key(key.as_bytes())?
.headers );
.get("Sec-WebSocket-Key") add_headers(&mut reply, extra_headers.as_ref())?;
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; Ok(reply.into())
let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
convert_key(key.as_bytes())?
);
add_headers(&mut reply, extra_headers)?;
Ok(reply.into())
}
} }
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> { fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> {
if let Some(eh) = extra_headers { if let Some(eh) = extra_headers {
for (k, v) in eh { for (k, v) in eh {
if let Some(k) = k { writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
}
} }
} }
writeln!(reply, "\r").unwrap(); writeln!(reply, "\r").unwrap();
@ -57,7 +44,7 @@ fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> R
Ok(()) Ok(())
} }
impl TryParse for Request { impl TryParse for Request<()> {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut hbuffer); let mut req = httparse::Request::new(&mut hbuffer);
@ -68,41 +55,29 @@ impl TryParse for Request {
} }
} }
impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request<()> {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method.expect("Bug: no method in header") != "GET" { if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol("Method is not GET".into()));
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(), "HTTP version should be 1.1 or higher".into(),
)); ));
} }
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: HeaderMap::from_httparse(raw.headers)?,
})
}
}
/// An error response sent to the client. let headers = HeaderMap::from_httparse(raw.headers)?;
#[derive(Debug)]
pub struct ErrorResponse {
/// HTTP error code.
pub error_code: StatusCode,
/// Extra response headers, if any.
pub headers: Option<HeaderMap>,
/// Response body, if any.
pub body: Option<String>,
}
impl From<StatusCode> for ErrorResponse { let mut request = Request::new(());
fn from(error_code: StatusCode) -> Self { *request.method_mut() = http::Method::GET;
ErrorResponse { *request.headers_mut() = headers;
error_code, *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
headers: None, // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
body: None, // so the only valid value we could get in the response would be 1.1.
} *request.version_mut() = http::Version::HTTP_11;
Ok(request)
} }
} }
@ -116,14 +91,20 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>; fn on_request(
self,
request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>, F: FnOnce(&Request<()>) -> StdResult<Option<HeaderMap>, Response<Option<String>>>,
{ {
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { fn on_request(
self,
request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> {
self(request) self(request)
} }
} }
@ -133,7 +114,10 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> { fn on_request(
self,
_request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> {
Ok(None) Ok(None)
} }
} }
@ -174,7 +158,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
} }
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type IncomingData = Request; type IncomingData = Request<()>;
type InternalStream = S; type InternalStream = S;
type FinalResult = WebSocket<S>; type FinalResult = WebSocket<S>;
@ -200,23 +184,26 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
match callback_result { match callback_result {
Ok(extra_headers) => { Ok(extra_headers) => {
let response = result.reply(extra_headers)?; let response = reply(&result, extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
Err(ErrorResponse { Err(resp) => {
error_code, if resp.status().is_success() {
headers, return Err(Error::Protocol(
body, "Custom response must not be successful".into(),
}) => { ));
self.error_code = Some(error_code.as_u16()); }
self.error_code = Some(resp.status().as_u16());
let mut response = format!( let mut response = format!(
"HTTP/1.1 {} {}\r\n", "{version:?} {status} {reason}\r\n",
error_code.as_str(), version = resp.version(),
error_code.canonical_reason().unwrap_or("") status = resp.status().as_u16(),
reason = resp.status().canonical_reason().unwrap_or("")
); );
add_headers(&mut response, headers)?; add_headers(&mut response, Some(resp.headers()))?;
if let Some(body) = body { if let Some(body) = resp.body() {
response += &body; response += &body;
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
@ -241,6 +228,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::reply;
use super::{HeaderMap, Request}; use super::{HeaderMap, Request};
use http::header::HeaderName; use http::header::HeaderName;
use http::Response; use http::Response;
@ -249,8 +237,8 @@ mod tests {
fn request_parsing() { fn request_parsing() {
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
assert_eq!(req.path, "/script.ws"); assert_eq!(req.uri().path(), "/script.ws");
assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]);
} }
#[test] #[test]
@ -264,7 +252,7 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap(); let _ = reply(&req, None).unwrap();
let extra_headers = { let extra_headers = {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
@ -279,7 +267,7 @@ mod tests {
headers headers
}; };
let reply = req.reply(Some(extra_headers)).unwrap(); let reply = reply(&req, Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!( assert_eq!(
req.headers().get("MyCustomHeader").unwrap(), req.headers().get("MyCustomHeader").unwrap(),

Loading…
Cancel
Save