Merge pull request #148 from Redrield/feature/follow-3xx

Add facilities to allow clients to follow HTTP 3xx redirects
pull/156/head
Daniel Abramov 4 years ago committed by GitHub
commit 2638bd69c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 77
      src/client.rs
  2. 5
      src/error.rs
  3. 9
      src/handshake/client.rs
  4. 13
      src/handshake/server.rs

@ -4,7 +4,7 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use http::Uri;
use http::{Uri, request::Parts};
use log::*;
use url::Url;
@ -89,25 +89,62 @@ use crate::stream::{Mode, NoDelay};
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
fn try_client_handshake(request: Request, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version.clone());
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
for attempt in 0..(max_redirects + 1) {
let request = create_request(&parts, &uri);
match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
if let Some(location) = res.headers().get("Location") {
uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri);
continue;
} else {
warn!("No `Location` found in redirect");
return Err(Error::Http(res));
}
}
other => return other,
}
}
unreachable!("Bug in a redirect handling logic")
}
/// Connect to the given WebSocket in blocking mode.
@ -123,7 +160,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
connect_with_config(request, None, 3)
}
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {

@ -9,6 +9,7 @@ use std::str;
use std::string;
use crate::protocol::Message;
use http::Response;
#[cfg(feature = "tls")]
pub mod tls {
@ -61,7 +62,7 @@ pub enum Error {
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(http::StatusCode),
Http(Response<Option<String>>),
/// HTTP format error.
HttpFormat(http::Error),
}
@ -79,7 +80,7 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}

@ -90,7 +90,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result,
tail,
} => {
self.verify_data.verify_response(&result)?;
let result = self.verify_data.verify_response(result)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
@ -156,12 +156,13 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
pub fn verify_response(&self, response: Response) -> Result<Response> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status()));
return Err(Error::Http(response.map(|_| None)));
}
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
@ -219,7 +220,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(())
Ok(response)
}
}

@ -195,7 +195,7 @@ pub struct ServerHandshake<S, C> {
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
error_response: Option<ErrorResponse>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -212,7 +212,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
role: ServerHandshake {
callback: Some(callback),
config,
error_code: None,
error_response: None,
_marker: PhantomData,
},
}
@ -259,22 +259,25 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
));
}
self.error_code = Some(resp.status().as_u16());
self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap();
let mut output = vec![];
write_response(&mut output, &resp)?;
if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes());
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
}
}
}
StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() {
if let Some(err) = self.error_response.take() {
debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?));
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);

Loading…
Cancel
Save