clean up http error handling

pull/148/head
Daniel Abramov 4 years ago
parent 521f1a0767
commit a8e06d2b39
  1. 6
      src/client.rs
  2. 10
      src/error.rs
  3. 17
      src/handshake/client.rs
  4. 13
      src/handshake/server.rs

@ -131,10 +131,14 @@ pub fn connect_with_config<Req: IntoClientRequest>(
match try_client_handshake(request, config) { match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
let location = res.headers().get("Location").ok_or(Error::NoLocation)?; if let Some(location) = res.headers().get("Location") {
uri = location.to_str()?.parse::<Uri>()?; uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri); debug!("Redirecting to {:?}", uri);
continue; continue;
} else {
warn!("No `Location` found in redirect");
return Err(Error::Http(res));
}
} }
other => return other, other => return other,
} }

@ -9,7 +9,7 @@ use std::str;
use std::string; use std::string;
use crate::protocol::Message; use crate::protocol::Message;
use http::{Response, StatusCode}; use http::Response;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
@ -61,12 +61,8 @@ pub enum Error {
Utf8, Utf8,
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error (status only).
HttpStatus(StatusCode),
/// HTTP error. /// HTTP error.
Http(Response<()>), Http(Response<Option<String>>),
/// No Location header in 3xx response
NoLocation,
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), HttpFormat(http::Error),
} }
@ -84,8 +80,6 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::NoLocation => write!(f, "No Location header specified"),
Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
} }

@ -90,12 +90,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result, result,
tail, tail,
} => { } => {
// If the status code received from the server is not 101, the let result = self.verify_data.verify_response(result)?;
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if result.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(result));
}
self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config); WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
@ -161,7 +156,13 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> { pub fn verify_response(&self, response: Response) -> Result<Response> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.map(|_| None)));
}
let headers = response.headers(); let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
@ -219,7 +220,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok(()) Ok(response)
} }
} }

@ -195,7 +195,7 @@ pub struct ServerHandshake<S, C> {
/// WebSocket configuration. /// WebSocket configuration.
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>, error_response: Option<ErrorResponse>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -212,7 +212,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
role: ServerHandshake { role: ServerHandshake {
callback: Some(callback), callback: Some(callback),
config, config,
error_code: None, error_response: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -259,22 +259,25 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
)); ));
} }
self.error_code = Some(resp.status().as_u16()); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap();
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &resp)?; write_response(&mut output, &resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend_from_slice(body.as_bytes());
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
} }
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_response.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); return Err(Error::Http(err));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);

Loading…
Cancel
Save