From b2a477b30538036fe140514bbcbbddf49fb19c40 Mon Sep 17 00:00:00 2001 From: Ivan Nikulin Date: Wed, 14 Oct 2020 12:58:52 +0100 Subject: [PATCH] GH-46 Return the response in case of non-101 response from server Following redirects inside the lib (https://github.com/snapview/tungstenite-rs/pull/148) has few flows: there is no redirect loop prevention in this case, in case of using the lib in proxy it's impossible to return the upstream response to browser, etc. With this change response is propagated to the lib's user, so it can decide what to do with it in case of redirects: either send it to browser for it to follow redirects or implement redirect following on their side. --- src/error.rs | 3 +++ src/handshake/client.rs | 41 ++++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/error.rs b/src/error.rs index ab24753..dc05a12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,6 +65,8 @@ pub enum Error { Url(Cow<'static, str>), /// HTTP error. Http(http::StatusCode), + /// HTTP response error. + HttpResponse(http::Response<()>), /// HTTP format error. HttpFormat(http::Error), } @@ -83,6 +85,7 @@ impl fmt::Display for Error { Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::HttpResponse(ref res) => write!(f, "HTTP response error: {}", res.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..e9be75e 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,6 +90,12 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { + // If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if result.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::HttpResponse(result)); + } + self.verify_data.verify_response(&result)?; debug!("Client handshake done."); let websocket = @@ -105,16 +111,18 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority() + let authority = uri + .authority() .ok_or_else(|| Error::Url("No host name in the URL".into()))? .as_str(); - let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ + let host = if let Some(idx) = authority.find('@') { + // handle possible name:password@ authority.split_at(idx + 1).1 } else { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())) + return Err(Error::Url("URL contains empty host name".into())); } write!( @@ -138,7 +146,7 @@ fn generate_request(request: Request, key: &str) -> Result> { for (k, v) in request.headers() { let mut k = k.as_str(); - if k == "sec-websocket-protocol" { + if k == "sec-websocket-protocol" { k = "Sec-WebSocket-Protocol"; } writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); @@ -157,14 +165,9 @@ struct VerifyData { impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { - // 1. If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); - } let headers = response.headers(); - // 2. If the response lacks an |Upgrade| header field or the |Upgrade| + // 1. If the response lacks an |Upgrade| header field or the |Upgrade| // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) @@ -178,7 +181,7 @@ impl VerifyData { "No \"Upgrade: websocket\" in server reply".into(), )); } - // 3. If the response lacks a |Connection| header field or the + // 2. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) @@ -192,7 +195,7 @@ impl VerifyData { "No \"Connection: upgrade\" in server reply".into(), )); } - // 4. If the response lacks a |Sec-WebSocket-Accept| header field or + // 3. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) @@ -205,14 +208,14 @@ impl VerifyData { "Key mismatch in Sec-WebSocket-Accept".into(), )); } - // 5. If the response includes a |Sec-WebSocket-Extensions| header + // 4. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) // TODO - // 6. If the response includes a |Sec-WebSocket-Protocol| header field + // 5. If the response includes a |Sec-WebSocket-Protocol| header field // and this header field indicates the use of a subprotocol that was // not present in the client's handshake (the server has indicated a // subprotocol not requested by the client), the client MUST _Fail @@ -266,8 +269,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { use super::super::machine::TryParse; - use crate::client::IntoClientRequest; use super::{generate_key, generate_request, Response}; + use crate::client::IntoClientRequest; #[test] fn random_keys() { @@ -304,7 +307,9 @@ mod tests { #[test] fn request_formatting_with_host() { - let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ @@ -321,7 +326,9 @@ mod tests { #[test] fn request_formatting_with_at() { - let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://user:pass@localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\