clean up the redirect logic a bit

pull/148/head
Daniel Abramov 4 years ago
parent 60f7b0f024
commit 521f1a0767
  1. 55
      src/client.rs
  2. 4
      src/protocol/mod.rs

@ -4,7 +4,7 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use http::Uri;
use http::{Uri, request::Parts};
use log::*;
use url::Url;
@ -89,10 +89,12 @@ use crate::stream::{Mode, NoDelay};
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> {
let mut request: Request = request.into_client_request()?;
fn inner(request: Request, config: Option<WebSocketConfig>) -> Result<(WebSocket<AutoStream>, Response)> {
fn try_client_handshake(request: Request, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
@ -112,38 +114,33 @@ pub fn connect_with_config<Req: IntoClientRequest>(
})
}
let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0);
let mut redirects = 0;
fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version.clone());
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
loop {
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code
// Have to manually clone Method because there is one field that contains a Box,
// but in the case of normal request methods it is Copy
let request2 = Request::builder()
.method(request.method().clone())
.version(request.version());
for attempt in 0..(max_redirects + 1) {
let request = create_request(&parts, &uri);
match inner(request, config) {
Ok(r) => return Ok(r),
Err(e) => match e {
Error::Http(res) => {
if res.status().is_redirection() {
let uri = res.headers().get("Location").ok_or(Error::NoLocation)?;
match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
let location = res.headers().get("Location").ok_or(Error::NoLocation)?;
uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri);
request = request2.uri(uri.to_str()?.parse::<Uri>()?).body(()).unwrap();
redirects += 1;
if redirects > max_redirects {
return Err(Error::Http(res));
}
} else {
return Err(Error::Http(res));
}
}
_ => return Err(e),
continue;
}
other => return other,
}
}
unreachable!("Bug in a redirect handling logic")
}
/// Connect to the given WebSocket in blocking mode.
@ -159,7 +156,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
connect_with_config(request, None, 0)
}
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {

@ -43,9 +43,6 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// The max number of redirects the client should follow before aborting the connection.
/// The default value is 3. `None` here means that the client will not attempt to follow redirects.
pub max_redirects: Option<u8>,
}
impl Default for WebSocketConfig {
@ -54,7 +51,6 @@ impl Default for WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
max_redirects: Some(3)
}
}
}

Loading…
Cancel
Save