|
|
|
@ -30,10 +30,15 @@ pub(crate) mod encryption { |
|
|
|
|
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>; |
|
|
|
|
|
|
|
|
|
pub type AutoStream<S> = MaybeTlsStream<S>; |
|
|
|
|
#[cfg(feature = "tls")] |
|
|
|
|
pub type Connector = async_tls::TlsConnector; |
|
|
|
|
#[cfg(feature = "native-tls")] |
|
|
|
|
pub type Connector = real_native_tls::TlsConnector; |
|
|
|
|
|
|
|
|
|
pub async fn wrap_stream<S>( |
|
|
|
|
socket: S, |
|
|
|
|
domain: String, |
|
|
|
|
connector: Option<Connector>, |
|
|
|
|
mode: Mode, |
|
|
|
|
) -> Result<AutoStream<S>, Error> |
|
|
|
|
where |
|
|
|
@ -44,13 +49,17 @@ pub(crate) mod encryption { |
|
|
|
|
Mode::Tls => { |
|
|
|
|
#[cfg(feature = "tls")] |
|
|
|
|
let stream = { |
|
|
|
|
let connector = AsyncTlsConnector::new(); |
|
|
|
|
let connector = connector.unwrap_or_else(|| AsyncTlsConnector::new()); |
|
|
|
|
connector.connect(&domain, socket)?.await? |
|
|
|
|
}; |
|
|
|
|
#[cfg(feature = "native-tls")] |
|
|
|
|
let stream = { |
|
|
|
|
let builder = real_native_tls::TlsConnector::builder(); |
|
|
|
|
let connector = builder.build()?; |
|
|
|
|
let connector = if let Some(connector) = connector { |
|
|
|
|
connector |
|
|
|
|
} else { |
|
|
|
|
let builder = real_native_tls::TlsConnector::builder(); |
|
|
|
|
builder.build()? |
|
|
|
|
}; |
|
|
|
|
let connector = AsyncTlsConnector::from(connector); |
|
|
|
|
connector.connect(&domain, socket).await? |
|
|
|
|
}; |
|
|
|
@ -59,23 +68,23 @@ pub(crate) mod encryption { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[cfg(feature = "tls-base")] |
|
|
|
|
pub use self::encryption::MaybeTlsStream; |
|
|
|
|
|
|
|
|
|
#[cfg(not(feature = "tls-base"))] |
|
|
|
|
pub(crate) mod encryption { |
|
|
|
|
use futures::io::{AsyncRead, AsyncWrite}; |
|
|
|
|
use futures::{future, Future}; |
|
|
|
|
|
|
|
|
|
use tungstenite::stream::Mode; |
|
|
|
|
use tungstenite::Error; |
|
|
|
|
|
|
|
|
|
pub type AutoStream<S> = S; |
|
|
|
|
pub type Connector = (); |
|
|
|
|
|
|
|
|
|
pub async fn wrap_stream<S>( |
|
|
|
|
pub(crate) async fn wrap_stream<S>( |
|
|
|
|
socket: S, |
|
|
|
|
_domain: String, |
|
|
|
|
_connector: Option<()>, |
|
|
|
|
mode: Mode, |
|
|
|
|
) -> Result<AutoStream<S>, Error> |
|
|
|
|
where |
|
|
|
@ -88,7 +97,7 @@ pub(crate) mod encryption { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
use self::encryption::{wrap_stream, AutoStream}; |
|
|
|
|
use self::encryption::AutoStream; |
|
|
|
|
|
|
|
|
|
/// Get a domain from an URL.
|
|
|
|
|
#[inline] |
|
|
|
@ -105,6 +114,22 @@ pub async fn client_async_tls<R, S>( |
|
|
|
|
request: R, |
|
|
|
|
stream: S, |
|
|
|
|
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> |
|
|
|
|
where |
|
|
|
|
R: Into<Request<'static>> + Unpin, |
|
|
|
|
S: 'static + AsyncRead + AsyncWrite + Send + Unpin, |
|
|
|
|
AutoStream<S>: Unpin, |
|
|
|
|
{ |
|
|
|
|
client_async_tls_with_connector(request, stream, None).await |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/// Creates a WebSocket handshake from a request and a stream,
|
|
|
|
|
/// upgrading the stream to TLS if required and using the given
|
|
|
|
|
/// connector.
|
|
|
|
|
pub async fn client_async_tls_with_connector<R, S>( |
|
|
|
|
request: R, |
|
|
|
|
stream: S, |
|
|
|
|
connector: Option<self::encryption::Connector>, |
|
|
|
|
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> |
|
|
|
|
where |
|
|
|
|
R: Into<Request<'static>> + Unpin, |
|
|
|
|
S: 'static + AsyncRead + AsyncWrite + Send + Unpin, |
|
|
|
@ -117,7 +142,7 @@ where |
|
|
|
|
// Make sure we check domain and mode first. URL must be valid.
|
|
|
|
|
let mode = url_mode(&request.url)?; |
|
|
|
|
|
|
|
|
|
let stream = wrap_stream(stream, domain, mode).await?; |
|
|
|
|
let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?; |
|
|
|
|
client_async(request, stream).await |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -145,7 +170,34 @@ pub(crate) mod async_std_runtime { |
|
|
|
|
let socket = try_socket.map_err(Error::Io)?; |
|
|
|
|
client_async_tls(request, socket).await |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[cfg(any(feature = "tls", feature = "native-tls"))] |
|
|
|
|
/// Connect to a given URL using the provided TLS connector.
|
|
|
|
|
pub async fn connect_async_with_tls_connector<R>( |
|
|
|
|
request: R, |
|
|
|
|
connector: Option<super::encryption::Connector>, |
|
|
|
|
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error> |
|
|
|
|
where |
|
|
|
|
R: Into<Request<'static>> + Unpin, |
|
|
|
|
{ |
|
|
|
|
let request: Request = request.into(); |
|
|
|
|
|
|
|
|
|
let domain = domain(&request)?; |
|
|
|
|
let port = request |
|
|
|
|
.url |
|
|
|
|
.port_or_known_default() |
|
|
|
|
.expect("Bug: port unknown"); |
|
|
|
|
|
|
|
|
|
let try_socket = TcpStream::connect((domain.as_str(), port)).await; |
|
|
|
|
let socket = try_socket.map_err(Error::Io)?; |
|
|
|
|
client_async_tls_with_connector(request, socket, connector).await |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#[cfg(feature = "async_std_runtime")] |
|
|
|
|
pub use async_std_runtime::connect_async; |
|
|
|
|
#[cfg(all(
|
|
|
|
|
feature = "async_std_runtime", |
|
|
|
|
any(feature = "tls", feature = "native-tls") |
|
|
|
|
))] |
|
|
|
|
pub use async_std_runtime::connect_async_with_tls_connector; |
|
|
|
|