diff --git a/src/error.rs b/src/error.rs index e224da7..a9c4ceb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,7 +3,7 @@ use std::{io, result, str, string}; use crate::protocol::{frame::coding::Data, Message}; -use http::Response; +use http::{header::HeaderName, Response}; use thiserror::Error; /// Result type of all Tungstenite library calls. @@ -138,7 +138,7 @@ pub enum CapacityError { } /// Indicates the specific type/cause of a protocol error. -#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Error, Debug, PartialEq, Eq, Clone)] pub enum ProtocolError { /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). #[error("Unsupported HTTP method used - only GET is allowed")] @@ -167,6 +167,10 @@ pub enum ProtocolError { /// Custom responses must be unsuccessful. #[error("Custom response must not be successful")] CustomResponseSuccessful, + /// Invalid header is passed. This header is formed by the library automatically + /// and must not be overwritten by the user. + #[error("Not allowed to pass overwrite the standard header {0}")] + InvalidHeader(HeaderName), /// No more data while still performing handshake. #[error("Handshake not finished")] HandshakeIncomplete, diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 07c41c1..36d6262 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -5,7 +5,7 @@ use std::{ marker::PhantomData, }; -use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; +use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; @@ -125,6 +125,14 @@ fn generate_request(request: Request, key: &str) -> Result> { .unwrap(); for (k, v) in request.headers() { + if k == header::CONNECTION + || k == header::UPGRADE + || k == header::SEC_WEBSOCKET_VERSION + || k == header::SEC_WEBSOCKET_KEY + || k == header::HOST + { + return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone()))); + } let mut k = k.as_str(); if k == "sec-websocket-protocol" { k = "Sec-WebSocket-Protocol";