clean up the redirect logic a bit

pull/148/head
Daniel Abramov 4 years ago
parent 60f7b0f024
commit 521f1a0767
  1. 63
      src/client.rs
  2. 4
      src/protocol/mod.rs

@ -4,7 +4,7 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::Uri; use http::{Uri, request::Parts};
use log::*; use log::*;
use url::Url; use url::Url;
@ -89,10 +89,12 @@ use crate::stream::{Mode, NoDelay};
pub fn connect_with_config<Req: IntoClientRequest>( pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let mut request: Request = request.into_client_request()?;
fn inner(request: Request, config: Option<WebSocketConfig>) -> Result<(WebSocket<AutoStream>, Response)> { fn try_client_handshake(request: Request, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request let host = request
@ -106,44 +108,39 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let addrs = (host, port).to_socket_addrs()?; let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e{ client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) })
} }
let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut redirects = 0; let mut builder = Request::builder()
.uri(uri.clone())
loop { .method(parts.method.clone())
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code .version(parts.version.clone());
// Have to manually clone Method because there is one field that contains a Box, *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
// but in the case of normal request methods it is Copy builder.body(()).expect("Failed to create `Request`")
let request2 = Request::builder() }
.method(request.method().clone())
.version(request.version()); let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
match inner(request, config) {
Ok(r) => return Ok(r), for attempt in 0..(max_redirects + 1) {
Err(e) => match e { let request = create_request(&parts, &uri);
Error::Http(res) => {
if res.status().is_redirection() { match try_client_handshake(request, config) {
let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
debug!("Redirecting to {:?}", uri); let location = res.headers().get("Location").ok_or(Error::NoLocation)?;
request = request2.uri(uri.to_str()?.parse::<Uri>()?).body(()).unwrap(); uri = location.to_str()?.parse::<Uri>()?;
redirects += 1; debug!("Redirecting to {:?}", uri);
if redirects > max_redirects { continue;
return Err(Error::Http(res));
}
} else {
return Err(Error::Http(res));
}
}
_ => return Err(e),
} }
other => return other,
} }
} }
unreachable!("Bug in a redirect handling logic")
} }
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
@ -159,7 +156,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> { pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None) connect_with_config(request, None, 0)
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {

@ -43,9 +43,6 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating /// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user. /// by a malicious user.
pub max_frame_size: Option<usize>, pub max_frame_size: Option<usize>,
/// The max number of redirects the client should follow before aborting the connection.
/// The default value is 3. `None` here means that the client will not attempt to follow redirects.
pub max_redirects: Option<u8>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
@ -54,7 +51,6 @@ impl Default for WebSocketConfig {
max_send_queue: None, max_send_queue: None,
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
max_redirects: Some(3)
} }
} }
} }

Loading…
Cancel
Save