diff --git a/src/client.rs b/src/client.rs index cba6109..1b20980 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; -use http::Uri; +use http::{Uri, request::Parts}; use log::*; use url::Url; @@ -89,25 +89,62 @@ use crate::stream::{Mode, NoDelay}; pub fn connect_with_config( request: Req, config: Option, + max_redirects: u8, ) -> Result<(WebSocket, Response)> { - let request: Request = request.into_client_request()?; - let uri = request.uri(); - let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = uri.port_u16().unwrap_or(match mode { - Mode::Plain => 80, - Mode::Tls => 443, - }); - let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; - NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + + fn try_client_handshake(request: Request, config: Option) + -> Result<(WebSocket, Response)> + { + let uri = request.uri(); + let mode = uri_mode(uri)?; + let host = request + .uri() + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + NoDelay::set_nodelay(&mut stream, true)?; + client_with_config(request, stream, config).map_err(|e| match e { + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) + } + + fn create_request(parts: &Parts, uri: &Uri) -> Request { + let mut builder = Request::builder() + .uri(uri.clone()) + .method(parts.method.clone()) + .version(parts.version.clone()); + *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); + builder.body(()).expect("Failed to create `Request`") + } + + let (parts, _) = request.into_client_request()?.into_parts(); + let mut uri = parts.uri.clone(); + + for attempt in 0..(max_redirects + 1) { + let request = create_request(&parts, &uri); + + match try_client_handshake(request, config) { + Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { + if let Some(location) = res.headers().get("Location") { + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; + } else { + warn!("No `Location` found in redirect"); + return Err(Error::Http(res)); + } + } + other => return other, + } + } + + unreachable!("Bug in a redirect handling logic") } /// Connect to the given WebSocket in blocking mode. @@ -123,7 +160,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) + connect_with_config(request, None, 3) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { diff --git a/src/error.rs b/src/error.rs index 5c96dc7..b2657cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; +use http::Response; #[cfg(feature = "tls")] pub mod tls { @@ -61,7 +62,7 @@ pub enum Error { /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. - Http(http::StatusCode), + Http(Response>), /// HTTP format error. HttpFormat(http::Error), } @@ -79,7 +80,7 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..bb159d7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,7 +90,7 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { - self.verify_data.verify_response(&result)?; + let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, self.config); @@ -156,12 +156,13 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: Response) -> Result { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); + return Err(Error::Http(response.map(|_| None))); } + let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| @@ -219,7 +220,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok(()) + Ok(response) } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 9412a6f..15f6b14 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -195,7 +195,7 @@ pub struct ServerHandshake { /// WebSocket configuration. config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. - error_code: Option, + error_response: Option, /// Internal stream type. _marker: PhantomData, } @@ -212,7 +212,7 @@ impl ServerHandshake { role: ServerHandshake { callback: Some(callback), config, - error_code: None, + error_response: None, _marker: PhantomData, }, } @@ -259,22 +259,25 @@ impl HandshakeRole for ServerHandshake { )); } - self.error_code = Some(resp.status().as_u16()); + self.error_response = Some(resp); + let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } StageResult::DoneWriting(stream) => { - if let Some(err) = self.error_code.take() { + if let Some(err) = self.error_response.take() { debug!("Server handshake failed."); - return Err(Error::Http(StatusCode::from_u16(err)?)); + return Err(Error::Http(err)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);