From 3a58069db2848a9946d9ee0c60d876ec2e4f8110 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 26 Jun 2018 11:12:57 +0200 Subject: [PATCH] Create helpers for config-like functions As suggested by @agalakhov --- examples/autobahn-client.rs | 4 +-- examples/autobahn-server.rs | 2 +- examples/client.rs | 2 +- examples/server.rs | 2 +- src/client.rs | 49 ++++++++++++++++++++++++++++++++----- src/server.rs | 35 +++++++++++++++++++++++--- 6 files changed, 79 insertions(+), 15 deletions(-) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index e5488b4..21f1d9b 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -12,7 +12,6 @@ const AGENT: &'static str = "Tungstenite"; fn get_case_count() -> Result { let (mut socket, _) = connect( Url::parse("ws://localhost:9001/getCaseCount").unwrap(), - None, )?; let msg = socket.read_message()?; socket.close(None)?; @@ -22,7 +21,6 @@ fn get_case_count() -> Result { fn update_reports() -> Result<()> { let (mut socket, _) = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), - None, )?; socket.close(None)?; Ok(()) @@ -33,7 +31,7 @@ fn run_test(case: u32) -> Result<()> { let case_url = Url::parse( &format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) ).unwrap(); - let (mut socket, _) = connect(case_url, None)?; + let (mut socket, _) = connect(case_url)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 53b7d7a..10aba75 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -16,7 +16,7 @@ fn must_not_block(err: HandshakeError) -> Error { } fn handle_client(stream: TcpStream) -> Result<()> { - let mut socket = accept(stream, None).map_err(must_not_block)?; + let mut socket = accept(stream).map_err(must_not_block)?; loop { match socket.read_message()? { msg @ Message::Text(_) | diff --git a/examples/client.rs b/examples/client.rs index 1ba7c71..13b7d59 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,7 +8,7 @@ use tungstenite::{Message, connect}; fn main() { env_logger::init(); - let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap(), None) + let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) .expect("Can't connect"); println!("Connected to the server"); diff --git a/examples/server.rs b/examples/server.rs index 9649d6c..86df2fe 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -25,7 +25,7 @@ fn main() { ]; Ok(Some(extra_headers)) }; - let mut websocket = accept_hdr(stream.unwrap(), callback, None).unwrap(); + let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); loop { let msg = websocket.read_message().unwrap(); diff --git a/src/client.rs b/src/client.rs index bb4d7f1..3e643ec 100644 --- a/src/client.rs +++ b/src/client.rs @@ -68,6 +68,9 @@ use error::{Error, Result}; /// Connect to the given WebSocket in blocking mode. /// +/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is +/// equal to calling `connect()` function. +/// /// The URL may be either ws:// or wss://. /// To support wss:// URLs, feature "tls" must be turned on. /// @@ -78,21 +81,40 @@ use error::{Error, Result}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>(request: Req, config: Option) - -> Result<(WebSocket, Response)> -{ +pub fn connect_with_config<'t, Req: Into>>( + request: Req, + config: Option +) -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; let addrs = request.url.to_socket_addrs()?; let mut stream = connect_to_some(addrs, &request.url, mode)?; NoDelay::set_nodelay(&mut stream, true)?; - client(request, stream, config) + client_with_config(request, stream, config) .map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), }) } +/// Connect to the given WebSocket in blocking mode. +/// +/// The URL may be either ws:// or wss://. +/// To support wss:// URLs, feature "tls" must be turned on. +/// +/// This function "just works" for those who wants a simple blocking solution +/// similar to `std::net::TcpStream`. If you want a non-blocking or other +/// custom stream, call `client` instead. +/// +/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, +/// use `client` instead. There is no need to enable the "tls" feature if you don't call +/// `connect` since it's the only function that uses native_tls. +pub fn connect<'t, Req: Into>>(request: Req) + -> Result<(WebSocket, Response)> +{ + connect_with_config(request, None) +} + fn connect_to_some(addrs: A, url: &Url, mode: Mode) -> Result where A: Iterator { @@ -120,12 +142,13 @@ pub fn url_mode(url: &Url) -> Result { } } -/// Do the client handshake over the given stream. +/// Do the client handshake over the given stream given a web socket configuration. Passing `None` +/// as configuration is equal to calling `client()` function. /// /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>( +pub fn client_with_config<'t, Stream, Req>( request: Req, stream: Stream, config: Option, @@ -136,3 +159,17 @@ where { ClientHandshake::start(stream, request.into(), config).handshake() } + +/// Do the client handshake over the given stream. +/// +/// Use this function if you need a nonblocking handshake support or if you +/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. +/// Any stream supporting `Read + Write` will do. +pub fn client<'t, Stream, Req>(request: Req, stream: Stream) + -> StdResult<(WebSocket, Response), HandshakeError>> +where + Stream: Read + Write, + Req: Into>, +{ + client_with_config(request, stream, None) +} diff --git a/src/server.rs b/src/server.rs index d737565..ba93508 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,27 +9,56 @@ use protocol::{WebSocket, WebSocketConfig}; use std::io::{Read, Write}; +/// Accept the given Stream as a WebSocket. +/// +/// Uses a configuration provided as an argument. Calling it with `None` will use the default one +/// used by `accept()`. +/// +/// This function starts a server WebSocket handshake over the given stream. +/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` +/// for the stream here. Any `Read + Write` streams are supported, including +/// those from `Mio` and others. +pub fn accept_with_config(stream: S, config: Option) + -> Result, HandshakeError>> +{ + accept_hdr_with_config(stream, NoCallback, config) +} + /// Accept the given Stream as a WebSocket. /// /// This function starts a server WebSocket handshake over the given stream. /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// for the stream here. Any `Read + Write` streams are supported, including /// those from `Mio` and others. -pub fn accept(stream: S, config: Option) +pub fn accept(stream: S) -> Result, HandshakeError>> { - accept_hdr(stream, NoCallback, config) + accept_with_config(stream, None) } /// Accept the given Stream as a WebSocket. /// +/// Uses a configuration provided as an argument. Calling it with `None` will use the default one +/// used by `accept_hdr()`. +/// /// This function does the same as `accept()` but accepts an extra callback /// for header processing. The callback receives headers of the incoming /// requests and is able to add extra headers to the reply. -pub fn accept_hdr( +pub fn accept_hdr_with_config( stream: S, callback: C, config: Option ) -> Result, HandshakeError>> { ServerHandshake::start(stream, callback, config).handshake() } + +/// Accept the given Stream as a WebSocket. +/// +/// This function does the same as `accept()` but accepts an extra callback +/// for header processing. The callback receives headers of the incoming +/// requests and is able to add extra headers to the reply. +pub fn accept_hdr(stream: S, callback: C) + -> Result, HandshakeError>> +{ + accept_hdr_with_config(stream, callback, None) +}