Merge pull request #148 from Redrield/feature/follow-3xx

Add facilities to allow clients to follow HTTP 3xx redirects
pull/156/head
Daniel Abramov 4 years ago committed by GitHub
commit 2638bd69c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 43
      src/client.rs
  2. 5
      src/error.rs
  3. 9
      src/handshake/client.rs
  4. 13
      src/handshake/server.rs

@ -4,7 +4,7 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::Uri; use http::{Uri, request::Parts};
use log::*; use log::*;
use url::Url; use url::Url;
@ -89,8 +89,12 @@ use crate::stream::{Mode, NoDelay};
pub fn connect_with_config<Req: IntoClientRequest>( pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into_client_request()?;
fn try_client_handshake(request: Request, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request let host = request
@ -108,6 +112,39 @@ pub fn connect_with_config<Req: IntoClientRequest>(
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) })
}
fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version.clone());
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
for attempt in 0..(max_redirects + 1) {
let request = create_request(&parts, &uri);
match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
if let Some(location) = res.headers().get("Location") {
uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri);
continue;
} else {
warn!("No `Location` found in redirect");
return Err(Error::Http(res));
}
}
other => return other,
}
}
unreachable!("Bug in a redirect handling logic")
} }
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
@ -123,7 +160,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> { pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None) connect_with_config(request, None, 3)
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {

@ -9,6 +9,7 @@ use std::str;
use std::string; use std::string;
use crate::protocol::Message; use crate::protocol::Message;
use http::Response;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
@ -61,7 +62,7 @@ pub enum Error {
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error. /// HTTP error.
Http(http::StatusCode), Http(Response<Option<String>>),
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), HttpFormat(http::Error),
} }
@ -79,7 +80,7 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
} }
} }

@ -90,7 +90,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result, result,
tail, tail,
} => { } => {
self.verify_data.verify_response(&result)?; let result = self.verify_data.verify_response(result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config); WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
@ -156,12 +156,13 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> { pub fn verify_response(&self, response: Response) -> Result<Response> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status())); return Err(Error::Http(response.map(|_| None)));
} }
let headers = response.headers(); let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
@ -219,7 +220,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok(()) Ok(response)
} }
} }

@ -195,7 +195,7 @@ pub struct ServerHandshake<S, C> {
/// WebSocket configuration. /// WebSocket configuration.
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>, error_response: Option<ErrorResponse>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -212,7 +212,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
role: ServerHandshake { role: ServerHandshake {
callback: Some(callback), callback: Some(callback),
config, config,
error_code: None, error_response: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
@ -259,22 +259,25 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
)); ));
} }
self.error_code = Some(resp.status().as_u16()); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap();
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &resp)?; write_response(&mut output, &resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend_from_slice(body.as_bytes());
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
} }
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_response.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?)); return Err(Error::Http(err));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);

Loading…
Cancel
Save