From 6bce14fa26f8b73363a2baafe55ce28614a66469 Mon Sep 17 00:00:00 2001 From: Redrield Date: Thu, 1 Oct 2020 18:27:51 -0400 Subject: [PATCH 1/5] Add facilities to allow clients to follow HTTP 3xx redirects * The connect() function defined in this crate will automatically follow redirecting responses. * Adds Error::Redirection, which is a special case of Error::Http that extracts the redirection target from the response headers, and stores it in the error object. Client implementations that build upon tungstenite can use this to implement redirecting. * A catch-all solution for redirects is not possible due to the abstraction transforming socket types to Read + Write, implementations that use the client_* methods need to handle redirections themselves. --- src/client.rs | 20 ++++++++++++++++++-- src/error.rs | 4 ++++ src/handshake/client.rs | 7 ++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/client.rs b/src/client.rs index cba6109..9f8516c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -91,6 +91,12 @@ pub fn connect_with_config( config: Option, ) -> Result<(WebSocket, Response)> { let request: Request = request.into_client_request()?; + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -104,10 +110,20 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { + match client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) + }) { + Ok(r) => Ok(r), + Err(e) => match e { + Error::Redirection(uri) => { + debug!("Redirecting to {}", uri); + let request = request2.uri(uri).body(()).unwrap(); + connect_with_config(request, config) + } + _ => Err(e), + } + } } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 5c96dc7..9df1a45 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; +use http::Uri; #[cfg(feature = "tls")] pub mod tls { @@ -62,6 +63,8 @@ pub enum Error { Url(Cow<'static, str>), /// HTTP error. Http(http::StatusCode), + /// HTTP 3xx redirection response + Redirection(Uri), /// HTTP format error. HttpFormat(http::Error), } @@ -80,6 +83,7 @@ impl fmt::Display for Error { Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..8b00338 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -160,7 +160,12 @@ impl VerifyData { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); + if response.status().is_redirection() { + let value = response.headers().get("Location").unwrap(); + return Err(Error::Redirection(value.to_str()?.parse()?)) + } else { + return Err(Error::Http(response.status())); + } } let headers = response.headers(); From 60f7b0f02453435605007e03ce07ef18e438a827 Mon Sep 17 00:00:00 2001 From: Redrield Date: Mon, 26 Oct 2020 13:15:21 -0400 Subject: [PATCH 2/5] Fix some code-review issues * Replace Redirection error with a general Http error that owns the response * Make the default client connect function iterative instead of recursive * Add a limit to the amount of redirects a client will attempt to perform --- src/client.rs | 82 +++++++++++++++++++++++++---------------- src/error.rs | 15 +++++--- src/handshake/client.rs | 15 +++----- src/handshake/server.rs | 2 +- src/protocol/mod.rs | 4 ++ 5 files changed, 70 insertions(+), 48 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9f8516c..49ed656 100644 --- a/src/client.rs +++ b/src/client.rs @@ -90,40 +90,60 @@ pub fn connect_with_config( request: Req, config: Option, ) -> Result<(WebSocket, Response)> { - let request: Request = request.into_client_request()?; - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - let uri = request.uri(); - let mode = uri_mode(uri)?; - let host = request - .uri() - .host() - .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = uri.port_u16().unwrap_or(match mode { - Mode::Plain => 80, - Mode::Tls => 443, - }); - let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; - NoDelay::set_nodelay(&mut stream, true)?; - match client_with_config(request, stream, config).map_err(|e| match e { - HandshakeError::Failure(f) => f, - HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), - }) { - Ok(r) => Ok(r), - Err(e) => match e { - Error::Redirection(uri) => { - debug!("Redirecting to {}", uri); - let request = request2.uri(uri).body(()).unwrap(); - connect_with_config(request, config) + let mut request: Request = request.into_client_request()?; + + fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + let uri = request.uri(); + let mode = uri_mode(uri)?; + let host = request + .uri() + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + NoDelay::set_nodelay(&mut stream, true)?; + client_with_config(request, stream, config).map_err(|e| match e{ + HandshakeError::Failure(f) => f, + HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), + }) + } + + let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); + let mut redirects = 0; + + loop { + // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code + // Have to manually clone Method because there is one field that contains a Box, + // but in the case of normal request methods it is Copy + let request2 = Request::builder() + .method(request.method().clone()) + .version(request.version()); + + match inner(request, config) { + Ok(r) => return Ok(r), + Err(e) => match e { + Error::Http(res) => { + if res.status().is_redirection() { + let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; + debug!("Redirecting to {:?}", uri); + request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); + redirects += 1; + if redirects > max_redirects { + return Err(Error::Http(res)); + } + } else { + return Err(Error::Http(res)); + } + } + _ => return Err(e), } - _ => Err(e), } } + } /// Connect to the given WebSocket in blocking mode. diff --git a/src/error.rs b/src/error.rs index 9df1a45..01edcb0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::Uri; +use http::{Response, StatusCode}; #[cfg(feature = "tls")] pub mod tls { @@ -61,10 +61,12 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), + /// HTTP error (status only). + HttpStatus(StatusCode), /// HTTP error. - Http(http::StatusCode), - /// HTTP 3xx redirection response - Redirection(Uri), + Http(Response<()>), + /// No Location header in 3xx response + NoLocation, /// HTTP format error. HttpFormat(http::Error), } @@ -82,8 +84,9 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP error: {}", code), - Error::Redirection(ref uri) => write!(f, "HTTP redirection to: {}", uri), + Error::NoLocation => write!(f, "No Location header specified"), + Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), + Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8b00338..e2ca308 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,6 +90,11 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { + // If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if result.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(result)); + } self.verify_data.verify_response(&result)?; debug!("Client handshake done."); let websocket = @@ -157,16 +162,6 @@ struct VerifyData { impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { - // 1. If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - if response.status().is_redirection() { - let value = response.headers().get("Location").unwrap(); - return Err(Error::Redirection(value.to_str()?.parse()?)) - } else { - return Err(Error::Http(response.status())); - } - } let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 9412a6f..406dc24 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -274,7 +274,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(StatusCode::from_u16(err)?)); + return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8137393..505dddd 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,6 +43,9 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, + /// The max number of redirects the client should follow before aborting the connection. + /// The default value is 3. `None` here means that the client will not attempt to follow redirects. + pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -51,6 +54,7 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), + max_redirects: Some(3) } } } From 521f1a0767339a7e455cdc92580aa477900a50d7 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 16 Nov 2020 19:25:08 +0100 Subject: [PATCH 3/5] clean up the redirect logic a bit --- src/client.rs | 63 +++++++++++++++++++++------------------------ src/protocol/mod.rs | 4 --- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/src/client.rs b/src/client.rs index 49ed656..b28cf3c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; -use http::Uri; +use http::{Uri, request::Parts}; use log::*; use url::Url; @@ -89,10 +89,12 @@ use crate::stream::{Mode, NoDelay}; pub fn connect_with_config( request: Req, config: Option, + max_redirects: u8, ) -> Result<(WebSocket, Response)> { - let mut request: Request = request.into_client_request()?; - fn inner(request: Request, config: Option) -> Result<(WebSocket, Response)> { + fn try_client_handshake(request: Request, config: Option) + -> Result<(WebSocket, Response)> + { let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -106,44 +108,39 @@ pub fn connect_with_config( let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e{ + client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), }) } - let max_redirects = config.as_ref().and_then(|c| c.max_redirects).unwrap_or(0); - let mut redirects = 0; - - loop { - // Copy all the fields from the initial reqeust **except** the URI. This will be used in the event of a redirection code - // Have to manually clone Method because there is one field that contains a Box, - // but in the case of normal request methods it is Copy - let request2 = Request::builder() - .method(request.method().clone()) - .version(request.version()); - - match inner(request, config) { - Ok(r) => return Ok(r), - Err(e) => match e { - Error::Http(res) => { - if res.status().is_redirection() { - let uri = res.headers().get("Location").ok_or(Error::NoLocation)?; - debug!("Redirecting to {:?}", uri); - request = request2.uri(uri.to_str()?.parse::()?).body(()).unwrap(); - redirects += 1; - if redirects > max_redirects { - return Err(Error::Http(res)); - } - } else { - return Err(Error::Http(res)); - } - } - _ => return Err(e), + fn create_request(parts: &Parts, uri: &Uri) -> Request { + let mut builder = Request::builder() + .uri(uri.clone()) + .method(parts.method.clone()) + .version(parts.version.clone()); + *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone(); + builder.body(()).expect("Failed to create `Request`") + } + + let (parts, _) = request.into_client_request()?.into_parts(); + let mut uri = parts.uri.clone(); + + for attempt in 0..(max_redirects + 1) { + let request = create_request(&parts, &uri); + + match try_client_handshake(request, config) { + Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { + let location = res.headers().get("Location").ok_or(Error::NoLocation)?; + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; } + other => return other, } } + unreachable!("Bug in a redirect handling logic") } /// Connect to the given WebSocket in blocking mode. @@ -159,7 +156,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) + connect_with_config(request, None, 0) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 505dddd..8137393 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -43,9 +43,6 @@ pub struct WebSocketConfig { /// be reasonably big for all normal use-cases but small enough to prevent memory eating /// by a malicious user. pub max_frame_size: Option, - /// The max number of redirects the client should follow before aborting the connection. - /// The default value is 3. `None` here means that the client will not attempt to follow redirects. - pub max_redirects: Option, } impl Default for WebSocketConfig { @@ -54,7 +51,6 @@ impl Default for WebSocketConfig { max_send_queue: None, max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), - max_redirects: Some(3) } } } From a8e06d2b39c6c1f9091ce7af2adf2510575ab9d3 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 16 Nov 2020 19:49:17 +0100 Subject: [PATCH 4/5] clean up http error handling --- src/client.rs | 12 ++++++++---- src/error.rs | 10 ++-------- src/handshake/client.rs | 17 +++++++++-------- src/handshake/server.rs | 13 ++++++++----- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/client.rs b/src/client.rs index b28cf3c..0b70af3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -131,10 +131,14 @@ pub fn connect_with_config( match try_client_handshake(request, config) { Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { - let location = res.headers().get("Location").ok_or(Error::NoLocation)?; - uri = location.to_str()?.parse::()?; - debug!("Redirecting to {:?}", uri); - continue; + if let Some(location) = res.headers().get("Location") { + uri = location.to_str()?.parse::()?; + debug!("Redirecting to {:?}", uri); + continue; + } else { + warn!("No `Location` found in redirect"); + return Err(Error::Http(res)); + } } other => return other, } diff --git a/src/error.rs b/src/error.rs index 01edcb0..b2657cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use std::str; use std::string; use crate::protocol::Message; -use http::{Response, StatusCode}; +use http::Response; #[cfg(feature = "tls")] pub mod tls { @@ -61,12 +61,8 @@ pub enum Error { Utf8, /// Invalid URL. Url(Cow<'static, str>), - /// HTTP error (status only). - HttpStatus(StatusCode), /// HTTP error. - Http(Response<()>), - /// No Location header in 3xx response - NoLocation, + Http(Response>), /// HTTP format error. HttpFormat(http::Error), } @@ -84,8 +80,6 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::NoLocation => write!(f, "No Location header specified"), - Error::HttpStatus(ref status) => write!(f, "HTTP error code: {}", status), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e2ca308..bb159d7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,12 +90,7 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { - // If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if result.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(result)); - } - self.verify_data.verify_response(&result)?; + let result = self.verify_data.verify_response(result)?; debug!("Client handshake done."); let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, self.config); @@ -161,7 +156,13 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: Response) -> Result { + // 1. If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.map(|_| None))); + } + let headers = response.headers(); // 2. If the response lacks an |Upgrade| header field or the |Upgrade| @@ -219,7 +220,7 @@ impl VerifyData { // the WebSocket Connection_. (RFC 6455) // TODO - Ok(()) + Ok(response) } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 406dc24..15f6b14 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -195,7 +195,7 @@ pub struct ServerHandshake { /// WebSocket configuration. config: Option, /// Error code/flag. If set, an error will be returned after sending response to the client. - error_code: Option, + error_response: Option, /// Internal stream type. _marker: PhantomData, } @@ -212,7 +212,7 @@ impl ServerHandshake { role: ServerHandshake { callback: Some(callback), config, - error_code: None, + error_response: None, _marker: PhantomData, }, } @@ -259,22 +259,25 @@ impl HandshakeRole for ServerHandshake { )); } - self.error_code = Some(resp.status().as_u16()); + self.error_response = Some(resp); + let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } StageResult::DoneWriting(stream) => { - if let Some(err) = self.error_code.take() { + if let Some(err) = self.error_response.take() { debug!("Server handshake failed."); - return Err(Error::HttpStatus(StatusCode::from_u16(err)?)); + return Err(Error::Http(err)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); From 09f5d34899416d950f9f51611c1757b189cd1ebd Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 17 Nov 2020 11:17:56 +0100 Subject: [PATCH 5/5] use 3 redirects as default for `connect` --- src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 0b70af3..1b20980 100644 --- a/src/client.rs +++ b/src/client.rs @@ -160,7 +160,7 @@ pub fn connect_with_config( /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None, 0) + connect_with_config(request, None, 3) } fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result {