//! Connection helper. use std::io::Result as IoResult; use std::net::SocketAddr; use tokio_tcp::TcpStream; use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; use tungstenite::Error; use super::{client_async, Request, WebSocketStream}; use crate::stream::{NoDelay, PeerAddr}; impl NoDelay for TcpStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { TcpStream::set_nodelay(self, nodelay) } } impl PeerAddr for TcpStream { fn peer_addr(&self) -> IoResult { self.peer_addr() } } #[cfg(feature = "tls")] mod encryption { use native_tls::TlsConnector; use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream}; use std::io::{Read, Result as IoResult, Write}; use std::net::SocketAddr; use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::stream::Mode; use tungstenite::Error; use crate::stream::{NoDelay, PeerAddr, Stream as StreamSwitcher}; /// A stream that might be protected with TLS. pub type MaybeTlsStream = StreamSwitcher>; pub type AutoStream = MaybeTlsStream; impl NoDelay for TlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.get_mut().get_mut().set_nodelay(nodelay) } } impl PeerAddr for TlsStream { fn peer_addr(&self) -> IoResult { self.get_ref().get_ref().peer_addr() } } pub fn wrap_stream( socket: S, domain: String, mode: Mode, ) -> Box, Error = Error> + Send> where S: 'static + AsyncRead + AsyncWrite + Send, { match mode { Mode::Plain => Box::new(future::ok(StreamSwitcher::Plain(socket))), Mode::Tls => Box::new( future::result(TlsConnector::new()) .map(TokioTlsConnector::from) .and_then(move |connector| connector.connect(&domain, socket)) .map(StreamSwitcher::Tls) .map_err(Error::Tls), ), } } } #[cfg(feature = "tls")] pub use self::encryption::MaybeTlsStream; #[cfg(not(feature = "tls"))] mod encryption { use futures::{future, Future}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::stream::Mode; use tungstenite::Error; pub type AutoStream = S; pub fn wrap_stream( socket: S, _domain: String, mode: Mode, ) -> Box, Error = Error> + Send> where S: 'static + AsyncRead + AsyncWrite + Send, { match mode { Mode::Plain => Box::new(future::ok(socket)), Mode::Tls => Box::new(future::err(Error::Url( "TLS support not compiled in.".into(), ))), } } } use self::encryption::{wrap_stream, AutoStream}; /// Get a domain from an URL. #[inline] fn domain(request: &Request) -> Result { match request.url.host_str() { Some(d) => Ok(d.to_string()), None => Err(Error::Url("no host name in the url".into())), } } /// Creates a WebSocket handshake from a request and a stream, /// upgrading the stream to TLS if required. pub fn client_async_tls( request: R, stream: S, ) -> Box>, Response), Error = Error> + Send> where R: Into>, S: 'static + AsyncRead + AsyncWrite + NoDelay + Send, { let request: Request = request.into(); let domain = match domain(&request) { Ok(domain) => domain, Err(err) => return Box::new(future::err(err)), }; // Make sure we check domain and mode first. URL must be valid. let mode = match url_mode(&request.url) { Ok(m) => m, Err(e) => return Box::new(future::err(e)), }; Box::new( wrap_stream(stream, domain, mode) .and_then(|mut stream| { NoDelay::set_nodelay(&mut stream, true) .map(move |()| stream) .map_err(|e| e.into()) }) .and_then(move |stream| client_async(request, stream)), ) } /// Connect to a given URL. pub fn connect_async( request: R, ) -> Box>, Response), Error = Error> + Send> where R: Into>, { let request: Request = request.into(); let domain = match domain(&request) { Ok(domain) => domain, Err(err) => return Box::new(future::err(err)), }; let port = request .url .port_or_known_default() .expect("Bug: port unknown"); Box::new( tokio_dns::TcpStream::connect((domain.as_str(), port)) .map_err(|e| e.into()) .and_then(move |socket| client_async_tls(request, socket)), ) }