Make the websocket configurable

pull/27/head
Daniel Abramov 6 years ago
parent 1f037abc34
commit 00303fa60c
  1. 8
      examples/autobahn-client.rs
  2. 2
      examples/autobahn-server.rs
  3. 2
      examples/client.rs
  4. 2
      examples/server.rs
  5. 12
      src/client.rs
  6. 19
      src/handshake/client.rs
  7. 15
      src/handshake/server.rs
  8. 51
      src/protocol/mod.rs
  9. 16
      src/server.rs

@ -11,7 +11,8 @@ const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
None,
)?;
let msg = socket.read_message()?;
socket.close(None)?;
@ -20,7 +21,8 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
None,
)?;
socket.close(None)?;
Ok(())
@ -31,7 +33,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let (mut socket, _) = connect(case_url)?;
let (mut socket, _) = connect(case_url, None)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -16,7 +16,7 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut socket = accept(stream, None).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |

@ -8,7 +8,7 @@ use tungstenite::{Message, connect};
fn main() {
env_logger::init();
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap(), None)
.expect("Can't connect");
println!("Connected to the server");

@ -25,7 +25,7 @@ fn main() {
];
Ok(Some(extra_headers))
};
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();
let mut websocket = accept_hdr(stream.unwrap(), callback, None).unwrap();
loop {
let msg = websocket.read_message().unwrap();

@ -7,6 +7,7 @@ use std::io::{Read, Write};
use url::Url;
use handshake::client::Response;
use protocol::WebSocketConfig;
#[cfg(feature="tls")]
mod encryption {
@ -77,7 +78,7 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
let request: Request = request.into();
@ -85,7 +86,7 @@ pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
let addrs = request.url.to_socket_addrs()?;
let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client(request, stream)
client(request, stream, config)
.map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
@ -126,11 +127,12 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
request: Req,
stream: Stream
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
{
ClientHandshake::start(stream, request.into()).handshake()
ClientHandshake::start(stream, request.into(), config).handshake()
}

@ -11,7 +11,7 @@ use rand;
use url::Url;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -71,12 +71,17 @@ impl From<Url> for Request<'static> {
#[derive(Debug)]
pub struct ClientHandshake<S> {
verify_data: VerifyData,
config: Option<WebSocketConfig>,
_marker: PhantomData<S>,
}
impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake.
pub fn start(stream: S, request: Request) -> MidHandshake<Self> {
pub fn start(
stream: S,
request: Request,
config: Option<WebSocketConfig>
) -> MidHandshake<Self> {
let key = generate_key();
let machine = {
@ -102,6 +107,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData { accept_key },
config,
_marker: PhantomData,
}
};
@ -125,8 +131,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client),
result))
let websocket = WebSocket::from_partially_read(
stream,
tail,
Role::Client,
self.config.clone(),
);
ProcessingResult::Done((websocket, result))
}
})
}

@ -8,7 +8,7 @@ use httparse;
use httparse::Status;
use error::{Error, Result};
use protocol::{WebSocket, Role};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -108,6 +108,8 @@ pub struct ServerHandshake<S, C> {
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -117,11 +119,11 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: C) -> MidHandshake<Self> {
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), _marker: PhantomData },
role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData },
}
}
}
@ -151,7 +153,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
}
StageResult::DoneWriting(stream) => {
debug!("Server handshake done.");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket)
}
})
}

@ -26,6 +26,23 @@ pub enum Role {
Client,
}
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>,
}
impl Default for WebSocketConfig {
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
}
}
}
/// WebSocket input-output stream.
///
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
@ -42,20 +59,26 @@ pub struct WebSocket<Stream> {
incomplete: Option<IncompleteMessage>,
/// Send: a data send queue.
send_queue: VecDeque<Frame>,
max_send_queue: usize,
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
}
impl<Stream> WebSocket<Stream> {
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket(stream: Stream, role: Role) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role)
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role, config)
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_partially_read(stream: Stream, part: Vec<u8>, role: Role) -> Self {
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role)
pub fn from_partially_read(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role, config)
}
/// Returns a shared reference to the inner stream.
@ -68,15 +91,19 @@ impl<Stream> WebSocket<Stream> {
}
/// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self {
fn from_frame_socket(
socket: FrameSocket<Stream>,
role: Role,
config: Option<WebSocketConfig>
) -> Self {
WebSocket {
role: role,
socket: socket,
state: WebSocketState::Active,
incomplete: None,
send_queue: VecDeque::new(),
max_send_queue: 1,
pong: None,
config: config.unwrap_or_else(|| WebSocketConfig::default()),
}
}
}
@ -104,7 +131,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Send a message to stream, if possible.
///
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// and Close requests. If the WebSocket's send queue is full, SendQueueFull will be returned
/// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
///
/// Note that only the last pong frame is stored to be sent, and only the
@ -113,8 +140,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
// Try to make some room for the new message
self.write_pending().no_block()?;
if self.send_queue.len() >= self.max_send_queue {
return Err(Error::SendQueueFull(message));
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message {
@ -466,7 +495,7 @@ mod tests {
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));

@ -5,7 +5,7 @@ pub use handshake::server::ServerHandshake;
use handshake::HandshakeError;
use handshake::server::{Callback, NoCallback};
use protocol::WebSocket;
use protocol::{WebSocket, WebSocketConfig};
use std::io::{Read, Write};
@ -15,10 +15,10 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(stream: S)
pub fn accept<S: Read + Write>(stream: S, config: Option<WebSocketConfig>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{
accept_hdr(stream, NoCallback)
accept_hdr(stream, NoCallback, config)
}
/// Accept the given Stream as a WebSocket.
@ -26,8 +26,10 @@ pub fn accept<S: Read + Write>(stream: S)
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
{
ServerHandshake::start(stream, callback).handshake()
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
config: Option<WebSocketConfig>
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback, config).handshake()
}

Loading…
Cancel
Save