client: overhaul of the request generation

pull/259/head
Daniel Abramov 3 years ago
parent 1b999136ef
commit d661f57224
  1. 29
      src/client.rs
  2. 5
      src/error.rs
  3. 159
      src/handshake/client.rs

@ -12,7 +12,7 @@ use log::*;
use url::Url; use url::Url;
use crate::{ use crate::{
handshake::client::{Request, Response}, handshake::client::{generate_key, Request, Response},
protocol::WebSocketConfig, protocol::WebSocketConfig,
stream::MaybeTlsStream, stream::MaybeTlsStream,
}; };
@ -178,7 +178,11 @@ where
/// Trait for converting various types into HTTP requests used for a client connection. /// Trait for converting various types into HTTP requests used for a client connection.
/// ///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and /// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`. /// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
/// simply take your request and pass it as is further without altering any headers or URLs, so
/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
/// the proper `http::Request<()>` for you.
pub trait IntoClientRequest { pub trait IntoClientRequest {
/// Convert into a `Request` that can be used for a client connection. /// Convert into a `Request` that can be used for a client connection.
fn into_client_request(self) -> Result<Request>; fn into_client_request(self) -> Result<Request>;
@ -210,7 +214,26 @@ impl<'a> IntoClientRequest for &'a Uri {
impl IntoClientRequest for Uri { impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request> { fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self).body(())?) let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
let host = authority
.find('@')
.map(|idx| authority.split_at(idx + 1).1)
.unwrap_or_else(|| authority);
if host.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
let req = Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key())
.uri(self)
.body(())?;
Ok(req)
} }
} }

@ -167,9 +167,8 @@ pub enum ProtocolError {
/// Custom responses must be unsuccessful. /// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")] #[error("Custom response must not be successful")]
CustomResponseSuccessful, CustomResponseSuccessful,
/// Invalid header is passed. This header is formed by the library automatically /// Invalid header is passed. Or the header is missing in the request. Or not present at all. Check the request that you pass.
/// and must not be overwritten by the user. #[error("Missing, duplicated or incorrect header {0}")]
#[error("Not allowed to pass overwrite the standard header {0}")]
InvalidHeader(HeaderName), InvalidHeader(HeaderName),
/// No more data while still performing handshake. /// No more data while still performing handshake.
#[error("Handshake not finished")] #[error("Handshake not finished")]

@ -5,7 +5,9 @@ use std::{
marker::PhantomData, marker::PhantomData,
}; };
use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
@ -52,12 +54,11 @@ impl<S: Read + Write> ClientHandshake<S> {
// Check the URI scheme: only ws or wss are supported // Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?; let _ = crate::client::uri_mode(request.uri())?;
let key = generate_key(); // Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let machine = { let machine = HandshakeMachine::start_write(stream, request);
let req = generate_request(request, &key)?;
HandshakeMachine::start_write(stream, req)
};
let client = { let client = {
let accept_key = derive_accept_key(key.as_ref()); let accept_key = derive_accept_key(key.as_ref());
@ -92,56 +93,73 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
} }
} }
/// Generate client request. /// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
write!( write!(
req, req,
"\ "GET {path} {version:?}\r\n",
GET {path} {version:?}\r\n\ path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
Host: {host}\r\n\ version = request.version()
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
version = request.version(),
host = host,
path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
key = key
) )
.unwrap(); .unwrap();
for (k, v) in request.headers() { // Headers that must be present in a correct request.
if k == header::CONNECTION const KEY_HEADERNAME: &str = "Sec-WebSocket-Key";
|| k == header::UPGRADE const WEBSOCKET_HEADERS: [&str; 5] =
|| k == header::SEC_WEBSOCKET_VERSION ["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME];
|| k == header::SEC_WEBSOCKET_KEY
|| k == header::HOST // We must extract a WebSocket key from a properly formed request or fail if it's not present.
{ let key = request
.headers()
.get(KEY_HEADERNAME)
.ok_or_else(|| {
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(KEY_HEADERNAME)))
})?
.to_str()?
.to_owned();
// We must check that all necessary headers for a valid request are present. Note that we have to
// deal with the fact that some apps seem to have a case-sensitive check for headers which is not
// correct and should not considered the correct behavior, but it seems like some apps ignore it.
// `http` by default writes all headers in lower-case which is fine (and does not violate the RFC)
// but some servers seem to be poorely written and ignore RFC.
//
// See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492
let headers = request.headers_mut();
for header in WEBSOCKET_HEADERS {
let value = headers.remove(header).ok_or_else(|| {
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(header)))
})?;
write!(req, "{header}: {value}\r\n", value = value.to_str()?).unwrap();
}
// Now we must ensure that the headers that we've written once are not anymore present in the map.
// If they do, then the request is invalid (some headers are duplicated there for some reason).
let insensitive: Vec<String> =
WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect();
for (k, v) in headers {
let mut name = k.as_str();
// We have already written the necessary headers once (above) and removed them from the map.
// If we encounter them again, then the request is considered invalid and error is returned.
// Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust.
if insensitive.iter().any(|x| x == name) {
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
} }
let mut k = k.as_str();
if k == "sec-websocket-protocol" { // Relates to the issue of some servers treating headers in a case-sensitive way, please see:
k = "Sec-WebSocket-Protocol"; // https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem)
if name == "sec-websocket-protocol" {
name = "Sec-WebSocket-Protocol";
} }
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
} }
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req)); trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok(req) Ok((req, key))
} }
/// Information for handshake verification. /// Information for handshake verification.
@ -241,7 +259,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
} }
/// Generate a random key for the `Sec-WebSocket-Key` header. /// Generate a random key for the `Sec-WebSocket-Key` header.
fn generate_key() -> String { pub fn generate_key() -> String {
// a base64-encoded (see Section 4 of [RFC4648]) value that, // a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455) // when decoded, is 16 bytes in length (RFC 6455)
let r: [u8; 16] = rand::random(); let r: [u8; 16] = rand::random();
@ -269,54 +287,41 @@ mod tests {
assert!(k2[..22].find('=').is_none()); assert!(k2[..22].find('=').is_none());
} }
#[test] fn construct_expected(host: &str, key: &str) -> Vec<u8> {
fn request_formatting() { format!(
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); "\
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\ GET /getCaseCount HTTP/1.1\r\n\
Host: localhost\r\n\ Host: {host}\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ Sec-WebSocket-Key: {key}\r\n\
\r\n"; \r\n"
let request = generate_request(request, key).unwrap(); )
println!("Request: {}", String::from_utf8_lossy(&request)); .into_bytes()
}
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[test] #[test]
fn request_formatting_with_host() { fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let (request, key) = generate_request(request).unwrap();
let correct = b"\ let correct = construct_expected("localhost:9001", &key);
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost:9001\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }
#[test] #[test]
fn request_formatting_with_at() { fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let (request, key) = generate_request(request).unwrap();
let correct = b"\ let correct = construct_expected("localhost:9001", &key);
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost:9001\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]); assert_eq!(&request[..], &correct[..]);
} }

Loading…
Cancel
Save