Add API for connecting and creating clients with a custom TLS connector

Fixes https://github.com/sdroege/async-tungstenite/issues/2
pull/5/head
Sebastian Dröge 5 years ago
parent 26099f2754
commit b4d5a9e84e
  1. 66
      src/connect.rs
  2. 7
      src/lib.rs

@ -30,10 +30,15 @@ pub(crate) mod encryption {
pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>; pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;
pub type AutoStream<S> = MaybeTlsStream<S>; pub type AutoStream<S> = MaybeTlsStream<S>;
#[cfg(feature = "tls")]
pub type Connector = async_tls::TlsConnector;
#[cfg(feature = "native-tls")]
pub type Connector = real_native_tls::TlsConnector;
pub async fn wrap_stream<S>( pub async fn wrap_stream<S>(
socket: S, socket: S,
domain: String, domain: String,
connector: Option<Connector>,
mode: Mode, mode: Mode,
) -> Result<AutoStream<S>, Error> ) -> Result<AutoStream<S>, Error>
where where
@ -44,13 +49,17 @@ pub(crate) mod encryption {
Mode::Tls => { Mode::Tls => {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
let stream = { let stream = {
let connector = AsyncTlsConnector::new(); let connector = connector.unwrap_or_else(|| AsyncTlsConnector::new());
connector.connect(&domain, socket)?.await? connector.connect(&domain, socket)?.await?
}; };
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
let stream = { let stream = {
let connector = if let Some(connector) = connector {
connector
} else {
let builder = real_native_tls::TlsConnector::builder(); let builder = real_native_tls::TlsConnector::builder();
let connector = builder.build()?; builder.build()?
};
let connector = AsyncTlsConnector::from(connector); let connector = AsyncTlsConnector::from(connector);
connector.connect(&domain, socket).await? connector.connect(&domain, socket).await?
}; };
@ -59,23 +68,23 @@ pub(crate) mod encryption {
} }
} }
} }
#[cfg(feature = "tls-base")] #[cfg(feature = "tls-base")]
pub use self::encryption::MaybeTlsStream; pub use self::encryption::MaybeTlsStream;
#[cfg(not(feature = "tls-base"))] #[cfg(not(feature = "tls-base"))]
pub(crate) mod encryption { pub(crate) mod encryption {
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use futures::{future, Future};
use tungstenite::stream::Mode; use tungstenite::stream::Mode;
use tungstenite::Error; use tungstenite::Error;
pub type AutoStream<S> = S; pub type AutoStream<S> = S;
pub type Connector = ();
pub async fn wrap_stream<S>( pub(crate) async fn wrap_stream<S>(
socket: S, socket: S,
_domain: String, _domain: String,
_connector: Option<()>,
mode: Mode, mode: Mode,
) -> Result<AutoStream<S>, Error> ) -> Result<AutoStream<S>, Error>
where where
@ -88,7 +97,7 @@ pub(crate) mod encryption {
} }
} }
use self::encryption::{wrap_stream, AutoStream}; use self::encryption::AutoStream;
/// Get a domain from an URL. /// Get a domain from an URL.
#[inline] #[inline]
@ -105,6 +114,22 @@ pub async fn client_async_tls<R, S>(
request: R, request: R,
stream: S, stream: S,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
AutoStream<S>: Unpin,
{
client_async_tls_with_connector(request, stream, None).await
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
pub async fn client_async_tls_with_connector<R, S>(
request: R,
stream: S,
connector: Option<self::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
where where
R: Into<Request<'static>> + Unpin, R: Into<Request<'static>> + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin, S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
@ -117,7 +142,7 @@ where
// Make sure we check domain and mode first. URL must be valid. // Make sure we check domain and mode first. URL must be valid.
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let stream = wrap_stream(stream, domain, mode).await?; let stream = self::encryption::wrap_stream(stream, domain, connector, mode).await?;
client_async(request, stream).await client_async(request, stream).await
} }
@ -145,7 +170,34 @@ pub(crate) mod async_std_runtime {
let socket = try_socket.map_err(Error::Io)?; let socket = try_socket.map_err(Error::Io)?;
client_async_tls(request, socket).await client_async_tls(request, socket).await
} }
#[cfg(any(feature = "tls", feature = "native-tls"))]
/// Connect to a given URL using the provided TLS connector.
pub async fn connect_async_with_tls_connector<R>(
request: R,
connector: Option<super::encryption::Connector>,
) -> Result<(WebSocketStream<AutoStream<TcpStream>>, Response), Error>
where
R: Into<Request<'static>> + Unpin,
{
let request: Request = request.into();
let domain = domain(&request)?;
let port = request
.url
.port_or_known_default()
.expect("Bug: port unknown");
let try_socket = TcpStream::connect((domain.as_str(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
client_async_tls_with_connector(request, socket, connector).await
}
} }
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "async_std_runtime")]
pub use async_std_runtime::connect_async; pub use async_std_runtime::connect_async;
#[cfg(all(
feature = "async_std_runtime",
any(feature = "tls", feature = "native-tls")
))]
pub use async_std_runtime::connect_async_with_tls_connector;

@ -46,8 +46,15 @@ use tungstenite::{
#[cfg(feature = "connect")] #[cfg(feature = "connect")]
pub use connect::client_async_tls; pub use connect::client_async_tls;
#[cfg(all(feature = "connect", any(feature = "tls", feature = "native-tls")))]
pub use connect::client_async_tls_with_connector;
#[cfg(feature = "async_std_runtime")] #[cfg(feature = "async_std_runtime")]
pub use connect::connect_async; pub use connect::connect_async;
#[cfg(all(
feature = "async_std_runtime",
any(feature = "tls", feature = "native-tls")
))]
pub use connect::connect_async_with_tls_connector;
#[cfg(all(feature = "connect", feature = "tls-base"))] #[cfg(all(feature = "connect", feature = "tls-base"))]
pub use connect::MaybeTlsStream; pub use connect::MaybeTlsStream;

Loading…
Cancel
Save