|
|
@ -1,17 +1,24 @@ |
|
|
|
//! Client handshake machine.
|
|
|
|
//! Client handshake machine.
|
|
|
|
|
|
|
|
|
|
|
|
use std::io::{Read, Write}; |
|
|
|
use std::{ |
|
|
|
use std::marker::PhantomData; |
|
|
|
io::{Read, Write}, |
|
|
|
|
|
|
|
marker::PhantomData, |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; |
|
|
|
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; |
|
|
|
use httparse::Status; |
|
|
|
use httparse::Status; |
|
|
|
use log::*; |
|
|
|
use log::*; |
|
|
|
|
|
|
|
|
|
|
|
use super::headers::{FromHttparse, MAX_HEADERS}; |
|
|
|
use super::{ |
|
|
|
use super::machine::{HandshakeMachine, StageResult, TryParse}; |
|
|
|
convert_key, |
|
|
|
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; |
|
|
|
headers::{FromHttparse, MAX_HEADERS}, |
|
|
|
use crate::error::{Error, Result}; |
|
|
|
machine::{HandshakeMachine, StageResult, TryParse}, |
|
|
|
use crate::protocol::{Role, WebSocket, WebSocketConfig}; |
|
|
|
HandshakeRole, MidHandshake, ProcessingResult, |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
use crate::{ |
|
|
|
|
|
|
|
error::{Error, Result}, |
|
|
|
|
|
|
|
protocol::{Role, WebSocket, WebSocketConfig}, |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
/// Client request type.
|
|
|
|
/// Client request type.
|
|
|
|
pub type Request = HttpRequest<()>; |
|
|
|
pub type Request = HttpRequest<()>; |
|
|
@ -35,15 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> { |
|
|
|
config: Option<WebSocketConfig>, |
|
|
|
config: Option<WebSocketConfig>, |
|
|
|
) -> Result<MidHandshake<Self>> { |
|
|
|
) -> Result<MidHandshake<Self>> { |
|
|
|
if request.method() != http::Method::GET { |
|
|
|
if request.method() != http::Method::GET { |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); |
|
|
|
"Invalid HTTP method, only GET supported".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if request.version() < http::Version::HTTP_11 { |
|
|
|
if request.version() < http::Version::HTTP_11 { |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); |
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Check the URI scheme: only ws or wss are supported
|
|
|
|
// Check the URI scheme: only ws or wss are supported
|
|
|
@ -58,18 +61,11 @@ impl<S: Read + Write> ClientHandshake<S> { |
|
|
|
|
|
|
|
|
|
|
|
let client = { |
|
|
|
let client = { |
|
|
|
let accept_key = convert_key(key.as_ref()).unwrap(); |
|
|
|
let accept_key = convert_key(key.as_ref()).unwrap(); |
|
|
|
ClientHandshake { |
|
|
|
ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData } |
|
|
|
verify_data: VerifyData { accept_key }, |
|
|
|
|
|
|
|
config, |
|
|
|
|
|
|
|
_marker: PhantomData, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
trace!("Client handshake initiated."); |
|
|
|
trace!("Client handshake initiated."); |
|
|
|
Ok(MidHandshake { |
|
|
|
Ok(MidHandshake { role: client, machine }) |
|
|
|
role: client, |
|
|
|
|
|
|
|
machine, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -85,11 +81,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> { |
|
|
|
StageResult::DoneWriting(stream) => { |
|
|
|
StageResult::DoneWriting(stream) => { |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) |
|
|
|
} |
|
|
|
} |
|
|
|
StageResult::DoneReading { |
|
|
|
StageResult::DoneReading { stream, result, tail } => { |
|
|
|
stream, |
|
|
|
|
|
|
|
result, |
|
|
|
|
|
|
|
tail, |
|
|
|
|
|
|
|
} => { |
|
|
|
|
|
|
|
let result = self.verify_data.verify_response(result)?; |
|
|
|
let result = self.verify_data.verify_response(result)?; |
|
|
|
debug!("Client handshake done."); |
|
|
|
debug!("Client handshake done."); |
|
|
|
let websocket = |
|
|
|
let websocket = |
|
|
@ -105,16 +97,16 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
let mut req = Vec::new(); |
|
|
|
let mut req = Vec::new(); |
|
|
|
let uri = request.uri(); |
|
|
|
let uri = request.uri(); |
|
|
|
|
|
|
|
|
|
|
|
let authority = uri.authority() |
|
|
|
let authority = |
|
|
|
.ok_or_else(|| Error::Url("No host name in the URL".into()))? |
|
|
|
uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); |
|
|
|
.as_str(); |
|
|
|
let host = if let Some(idx) = authority.find('@') { |
|
|
|
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
|
|
|
|
// handle possible name:password@
|
|
|
|
authority.split_at(idx + 1).1 |
|
|
|
authority.split_at(idx + 1).1 |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
authority |
|
|
|
authority |
|
|
|
}; |
|
|
|
}; |
|
|
|
if authority.is_empty() { |
|
|
|
if authority.is_empty() { |
|
|
|
return Err(Error::Url("URL contains empty host name".into())) |
|
|
|
return Err(Error::Url("URL contains empty host name".into())); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
write!( |
|
|
|
write!( |
|
|
@ -128,17 +120,15 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> { |
|
|
|
Sec-WebSocket-Key: {key}\r\n", |
|
|
|
Sec-WebSocket-Key: {key}\r\n", |
|
|
|
version = request.version(), |
|
|
|
version = request.version(), |
|
|
|
host = host, |
|
|
|
host = host, |
|
|
|
path = uri |
|
|
|
path = |
|
|
|
.path_and_query() |
|
|
|
uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), |
|
|
|
.ok_or_else(|| Error::Url("No path/query in URL".into()))? |
|
|
|
|
|
|
|
.as_str(), |
|
|
|
|
|
|
|
key = key |
|
|
|
key = key |
|
|
|
) |
|
|
|
) |
|
|
|
.unwrap(); |
|
|
|
.unwrap(); |
|
|
|
|
|
|
|
|
|
|
|
for (k, v) in request.headers() { |
|
|
|
for (k, v) in request.headers() { |
|
|
|
let mut k = k.as_str(); |
|
|
|
let mut k = k.as_str(); |
|
|
|
if k == "sec-websocket-protocol" { |
|
|
|
if k == "sec-websocket-protocol" { |
|
|
|
k = "Sec-WebSocket-Protocol"; |
|
|
|
k = "Sec-WebSocket-Protocol"; |
|
|
|
} |
|
|
|
} |
|
|
|
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
|
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); |
|
|
@ -175,9 +165,7 @@ impl VerifyData { |
|
|
|
.map(|h| h.eq_ignore_ascii_case("websocket")) |
|
|
|
.map(|h| h.eq_ignore_ascii_case("websocket")) |
|
|
|
.unwrap_or(false) |
|
|
|
.unwrap_or(false) |
|
|
|
{ |
|
|
|
{ |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); |
|
|
|
"No \"Upgrade: websocket\" in server reply".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
// 3. If the response lacks a |Connection| header field or the
|
|
|
|
// 3. If the response lacks a |Connection| header field or the
|
|
|
|
// |Connection| header field doesn't contain a token that is an
|
|
|
|
// |Connection| header field doesn't contain a token that is an
|
|
|
@ -189,22 +177,14 @@ impl VerifyData { |
|
|
|
.map(|h| h.eq_ignore_ascii_case("Upgrade")) |
|
|
|
.map(|h| h.eq_ignore_ascii_case("Upgrade")) |
|
|
|
.unwrap_or(false) |
|
|
|
.unwrap_or(false) |
|
|
|
{ |
|
|
|
{ |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); |
|
|
|
"No \"Connection: upgrade\" in server reply".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
|
|
|
|
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
|
|
|
|
// the |Sec-WebSocket-Accept| contains a value other than the
|
|
|
|
// the |Sec-WebSocket-Accept| contains a value other than the
|
|
|
|
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
|
|
|
|
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
|
|
|
|
// Connection_. (RFC 6455)
|
|
|
|
// Connection_. (RFC 6455)
|
|
|
|
if !headers |
|
|
|
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { |
|
|
|
.get("Sec-WebSocket-Accept") |
|
|
|
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); |
|
|
|
.map(|h| h == &self.accept_key) |
|
|
|
|
|
|
|
.unwrap_or(false) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
return Err(Error::Protocol( |
|
|
|
|
|
|
|
"Key mismatch in Sec-WebSocket-Accept".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
// 5. If the response includes a |Sec-WebSocket-Extensions| header
|
|
|
|
// 5. If the response includes a |Sec-WebSocket-Extensions| header
|
|
|
|
// field and this header field indicates the use of an extension
|
|
|
|
// field and this header field indicates the use of an extension
|
|
|
@ -238,9 +218,7 @@ impl TryParse for Response { |
|
|
|
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { |
|
|
|
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { |
|
|
|
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { |
|
|
|
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { |
|
|
|
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
|
|
|
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { |
|
|
|
return Err(Error::Protocol( |
|
|
|
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); |
|
|
|
"HTTP version should be 1.1 or higher".into(), |
|
|
|
|
|
|
|
)); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
let headers = HeaderMap::from_httparse(raw.headers)?; |
|
|
|
let headers = HeaderMap::from_httparse(raw.headers)?; |
|
|
@ -266,9 +244,8 @@ fn generate_key() -> String { |
|
|
|
|
|
|
|
|
|
|
|
#[cfg(test)] |
|
|
|
#[cfg(test)] |
|
|
|
mod tests { |
|
|
|
mod tests { |
|
|
|
use super::super::machine::TryParse; |
|
|
|
use super::{super::machine::TryParse, generate_key, generate_request, Response}; |
|
|
|
use crate::client::IntoClientRequest; |
|
|
|
use crate::client::IntoClientRequest; |
|
|
|
use super::{generate_key, generate_request, Response}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test] |
|
|
|
#[test] |
|
|
|
fn random_keys() { |
|
|
|
fn random_keys() { |
|
|
@ -342,9 +319,6 @@ mod tests { |
|
|
|
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; |
|
|
|
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; |
|
|
|
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); |
|
|
|
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); |
|
|
|
assert_eq!(resp.status(), http::StatusCode::OK); |
|
|
|
assert_eq!(resp.status(), http::StatusCode::OK); |
|
|
|
assert_eq!( |
|
|
|
assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],); |
|
|
|
resp.headers().get("Content-Type").unwrap(), |
|
|
|
|
|
|
|
&b"text/html"[..], |
|
|
|
|
|
|
|
); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|