|
|
@ -5,7 +5,9 @@ use std::{ |
|
|
|
marker::PhantomData, |
|
|
|
marker::PhantomData, |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; |
|
|
|
use http::{ |
|
|
|
|
|
|
|
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode, |
|
|
|
|
|
|
|
}; |
|
|
|
use httparse::Status; |
|
|
|
use httparse::Status; |
|
|
|
use log::*; |
|
|
|
use log::*; |
|
|
|
|
|
|
|
|
|
|
@ -52,12 +54,11 @@ impl<S: Read + Write> ClientHandshake<S> { |
|
|
|
// Check the URI scheme: only ws or wss are supported
|
|
|
|
// Check the URI scheme: only ws or wss are supported
|
|
|
|
let _ = crate::client::uri_mode(request.uri())?; |
|
|
|
let _ = crate::client::uri_mode(request.uri())?; |
|
|
|
|
|
|
|
|
|
|
|
let key = generate_key(); |
|
|
|
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
|
|
|
|
|
|
|
// Also extract the key from it (it must be present in a correct request).
|
|
|
|
|
|
|
|
let (request, key) = generate_request(request)?; |
|
|
|
|
|
|
|
|
|
|
|
let machine = { |
|
|
|
let machine = HandshakeMachine::start_write(stream, request); |
|
|
|
let req = generate_request(request, &key)?; |
|
|
|
|
|
|
|
HandshakeMachine::start_write(stream, req) |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let client = { |
|
|
|
let client = { |
|
|
|
let accept_key = derive_accept_key(key.as_ref()); |
|
|
|
let accept_key = derive_accept_key(key.as_ref()); |
|
|
@ -92,56 +93,73 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// Generate client request.
|
|
|
|
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
|
|
|
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> { |
|
|
|
let mut req = Vec::new(); |
|
|
|
let mut req = Vec::new(); |
|
|
|
let uri = request.uri(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); |
|
|
|
|
|
|
|
let host = if let Some(idx) = authority.find('@') { |
|
|
|
|
|
|
|
// handle possible name:password@
|
|
|
|
|
|
|
|
authority.split_at(idx + 1).1 |
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
authority |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
if authority.is_empty() { |
|
|
|
|
|
|
|
return Err(Error::Url(UrlError::EmptyHostName)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
write!( |
|
|
|
write!( |
|
|
|
req, |
|
|
|
req, |
|
|
|
"\ |
|
|
|
"GET {path} {version:?}\r\n", |
|
|
|
GET {path} {version:?}\r\n\ |
|
|
|
path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), |
|
|
|
Host: {host}\r\n\ |
|
|
|
version = request.version() |
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
|
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Version: 13\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Key: {key}\r\n", |
|
|
|
|
|
|
|
version = request.version(), |
|
|
|
|
|
|
|
host = host, |
|
|
|
|
|
|
|
path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), |
|
|
|
|
|
|
|
key = key |
|
|
|
|
|
|
|
) |
|
|
|
) |
|
|
|
.unwrap(); |
|
|
|
.unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
for (k, v) in request.headers() { |
|
|
|
// Headers that must be present in a correct request.
|
|
|
|
if k == header::CONNECTION |
|
|
|
const KEY_HEADERNAME: &str = "Sec-WebSocket-Key"; |
|
|
|
|| k == header::UPGRADE |
|
|
|
const WEBSOCKET_HEADERS: [&str; 5] = |
|
|
|
|| k == header::SEC_WEBSOCKET_VERSION |
|
|
|
["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME]; |
|
|
|
|| k == header::SEC_WEBSOCKET_KEY |
|
|
|
|
|
|
|
|| k == header::HOST |
|
|
|
// We must extract a WebSocket key from a properly formed request or fail if it's not present.
|
|
|
|
{ |
|
|
|
let key = request |
|
|
|
|
|
|
|
.headers() |
|
|
|
|
|
|
|
.get(KEY_HEADERNAME) |
|
|
|
|
|
|
|
.ok_or_else(|| { |
|
|
|
|
|
|
|
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(KEY_HEADERNAME))) |
|
|
|
|
|
|
|
})? |
|
|
|
|
|
|
|
.to_str()? |
|
|
|
|
|
|
|
.to_owned(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// We must check that all necessary headers for a valid request are present. Note that we have to
|
|
|
|
|
|
|
|
// deal with the fact that some apps seem to have a case-sensitive check for headers which is not
|
|
|
|
|
|
|
|
// correct and should not considered the correct behavior, but it seems like some apps ignore it.
|
|
|
|
|
|
|
|
// `http` by default writes all headers in lower-case which is fine (and does not violate the RFC)
|
|
|
|
|
|
|
|
// but some servers seem to be poorely written and ignore RFC.
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492
|
|
|
|
|
|
|
|
let headers = request.headers_mut(); |
|
|
|
|
|
|
|
for header in WEBSOCKET_HEADERS { |
|
|
|
|
|
|
|
let value = headers.remove(header).ok_or_else(|| { |
|
|
|
|
|
|
|
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(header))) |
|
|
|
|
|
|
|
})?; |
|
|
|
|
|
|
|
write!(req, "{header}: {value}\r\n", value = value.to_str()?).unwrap(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Now we must ensure that the headers that we've written once are not anymore present in the map.
|
|
|
|
|
|
|
|
// If they do, then the request is invalid (some headers are duplicated there for some reason).
|
|
|
|
|
|
|
|
let insensitive: Vec<String> = |
|
|
|
|
|
|
|
WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect(); |
|
|
|
|
|
|
|
for (k, v) in headers { |
|
|
|
|
|
|
|
let mut name = k.as_str(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// We have already written the necessary headers once (above) and removed them from the map.
|
|
|
|
|
|
|
|
// If we encounter them again, then the request is considered invalid and error is returned.
|
|
|
|
|
|
|
|
// Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust.
|
|
|
|
|
|
|
|
if insensitive.iter().any(|x| x == name) { |
|
|
|
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); |
|
|
|
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); |
|
|
|
} |
|
|
|
} |
|
|
|
let mut k = k.as_str(); |
|
|
|
|
|
|
|
if k == "sec-websocket-protocol" { |
|
|
|
// Relates to the issue of some servers treating headers in a case-sensitive way, please see:
|
|
|
|
k = "Sec-WebSocket-Protocol"; |
|
|
|
// https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem)
|
|
|
|
|
|
|
|
if name == "sec-websocket-protocol" { |
|
|
|
|
|
|
|
name = "Sec-WebSocket-Protocol"; |
|
|
|
} |
|
|
|
} |
|
|
|
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
writeln!(req, "\r").unwrap(); |
|
|
|
writeln!(req, "\r").unwrap(); |
|
|
|
trace!("Request: {:?}", String::from_utf8_lossy(&req)); |
|
|
|
trace!("Request: {:?}", String::from_utf8_lossy(&req)); |
|
|
|
Ok(req) |
|
|
|
Ok((req, key)) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// Information for handshake verification.
|
|
|
|
/// Information for handshake verification.
|
|
|
@ -241,7 +259,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// Generate a random key for the `Sec-WebSocket-Key` header.
|
|
|
|
/// Generate a random key for the `Sec-WebSocket-Key` header.
|
|
|
|
fn generate_key() -> String { |
|
|
|
pub fn generate_key() -> String { |
|
|
|
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
|
|
|
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
|
|
|
// when decoded, is 16 bytes in length (RFC 6455)
|
|
|
|
// when decoded, is 16 bytes in length (RFC 6455)
|
|
|
|
let r: [u8; 16] = rand::random(); |
|
|
|
let r: [u8; 16] = rand::random(); |
|
|
@ -269,54 +287,41 @@ mod tests { |
|
|
|
assert!(k2[..22].find('=').is_none()); |
|
|
|
assert!(k2[..22].find('=').is_none()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
fn construct_expected(host: &str, key: &str) -> Vec<u8> { |
|
|
|
fn request_formatting() { |
|
|
|
format!( |
|
|
|
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); |
|
|
|
"\ |
|
|
|
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; |
|
|
|
|
|
|
|
let correct = b"\ |
|
|
|
|
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
Host: localhost\r\n\ |
|
|
|
Host: {host}\r\n\ |
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
Sec-WebSocket-Version: 13\r\n\ |
|
|
|
Sec-WebSocket-Version: 13\r\n\ |
|
|
|
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ |
|
|
|
Sec-WebSocket-Key: {key}\r\n\ |
|
|
|
\r\n"; |
|
|
|
\r\n" |
|
|
|
let request = generate_request(request, key).unwrap(); |
|
|
|
) |
|
|
|
println!("Request: {}", String::from_utf8_lossy(&request)); |
|
|
|
.into_bytes() |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
|
|
|
fn request_formatting() { |
|
|
|
|
|
|
|
let request = "ws://localhost/getCaseCount".into_client_request().unwrap(); |
|
|
|
|
|
|
|
let (request, key) = generate_request(request).unwrap(); |
|
|
|
|
|
|
|
let correct = construct_expected("localhost", &key); |
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
|
fn request_formatting_with_host() { |
|
|
|
fn request_formatting_with_host() { |
|
|
|
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; |
|
|
|
let (request, key) = generate_request(request).unwrap(); |
|
|
|
let correct = b"\ |
|
|
|
let correct = construct_expected("localhost:9001", &key); |
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
|
|
|
|
Host: localhost:9001\r\n\ |
|
|
|
|
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
|
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Version: 13\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ |
|
|
|
|
|
|
|
\r\n"; |
|
|
|
|
|
|
|
let request = generate_request(request, key).unwrap(); |
|
|
|
|
|
|
|
println!("Request: {}", String::from_utf8_lossy(&request)); |
|
|
|
|
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
|
fn request_formatting_with_at() { |
|
|
|
fn request_formatting_with_at() { |
|
|
|
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; |
|
|
|
let (request, key) = generate_request(request).unwrap(); |
|
|
|
let correct = b"\ |
|
|
|
let correct = construct_expected("localhost:9001", &key); |
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
|
|
|
|
Host: localhost:9001\r\n\ |
|
|
|
|
|
|
|
Connection: Upgrade\r\n\ |
|
|
|
|
|
|
|
Upgrade: websocket\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Version: 13\r\n\ |
|
|
|
|
|
|
|
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ |
|
|
|
|
|
|
|
\r\n"; |
|
|
|
|
|
|
|
let request = generate_request(request, key).unwrap(); |
|
|
|
|
|
|
|
println!("Request: {}", String::from_utf8_lossy(&request)); |
|
|
|
|
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
assert_eq!(&request[..], &correct[..]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|