diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index e053f5e..e5488b4 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -11,7 +11,8 @@ const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { let (mut socket, _) = connect( - Url::parse("ws://localhost:9001/getCaseCount").unwrap() + Url::parse("ws://localhost:9001/getCaseCount").unwrap(), + None, )?; let msg = socket.read_message()?; socket.close(None)?; @@ -20,7 +21,8 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( - Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() + Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), + None, )?; socket.close(None)?; Ok(()) @@ -31,7 +33,7 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let (mut socket, _) = connect(case_url)?; + let (mut socket, _) = connect(case_url, None)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 10aba75..53b7d7a 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -16,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream).map_err(must_not_block)?; + let mut socket = accept(stream, None).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/client.rs b/examples/client.rs index 13b7d59..1ba7c71 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,7 +8,7 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init(); - let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) + let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap(), None) .expect("Can't connect"); println!("Connected to the server"); diff --git a/examples/server.rs b/examples/server.rs index 86df2fe..9649d6c 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -25,7 +25,7 @@ fn main() { ]; Ok(Some(extra_headers)) }; - let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); + let mut websocket = accept_hdr(stream.unwrap(), callback, None).unwrap(); loop { let msg = websocket.read_message().unwrap(); diff --git a/src/client.rs b/src/client.rs index 2326ac7..bb4d7f1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::io::{Read, Write}; use url::Url; use handshake::client::Response; +use protocol::WebSocketConfig; #[cfg(feature="tls")] mod encryption { @@ -77,7 +78,7 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req) +pub fn connect<'t, Req: Into>>(request: Req, config: Option) -> Result<(WebSocket, Response)> { let request: Request = request.into(); @@ -85,7 +86,7 @@ pub fn connect<'t, Req: Into>>(request: Req) let addrs = request.url.to_socket_addrs()?; let mut stream = connect_to_some(addrs, &request.url, mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client(request, stream) + client(request, stream, config) .map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), @@ -126,11 +127,12 @@ pub fn url_mode(url: &Url) -> Result { /// Any stream supporting `Read + Write` will do. pub fn client<'t, Stream, Req>( request: Req, - stream: Stream - ) -> StdResult<(WebSocket, Response), HandshakeError>> + stream: Stream, + config: Option, +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: Into>, { - ClientHandshake::start(stream, request.into()).handshake() + ClientHandshake::start(stream, request.into(), config).handshake() } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 00767c4..15494f9 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -11,7 +11,7 @@ use rand; use url::Url; use error::{Error, Result}; -use protocol::{WebSocket, Role}; +use protocol::{WebSocket, WebSocketConfig, Role}; use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; @@ -71,12 +71,17 @@ impl From for Request<'static> { #[derive(Debug)] pub struct ClientHandshake { verify_data: VerifyData, + config: Option, _marker: PhantomData, } impl ClientHandshake { /// Initiate a client handshake. - pub fn start(stream: S, request: Request) -> MidHandshake { + pub fn start( + stream: S, + request: Request, + config: Option + ) -> MidHandshake { let key = generate_key(); let machine = { @@ -102,6 +107,7 @@ impl ClientHandshake { let accept_key = convert_key(key.as_ref()).unwrap(); ClientHandshake { verify_data: VerifyData { accept_key }, + config, _marker: PhantomData, } }; @@ -125,8 +131,13 @@ impl HandshakeRole for ClientHandshake { StageResult::DoneReading { stream, result, tail, } => { self.verify_data.verify_response(&result)?; debug!("Client handshake done."); - ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), - result)) + let websocket = WebSocket::from_partially_read( + stream, + tail, + Role::Client, + self.config.clone(), + ); + ProcessingResult::Done((websocket, result)) } }) } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index e88cd38..11a8ffb 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -8,7 +8,7 @@ use httparse; use httparse::Status; use error::{Error, Result}; -use protocol::{WebSocket, Role}; +use protocol::{WebSocket, WebSocketConfig, Role}; use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; @@ -108,6 +108,8 @@ pub struct ServerHandshake { /// to reply to it. The callback returns an optional headers which will be added to the reply /// which the server sends to the user. callback: Option, + /// WebSocket configuration. + config: Option, /// Internal stream type. _marker: PhantomData, } @@ -117,11 +119,11 @@ impl ServerHandshake { /// the handshake, this callback will be called when the a websocket client connnects to the /// server, you can specify the callback if you want to add additional header to the client /// upon join based on the incoming headers. - pub fn start(stream: S, callback: C) -> MidHandshake { + pub fn start(stream: S, callback: C, config: Option) -> MidHandshake { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { callback: Some(callback), _marker: PhantomData }, + role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData }, } } } @@ -151,7 +153,12 @@ impl HandshakeRole for ServerHandshake { } StageResult::DoneWriting(stream) => { debug!("Server handshake done."); - ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) + let websocket = WebSocket::from_raw_socket( + stream, + Role::Server, + self.config.clone(), + ); + ProcessingResult::Done(websocket) } }) } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index a7ea9b2..adfe000 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -26,6 +26,23 @@ pub enum Role { Client, } +/// The configuration for WebSocket connection. +#[derive(Debug, Clone, Copy)] +pub struct WebSocketConfig { + /// The size of the send queue. You can use it to turn on/off the backpressure features. `None` + /// means here that the size of the queue is unlimited. The default value is the unlimited + /// queue. + pub max_send_queue: Option, +} + +impl Default for WebSocketConfig { + fn default() -> Self { + WebSocketConfig { + max_send_queue: None, + } + } +} + /// WebSocket input-output stream. /// /// This is THE structure you want to create to be able to speak the WebSocket protocol. @@ -42,20 +59,26 @@ pub struct WebSocket { incomplete: Option, /// Send: a data send queue. send_queue: VecDeque, - max_send_queue: usize, /// Send: an OOB pong message. pong: Option, + /// The configuration for the websocket session. + config: WebSocketConfig, } impl WebSocket { /// Convert a raw socket into a WebSocket without performing a handshake. - pub fn from_raw_socket(stream: Stream, role: Role) -> Self { - WebSocket::from_frame_socket(FrameSocket::new(stream), role) + pub fn from_raw_socket(stream: Stream, role: Role, config: Option) -> Self { + WebSocket::from_frame_socket(FrameSocket::new(stream), role, config) } /// Convert a raw socket into a WebSocket without performing a handshake. - pub fn from_partially_read(stream: Stream, part: Vec, role: Role) -> Self { - WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) + pub fn from_partially_read( + stream: Stream, + part: Vec, + role: Role, + config: Option, + ) -> Self { + WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role, config) } /// Returns a shared reference to the inner stream. @@ -68,15 +91,19 @@ impl WebSocket { } /// Convert a frame socket into a WebSocket. - fn from_frame_socket(socket: FrameSocket, role: Role) -> Self { + fn from_frame_socket( + socket: FrameSocket, + role: Role, + config: Option + ) -> Self { WebSocket { role: role, socket: socket, state: WebSocketState::Active, incomplete: None, send_queue: VecDeque::new(), - max_send_queue: 1, pong: None, + config: config.unwrap_or_else(|| WebSocketConfig::default()), } } } @@ -104,7 +131,7 @@ impl WebSocket { /// Send a message to stream, if possible. /// /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping - /// and Close requests. If the WebSocket's send queue is full, SendQueueFull will be returned + /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. /// /// Note that only the last pong frame is stored to be sent, and only the @@ -113,8 +140,10 @@ impl WebSocket { // Try to make some room for the new message self.write_pending().no_block()?; - if self.send_queue.len() >= self.max_send_queue { - return Err(Error::SendQueueFull(message)); + if let Some(max_send_queue) = self.config.max_send_queue { + if self.send_queue.len() >= max_send_queue { + return Err(Error::SendQueueFull(message)); + } } let frame = match message { @@ -466,7 +495,7 @@ mod tests { 0x82, 0x03, 0x01, 0x02, 0x03, ]); - let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); + let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); diff --git a/src/server.rs b/src/server.rs index 68e026f..d737565 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,7 @@ pub use handshake::server::ServerHandshake; use handshake::HandshakeError; use handshake::server::{Callback, NoCallback}; -use protocol::WebSocket; +use protocol::{WebSocket, WebSocketConfig}; use std::io::{Read, Write}; @@ -15,10 +15,10 @@ use std::io::{Read, Write}; /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept(stream: S) +pub fn accept(stream: S, config: Option) -> Result, HandshakeError>> { - accept_hdr(stream, NoCallback) + accept_hdr(stream, NoCallback, config) } /// Accept the given Stream as a WebSocket. @@ -26,8 +26,10 @@ pub fn accept(stream: S) /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr(stream: S, callback: C) - -> Result, HandshakeError>> -{ - ServerHandshake::start(stream, callback).handshake() +pub fn accept_hdr( + stream: S, + callback: C, + config: Option +) -> Result, HandshakeError>> { + ServerHandshake::start(stream, callback, config).handshake() }