fix: return err if try to overwrite standard hedaer

pull/227/head
yifei 3 years ago
parent bc1b88b820
commit 9f563561a4
  1. 8
      src/error.rs
  2. 10
      src/handshake/client.rs

@ -3,7 +3,7 @@
use std::{io, result, str, string}; use std::{io, result, str, string};
use crate::protocol::{frame::coding::Data, Message}; use crate::protocol::{frame::coding::Data, Message};
use http::Response; use http::{header::HeaderName, Response};
use thiserror::Error; use thiserror::Error;
/// Result type of all Tungstenite library calls. /// Result type of all Tungstenite library calls.
@ -138,7 +138,7 @@ pub enum CapacityError {
} }
/// Indicates the specific type/cause of a protocol error. /// Indicates the specific type/cause of a protocol error.
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] #[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum ProtocolError { pub enum ProtocolError {
/// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used).
#[error("Unsupported HTTP method used - only GET is allowed")] #[error("Unsupported HTTP method used - only GET is allowed")]
@ -167,6 +167,10 @@ pub enum ProtocolError {
/// Custom responses must be unsuccessful. /// Custom responses must be unsuccessful.
#[error("Custom response must not be successful")] #[error("Custom response must not be successful")]
CustomResponseSuccessful, CustomResponseSuccessful,
/// Invalid header is passed. This header is formed by the library automatically
/// and must not be overwritten by the user.
#[error("Not allowed to pass overwrite the standard header {0}")]
InvalidHeader(HeaderName),
/// No more data while still performing handshake. /// No more data while still performing handshake.
#[error("Handshake not finished")] #[error("Handshake not finished")]
HandshakeIncomplete, HandshakeIncomplete,

@ -5,7 +5,7 @@ use std::{
marker::PhantomData, marker::PhantomData,
}; };
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
@ -125,6 +125,14 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
.unwrap(); .unwrap();
for (k, v) in request.headers() { for (k, v) in request.headers() {
if k == header::CONNECTION
|| k == header::UPGRADE
|| k == header::SEC_WEBSOCKET_VERSION
|| k == header::SEC_WEBSOCKET_KEY
|| k == header::HOST
{
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
}
let mut k = k.as_str(); let mut k = k.as_str();
if k == "sec-websocket-protocol" { if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol"; k = "Sec-WebSocket-Protocol";

Loading…
Cancel
Save