diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 36d6262..01fa8a5 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -96,17 +96,22 @@ impl HandshakeRole for ClientHandshake { fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); + let headers = request.headers(); - let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); - let host = if let Some(idx) = authority.find('@') { - // handle possible name:password@ - authority.split_at(idx + 1).1 + let host = if headers.contains_key(header::HOST) { + headers.get(header::HOST).unwrap().to_str()? } else { - authority + let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); + if authority.is_empty() { + return Err(Error::Url(UrlError::EmptyHostName)); + } + if let Some(idx) = authority.find('@') { + // handle possible name:password@ + authority.split_at(idx + 1).1 + } else { + authority + } }; - if authority.is_empty() { - return Err(Error::Url(UrlError::EmptyHostName)); - } write!( req, @@ -124,12 +129,14 @@ fn generate_request(request: Request, key: &str) -> Result> { ) .unwrap(); - for (k, v) in request.headers() { + for (k, v) in headers { + if k == header::HOST { + continue; + } if k == header::CONNECTION || k == header::UPGRADE || k == header::SEC_WEBSOCKET_VERSION || k == header::SEC_WEBSOCKET_KEY - || k == header::HOST { return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); } @@ -320,6 +327,27 @@ mod tests { assert_eq!(&request[..], &correct[..]); } + #[test] + fn request_formatting_with_host_header() { + let mut request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); + request.headers_mut().insert( + http::header::HOST, + "foo.com:8080".parse().unwrap(), + ); + let key = "A70tsIbeMZUbJHh5BWFw6Q=="; + let correct = b"\ + GET /getCaseCount HTTP/1.1\r\n\ + Host: foo.com:8080\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ + \r\n"; + let request = generate_request(request, key).unwrap(); + println!("Request: {}", String::from_utf8_lossy(&request)); + assert_eq!(&request[..], &correct[..]); + } + #[test] fn response_parsing() { const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";