From 09a9b7ceef30e9b7c4624c04b3f3cfb320440ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sat, 23 Nov 2019 13:17:51 +0200 Subject: [PATCH] Remove custom Request/Response types from server code Fixes https://github.com/snapview/tungstenite-rs/issues/92 --- src/handshake/server.rs | 144 ++++++++++++++++++---------------------- 1 file changed, 66 insertions(+), 78 deletions(-) diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 5a5890c..3293b80 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,7 +5,7 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::{HeaderMap, StatusCode}; +use http::{HeaderMap, Request, Response, StatusCode}; use httparse::Status; use log::*; @@ -15,41 +15,28 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Request from the client. -#[derive(Debug)] -pub struct Request { - /// Path part of the URL. - pub path: String, - /// HTTP headers. - pub headers: HeaderMap, -} - -impl Request { - /// Reply to the response. - pub fn reply(&self, extra_headers: Option) -> Result> { - let key = self - .headers - .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let mut reply = format!( - "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", - convert_key(key.as_bytes())? - ); - add_headers(&mut reply, extra_headers)?; - Ok(reply.into()) - } +/// Reply to the response. +fn reply(request: &Request<()>, extra_headers: Option) -> Result> { + let key = request + .headers() + .get("Sec-WebSocket-Key") + .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + let mut reply = format!( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", + convert_key(key.as_bytes())? + ); + add_headers(&mut reply, extra_headers.as_ref())?; + Ok(reply.into()) } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> Result<()> { +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { if let Some(eh) = extra_headers { for (k, v) in eh { - if let Some(k) = k { - writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); - } + writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); } } writeln!(reply, "\r").unwrap(); @@ -57,7 +44,7 @@ fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> R Ok(()) } -impl TryParse for Request { +impl TryParse for Request<()> { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Request::new(&mut hbuffer); @@ -68,41 +55,29 @@ impl TryParse for Request { } } -impl<'h, 'b: 'h> FromHttparse> for Request { +impl<'h, 'b: 'h> FromHttparse> for Request<()> { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { return Err(Error::Protocol("Method is not GET".into())); } + if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( "HTTP version should be 1.1 or higher".into(), )); } - Ok(Request { - path: raw.path.expect("Bug: no path in header").into(), - headers: HeaderMap::from_httparse(raw.headers)?, - }) - } -} -/// An error response sent to the client. -#[derive(Debug)] -pub struct ErrorResponse { - /// HTTP error code. - pub error_code: StatusCode, - /// Extra response headers, if any. - pub headers: Option, - /// Response body, if any. - pub body: Option, -} + let headers = HeaderMap::from_httparse(raw.headers)?; -impl From for ErrorResponse { - fn from(error_code: StatusCode) -> Self { - ErrorResponse { - error_code, - headers: None, - body: None, - } + let mut request = Request::new(()); + *request.method_mut() = http::Method::GET; + *request.headers_mut() = headers; + *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *request.version_mut() = http::Version::HTTP_11; + + Ok(request) } } @@ -116,14 +91,20 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request( + self, + request: &Request<()>, + ) -> StdResult, Response>>; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request<()>) -> StdResult, Response>>, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + request: &Request<()>, + ) -> StdResult, Response>> { self(request) } } @@ -133,7 +114,10 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + _request: &Request<()>, + ) -> StdResult, Response>> { Ok(None) } } @@ -174,7 +158,7 @@ impl ServerHandshake { } impl HandshakeRole for ServerHandshake { - type IncomingData = Request; + type IncomingData = Request<()>; type InternalStream = S; type FinalResult = WebSocket; @@ -200,23 +184,26 @@ impl HandshakeRole for ServerHandshake { match callback_result { Ok(extra_headers) => { - let response = result.reply(extra_headers)?; + let response = reply(&result, extra_headers)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } - Err(ErrorResponse { - error_code, - headers, - body, - }) => { - self.error_code = Some(error_code.as_u16()); + Err(resp) => { + if resp.status().is_success() { + return Err(Error::Protocol( + "Custom response must not be successful".into(), + )); + } + + self.error_code = Some(resp.status().as_u16()); let mut response = format!( - "HTTP/1.1 {} {}\r\n", - error_code.as_str(), - error_code.canonical_reason().unwrap_or("") + "{version:?} {status} {reason}\r\n", + version = resp.version(), + status = resp.status().as_u16(), + reason = resp.status().canonical_reason().unwrap_or("") ); - add_headers(&mut response, headers)?; - if let Some(body) = body { + add_headers(&mut response, Some(resp.headers()))?; + if let Some(body) = resp.body() { response += &body; } ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) @@ -241,6 +228,7 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { use super::super::machine::TryParse; + use super::reply; use super::{HeaderMap, Request}; use http::header::HeaderName; use http::Response; @@ -249,8 +237,8 @@ mod tests { fn request_parsing() { const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - assert_eq!(req.path, "/script.ws"); - assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); + assert_eq!(req.uri().path(), "/script.ws"); + assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); } #[test] @@ -264,7 +252,7 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply(None).unwrap(); + let _ = reply(&req, None).unwrap(); let extra_headers = { let mut headers = HeaderMap::new(); @@ -279,7 +267,7 @@ mod tests { headers }; - let reply = req.reply(Some(extra_headers)).unwrap(); + let reply = reply(&req, Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( req.headers().get("MyCustomHeader").unwrap(),