//! Connection helper. use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; use tungstenite::Error; use futures::io::{AsyncRead, AsyncWrite}; use super::{client_async, Request, WebSocketStream}; #[cfg(feature = "tls")] pub(crate) mod encryption { use async_tls::TlsConnector as AsyncTlsConnector; use async_tls::client::TlsStream; use tungstenite::stream::Mode; use tungstenite::Error; use futures::io::{AsyncRead, AsyncWrite}; use crate::stream::Stream as StreamSwitcher; /// A stream that might be protected with TLS. pub type MaybeTlsStream = StreamSwitcher>; pub type AutoStream = MaybeTlsStream; pub async fn wrap_stream( socket: S, domain: String, mode: Mode, ) -> Result, Error> where S: 'static + AsyncRead + AsyncWrite + Send + Unpin, { match mode { Mode::Plain => Ok(StreamSwitcher::Plain(socket)), Mode::Tls => { let stream = AsyncTlsConnector::new(); let connected = stream.connect(&domain, socket)?.await; match connected { Err(e) => Err(Error::Io(e)), Ok(s) => Ok(StreamSwitcher::Tls(s)), } } } } } #[cfg(feature = "tls")] pub use self::encryption::MaybeTlsStream; #[cfg(not(feature = "tls"))] pub(crate) mod encryption { use futures::{future, Future}; use futures::io::{AsyncRead, AsyncWrite}; use tungstenite::stream::Mode; use tungstenite::Error; pub type AutoStream = S; pub async fn wrap_stream( socket: S, _domain: String, mode: Mode, ) -> Result, Error> where S: 'static + AsyncRead + AsyncWrite + Send + Unpin, { match mode { Mode::Plain => Ok(socket), Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), } } } use self::encryption::{wrap_stream, AutoStream}; /// Get a domain from an URL. #[inline] fn domain(request: &Request) -> Result { match request.url.host_str() { Some(d) => Ok(d.to_string()), None => Err(Error::Url("no host name in the url".into())), } } /// Creates a WebSocket handshake from a request and a stream, /// upgrading the stream to TLS if required. pub async fn client_async_tls( request: R, stream: S, ) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin, AutoStream: Unpin, { let request: Request = request.into(); let domain = domain(&request)?; // Make sure we check domain and mode first. URL must be valid. let mode = url_mode(&request.url)?; let stream = wrap_stream(stream, domain, mode).await?; client_async(request, stream).await } #[cfg(feature = "async_std_runtime")] pub(crate) mod async_std_runtime { use super::*; use async_std::net::TcpStream; /// Connect to a given URL. pub async fn connect_async( request: R, ) -> Result<(WebSocketStream>, Response), Error> where R: Into> + Unpin, { let request: Request = request.into(); let domain = domain(&request)?; let port = request .url .port_or_known_default() .expect("Bug: port unknown"); let try_socket = TcpStream::connect((domain.as_str(), port)).await; let socket = try_socket.map_err(Error::Io)?; client_async_tls(request, socket).await } } #[cfg(feature = "async_std_runtime")] pub use async_std_runtime::connect_async;