diff --git a/src/lib.rs b/src/lib.rs index 4947ac3..f5a49ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -377,12 +377,24 @@ where pub(crate) fn domain( request: &tungstenite::handshake::client::Request, ) -> Result { - match request.uri().host() { - Some(d) => Ok(d.to_string()), - None => Err(tungstenite::Error::Url( - tungstenite::error::UrlError::NoHostName, - )), - } + request + .uri() + .host() + .map(|host| { + // If host is an IPv6 address, it might be surrounded by brackets. These brackets are + // *not* part of a valid IP, so they must be stripped out. + // + // The URI from the request is guaranteed to be valid, so we don't need a separate + // check for the closing bracket. + let host = if host.starts_with('[') { + &host[1..host.len() - 1] + } else { + host + }; + + host.to_owned() + }) + .ok_or_else(|| tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName)) } #[cfg(any( @@ -407,3 +419,29 @@ pub(crate) fn port( tungstenite::error::UrlError::UnsupportedUrlScheme, )) } + +#[cfg(test)] +mod tests { + #[cfg(any( + feature = "async-tls", + feature = "async-std-runtime", + feature = "tokio-runtime", + feature = "gio-runtime" + ))] + #[test] + fn domain_strips_ipv6_brackets() { + use tungstenite::client::IntoClientRequest; + + let request = "ws://[::1]:80".into_client_request().unwrap(); + assert_eq!(crate::domain(&request).unwrap(), "::1"); + } + + #[test] + fn requests_cannot_contain_invalid_uris() { + use tungstenite::client::IntoClientRequest; + + assert!("ws://[".into_client_request().is_err()); + assert!("ws://[blabla/bla".into_client_request().is_err()); + assert!("ws://[::1/bla".into_client_request().is_err()); + } +}