Make the websocket configurable

pull/27/head
Daniel Abramov 6 years ago
parent 1f037abc34
commit 00303fa60c
  1. 8
      examples/autobahn-client.rs
  2. 2
      examples/autobahn-server.rs
  3. 2
      examples/client.rs
  4. 2
      examples/server.rs
  5. 12
      src/client.rs
  6. 19
      src/handshake/client.rs
  7. 15
      src/handshake/server.rs
  8. 51
      src/protocol/mod.rs
  9. 16
      src/server.rs

@ -11,7 +11,8 @@ const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap() Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
None,
)?; )?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
@ -20,7 +21,8 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
None,
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())
@ -31,7 +33,7 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse( let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap(); ).unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect(case_url, None)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -16,7 +16,7 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
} }
fn handle_client(stream: TcpStream) -> Result<()> { fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept(stream, None).map_err(must_not_block)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) |

@ -8,7 +8,7 @@ use tungstenite::{Message, connect};
fn main() { fn main() {
env_logger::init(); env_logger::init();
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap(), None)
.expect("Can't connect"); .expect("Can't connect");
println!("Connected to the server"); println!("Connected to the server");

@ -25,7 +25,7 @@ fn main() {
]; ];
Ok(Some(extra_headers)) Ok(Some(extra_headers))
}; };
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); let mut websocket = accept_hdr(stream.unwrap(), callback, None).unwrap();
loop { loop {
let msg = websocket.read_message().unwrap(); let msg = websocket.read_message().unwrap();

@ -7,6 +7,7 @@ use std::io::{Read, Write};
use url::Url; use url::Url;
use handshake::client::Response; use handshake::client::Response;
use protocol::WebSocketConfig;
#[cfg(feature="tls")] #[cfg(feature="tls")]
mod encryption { mod encryption {
@ -77,7 +78,7 @@ use error::{Error, Result};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) pub fn connect<'t, Req: Into<Request<'t>>>(request: Req, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)> -> Result<(WebSocket<AutoStream>, Response)>
{ {
let request: Request = request.into(); let request: Request = request.into();
@ -85,7 +86,7 @@ pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
let addrs = request.url.to_socket_addrs()?; let addrs = request.url.to_socket_addrs()?;
let mut stream = connect_to_some(addrs, &request.url, mode)?; let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client(request, stream) client(request, stream, config)
.map_err(|e| match e { .map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
@ -126,11 +127,12 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>( pub fn client<'t, Stream, Req>(
request: Req, request: Req,
stream: Stream stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,
{ {
ClientHandshake::start(stream, request.into()).handshake() ClientHandshake::start(stream, request.into(), config).handshake()
} }

@ -11,7 +11,7 @@ use rand;
use url::Url; use url::Url;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -71,12 +71,17 @@ impl From<Url> for Request<'static> {
#[derive(Debug)] #[derive(Debug)]
pub struct ClientHandshake<S> { pub struct ClientHandshake<S> {
verify_data: VerifyData, verify_data: VerifyData,
config: Option<WebSocketConfig>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<S: Read + Write> ClientHandshake<S> { impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start(stream: S, request: Request) -> MidHandshake<Self> { pub fn start(
stream: S,
request: Request,
config: Option<WebSocketConfig>
) -> MidHandshake<Self> {
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
@ -102,6 +107,7 @@ impl<S: Read + Write> ClientHandshake<S> {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake {
verify_data: VerifyData { accept_key }, verify_data: VerifyData { accept_key },
config,
_marker: PhantomData, _marker: PhantomData,
} }
}; };
@ -125,8 +131,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?; self.verify_data.verify_response(&result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
ProcessingResult::Done((WebSocket::from_partially_read(stream, tail, Role::Client), let websocket = WebSocket::from_partially_read(
result)) stream,
tail,
Role::Client,
self.config.clone(),
);
ProcessingResult::Done((websocket, result))
} }
}) })
} }

@ -8,7 +8,7 @@ use httparse;
use httparse::Status; use httparse::Status;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, Role}; use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
@ -108,6 +108,8 @@ pub struct ServerHandshake<S, C> {
/// to reply to it. The callback returns an optional headers which will be added to the reply /// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user. /// which the server sends to the user.
callback: Option<C>, callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -117,11 +119,11 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
/// the handshake, this callback will be called when the a websocket client connnects to the /// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client /// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers. /// upon join based on the incoming headers.
pub fn start(stream: S, callback: C) -> MidHandshake<Self> { pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), _marker: PhantomData }, role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData },
} }
} }
} }
@ -151,7 +153,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
debug!("Server handshake done."); debug!("Server handshake done.");
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server)) let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket)
} }
}) })
} }

@ -26,6 +26,23 @@ pub enum Role {
Client, Client,
} }
/// The configuration for WebSocket connection.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
pub max_send_queue: Option<usize>,
}
impl Default for WebSocketConfig {
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
}
}
}
/// WebSocket input-output stream. /// WebSocket input-output stream.
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
@ -42,20 +59,26 @@ pub struct WebSocket<Stream> {
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send: a data send queue. /// Send: a data send queue.
send_queue: VecDeque<Frame>, send_queue: VecDeque<Frame>,
max_send_queue: usize,
/// Send: an OOB pong message. /// Send: an OOB pong message.
pong: Option<Frame>, pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig,
} }
impl<Stream> WebSocket<Stream> { impl<Stream> WebSocket<Stream> {
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket(stream: Stream, role: Role) -> Self { pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket::from_frame_socket(FrameSocket::new(stream), role) WebSocket::from_frame_socket(FrameSocket::new(stream), role, config)
} }
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_partially_read(stream: Stream, part: Vec<u8>, role: Role) -> Self { pub fn from_partially_read(
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role) stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
WebSocket::from_frame_socket(FrameSocket::from_partially_read(stream, part), role, config)
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
@ -68,15 +91,19 @@ impl<Stream> WebSocket<Stream> {
} }
/// Convert a frame socket into a WebSocket. /// Convert a frame socket into a WebSocket.
fn from_frame_socket(socket: FrameSocket<Stream>, role: Role) -> Self { fn from_frame_socket(
socket: FrameSocket<Stream>,
role: Role,
config: Option<WebSocketConfig>
) -> Self {
WebSocket { WebSocket {
role: role, role: role,
socket: socket, socket: socket,
state: WebSocketState::Active, state: WebSocketState::Active,
incomplete: None, incomplete: None,
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
max_send_queue: 1,
pong: None, pong: None,
config: config.unwrap_or_else(|| WebSocketConfig::default()),
} }
} }
} }
@ -104,7 +131,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Send a message to stream, if possible. /// Send a message to stream, if possible.
/// ///
/// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping /// WebSocket will buffer a configurable number of messages at a time, except to reply to Ping
/// and Close requests. If the WebSocket's send queue is full, SendQueueFull will be returned /// and Close requests. If the WebSocket's send queue is full, `SendQueueFull` will be returned
/// along with the passed message. Otherwise, the message is queued and Ok(()) is returned. /// along with the passed message. Otherwise, the message is queued and Ok(()) is returned.
/// ///
/// Note that only the last pong frame is stored to be sent, and only the /// Note that only the last pong frame is stored to be sent, and only the
@ -113,8 +140,10 @@ impl<Stream: Read + Write> WebSocket<Stream> {
// Try to make some room for the new message // Try to make some room for the new message
self.write_pending().no_block()?; self.write_pending().no_block()?;
if self.send_queue.len() >= self.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
return Err(Error::SendQueueFull(message)); if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
} }
let frame = match message { let frame = match message {
@ -466,7 +495,7 @@ mod tests {
0x82, 0x03, 0x82, 0x03,
0x01, 0x02, 0x03, 0x01, 0x02, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));

@ -5,7 +5,7 @@ pub use handshake::server::ServerHandshake;
use handshake::HandshakeError; use handshake::HandshakeError;
use handshake::server::{Callback, NoCallback}; use handshake::server::{Callback, NoCallback};
use protocol::WebSocket; use protocol::{WebSocket, WebSocketConfig};
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -15,10 +15,10 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others.
pub fn accept<S: Read + Write>(stream: S) pub fn accept<S: Read + Write>(stream: S, config: Option<WebSocketConfig>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{ {
accept_hdr(stream, NoCallback) accept_hdr(stream, NoCallback, config)
} }
/// Accept the given Stream as a WebSocket. /// Accept the given Stream as a WebSocket.
@ -26,8 +26,10 @@ pub fn accept<S: Read + Write>(stream: S)
/// This function does the same as `accept()` but accepts an extra callback /// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming /// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply. /// requests and is able to add extra headers to the reply.
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C) pub fn accept_hdr<S: Read + Write, C: Callback>(
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> stream: S,
{ callback: C,
ServerHandshake::start(stream, callback).handshake() config: Option<WebSocketConfig>
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback, config).handshake()
} }

Loading…
Cancel
Save