From bd87b429e2381fad414d787cc4cd526abc24623d Mon Sep 17 00:00:00 2001 From: Sunny Date: Mon, 31 Jan 2022 16:52:32 +0800 Subject: [PATCH] Allow user to overwrite the |host| header while generating the request As we discussed in #255, it's possible for us to provide some way to sepcifiy the generated request's |host| header. Now when the incoming request stuct contains a |host| header, we will use it's value as outgoing |host| header's value. If there's no |host| header specified, we still use the authority part of uri as |host| header. This behavior is complied with RFC 6455 and close issue #255 --- src/handshake/client.rs | 48 ++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 36d6262..01fa8a5 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -96,17 +96,22 @@ impl HandshakeRole for ClientHandshake { fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); + let headers = request.headers(); - let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); - let host = if let Some(idx) = authority.find('@') { - // handle possible name:password@ - authority.split_at(idx + 1).1 + let host = if headers.contains_key(header::HOST) { + headers.get(header::HOST).unwrap().to_str()? } else { - authority + let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); + if authority.is_empty() { + return Err(Error::Url(UrlError::EmptyHostName)); + } + if let Some(idx) = authority.find('@') { + // handle possible name:password@ + authority.split_at(idx + 1).1 + } else { + authority + } }; - if authority.is_empty() { - return Err(Error::Url(UrlError::EmptyHostName)); - } write!( req, @@ -124,12 +129,14 @@ fn generate_request(request: Request, key: &str) -> Result> { ) .unwrap(); - for (k, v) in request.headers() { + for (k, v) in headers { + if k == header::HOST { + continue; + } if k == header::CONNECTION || k == header::UPGRADE || k == header::SEC_WEBSOCKET_VERSION || k == header::SEC_WEBSOCKET_KEY - || k == header::HOST { return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); } @@ -320,6 +327,27 @@ mod tests { assert_eq!(&request[..], &correct[..]); } + #[test] + fn request_formatting_with_host_header() { + let mut request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); + request.headers_mut().insert( + http::header::HOST, + "foo.com:8080".parse().unwrap(), + ); + let key = "A70tsIbeMZUbJHh5BWFw6Q=="; + let correct = b"\ + GET /getCaseCount HTTP/1.1\r\n\ + Host: foo.com:8080\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ + \r\n"; + let request = generate_request(request, key).unwrap(); + println!("Request: {}", String::from_utf8_lossy(&request)); + assert_eq!(&request[..], &correct[..]); + } + #[test] fn response_parsing() { const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";