client: overhaul of the request generation

pull/259/head
Daniel Abramov 3 years ago
parent 1b999136ef
commit d661f57224
  1. 29
      src/client.rs
  2. 5
      src/error.rs
  3. 159
      src/handshake/client.rs

@ -12,7 +12,7 @@ use log::*;
use url::Url;
use crate::{
handshake::client::{Request, Response},
handshake::client::{generate_key, Request, Response},
protocol::WebSocketConfig,
stream::MaybeTlsStream,
};
@ -178,7 +178,11 @@ where
/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
/// simply take your request and pass it as is further without altering any headers or URLs, so
/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
/// the proper `http::Request<()>` for you.
pub trait IntoClientRequest {
/// Convert into a `Request` that can be used for a client connection.
fn into_client_request(self) -> Result<Request>;
@ -210,7 +214,26 @@ impl<'a> IntoClientRequest for &'a Uri {
impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self).body(())?)
let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
let host = authority
.find('@')
.map(|idx| authority.split_at(idx + 1).1)
.unwrap_or_else(|| authority);
if host.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
let req = Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key())
.uri(self)
.body(())?;
Ok(req)
}
}

@ -167,9 +167,8 @@ pub enum ProtocolError {
/// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")]
CustomResponseSuccessful,
/// Invalid header is passed. This header is formed by the library automatically
/// and must not be overwritten by the user.
#[error("Not allowed to pass overwrite the standard header {0}")]
/// Invalid header is passed. Or the header is missing in the request. Or not present at all. Check the request that you pass.
#[error("Missing, duplicated or incorrect header {0}")]
InvalidHeader(HeaderName),
/// No more data while still performing handshake.
#[error("Handshake not finished")]

@ -5,7 +5,9 @@ use std::{
marker::PhantomData,
};
use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
use httparse::Status;
use log::*;
@ -52,12 +54,11 @@ impl<S: Read + Write> ClientHandshake<S> {
// Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?;
let key = generate_key();
// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let machine = {
let req = generate_request(request, &key)?;
HandshakeMachine::start_write(stream, req)
};
let machine = HandshakeMachine::start_write(stream, request);
let client = {
let accept_key = derive_accept_key(key.as_ref());
@ -92,56 +93,73 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
}
/// Generate client request.
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
write!(
req,
"\
GET {path} {version:?}\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
version = request.version(),
host = host,
path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
key = key
"GET {path} {version:?}\r\n",
path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
version = request.version()
)
.unwrap();
for (k, v) in request.headers() {
if k == header::CONNECTION
|| k == header::UPGRADE
|| k == header::SEC_WEBSOCKET_VERSION
|| k == header::SEC_WEBSOCKET_KEY
|| k == header::HOST
{
// Headers that must be present in a correct request.
const KEY_HEADERNAME: &str = "Sec-WebSocket-Key";
const WEBSOCKET_HEADERS: [&str; 5] =
["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME];
// We must extract a WebSocket key from a properly formed request or fail if it's not present.
let key = request
.headers()
.get(KEY_HEADERNAME)
.ok_or_else(|| {
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(KEY_HEADERNAME)))
})?
.to_str()?
.to_owned();
// We must check that all necessary headers for a valid request are present. Note that we have to
// deal with the fact that some apps seem to have a case-sensitive check for headers which is not
// correct and should not considered the correct behavior, but it seems like some apps ignore it.
// `http` by default writes all headers in lower-case which is fine (and does not violate the RFC)
// but some servers seem to be poorely written and ignore RFC.
//
// See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492
let headers = request.headers_mut();
for header in WEBSOCKET_HEADERS {
let value = headers.remove(header).ok_or_else(|| {
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(header)))
})?;
write!(req, "{header}: {value}\r\n", value = value.to_str()?).unwrap();
}
// Now we must ensure that the headers that we've written once are not anymore present in the map.
// If they do, then the request is invalid (some headers are duplicated there for some reason).
let insensitive: Vec<String> =
WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect();
for (k, v) in headers {
let mut name = k.as_str();
// We have already written the necessary headers once (above) and removed them from the map.
// If we encounter them again, then the request is considered invalid and error is returned.
// Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust.
if insensitive.iter().any(|x| x == name) {
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
}
let mut k = k.as_str();
if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol";
// Relates to the issue of some servers treating headers in a case-sensitive way, please see:
// https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem)
if name == "sec-websocket-protocol" {
name = "Sec-WebSocket-Protocol";
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
}
writeln!(req, "\r").unwrap();
trace!("Request: {:?}", String::from_utf8_lossy(&req));
Ok(req)
Ok((req, key))
}
/// Information for handshake verification.
@ -241,7 +259,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
}
/// Generate a random key for the `Sec-WebSocket-Key` header.
fn generate_key() -> String {
pub fn generate_key() -> String {
// a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455)
let r: [u8; 16] = rand::random();
@ -269,54 +287,41 @@ mod tests {
assert!(k2[..22].find('=').is_none());
}
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
fn construct_expected(host: &str, key: &str) -> Vec<u8> {
format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
Sec-WebSocket-Key: {key}\r\n\
\r\n"
)
.into_bytes()
}
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost:9001\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Host: localhost:9001\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
let (request, key) = generate_request(request).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}

Loading…
Cancel
Save