diff --git a/src/client.rs b/src/client.rs index 9f8516c..49ed656 100644 --- a/src/client.rs +++ b/src/client.rs @@ -90,40 +90,60 @@ pub fn connect_with_config( request: Req, config: Option, ) -> Result<(WebSocket, Response)> { - let request: Request = request.into_client_request()?; - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - let uri = request.uri(); - let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = uri.port_u16().unwrap_or(match mode { - Mode::Plain => 80, - Mode::Tls => 443, - }); - let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; - NoDelay::set_nodelay(&mut stream, true)?; - match client_with_config(request, stream, config).map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) { - Ok(r) => Ok(r), - Err(e) => match e { - Error::Redirection(uri) => { - debug!("Redirecting to {}", uri); - let request = request2.uri(uri).body(()).unwrap(); - connect_with_config(request, config) + let mut request: Request = request.into_client_request()?; + + fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + let uri = request.uri(); + let mode = uri_mode(uri)?; + let host = request + .uri() + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + NoDelay::set_nodelay(&mut stream, true)?; + client_with_config(request, stream, config).map_err(|e| match e{ + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) + } + + let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); + let mut redirects = 0; + + loop { + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); + + match inner(request, config) { + Ok(r) => return Ok(r), + Err(e) => match e { + Error::Http(res) => { + if res.status().is_redirection() { + let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; + debug!("Redirecting to {:?}", uri); + request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); + redirects += 1; + if redirects > max_redirects { + return Err(Error::Http(res)); + } + } else { + return Err(Error::Http(res)); + } + } + _ => return Err(e), } - _ => Err(e), } } + } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 9df1a45..01edcb0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::Uri; +use http::{Response, StatusCode}; #[cfg(feature = "tls")] pub mod tls { @@ -61,10 +61,12 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), + /// HTTP error (status only). + HttpStatus(StatusCode), /// HTTP error. - Http(http::StatusCode), - /// HTTP 3xx redirection response - Redirection(Uri), + Http(Response<()>), + /// No Location header in 3xx response + NoLocation, /// HTTP format error. HttpFormat(http::Error), } @@ -82,8 +84,9 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP error: {}", code), - Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), + Error::NoLocation => write!(f, "No Location header specified"), + Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), + Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8b00338..e2ca308 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,6 +90,11 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { + // If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if result.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(result)); + } self.verify_data.verify_response(&result)?; debug!("Client handshake done."); let websocket = @@ -157,16 +162,6 @@ struct VerifyData { impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { - // 1. If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - if response.status().is_redirection() { - let value = response.headers().get("Location").unwrap(); - return Err(Error::Redirection(value.to_str()?.parse()?)) - } else { - return Err(Error::Http(response.status())); - } - } let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 9412a6f..406dc24 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -274,7 +274,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(StatusCode::from_u16(err)?)); + return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8137393..505dddd 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,6 +43,9 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, + /// The max number of redirects the client should follow before aborting the connection. + /// The default value is 3. `None` here means that the client will not attempt to follow redirects. + pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -51,6 +54,7 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), + max_redirects: Some(3) } } }