Fix some code-review issues

* Replace Redirection error with a general Http error that owns the
response
* Make the default client connect function iterative instead of
recursive
* Add a limit to the amount of redirects a client will attempt to
perform
pull/148/head
Redrield 4 years ago committed by Daniel Abramov
parent 6bce14fa26
commit 60f7b0f024
  1. 50
      src/client.rs
  2. 15
      src/error.rs
  3. 15
      src/handshake/client.rs
  4. 2
      src/handshake/server.rs
  5. 4
      src/protocol/mod.rs

@ -90,13 +90,9 @@ pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into_client_request()?; let mut request: Request = request.into_client_request()?;
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code
// Have to manually clone Method because there is one field that contains a Box, fn inner(request: Request, config: Option<WebSocketConfig>) -> Result<(WebSocket<AutoStream>, Response)> {
// but in the case of normal request methods it is Copy
let request2 = Request::builder()
.method(request.method().clone())
.version(request.version());
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request let host = request
@ -110,22 +106,46 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let addrs = (host, port).to_socket_addrs()?; let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
match client_with_config(request, stream, config).map_err(|e| match e { client_with_config(request, stream, config).map_err(|e| match e{
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) { })
Ok(r) => Ok(r), }
let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0);
let mut redirects = 0;
loop {
// Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code
// Have to manually clone Method because there is one field that contains a Box,
// but in the case of normal request methods it is Copy
let request2 = Request::builder()
.method(request.method().clone())
.version(request.version());
match inner(request, config) {
Ok(r) => return Ok(r),
Err(e) => match e { Err(e) => match e {
Error::Redirection(uri) => { Error::Http(res) => {
debug!("Redirecting to {}", uri); if res.status().is_redirection() {
let request = request2.uri(uri).body(()).unwrap(); let uri = res.headers().get("Location").ok_or(Error::NoLocation)?;
connect_with_config(request, config) debug!("Redirecting to {:?}", uri);
request = request2.uri(uri.to_str()?.parse::<Uri>()?).body(()).unwrap();
redirects += 1;
if redirects > max_redirects {
return Err(Error::Http(res));
}
} else {
return Err(Error::Http(res));
}
} }
_ => Err(e), _ => return Err(e),
} }
} }
} }
}
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
/// ///
/// The URL may be either ws:// or wss://. /// The URL may be either ws:// or wss://.

@ -9,7 +9,7 @@ use std::str;
use std::string; use std::string;
use crate::protocol::Message; use crate::protocol::Message;
use http::Uri; use http::{Response, StatusCode};
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
@ -61,10 +61,12 @@ pub enum Error {
Utf8, Utf8,
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error (status only).
HttpStatus(StatusCode),
/// HTTP error. /// HTTP error.
Http(http::StatusCode), Http(Response<()>),
/// HTTP 3xx redirection response /// No Location header in 3xx response
Redirection(Uri), NoLocation,
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), HttpFormat(http::Error),
} }
@ -82,8 +84,9 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code), Error::NoLocation => write!(f, "No Location header specified"),
Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
} }
} }

@ -90,6 +90,11 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
result, result,
tail, tail,
} => { } => {
// If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if result.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(result));
}
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket =
@ -157,16 +162,6 @@ struct VerifyData {
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> { pub fn verify_response(&self, response: &Response) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
if response.status().is_redirection() {
let value = response.headers().get("Location").unwrap();
return Err(Error::Redirection(value.to_str()?.parse()?))
} else {
return Err(Error::Http(response.status()));
}
}
let headers = response.headers(); let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|

@ -274,7 +274,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_code.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?)); return Err(Error::HttpStatus(StatusCode::from_u16(err)?));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);

@ -43,6 +43,9 @@ pub struct WebSocketConfig {
/// be reasonably big for all normal use-cases but small enough to prevent memory eating /// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user. /// by a malicious user.
pub max_frame_size: Option<usize>, pub max_frame_size: Option<usize>,
/// The max number of redirects the client should follow before aborting the connection.
/// The default value is 3. `None` here means that the client will not attempt to follow redirects.
pub max_redirects: Option<u8>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
@ -51,6 +54,7 @@ impl Default for WebSocketConfig {
max_send_queue: None, max_send_queue: None,
max_message_size: Some(64 << 20), max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20), max_frame_size: Some(16 << 20),
max_redirects: Some(3)
} }
} }
} }

Loading…
Cancel
Save