|
|
|
@ -61,11 +61,18 @@ impl<S: Read + Write> ClientHandshake<S> { |
|
|
|
|
|
|
|
|
|
let client = { |
|
|
|
|
let accept_key = derive_accept_key(key.as_ref()); |
|
|
|
|
ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData } |
|
|
|
|
ClientHandshake { |
|
|
|
|
verify_data: VerifyData { accept_key }, |
|
|
|
|
config, |
|
|
|
|
_marker: PhantomData, |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
trace!("Client handshake initiated."); |
|
|
|
|
Ok(MidHandshake { role: client, machine }) |
|
|
|
|
// trace!("Client handshake initiated.");
|
|
|
|
|
Ok(MidHandshake { |
|
|
|
|
role: client, |
|
|
|
|
machine, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -81,7 +88,11 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> { |
|
|
|
|
StageResult::DoneWriting(stream) => { |
|
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) |
|
|
|
|
} |
|
|
|
|
StageResult::DoneReading { stream, result, tail } => { |
|
|
|
|
StageResult::DoneReading { |
|
|
|
|
stream, |
|
|
|
|
result, |
|
|
|
|
tail, |
|
|
|
|
} => { |
|
|
|
|
let result = self.verify_data.verify_response(result)?; |
|
|
|
|
debug!("Client handshake done."); |
|
|
|
|
let websocket = |
|
|
|
@ -97,7 +108,10 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
|
let mut req = Vec::new(); |
|
|
|
|
let uri = request.uri(); |
|
|
|
|
|
|
|
|
|
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); |
|
|
|
|
let authority = uri |
|
|
|
|
.authority() |
|
|
|
|
.ok_or(Error::Url(UrlError::NoHostName))? |
|
|
|
|
.as_str(); |
|
|
|
|
let host = if let Some(idx) = authority.find('@') { |
|
|
|
|
// handle possible name:password@
|
|
|
|
|
authority.split_at(idx + 1).1 |
|
|
|
@ -119,7 +133,10 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
|
Sec-WebSocket-Key: {key}\r\n", |
|
|
|
|
version = request.version(), |
|
|
|
|
host = host, |
|
|
|
|
path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), |
|
|
|
|
path = uri |
|
|
|
|
.path_and_query() |
|
|
|
|
.ok_or(Error::Url(UrlError::NoPathOrQuery))? |
|
|
|
|
.as_str(), |
|
|
|
|
key = key |
|
|
|
|
) |
|
|
|
|
.unwrap(); |
|
|
|
@ -140,7 +157,7 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
|
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
|
} |
|
|
|
|
writeln!(req, "\r").unwrap(); |
|
|
|
|
trace!("Request: {:?}", String::from_utf8_lossy(&req)); |
|
|
|
|
// trace!("Request: {:?}", String::from_utf8_lossy(&req));
|
|
|
|
|
Ok(req) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -171,7 +188,9 @@ impl VerifyData { |
|
|
|
|
.map(|h| h.eq_ignore_ascii_case("websocket")) |
|
|
|
|
.unwrap_or(false) |
|
|
|
|
{ |
|
|
|
|
return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); |
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
ProtocolError::MissingUpgradeWebSocketHeader, |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
// 3. If the response lacks a |Connection| header field or the
|
|
|
|
|
// |Connection| header field doesn't contain a token that is an
|
|
|
|
@ -183,14 +202,22 @@ impl VerifyData { |
|
|
|
|
.map(|h| h.eq_ignore_ascii_case("Upgrade")) |
|
|
|
|
.unwrap_or(false) |
|
|
|
|
{ |
|
|
|
|
return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); |
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
ProtocolError::MissingConnectionUpgradeHeader, |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
|
|
|
|
|
// the |Sec-WebSocket-Accept| contains a value other than the
|
|
|
|
|
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
|
|
|
|
|
// Connection_. (RFC 6455)
|
|
|
|
|
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { |
|
|
|
|
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); |
|
|
|
|
if !headers |
|
|
|
|
.get("Sec-WebSocket-Accept") |
|
|
|
|
.map(|h| h == &self.accept_key) |
|
|
|
|
.unwrap_or(false) |
|
|
|
|
{ |
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
ProtocolError::SecWebSocketAcceptKeyMismatch, |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
// 5. If the response includes a |Sec-WebSocket-Extensions| header
|
|
|
|
|
// field and this header field indicates the use of an extension
|
|
|
|
@ -288,7 +315,9 @@ mod tests { |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn request_formatting_with_host() { |
|
|
|
|
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
|
let request = "wss://localhost:9001/getCaseCount" |
|
|
|
|
.into_client_request() |
|
|
|
|
.unwrap(); |
|
|
|
|
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; |
|
|
|
|
let correct = b"\ |
|
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
@ -305,7 +334,9 @@ mod tests { |
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
|
fn request_formatting_with_at() { |
|
|
|
|
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); |
|
|
|
|
let request = "wss://user:pass@localhost:9001/getCaseCount" |
|
|
|
|
.into_client_request() |
|
|
|
|
.unwrap(); |
|
|
|
|
let key = "A70tsIbeMZUbJHh5BWFw6Q=="; |
|
|
|
|
let correct = b"\ |
|
|
|
|
GET /getCaseCount HTTP/1.1\r\n\ |
|
|
|
@ -325,6 +356,9 @@ mod tests { |
|
|
|
|
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; |
|
|
|
|
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); |
|
|
|
|
assert_eq!(resp.status(), http::StatusCode::OK); |
|
|
|
|
assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],); |
|
|
|
|
assert_eq!( |
|
|
|
|
resp.headers().get("Content-Type").unwrap(), |
|
|
|
|
&b"text/html"[..], |
|
|
|
|
); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|