Use Response for the server handshake callback too

And add a public create_response(&Request) function that creates an
initial response. This can be used to simplify integration into existing
HTTP libraries.
pull/93/head
Sebastian Dröge 5 years ago
parent 09a9b7ceef
commit 1ecc4f900d
  1. 161
      src/handshake/server.rs

@ -1,7 +1,6 @@
//! Server handshake machine. //! Server handshake machine.
use std::fmt::Write as FmtWrite; use std::io::{self, Read, Write};
use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult; use std::result::Result as StdResult;
@ -15,31 +14,84 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Reply to the response. /// Create a response for the request.
fn reply(request: &Request<()>, extra_headers: Option<HeaderMap>) -> Result<Vec<u8>> { pub fn create_response(request: &Request<()>) -> Result<Response<()>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol("Method is not GET".into()));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
if !request
.headers()
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in client request".into(),
));
}
if !request
.headers()
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in client request".into(),
));
}
if !request
.headers()
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
}
let key = request let key = request
.headers() .headers()
.get("Sec-WebSocket-Key") .get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!(
"\ let mut response = Response::builder();
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\ response.status(StatusCode::SWITCHING_PROTOCOLS);
Upgrade: websocket\r\n\ response.version(request.version());
Sec-WebSocket-Accept: {}\r\n", response.header("Connection", "Upgrade");
convert_key(key.as_bytes())? response.header("Upgrade", "websocket");
); response.header("Sec-WebSocket-Accept", convert_key(key.as_bytes())?);
add_headers(&mut reply, extra_headers.as_ref())?;
Ok(reply.into()) Ok(response.body(())?)
} }
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { // Assumes that this is a valid response
if let Some(eh) = extra_headers { fn write_response<T>(w: &mut dyn io::Write, response: &Response<T>) -> Result<()> {
for (k, v) in eh { writeln!(
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); w,
} "{version:?} {status} {reason}\r",
version = response.version(),
status = response.status(),
reason = response.status().canonical_reason().unwrap_or(""),
)?;
for (k, v) in response.headers() {
writeln!(w, "{}: {}\r", k, v.to_str()?).unwrap();
} }
writeln!(reply, "\r").unwrap();
writeln!(w, "\r")?;
Ok(()) Ok(())
} }
@ -94,18 +146,20 @@ pub trait Callback: Sized {
fn on_request( fn on_request(
self, self,
request: &Request<()>, request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>>; response: Response<()>,
) -> StdResult<Response<()>, Response<Option<String>>>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request<()>) -> StdResult<Option<HeaderMap>, Response<Option<String>>>, F: FnOnce(&Request<()>, Response<()>) -> StdResult<Response<()>, Response<Option<String>>>,
{ {
fn on_request( fn on_request(
self, self,
request: &Request<()>, request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> { response: Response<()>,
self(request) ) -> StdResult<Response<()>, Response<Option<String>>> {
self(request, response)
} }
} }
@ -117,8 +171,9 @@ impl Callback for NoCallback {
fn on_request( fn on_request(
self, self,
_request: &Request<()>, _request: &Request<()>,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> { response: Response<()>,
Ok(None) ) -> StdResult<Response<()>, Response<Option<String>>> {
Ok(response)
} }
} }
@ -176,16 +231,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol("Junk after client request".into()));
} }
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result) callback.on_request(&result, response)
} else { } else {
Ok(None) Ok(response)
}; };
match callback_result { match callback_result {
Ok(extra_headers) => { Ok(response) => {
let response = reply(&result, extra_headers)?; let mut output = vec![];
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) write_response(&mut output, &response)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
Err(resp) => { Err(resp) => {
@ -196,17 +253,13 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
} }
self.error_code = Some(resp.status().as_u16()); self.error_code = Some(resp.status().as_u16());
let mut response = format!(
"{version:?} {status} {reason}\r\n", let mut output = vec![];
version = resp.version(), write_response(&mut output, &resp)?;
status = resp.status().as_u16(),
reason = resp.status().canonical_reason().unwrap_or("")
);
add_headers(&mut response, Some(resp.headers()))?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
response += &body; output.extend_from_slice(body.as_bytes());
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
} }
} }
@ -228,10 +281,8 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::reply; use super::create_response;
use super::{HeaderMap, Request}; use super::Request;
use http::header::HeaderName;
use http::Response;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
@ -252,27 +303,11 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = reply(&req, None).unwrap(); let response = create_response(&req).unwrap();
let extra_headers = {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(),
"MyCustomValue".parse().unwrap(),
);
headers.insert(
HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(),
"LOL".parse().unwrap(),
);
headers
};
let reply = reply(&req, Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!( assert_eq!(
req.headers().get("MyCustomHeader").unwrap(), response.headers().get("Sec-WebSocket-Accept").unwrap(),
b"MyCustomValue".as_ref() b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref()
); );
assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref());
} }
} }

Loading…
Cancel
Save