Merge pull request #27 from snapview/websocket-config

Allow the configuration of `WebSocket`
pull/37/head
Alexey Galakhov 6 years ago committed by GitHub
commit b93abcf900
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      Cargo.toml
  2. 4
      examples/autobahn-client.rs
  3. 57
      src/client.rs
  4. 19
      src/handshake/client.rs
  5. 15
      src/handshake/server.rs
  6. 51
      src/protocol/mod.rs
  7. 37
      src/server.rs

@ -9,7 +9,7 @@ readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.5.4" documentation = "https://docs.rs/tungstenite/0.5.4"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.5.4" version = "0.6.0"
[features] [features]
default = ["tls"] default = ["tls"]

@ -11,7 +11,7 @@ const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap() Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
)?; )?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
@ -20,7 +20,7 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())

@ -7,6 +7,7 @@ use std::io::{Read, Write};
use url::Url; use url::Url;
use handshake::client::Response; use handshake::client::Response;
use protocol::WebSocketConfig;
#[cfg(feature="tls")] #[cfg(feature="tls")]
mod encryption { mod encryption {
@ -67,6 +68,9 @@ use error::{Error, Result};
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
/// ///
/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
/// equal to calling `connect()` function.
///
/// The URL may be either ws:// or wss://. /// The URL may be either ws:// or wss://.
/// To support wss:// URLs, feature "tls" must be turned on. /// To support wss:// URLs, feature "tls" must be turned on.
/// ///
@ -77,21 +81,40 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
-> Result<(WebSocket<AutoStream>, Response)> request: Req,
{ config: Option<WebSocketConfig>
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into(); let request: Request = request.into();
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let addrs = request.url.to_socket_addrs()?; let addrs = request.url.to_socket_addrs()?;
let mut stream = connect_to_some(addrs, &request.url, mode)?; let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client(request, stream) client_with_config(request, stream, config)
.map_err(|e| match e { .map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) })
} }
/// Connect to the given WebSocket in blocking mode.
///
/// The URL may be either ws:// or wss://.
/// To support wss:// URLs, feature "tls" must be turned on.
///
/// This function "just works" for those who wants a simple blocking solution
/// similar to `std::net::TcpStream`. If you want a non-blocking or other
/// custom stream, call `client` instead.
///
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Response)>
{
connect_with_config(request, None)
}
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream> fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr> where A: Iterator<Item=SocketAddr>
{ {
@ -119,18 +142,34 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
} }
} }
/// Do the client handshake over the given stream. /// Do the client handshake over the given stream given a web socket configuration. Passing `None`
/// as configuration is equal to calling `client()` function.
/// ///
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>( pub fn client_with_config<'t, Stream, Req>(
request: Req, request: Req,
stream: Stream stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
{
ClientHandshake::start(stream, request.into(), config).handshake()
}
/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,
{ {
ClientHandshake::start(stream, request.into()).handshake() client_with_config(request, stream, None)
} }

@ -11,7 +11,7 @@ use rand;
use url::Url; use url::Url;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -71,12 +71,17 @@ impl From<Url> for Request<'static> {
#[derive(Debug)] #[derive(Debug)]
pub struct ClientHandshake<S> { pub struct ClientHandshake<S> {
verify_data: VerifyData, verify_data: VerifyData,
config: Option<WebSocketConfig>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<S: Read + Write> ClientHandshake<S> { impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start(stream: S, request: Request) -> MidHandshake<Self> { pub fn start(
stream: S,
request: Request,
config: Option<WebSocketConfig>
) -> MidHandshake<Self> {
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
@ -102,6 +107,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake {
verify_data: VerifyData { accept_key }, verify_data: VerifyData { accept_key },
config,
_marker: PhantomData, _marker: PhantomData,
} }
}; };
@ -125,8 +131,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), let websocket = WebSocket::from_partially_read(
result)) stream,
tail,
Role::Client,
self.config.clone(),
);
ProcessingResult::Done((websocket, result))
} }
}) })
} }

@ -8,7 +8,7 @@ use httparse;
use httparse::Status; use httparse::Status;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -108,6 +108,8 @@ pub struct ServerHandshake<S, C> {
/// to reply to it. The callback returns an optional headers which will be added to the reply /// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user. /// which the server sends to the user.
callback: Option<C>, callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -117,11 +119,11 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
/// the handshake, this callback will be called when the a websocket client connnects to the /// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client /// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers. /// upon join based on the incoming headers.
pub fn start(stream: S, callback: C) -> MidHandshake<Self> { pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), _marker: PhantomData }, role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData },
} }
} }
} }
@ -151,7 +153,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
debug!("Server handshake done."); debug!("Server handshake done.");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket)
} }
}) })
} }

@ -26,6 +26,23 @@ pub enum Role {
Client, Client,
} }
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>,
}
impl Default for WebSocketConfig {
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
}
}
}
/// WebSocket input-output stream. /// WebSocket input-output stream.
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
@ -42,20 +59,26 @@ pub struct WebSocket<Stream> {
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send: a data send queue. /// Send: a data send queue.
send_queue: VecDeque<Frame>, send_queue: VecDeque<Frame>,
max_send_queue: usize,
/// Send: an OOB pong message. /// Send: an OOB pong message.
pong: Option<Frame>, pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
} }
impl<Stream> WebSocket<Stream> { impl<Stream> WebSocket<Stream> {
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket(stream: Stream, role: Role) -> Self { pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role) WebSocket::from_frame_socket(FrameSocket::new(stream), role, config)
} }
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_partially_read(stream: Stream, part: Vec<u8>, role: Role) -> Self { pub fn from_partially_read(
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role, config)
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
@ -68,15 +91,19 @@ impl<Stream> WebSocket<Stream> {
} }
/// Convert a frame socket into a WebSocket. /// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self { fn from_frame_socket(
socket: FrameSocket<Stream>,
role: Role,
config: Option<WebSocketConfig>
) -> Self {
WebSocket { WebSocket {
role: role, role: role,
socket: socket, socket: socket,
state: WebSocketState::Active, state: WebSocketState::Active,
incomplete: None, incomplete: None,
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
max_send_queue: 1,
pong: None, pong: None,
config: config.unwrap_or_else(|| WebSocketConfig::default()),
} }
} }
} }
@ -104,7 +131,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Send a message to stream, if possible. /// Send a message to stream, if possible.
/// ///
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// and Close requests. If the WebSocket's send queue is full, SendQueueFull will be returned /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// ///
/// Note that only the last pong frame is stored to be sent, and only the /// Note that only the last pong frame is stored to be sent, and only the
@ -113,8 +140,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
// Try to make some room for the new message // Try to make some room for the new message
self.write_pending().no_block()?; self.write_pending().no_block()?;
if self.send_queue.len() >= self.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
return Err(Error::SendQueueFull(message)); if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
} }
let frame = match message { let frame = match message {
@ -466,7 +495,7 @@ mod tests {
0x82, 0x03, 0x82, 0x03,
0x01, 0x02, 0x03, 0x01, 0x02, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));

@ -5,10 +5,25 @@ pub use handshake::server::ServerHandshake;
use handshake::HandshakeError; use handshake::HandshakeError;
use handshake::server::{Callback, NoCallback}; use handshake::server::{Callback, NoCallback};
use protocol::WebSocket; use protocol::{WebSocket, WebSocketConfig};
use std::io::{Read, Write}; use std::io::{Read, Write};
/// Accept the given Stream as a WebSocket.
///
/// Uses a configuration provided as an argument. Calling it with `None` will use the default one
/// used by `accept()`.
///
/// This function starts a server WebSocket handshake over the given stream.
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketConfig>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{
accept_hdr_with_config(stream, NoCallback, config)
}
/// Accept the given Stream as a WebSocket. /// Accept the given Stream as a WebSocket.
/// ///
/// This function starts a server WebSocket handshake over the given stream. /// This function starts a server WebSocket handshake over the given stream.
@ -18,7 +33,23 @@ use std::io::{Read, Write};
pub fn accept<S: Read + Write>(stream: S) pub fn accept<S: Read + Write>(stream: S)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{ {
accept_hdr(stream, NoCallback) accept_with_config(stream, None)
}
/// Accept the given Stream as a WebSocket.
///
/// Uses a configuration provided as an argument. Calling it with `None` will use the default one
/// used by `accept_hdr()`.
///
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
stream: S,
callback: C,
config: Option<WebSocketConfig>
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback, config).handshake()
} }
/// Accept the given Stream as a WebSocket. /// Accept the given Stream as a WebSocket.
@ -29,5 +60,5 @@ pub fn accept<S: Read + Write>(stream: S)
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C) pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
{ {
ServerHandshake::start(stream, callback).handshake() accept_hdr_with_config(stream, callback, None)
} }

Loading…
Cancel
Save