diff --git a/Cargo.toml b/Cargo.toml index 6569d79..ac93fdc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ edition = "2018" default = ["tls"] tls = ["native-tls"] tls-vendored = ["native-tls", "native-tls/vendored"] -deflate = ["flate2"] +deflate = [] [dependencies] base64 = "0.12.0" @@ -30,12 +30,7 @@ rand = "0.7.2" sha-1 = "0.9" url = "2.1.0" utf-8 = "0.7.5" - -[dependencies.flate2] -optional = true -default-features = false -version = "1.0" -features = ["zlib"] +flate2 = { version = "1.0", features = ["zlib"], default-features = false } [dependencies.native-tls] optional = true diff --git a/README.md b/README.md index a0351db..5374cfe 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Features Tungstenite provides a complete implementation of the WebSocket specification. TLS is supported on all platforms using native-tls. -There is no support for permessage-deflate at the moment. It's planned. +Permessage-deflate. Testing ------- diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 1b4ee94..523056d 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -3,17 +3,13 @@ use url::Url; use tungstenite::client::connect_with_config; use tungstenite::extensions::deflate::DeflateExt; -use tungstenite::extensions::uncompressed::PlainTextExt; use tungstenite::protocol::WebSocketConfig; -use tungstenite::{connect, Error, Message, Result, WebSocket}; +use tungstenite::{connect, Error, Message, Result}; const AGENT: &str = "Tungstenite"; fn get_case_count() -> Result { - let (mut socket, _): (WebSocket<_, PlainTextExt>, _) = connect_with_config( - Url::parse("ws://localhost:9001/getCaseCount").unwrap(), - None, - )?; + let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let msg = socket.read_message()?; socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) diff --git a/fuzz/fuzz_targets/read_message_client.rs b/fuzz/fuzz_targets/read_message_client.rs index affdb3e..1c0708b 100644 --- a/fuzz/fuzz_targets/read_message_client.rs +++ b/fuzz/fuzz_targets/read_message_client.rs @@ -1,12 +1,11 @@ #![no_main] -#[macro_use] -extern crate libfuzzer_sys; +#[macro_use] extern crate libfuzzer_sys; extern crate tungstenite; use std::io; use std::io::Cursor; -use tungstenite::protocol::Role; use tungstenite::WebSocket; +use tungstenite::protocol::Role; //use std::result::Result; // FIXME: copypasted from tungstenite's protocol/mod.rs diff --git a/scripts/autobahn-client.sh b/scripts/autobahn-client.sh old mode 100644 new mode 100755 diff --git a/scripts/autobahn-server.sh b/scripts/autobahn-server.sh old mode 100644 new mode 100755 diff --git a/src/client.rs b/src/client.rs index 25416cb..3911342 100644 --- a/src/client.rs +++ b/src/client.rs @@ -66,7 +66,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::error::{Error, Result}; -use crate::extensions::uncompressed::PlainTextExt; +use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::handshake::client::ClientHandshake; use crate::handshake::HandshakeError; @@ -88,12 +88,13 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config( +pub fn connect_with_config( request: Req, - config: Option>, -) -> Result<(WebSocket, Response)> + config: Option>, +) -> Result<(WebSocket, Response)> where - E: WebSocketExtension, + Req: IntoClientRequest, + Ext: WebSocketExtension, { let request: Request = request.into_client_request()?; let uri = request.uri(); @@ -129,7 +130,7 @@ where /// `connect` since it's the only function that uses native_tls. pub fn connect( request: Req, -) -> Result<(WebSocket, Response)> { +) -> Result<(WebSocket, Response)> { connect_with_config(request, None) } @@ -166,15 +167,15 @@ pub fn uri_mode(uri: &Uri) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client_with_config( +pub fn client_with_config( request: Req, stream: Stream, - config: Option>, -) -> StdResult<(WebSocket, Response), HandshakeError>> + config: Option>, +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: IntoClientRequest, - E: WebSocketExtension, + Ext: WebSocketExtension, { ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake() } @@ -188,8 +189,8 @@ pub fn client( request: Req, stream: Stream, ) -> StdResult< - (WebSocket, Response), - HandshakeError>, + (WebSocket, Response), + HandshakeError>, > where Stream: Read + Write, diff --git a/src/error.rs b/src/error.rs index 9915787..547a931 100644 --- a/src/error.rs +++ b/src/error.rs @@ -68,7 +68,7 @@ pub enum Error { /// HTTP format error. HttpFormat(http::Error), /// An error from a WebSocket extension. - ExtensionError(Box), + ExtensionError(Cow<'static, str>), } impl fmt::Display for Error { diff --git a/src/extensions/deflate.rs b/src/extensions/deflate.rs index b172e78..fe1f7d9 100644 --- a/src/extensions/deflate.rs +++ b/src/extensions/deflate.rs @@ -2,7 +2,7 @@ use std::fmt::{Display, Formatter}; -use crate::extensions::uncompressed::PlainTextExt; +use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::Frame; @@ -15,10 +15,20 @@ use flate2::{ }; use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS}; use http::{HeaderValue, Request, Response}; +use std::borrow::Cow; use std::mem::replace; use std::slice; -const EXT_NAME: &str = "permessage-deflate"; +/// The WebSocket Extension Identifier as per the IANA registry. +const EXT_IDENT: &str = "permessage-deflate"; + +/// The minimum size of the LZ77 sliding window size. +const LZ77_MIN_WINDOW_SIZE: u8 = 9; + +/// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter +/// indicates that the client can receive messages compressed using an LZ77 sliding window of up to +/// 32,768 bytes. RFC 7692 7.1.2.1. +const LZ77_MAX_WINDOW_SIZE: u8 = 15; /// A permessage-deflate configuration. #[derive(Clone, Copy, Debug)] @@ -33,9 +43,14 @@ pub struct DeflateConfig { max_window_bits: u8, /// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1. request_no_context_takeover: bool, + /// Whether to accept `no_context_takeover`. accept_no_context_takeover: bool, + // Whether the compressor should be reset after usage. compress_reset: bool, + // Whether the decompressor should be reset after usage. decompress_reset: bool, + /// The active compression level. The integer here is typically on a scale of 0-9 where 0 means + /// "no compression" and 9 means "take as long as you'd like". compression_level: Compression, } @@ -91,7 +106,7 @@ impl DeflateConfig { /// Sets the LZ77 sliding window size. pub fn set_max_window_bits(&mut self, max_window_bits: u8) { - assert!((9u8..=15u8).contains(&max_window_bits)); + assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits)); self.max_window_bits = max_window_bits; } @@ -110,7 +125,7 @@ impl Default for DeflateConfig { fn default() -> Self { DeflateConfig { max_message_size: Some(MAX_MESSAGE_SIZE), - max_window_bits: 15, + max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, accept_no_context_takeover: true, compress_reset: false, @@ -135,7 +150,7 @@ impl Default for DeflateConfigBuilder { fn default() -> Self { DeflateConfigBuilder { max_message_size: Some(MAX_MESSAGE_SIZE), - max_window_bits: 15, + max_window_bits: LZ77_MAX_WINDOW_SIZE, request_no_context_takeover: false, accept_no_context_takeover: true, fragments_grow: true, @@ -154,7 +169,7 @@ impl DeflateConfigBuilder { /// Sets the LZ77 sliding window size. Panics if the provided size is not in `9..=15`. pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder { assert!( - (9u8..=15u8).contains(&max_window_bits), + (LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits), "max window bits must be in range 9..=15" ); self.max_window_bits = max_window_bits; @@ -207,26 +222,7 @@ pub struct DeflateExt { /// The deflate compressor. deflator: Deflator, /// If this deflate extension is not used, messages will be forwarded to this extension. - uncompressed_extension: PlainTextExt, -} - -impl Clone for DeflateExt { - fn clone(&self) -> Self { - DeflateExt { - enabled: self.enabled, - config: self.config, - fragments: vec![], - inflator: Inflator::new(), - deflator: Deflator::new(self.config.compression_level()), - uncompressed_extension: PlainTextExt::new(self.config.max_message_size()), - } - } -} - -impl Default for DeflateExt { - fn default() -> Self { - DeflateExt::new(Default::default()) - } + uncompressed_extension: UncompressedExt, } impl DeflateExt { @@ -238,7 +234,7 @@ impl DeflateExt { fragments: vec![], inflator: Inflator::new(), deflator: Deflator::new(Compression::fast()), - uncompressed_extension: PlainTextExt::new(config.max_message_size()), + uncompressed_extension: UncompressedExt::new(config.max_message_size()), } } @@ -262,10 +258,10 @@ impl DeflateExt { match window_bits_str.trim().parse() { Ok(mut window_bits) => { if window_bits == 8 { - window_bits = 9; + window_bits = LZ77_MIN_WINDOW_SIZE; } - if window_bits >= 9 && window_bits <= 15 { + if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE { if window_bits != self.config.max_window_bits() { Ok(Some(window_bits)) } else { @@ -284,7 +280,7 @@ impl DeflateExt { fn decline(&mut self, res: &mut Response) { self.enabled = false; - res.headers_mut().remove(EXT_NAME); + res.headers_mut().remove(EXT_IDENT); } } @@ -319,7 +315,7 @@ impl std::error::Error for DeflateExtensionError {} impl From for crate::Error { fn from(e: DeflateExtensionError) -> Self { - crate::Error::ExtensionError(Box::new(e)) + crate::Error::ExtensionError(Cow::from(e.to_string())) } } @@ -329,6 +325,12 @@ impl From for DeflateExtensionError { } } +impl Default for DeflateExt { + fn default() -> Self { + DeflateExt::new(Default::default()) + } +} + impl WebSocketExtension for DeflateExt { type Error = DeflateExtensionError; @@ -344,14 +346,14 @@ impl WebSocketExtension for DeflateExt { } fn on_make_request(&mut self, mut request: Request) -> Request { - let mut header_value = String::from(EXT_NAME); + let mut header_value = String::from(EXT_IDENT); let DeflateConfig { max_window_bits, request_no_context_takeover, .. } = self.config; - if max_window_bits < 15 { + if max_window_bits < LZ77_MAX_WINDOW_SIZE { header_value.push_str(&format!( "; client_max_window_bits={}; server_max_window_bits={}", max_window_bits, max_window_bits @@ -486,7 +488,7 @@ impl WebSocketExtension for DeflateExt { } if !response_str.contains("client_max_window_bits") - && self.config.max_window_bits() < 15 + && self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE { continue; } @@ -671,7 +673,7 @@ impl WebSocketExtension for DeflateExt { if self.enabled && (!self.fragments.is_empty() || frame.header().rsv1) { if !frame.header().is_final { self.fragments.push(frame); - return Ok(None); + Ok(None) } else { let message = if let OpCode::Data(Data::Continue) = frame.header().opcode { self.fragments.push(frame); diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index d8207de..0e56bea 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -13,7 +13,7 @@ pub mod uncompressed; /// A trait for defining WebSocket extensions. Extensions may be stacked by nesting them inside /// one another. -pub trait WebSocketExtension: Default + Clone { +pub trait WebSocketExtension { /// An error type that the extension produces. type Error: Into; diff --git a/src/extensions/uncompressed.rs b/src/extensions/uncompressed.rs index c4f4643..ec5919b 100644 --- a/src/extensions/uncompressed.rs +++ b/src/extensions/uncompressed.rs @@ -2,47 +2,31 @@ use crate::extensions::WebSocketExtension; use crate::protocol::frame::coding::{Data, OpCode}; use crate::protocol::frame::Frame; use crate::protocol::message::{IncompleteMessage, IncompleteMessageType}; -use crate::protocol::MAX_MESSAGE_SIZE; use crate::{Error, Message}; /// An uncompressed message handler for a WebSocket. #[derive(Debug)] -pub struct PlainTextExt { +pub struct UncompressedExt { incomplete: Option, max_message_size: Option, } -impl PlainTextExt { - /// Builds a new `PlainTextExt` that will permit a maximum message size of `max_message_size` +impl UncompressedExt { + /// Builds a new `UncompressedExt` that will permit a maximum message size of `max_message_size` /// or will be unbounded if `None`. - pub fn new(max_message_size: Option) -> PlainTextExt { - PlainTextExt { + pub fn new(max_message_size: Option) -> UncompressedExt { + UncompressedExt { incomplete: None, max_message_size, } } } -impl Clone for PlainTextExt { - fn clone(&self) -> Self { - Self::default() - } -} - -impl Default for PlainTextExt { - fn default() -> Self { - PlainTextExt { - incomplete: None, - max_message_size: Some(MAX_MESSAGE_SIZE), - } - } -} - -impl WebSocketExtension for PlainTextExt { +impl WebSocketExtension for UncompressedExt { type Error = Error; fn new(max_message_size: Option) -> Self { - PlainTextExt { + UncompressedExt { incomplete: None, max_message_size, } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 9e1dffd..d8c84d7 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -22,24 +22,25 @@ pub type Response = HttpResponse<()>; /// Client handshake role. #[derive(Debug)] -pub struct ClientHandshake +pub struct ClientHandshake where - E: WebSocketExtension, + Extension: WebSocketExtension, { verify_data: VerifyData, - config: Option>, + config: Option>>, _marker: PhantomData, } -impl ClientHandshake +impl ClientHandshake where - E: WebSocketExtension, + Stream: Read + Write, + Ext: WebSocketExtension, { /// Initiate a client handshake. pub fn start( - stream: S, + stream: Stream, request: Request, - mut config: Option>, + mut config: Option>, ) -> Result> { if request.method() != http::Method::GET { return Err(Error::Protocol( @@ -67,7 +68,7 @@ where let accept_key = convert_key(key.as_ref()).unwrap(); ClientHandshake { verify_data: VerifyData { accept_key }, - config, + config: Some(config), _marker: PhantomData, } }; @@ -80,13 +81,14 @@ where } } -impl HandshakeRole for ClientHandshake +impl HandshakeRole for ClientHandshake where - E: WebSocketExtension, + Stream: Read + Write, + Ext: WebSocketExtension, { type IncomingData = Response; - type InternalStream = S; - type FinalResult = (WebSocket, Response); + type InternalStream = Stream; + type FinalResult = (WebSocket, Response); fn stage_finished( &mut self, @@ -101,11 +103,11 @@ where result, tail, } => { - self.verify_data - .verify_response(&result, &mut self.config)?; + let mut config = self.config.take().unwrap(); + + self.verify_data.verify_response(&result, &mut config)?; debug!("Client handshake done."); - let websocket = - WebSocket::from_partially_read(stream, tail, Role::Client, self.config.clone()); + let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config); ProcessingResult::Done((websocket, result)) } }) @@ -113,13 +115,13 @@ where } /// Generate client request. -fn generate_request( +fn generate_request( request: Request, key: &str, - config: &mut Option>, + config: &mut Option>, ) -> Result> where - E: WebSocketExtension, + Ext: WebSocketExtension, { let request = match config { Some(ref mut config) => config.encoder.on_make_request(request), @@ -181,13 +183,13 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response( + pub fn verify_response( &self, response: &Response, - config: &mut Option>, + config: &mut Option>, ) -> Result<()> where - E: WebSocketExtension, + Ext: WebSocketExtension, { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) @@ -306,7 +308,7 @@ mod tests { use super::super::machine::TryParse; use super::{generate_key, generate_request, Response}; use crate::client::IntoClientRequest; - use crate::extensions::uncompressed::PlainTextExt; + use crate::extensions::uncompressed::UncompressedExt; #[test] fn random_keys() { @@ -337,7 +339,8 @@ mod tests { Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; let request = - generate_request::(request, key, &mut Some(Default::default())).unwrap(); + generate_request::(request, key, &mut Some(Default::default())) + .unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } @@ -357,7 +360,8 @@ mod tests { Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; let request = - generate_request::(request, key, &mut Some(Default::default())).unwrap(); + generate_request::(request, key, &mut Some(Default::default())) + .unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } @@ -377,7 +381,8 @@ mod tests { Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\ \r\n"; let request = - generate_request::(request, key, &mut Some(Default::default())).unwrap(); + generate_request::(request, key, &mut Some(Default::default())) + .unwrap(); println!("Request: {}", String::from_utf8_lossy(&request)); assert_eq!(&request[..], &correct[..]); } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 21b4de4..c04755a 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -93,10 +93,9 @@ pub fn create_response(request: &Request) -> Result { fn write_response(w: &mut dyn io::Write, response: &HttpResponse) -> Result<()> { writeln!( w, - "{version:?} {status} {reason}\r", + "{version:?} {status}\r", version = response.version(), - status = response.status(), - reason = response.status().canonical_reason().unwrap_or(""), + status = response.status() )?; for (k, v) in response.headers() { @@ -192,37 +191,43 @@ impl Callback for NoCallback { /// Server handshake role. #[allow(missing_copy_implementations)] #[derive(Debug)] -pub struct ServerHandshake +pub struct ServerHandshake where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// Callback which is called whenever the server read the request from the client and is ready /// to reply to it. The callback returns an optional headers which will be added to the reply /// which the server sends to the user. callback: Option, /// WebSocket configuration. - config: Option>, + config: Option>>, /// Error code/flag. If set, an error will be returned after sending response to the client. error_code: Option, /// Internal stream type. _marker: PhantomData, } -impl ServerHandshake +impl ServerHandshake where - E: WebSocketExtension, + S: Read + Write, + C: Callback, + Ext: WebSocketExtension, { /// Start server handshake. `callback` specifies a custom callback which the user can pass to /// the handshake, this callback will be called when the a websocket client connnects to the /// server, you can specify the callback if you want to add additional header to the client /// upon join based on the incoming headers. - pub fn start(stream: S, callback: C, config: Option>) -> MidHandshake { + pub fn start( + stream: S, + callback: C, + config: Option>, + ) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), role: ServerHandshake { callback: Some(callback), - config, + config: Some(config), error_code: None, _marker: PhantomData, }, @@ -230,13 +235,15 @@ where } } -impl HandshakeRole for ServerHandshake +impl HandshakeRole for ServerHandshake where - E: WebSocketExtension, + S: Read + Write, + C: Callback, + Ext: WebSocketExtension, { type IncomingData = Request; type InternalStream = S; - type FinalResult = WebSocket; + type FinalResult = WebSocket; fn stage_finished( &mut self, @@ -254,7 +261,7 @@ where let mut response = create_response(&request)?; - if let Some(ref mut config) = self.config { + if let Some(ref mut config) = self.config.as_mut().unwrap() { if let Err(e) = config.encoder.on_receive_request(&request, &mut response) { return Err(e.into()); } @@ -298,8 +305,11 @@ where return Err(Error::Http(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); - let websocket = - WebSocket::from_raw_socket(stream, Role::Server, self.config.clone()); + let websocket = WebSocket::from_raw_socket( + stream, + Role::Server, + self.config.take().unwrap(), + ); ProcessingResult::Done(websocket) } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index e95207f..28c3a43 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -16,7 +16,7 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use self::frame::{Frame, FrameCodec}; use self::message::IncompleteMessage; use crate::error::{Error, Result}; -use crate::extensions::uncompressed::PlainTextExt; +use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use crate::util::NonBlockingResult; @@ -33,7 +33,7 @@ pub enum Role { /// The configuration for WebSocket connection. #[derive(Debug, Copy, Clone)] -pub struct WebSocketConfig +pub struct WebSocketConfig where E: WebSocketExtension, { @@ -83,26 +83,30 @@ where /// This is THE structure you want to create to be able to speak the WebSocket protocol. /// It may be created by calling `connect`, `accept` or `client` functions. #[derive(Debug)] -pub struct WebSocket +pub struct WebSocket where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// The underlying socket. socket: Stream, /// The context for managing a WebSocket. - context: WebSocketContext, + context: WebSocketContext, } -impl WebSocket +impl WebSocket where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// Convert a raw socket into a WebSocket without performing a handshake. /// /// Call this function if you're using Tungstenite as a part of a web framework /// or together with an existing one. If you need an initial handshake, use /// `connect()` or `accept()` functions of the crate to construct a websocket. - pub fn from_raw_socket(stream: Stream, role: Role, config: Option>) -> Self { + pub fn from_raw_socket( + stream: Stream, + role: Role, + config: Option>, + ) -> Self { WebSocket { socket: stream, context: WebSocketContext::new(role, config), @@ -118,7 +122,7 @@ where stream: Stream, part: Vec, role: Role, - config: Option>, + config: Option>, ) -> Self { WebSocket { socket: stream, @@ -137,12 +141,12 @@ where } /// Change the configuration. - pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { + pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { self.context.set_config(set_func) } /// Read the configuration. - pub fn get_config(&self) -> &WebSocketConfig { + pub fn get_config(&self) -> &WebSocketConfig { self.context.get_config() } @@ -162,9 +166,10 @@ where } } -impl WebSocket +impl WebSocket where - E: WebSocketExtension, + Stream: Read + Write, + Ext: WebSocketExtension, { /// Read a message from stream, if possible. /// @@ -248,9 +253,9 @@ where /// A context for managing WebSocket stream. #[derive(Debug)] -pub struct WebSocketContext +pub struct WebSocketContext where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// Server or client? role: Role, @@ -265,15 +270,15 @@ where /// Send: an OOB pong message. pong: Option, /// The configuration for the websocket session. - config: WebSocketConfig, + config: WebSocketConfig, } -impl WebSocketContext +impl WebSocketContext where - E: WebSocketExtension, + Ext: WebSocketExtension, { /// Create a WebSocket context that manages a post-handshake stream. - pub fn new(role: Role, config: Option>) -> Self { + pub fn new(role: Role, config: Option>) -> Self { let config = config.unwrap_or_else(Default::default); WebSocketContext { @@ -291,7 +296,7 @@ where pub fn from_partially_read( part: Vec, role: Role, - config: Option>, + config: Option>, ) -> Self { WebSocketContext { frame: FrameCodec::from_partially_read(part), @@ -300,12 +305,12 @@ where } /// Change the configuration. - pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { + pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { set_func(&mut self.config) } /// Read the configuration. - pub fn get_config(&self) -> &WebSocketConfig { + pub fn get_config(&self) -> &WebSocketConfig { &self.config } @@ -672,7 +677,7 @@ impl CheckConnectionReset for Result { mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; - use crate::extensions::uncompressed::PlainTextExt; + use crate::extensions::uncompressed::UncompressedExt; use std::io; use std::io::Cursor; @@ -700,7 +705,7 @@ mod tests { 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02, 0x03, ]); - let mut socket: WebSocket<_, PlainTextExt> = + let mut socket: WebSocket<_, UncompressedExt> = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); @@ -723,7 +728,7 @@ mod tests { let limit = WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: PlainTextExt::new(Some(10)), + encoder: UncompressedExt::new(Some(10)), }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( @@ -738,7 +743,7 @@ mod tests { let limit = WebSocketConfig { max_send_queue: None, max_frame_size: Some(16 << 20), - encoder: PlainTextExt::new(Some(2)), + encoder: UncompressedExt::new(Some(2)), }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); assert_eq!( diff --git a/src/server.rs b/src/server.rs index 66130de..99e3757 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,7 +7,7 @@ use crate::handshake::HandshakeError; use crate::protocol::{WebSocket, WebSocketConfig}; -use crate::extensions::uncompressed::PlainTextExt; +use crate::extensions::uncompressed::UncompressedExt; use crate::extensions::WebSocketExtension; use std::io::{Read, Write}; @@ -20,12 +20,13 @@ use std::io::{Read, Write}; /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept_with_config( - stream: S, - config: Option>, -) -> Result, HandshakeError>> +pub fn accept_with_config( + stream: Stream, + config: Option>, +) -> Result, HandshakeError>> where - E: WebSocketExtension, + Stream: Read + Write, + Ext: WebSocketExtension, { accept_hdr_with_config(stream, NoCallback, config) } @@ -38,8 +39,10 @@ where /// those from `Mio` and others. pub fn accept( stream: S, -) -> Result, HandshakeError>> -{ +) -> Result< + WebSocket, + HandshakeError>, +> { accept_with_config(stream, None) } @@ -51,13 +54,15 @@ pub fn accept( /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr_with_config( +pub fn accept_hdr_with_config( stream: S, callback: C, - config: Option>, -) -> Result, HandshakeError>> + config: Option>, +) -> Result, HandshakeError>> where - E: WebSocketExtension, + S: Read + Write, + C: Callback, + Ext: WebSocketExtension, { ServerHandshake::start(stream, callback, config).handshake() } @@ -70,6 +75,6 @@ where pub fn accept_hdr( stream: S, callback: C, -) -> Result, HandshakeError>> { +) -> Result, HandshakeError>> { accept_hdr_with_config(stream, callback, None) } diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 1565d2e..87396ff 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -8,16 +8,16 @@ use std::time::Duration; use native_tls::TlsStream; use net2::TcpStreamExt; -use tungstenite::extensions::uncompressed::PlainTextExt; +use tungstenite::extensions::uncompressed::UncompressedExt; use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; use url::Url; -type Sock = WebSocket>, E>; +type Sock = WebSocket>, Ext>; fn do_test(port: u16, client_task: CT, server_task: ST) where - CT: FnOnce(Sock) + Send + 'static, - ST: FnOnce(WebSocket), + CT: FnOnce(Sock) + Send + 'static, + ST: FnOnce(WebSocket), { env_logger::try_init().ok();