diff --git a/src/client.rs b/src/client.rs index cba6109..9f8516c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -91,6 +91,12 @@ pub fn connect_with_config( config: Option, ) -> Result<(WebSocket, Response)> { let request: Request = request.into_client_request()?; + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -104,10 +110,20 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { + match client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + }) { + Ok(r) => Ok(r), + Err(e) => match e { + Error::Redirection(uri) => { + debug!("Redirecting to {}", uri); + let request = request2.uri(uri).body(()).unwrap(); + connect_with_config(request, config) + } + _ => Err(e), + } + } } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 5c96dc7..9df1a45 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; +use http::Uri; #[cfg(feature = "tls")] pub mod tls { @@ -62,6 +63,8 @@ pub enum Error { Url(Cow<'static, str>), /// HTTP error. Http(http::StatusCode), + /// HTTP 3xx redirection response + Redirection(Uri), /// HTTP format error. HttpFormat(http::Error), } @@ -80,6 +83,7 @@ impl fmt::Display for Error { Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..8b00338 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -160,7 +160,12 @@ impl VerifyData { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); + if response.status().is_redirection() { + let value = response.headers().get("Location").unwrap(); + return Err(Error::Redirection(value.to_str()?.parse()?)) + } else { + return Err(Error::Http(response.status())); + } } let headers = response.headers();