From d661f57224c4dfd1c46fe5cde6761637ae2f5c4d Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 3 Feb 2022 19:33:55 +0100 Subject: [PATCH] client: overhaul of the request generation --- src/client.rs | 29 +++++++- src/error.rs | 5 +- src/handshake/client.rs | 159 +++++++++++++++++++++------------------- 3 files changed, 110 insertions(+), 83 deletions(-) diff --git a/src/client.rs b/src/client.rs index 12dfe9b..2bc522a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,7 +12,7 @@ use log::*; use url::Url; use crate::{ - handshake::client::{Request, Response}, + handshake::client::{generate_key, Request, Response}, protocol::WebSocketConfig, stream::MaybeTlsStream, }; @@ -178,7 +178,11 @@ where /// Trait for converting various types into HTTP requests used for a client connection. /// /// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and -/// `http::Request<()>`. +/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will +/// simply take your request and pass it as is further without altering any headers or URLs, so +/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass +/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating +/// the proper `http::Request<()>` for you. pub trait IntoClientRequest { /// Convert into a `Request` that can be used for a client connection. fn into_client_request(self) -> Result; @@ -210,7 +214,26 @@ impl<'a> IntoClientRequest for &'a Uri { impl IntoClientRequest for Uri { fn into_client_request(self) -> Result { - Ok(Request::get(self).body(())?) + let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); + let host = authority + .find('@') + .map(|idx| authority.split_at(idx + 1).1) + .unwrap_or_else(|| authority); + + if host.is_empty() { + return Err(Error::Url(UrlError::EmptyHostName)); + } + + let req = Request::builder() + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()) + .uri(self) + .body(())?; + Ok(req) } } diff --git a/src/error.rs b/src/error.rs index a9c4ceb..c025080 100644 --- a/src/error.rs +++ b/src/error.rs @@ -167,9 +167,8 @@ pub enum ProtocolError { /// Custom responses must be unsuccessful. #[error("Custom response must not be successful")] CustomResponseSuccessful, - /// Invalid header is passed. This header is formed by the library automatically - /// and must not be overwritten by the user. - #[error("Not allowed to pass overwrite the standard header {0}")] + /// Invalid header is passed. Or the header is missing in the request. Or not present at all. Check the request that you pass. + #[error("Missing, duplicated or incorrect header {0}")] InvalidHeader(HeaderName), /// No more data while still performing handshake. #[error("Handshake not finished")] diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 36d6262..3bee72c 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -5,7 +5,9 @@ use std::{ marker::PhantomData, }; -use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; +use http::{ + header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, +}; use httparse::Status; use log::*; @@ -52,12 +54,11 @@ impl ClientHandshake { // Check the URI scheme: only ws or wss are supported let _ = crate::client::uri_mode(request.uri())?; - let key = generate_key(); + // Convert and verify the `http::Request` and turn it into the request as per RFC. + // Also extract the key from it (it must be present in a correct request). + let (request, key) = generate_request(request)?; - let machine = { - let req = generate_request(request, &key)?; - HandshakeMachine::start_write(stream, req) - }; + let machine = HandshakeMachine::start_write(stream, request); let client = { let accept_key = derive_accept_key(key.as_ref()); @@ -92,56 +93,73 @@ impl HandshakeRole for ClientHandshake { } } -/// Generate client request. -fn generate_request(request: Request, key: &str) -> Result> { +/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it. +fn generate_request(mut request: Request) -> Result<(Vec, String)> { let mut req = Vec::new(); - let uri = request.uri(); - - let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); - let host = if let Some(idx) = authority.find('@') { - // handle possible name:password@ - authority.split_at(idx + 1).1 - } else { - authority - }; - if authority.is_empty() { - return Err(Error::Url(UrlError::EmptyHostName)); - } - write!( req, - "\ - GET {path} {version:?}\r\n\ - Host: {host}\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n", - version = request.version(), - host = host, - path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), - key = key + "GET {path} {version:?}\r\n", + path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), + version = request.version() ) .unwrap(); - for (k, v) in request.headers() { - if k == header::CONNECTION - || k == header::UPGRADE - || k == header::SEC_WEBSOCKET_VERSION - || k == header::SEC_WEBSOCKET_KEY - || k == header::HOST - { + // Headers that must be present in a correct request. + const KEY_HEADERNAME: &str = "Sec-WebSocket-Key"; + const WEBSOCKET_HEADERS: [&str; 5] = + ["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME]; + + // We must extract a WebSocket key from a properly formed request or fail if it's not present. + let key = request + .headers() + .get(KEY_HEADERNAME) + .ok_or_else(|| { + Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(KEY_HEADERNAME))) + })? + .to_str()? + .to_owned(); + + // We must check that all necessary headers for a valid request are present. Note that we have to + // deal with the fact that some apps seem to have a case-sensitive check for headers which is not + // correct and should not considered the correct behavior, but it seems like some apps ignore it. + // `http` by default writes all headers in lower-case which is fine (and does not violate the RFC) + // but some servers seem to be poorely written and ignore RFC. + // + // See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492 + let headers = request.headers_mut(); + for header in WEBSOCKET_HEADERS { + let value = headers.remove(header).ok_or_else(|| { + Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(header))) + })?; + write!(req, "{header}: {value}\r\n", value = value.to_str()?).unwrap(); + } + + // Now we must ensure that the headers that we've written once are not anymore present in the map. + // If they do, then the request is invalid (some headers are duplicated there for some reason). + let insensitive: Vec = + WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect(); + for (k, v) in headers { + let mut name = k.as_str(); + + // We have already written the necessary headers once (above) and removed them from the map. + // If we encounter them again, then the request is considered invalid and error is returned. + // Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust. + if insensitive.iter().any(|x| x == name) { return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); } - let mut k = k.as_str(); - if k == "sec-websocket-protocol" { - k = "Sec-WebSocket-Protocol"; + + // Relates to the issue of some servers treating headers in a case-sensitive way, please see: + // https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem) + if name == "sec-websocket-protocol" { + name = "Sec-WebSocket-Protocol"; } - writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); + + writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap(); } + writeln!(req, "\r").unwrap(); trace!("Request: {:?}", String::from_utf8_lossy(&req)); - Ok(req) + Ok((req, key)) } /// Information for handshake verification. @@ -241,7 +259,7 @@ impl<'h, 'b: 'h> FromHttparse> for Response { } /// Generate a random key for the `Sec-WebSocket-Key` header. -fn generate_key() -> String { +pub fn generate_key() -> String { // a base64-encoded (see Section 4 of [RFC4648]) value that, // when decoded, is 16 bytes in length (RFC 6455) let r: [u8; 16] = rand::random(); @@ -269,54 +287,41 @@ mod tests { assert!(k2[..22].find('=').is_none()); } - #[test] - fn request_formatting() { - let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); - let key = "A70tsIbeMZUbJHh5BWFw6Q=="; - let correct = b"\ + fn construct_expected(host: &str, key: &str) -> Vec { + format!( + "\ GET /getCaseCount HTTP/1.1\r\n\ - Host: localhost\r\n\ + Host: {host}\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ - \r\n"; - let request = generate_request(request, key).unwrap(); - println!("Request: {}", String::from_utf8_lossy(&request)); + Sec-WebSocket-Key: {key}\r\n\ + \r\n" + ) + .into_bytes() + } + + #[test] + fn request_formatting() { + let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); + let (request, key) = generate_request(request).unwrap(); + let correct = construct_expected("localhost", &key); assert_eq!(&request[..], &correct[..]); } #[test] fn request_formatting_with_host() { let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); - let key = "A70tsIbeMZUbJHh5BWFw6Q=="; - let correct = b"\ - GET /getCaseCount HTTP/1.1\r\n\ - Host: localhost:9001\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ - \r\n"; - let request = generate_request(request, key).unwrap(); - println!("Request: {}", String::from_utf8_lossy(&request)); + let (request, key) = generate_request(request).unwrap(); + let correct = construct_expected("localhost:9001", &key); assert_eq!(&request[..], &correct[..]); } #[test] fn request_formatting_with_at() { let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); - let key = "A70tsIbeMZUbJHh5BWFw6Q=="; - let correct = b"\ - GET /getCaseCount HTTP/1.1\r\n\ - Host: localhost:9001\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ - \r\n"; - let request = generate_request(request, key).unwrap(); - println!("Request: {}", String::from_utf8_lossy(&request)); + let (request, key) = generate_request(request).unwrap(); + let correct = construct_expected("localhost:9001", &key); assert_eq!(&request[..], &correct[..]); }