diff --git a/src/connect.rs b/src/connect.rs index 5fccf2e..3136d82 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -10,6 +10,8 @@ use self::tokio_core::reactor::Remote; use self::tokio_dns::tcp_connect; use futures::{future, Future}; +use tokio_io::{AsyncRead, AsyncWrite}; + use tungstenite::Error; use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; @@ -92,27 +94,37 @@ mod encryption { use self::encryption::{AutoStream, wrap_stream}; -/// Connect to a given URL. -pub fn connect_async(request: R, handle: Remote) - -> Box>, Response), Error=Error>> +/// Get a domain from an URL. +#[inline] +fn domain(request: &Request) -> Result { + match request.url.host_str() { + Some(d) => Ok(d.to_string()), + None => Err(Error::Url("no host name in the url".into())), + } +} + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub fn client_async_tls(request: R, stream: S) + -> Box>, Response), Error=Error>> where - R: Into> + R: Into>, + S: 'static + AsyncRead + AsyncWrite + NoDelay, { let request: Request = request.into(); + let domain = match domain(&request) { + Ok(domain) => domain, + Err(err) => return Box::new(future::err(err)), + }; + // Make sure we check domain and mode first. URL must be valid. let mode = match url_mode(&request.url) { Ok(m) => m, Err(e) => return Box::new(future::err(e.into())), }; - let domain = match request.url.host_str() { - Some(d) => d.to_string(), - None => return Box::new(future::err(Error::Url("No host name in the URL".into()))), - }; - let port = request.url.port_or_known_default().expect("Bug: port unknown"); - Box::new(tcp_connect((domain.as_str(), port), handle).map_err(|e| e.into()) - .and_then(move |socket| wrap_stream(socket, domain, mode)) + Box::new(wrap_stream(stream, domain, mode) .and_then(|mut stream| { NoDelay::set_nodelay(&mut stream, true) .map(move |()| stream) @@ -120,3 +132,21 @@ where }) .and_then(move |stream| client_async(request, stream))) } + +/// Connect to a given URL. +pub fn connect_async(request: R, handle: Remote) + -> Box>, Response), Error=Error>> +where + R: Into> +{ + let request: Request = request.into(); + + let domain = match domain(&request) { + Ok(domain) => domain, + Err(err) => return Box::new(future::err(err)), + }; + let port = request.url.port_or_known_default().expect("Bug: port unknown"); + + Box::new(tcp_connect((domain.as_str(), port), handle).map_err(|e| e.into()) + .and_then(move |socket| client_async_tls(request, socket))) +} diff --git a/src/lib.rs b/src/lib.rs index 3402c05..b7fb89a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ use tungstenite::error::Error as WsError; use tungstenite::server; #[cfg(feature="connect")] -pub use connect::connect_async; +pub use connect::{connect_async, client_async_tls}; /// Creates a WebSocket handshake from a request and a stream. /// For convenience, the user may call this with a url string, a URL,