Fix some code-review issues

* Replace Redirection error with a general Http error that owns the
response
* Make the default client connect function iterative instead of
recursive
* Add a limit to the amount of redirects a client will attempt to
perform
pull/148/head
Redrield 4 years ago committed by Daniel Abramov
parent 6bce14fa26
commit 60f7b0f024
  1. 82
      src/client.rs
  2. 15
      src/error.rs
  3. 15
      src/handshake/client.rs
  4. 2
      src/handshake/server.rs
  5. 4
      src/protocol/mod.rs

@ -90,40 +90,60 @@ pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into_client_request()?;
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code
// Have to manually clone Method because there is one field that contains a Box,
// but in the case of normal request methods it is Copy
let request2 = Request::builder()
.method(request.method().clone())
.version(request.version());
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
match client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) {
Ok(r) => Ok(r),
Err(e) => match e {
Error::Redirection(uri) => {
debug!("Redirecting to {}", uri);
let request = request2.uri(uri).body(()).unwrap();
connect_with_config(request, config)
let mut request: Request = request.into_client_request()?;
fn inner(request: Request, config: Option<WebSocketConfig>) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e{
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0);
let mut redirects = 0;
loop {
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code
// Have to manually clone Method because there is one field that contains a Box,
// but in the case of normal request methods it is Copy
let request2 = Request::builder()
.method(request.method().clone())
.version(request.version());
match inner(request, config) {
Ok(r) => return Ok(r),
Err(e) => match e {
Error::Http(res) => {
if res.status().is_redirection() {
let uri = res.headers().get("Location").ok_or(Error::NoLocation)?;
debug!("Redirecting to {:?}", uri);
request = request2.uri(uri.to_str()?.parse::<Uri>()?).body(()).unwrap();
redirects += 1;
if redirects > max_redirects {
return Err(Error::Http(res));
}
} else {
return Err(Error::Http(res));
}
}
_ => return Err(e),
}
_ => Err(e),
}
}
}
/// Connect to the given WebSocket in blocking mode.

@ -9,7 +9,7 @@ use std::str;
use std::string;
use crate::protocol::Message;
use http::Uri;
use http::{Response, StatusCode};
#[cfg(feature = "tls")]
pub mod tls {
@ -61,10 +61,12 @@ pub enum Error {
Utf8,
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error (status only).
HttpStatus(StatusCode),
/// HTTP error.
Http(http::StatusCode),
/// HTTP 3xx redirection response
Redirection(Uri),
Http(Response<()>),
/// No Location header in 3xx response
NoLocation,
/// HTTP format error.
HttpFormat(http::Error),
}
@ -82,8 +84,9 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri),
Error::NoLocation => write!(f, "No Location header specified"),
Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}

@ -90,6 +90,11 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result,
tail,
} => {
// If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if result.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(result));
}
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
let websocket =
@ -157,16 +162,6 @@ struct VerifyData {
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
if response.status().is_redirection() {
let value = response.headers().get("Location").unwrap();
return Err(Error::Redirection(value.to_str()?.parse()?))
} else {
return Err(Error::Http(response.status()));
}
}
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|

@ -274,7 +274,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() {
debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?));
return Err(Error::HttpStatus(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);

@ -43,6 +43,9 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// The max number of redirects the client should follow before aborting the connection.
/// The default value is 3. `None` here means that the client will not attempt to follow redirects.
pub max_redirects: Option<u8>,
}
impl Default for WebSocketConfig {
@ -51,6 +54,7 @@ impl Default for WebSocketConfig {
max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
max_redirects: Some(3)
}
}
}

Loading…
Cancel
Save