diff --git a/src/handshake/client.rs b/src/handshake/client.rs index f3c92f7..44b4cd1 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -15,11 +15,11 @@ use super::{ machine::{HandshakeMachine, StageResult, TryParse}, HandshakeRole, MidHandshake, ProcessingResult, }; +use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers}; use crate::{ error::{Error, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; -use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers}; /// Client request type. pub type Request = HttpRequest<()>; @@ -62,7 +62,11 @@ impl ClientHandshake { let client = { let accept_key = convert_key(key.as_ref()).unwrap(); - ClientHandshake { verify_data: VerifyData { accept_key }, config: Some(config), _marker: PhantomData } + ClientHandshake { + verify_data: VerifyData { accept_key }, + config: Some(config), + _marker: PhantomData, + } }; trace!("Client handshake initiated."); @@ -83,10 +87,12 @@ impl HandshakeRole for ClientHandshake { ProcessingResult::Continue(HandshakeMachine::start_read(stream)) } StageResult::DoneReading { stream, result, tail } => { - let result = self.verify_data.verify_response(result)?; + let mut config = self.config.take().unwrap(); + let result = self.verify_data.verify_response(result, &mut config)?; + debug!("Client handshake done."); - let websocket = - WebSocket::from_partially_read(stream, tail, Role::Client, self.config); + + let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config); ProcessingResult::Done((websocket, result)) } }) @@ -154,7 +160,7 @@ struct VerifyData { impl VerifyData { pub fn verify_response( &self, - response: &Response, + response: Response, config: &mut Option, ) -> Result { // 1. If the status code received from the server is not 101, the @@ -202,7 +208,7 @@ impl VerifyData { // that was not present in the client's handshake (the server has // indicated an extension not requested by the client), the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - verify_compression_resp_headers(response, config)?; + verify_compression_resp_headers(&response, config)?; // 6. If the response includes a |Sec-WebSocket-Protocol| header field // and this header field indicates the use of a subprotocol that was @@ -293,9 +299,7 @@ mod tests { #[test] fn request_formatting_with_host() { - let request = "wss://localhost:9001/getCaseCount" - .into_client_request() - .unwrap(); + let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ @@ -312,9 +316,7 @@ mod tests { #[test] fn request_formatting_with_at() { - let request = "wss://user:pass@localhost:9001/getCaseCount" - .into_client_request() - .unwrap(); + let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap(); let key = "A70tsIbeMZUbJHh5BWFw6Q=="; let correct = b"\ GET /getCaseCount HTTP/1.1\r\n\ diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 35fb73d..dba60cb 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -15,10 +15,10 @@ use super::{ headers::{FromHttparse, MAX_HEADERS}, machine::{HandshakeMachine, StageResult, TryParse}, HandshakeRole, MidHandshake, ProcessingResult, - extensions::verify_compression_req_headers }; use crate::{ error::{Error, Result}, + extensions::compression::verify_compression_req_headers, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -228,7 +228,7 @@ impl HandshakeRole for ServerHandshake { } let mut response = create_response(&result)?; - verify_compression_req_headers(&request, &mut response, &mut self.config)?; + verify_compression_req_headers(&result, &mut response, &mut self.config)?; let callback_result = if let Some(callback) = self.callback.take() { callback.on_request(&result, response) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 3aca767..b79be78 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -65,11 +65,7 @@ pub struct ExtensionHeaders { impl Default for ExtensionHeaders { fn default() -> Self { - ExtensionHeaders { - rsv1: false, - rsv2: false, - rsv3: false, - } + ExtensionHeaders { rsv1: false, rsv2: false, rsv3: false } } } @@ -210,12 +206,7 @@ impl FrameHeader { } let ext_headers = ExtensionHeaders { rsv1, rsv2, rsv3 }; - let hdr = FrameHeader { - is_final, - ext_headers, - opcode, - mask, - }; + let hdr = FrameHeader { is_final, ext_headers, opcode, mask }; Ok(Some((hdr, length))) } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 0b4b5ac..5dedb1f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -2,7 +2,7 @@ pub mod frame; -mod message; +pub(crate) mod message; pub use self::{frame::CloseFrame, message::Message}; @@ -18,11 +18,12 @@ use self::{ coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}, Frame, FrameCodec, }, - message::{IncompleteMessage, IncompleteMessageType}, - extensions::{WebSocketExtension, compression::{CompressionSwitcher, WsCompression}}; + message::IncompleteMessage, }; use crate::{ error::{Error, Result}, + extensions::compression::{CompressionSwitcher, WsCompression}, + extensions::WebSocketExtension, util::NonBlockingResult, }; @@ -44,6 +45,10 @@ pub struct WebSocketConfig { /// means here that the size of the queue is unlimited. The default value is the unlimited /// queue. pub max_send_queue: Option, + /// The maximum size of a message. `None` means no size limit. The default value is 64 MiB + /// which should be reasonably big for all normal use-cases but small enough to prevent + /// memory eating by a malicious user. + pub max_message_size: Option, /// The maximum size of a single message frame. `None` means no size limit. The limit is for /// frame payload NOT including the frame header. The default value is 16 MiB which should /// be reasonably big for all normal use-cases but small enough to prevent memory eating @@ -57,6 +62,7 @@ impl Default for WebSocketConfig { fn default() -> Self { WebSocketConfig { max_send_queue: None, + max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), compression: WsCompression::None(Some(MAX_MESSAGE_SIZE)), } @@ -501,9 +507,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol( - "Connection reset without closing handshake".into(), - )), + _ => Err(Error::Protocol("Connection reset without closing handshake".into())), } } } @@ -629,6 +633,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; + use crate::extensions::compression::WsCompression; use std::{io, io::Cursor}; struct WriteMoc(Stream); @@ -668,7 +673,11 @@ mod tests { 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, ]); - let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; + let limit = WebSocketConfig { + max_message_size: Some(10), + compression: WsCompression::None(Some(10)), + ..WebSocketConfig::default() + }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(), @@ -679,7 +688,11 @@ mod tests { #[test] fn size_limiting_binary() { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); - let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; + let limit = WebSocketConfig { + max_message_size: Some(2), + compression: WsCompression::None(Some(2)), + ..WebSocketConfig::default() + }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( socket.read_message().unwrap_err().to_string(),