diff --git a/src/error.rs b/src/error.rs index ab24753..dc05a12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,6 +65,8 @@ pub enum Error { Url(Cow<'static, str>), /// HTTP error. Http(http::StatusCode), + /// HTTP response error. + HttpResponse(http::Response<()>), /// HTTP format error. HttpFormat(http::Error), } @@ -83,6 +85,7 @@ impl fmt::Display for Error { Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::HttpResponse(ref res) => write!(f, "HTTP response error: {}", res.status()), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 745da90..e9be75e 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -90,6 +90,12 @@ impl HandshakeRole for ClientHandshake { result, tail, } => { + // If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) + if result.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::HttpResponse(result)); + } + self.verify_data.verify_response(&result)?; debug!("Client handshake done."); let websocket = @@ -105,16 +111,18 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = uri.authority() + let authority = uri + .authority() .ok_or_else(|| Error::Url("No host name in the URL".into()))? .as_str(); - let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ + let host = if let Some(idx) = authority.find('@') { + // handle possible name:password@ authority.split_at(idx + 1).1 } else { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())) + return Err(Error::Url("URL contains empty host name".into())); } write!( @@ -138,7 +146,7 @@ fn generate_request(request: Request, key: &str) -> Result> { for (k, v) in request.headers() { let mut k = k.as_str(); - if k == "sec-websocket-protocol" { + if k == "sec-websocket-protocol" { k = "Sec-WebSocket-Protocol"; } writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); @@ -157,14 +165,9 @@ struct VerifyData { impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { - // 1. If the status code received from the server is not 101, the - // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::Http(response.status())); - } let headers = response.headers(); - // 2. If the response lacks an |Upgrade| header field or the |Upgrade| + // 1. If the response lacks an |Upgrade| header field or the |Upgrade| // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) @@ -178,7 +181,7 @@ impl VerifyData { "No \"Upgrade: websocket\" in server reply".into(), )); } - // 3. If the response lacks a |Connection| header field or the + // 2. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) @@ -192,7 +195,7 @@ impl VerifyData { "No \"Connection: upgrade\" in server reply".into(), )); } - // 4. If the response lacks a |Sec-WebSocket-Accept| header field or + // 3. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) @@ -205,14 +208,14 @@ impl VerifyData { "Key mismatch in Sec-WebSocket-Accept".into(), )); } - // 5. If the response includes a |Sec-WebSocket-Extensions| header + // 4. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) // TODO - // 6. If the response includes a |Sec-WebSocket-Protocol| header field + // 5. If the response includes a |Sec-WebSocket-Protocol| header field // and this header field indicates the use of a subprotocol that was // not present in the client's handshake (the server has indicated a // subprotocol not requested by the client), the client MUST _Fail @@ -266,8 +269,8 @@ fn generate_key() -> String { #[cfg(test)] mod tests { use super::super::machine::TryParse; - use crate::client::IntoClientRequest; use super::{generate_key, generate_request, Response}; + use crate::client::IntoClientRequest; #[test] fn random_keys() { @@ -304,7 +307,9 @@ mod tests { #[test] fn request_formatting_with_host() { - let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ @@ -321,7 +326,9 @@ mod tests { #[test] fn request_formatting_with_at() { - let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); + let request = "wss://user:pass@localhost:9001/getCaseCount" + .into_client_request() + .unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\