diff --git a/src/client.rs b/src/client.rs index 49ed656..b28cf3c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; -use http::Uri; +use http::{Uri, request::Parts}; use log::*; use url::Url; @@ -89,10 +89,12 @@ use crate::stream::{Mode, NoDelay}; pub fn connect_with_config( request: Req, config: Option, + max_redirects: u8, ) -> Result<(WebSocket, Response)> { - let mut request: Request = request.into_client_request()?; - fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + fn try_client_handshake(request: Request, config: Option) + -> Result<(WebSocket, Response)> + { let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -106,44 +108,39 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e{ + client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), }) } - let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); - let mut redirects = 0; - - loop { - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - - match inner(request, config) { - Ok(r) => return Ok(r), - Err(e) => match e { - Error::Http(res) => { - if res.status().is_redirection() { - let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; - debug!("Redirecting to {:?}", uri); - request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); - redirects += 1; - if redirects > max_redirects { - return Err(Error::Http(res)); - } - } else { - return Err(Error::Http(res)); - } - } - _ => return Err(e), + fn create_request(parts: &Parts, uri: &Uri) -> Request { + let mut builder = Request::builder() + .uri(uri.clone()) + .method(parts.method.clone()) + .version(parts.version.clone()); + *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); + builder.body(()).expect("Failed to create `Request`") + } + + let (parts, _) = request.into_client_request()?.into_parts(); + let mut uri = parts.uri.clone(); + + for attempt in 0..(max_redirects + 1) { + let request = create_request(&parts, &uri); + + match try_client_handshake(request, config) { + Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { + let location = res.headers().get("Location").ok_or(Error::NoLocation)?; + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; } + other => return other, } } + unreachable!("Bug in a redirect handling logic") } /// Connect to the given WebSocket in blocking mode. @@ -159,7 +156,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) + connect_with_config(request, None, 0) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 505dddd..8137393 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,9 +43,6 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, - /// The max number of redirects the client should follow before aborting the connection. - /// The default value is 3. `None` here means that the client will not attempt to follow redirects. - pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -54,7 +51,6 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), - max_redirects: Some(3) } } }