Allow user to overwrite the |host| header while generating the request

As we discussed in #255, it's possible for us to provide
some way to sepcifiy the generated request's |host| header.

Now when the incoming request stuct contains a |host| header,
we will use it's value as outgoing |host| header's value.

If there's no |host| header specified, we still use the authority
part of uri as |host| header.

This behavior is complied with RFC 6455 and close issue #255
pull/258/head
Sunny 3 years ago
parent cd79500d25
commit bd87b429e2
  1. 48
      src/handshake/client.rs

@ -96,17 +96,22 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();
let headers = request.headers();
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
let host = if headers.contains_key(header::HOST) {
headers.get(header::HOST).unwrap().to_str()?
} else {
authority
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
if authority.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
}
};
if authority.is_empty() {
return Err(Error::Url(UrlError::EmptyHostName));
}
write!(
req,
@ -124,12 +129,14 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
)
.unwrap();
for (k, v) in request.headers() {
for (k, v) in headers {
if k == header::HOST {
continue;
}
if k == header::CONNECTION
|| k == header::UPGRADE
|| k == header::SEC_WEBSOCKET_VERSION
|| k == header::SEC_WEBSOCKET_KEY
|| k == header::HOST
{
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
}
@ -320,6 +327,27 @@ mod tests {
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn request_formatting_with_host_header() {
let mut request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
request.headers_mut().insert(
http::header::HOST,
"foo.com:8080".parse().unwrap(),
);
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Host: foo.com:8080\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request = generate_request(request, key).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
#[test]
fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";

Loading…
Cancel
Save