From a8e06d2b39c6c1f9091ce7af2adf2510575ab9d3 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 16 Nov 2020 19:49:17 +0100 Subject: [PATCH] clean up http error handling --- src/client.rs | 12 ++++++++---- src/error.rs | 10 ++-------- src/handshake/client.rs | 17 +++++++++-------- src/handshake/server.rs | 13 ++++++++----- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/client.rs b/src/client.rs index b28cf3c..0b70af3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -131,10 +131,14 @@ pub fn connect_with_config( match try_client_handshake(request, config) { Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { - let location = res.headers().get("Location").ok_or(Error::NoLocation)?; - uri = location.to_str()?.parse::()?; - debug!("Redirecting to {:?}", uri); - continue; + if let Some(location) = res.headers().get("Location") { + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; + } else { + warn!("No `Location` found in redirect"); + return Err(Error::Http(res)); + } } other => return other, } diff --git a/src/error.rs b/src/error.rs index 01edcb0..b2657cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::{Response, StatusCode}; +use http::Response; #[cfg(feature = "tls")] pub mod tls { @@ -61,12 +61,8 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), - /// HTTP error (status only). - HttpStatus(StatusCode), /// HTTP error. - Http(Response<()>), - /// No Location header in 3xx response - NoLocation, + Http(Response>), /// HTTP format error. HttpFormat(http::Error), } @@ -84,8 +80,6 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::NoLocation => write!(f, "No Location header specified"), - Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e2ca308..bb159d7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,12 +90,7 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { - // If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if result.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(result)); - } - self.verify_data.verify_response(&result)?; + let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, self.config); @@ -161,7 +156,13 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: Response) -> Result { + // 1. If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.map(|_| None))); + } + let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| @@ -219,7 +220,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok(()) + Ok(response) } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 406dc24..15f6b14 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -195,7 +195,7 @@ pub struct ServerHandshake { /// WebSocket configuration. config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. - error_code: Option, + error_response: Option, /// Internal stream type. _marker: PhantomData, } @@ -212,7 +212,7 @@ impl ServerHandshake { role: ServerHandshake { callback: Some(callback), config, - error_code: None, + error_response: None, _marker: PhantomData, }, } @@ -259,22 +259,25 @@ impl HandshakeRole for ServerHandshake { )); } - self.error_code = Some(resp.status().as_u16()); + self.error_response = Some(resp); + let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } StageResult::DoneWriting(stream) => { - if let Some(err) = self.error_code.take() { + if let Some(err) = self.error_response.take() { debug!("Server handshake failed."); - return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); + return Err(Error::Http(err)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);