diff --git a/Cargo.toml b/Cargo.toml index c554212..a2118c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,9 @@ all-features = true default = [] native-tls = ["native-tls-crate"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] -rustls-tls-native-roots = ["rustls", "webpki", "rustls-native-certs"] -rustls-tls-webpki-roots = ["rustls", "webpki", "webpki-roots"] +rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] +rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] +__rustls-tls = ["rustls", "webpki"] [dependencies] base64 = "0.13.0" diff --git a/examples/srv_accept_unmasked_frames.rs b/examples/srv_accept_unmasked_frames.rs index 953f9f1..b280fba 100644 --- a/examples/srv_accept_unmasked_frames.rs +++ b/examples/srv_accept_unmasked_frames.rs @@ -1,8 +1,8 @@ use std::{net::TcpListener, thread::spawn}; use tungstenite::{ + accept_hdr_with_config, handshake::server::{Request, Response}, protocol::WebSocketConfig, - server::accept_hdr_with_config, }; fn main() { diff --git a/src/client.rs b/src/client.rs index 04ce540..67a3c41 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,118 +14,9 @@ use url::Url; use crate::{ handshake::client::{Request, Response}, protocol::WebSocketConfig, + stream::MaybeTlsStream, }; -#[cfg(feature = "native-tls")] -mod encryption { - pub use native_tls_crate::TlsStream; - use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; - use std::net::TcpStream; - - pub use crate::stream::Stream as StreamSwitcher; - /// TCP stream switcher (plain/TLS). - pub type AutoStream = StreamSwitcher>; - - use crate::{ - error::{Result, TlsError}, - stream::Mode, - }; - - pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(stream)), - Mode::Tls => { - let connector = TlsConnector::builder().build().map_err(TlsError::Native)?; - connector - .connect(domain, stream) - .map_err(|e| match e { - TlsHandshakeError::Failure(f) => TlsError::Native(f).into(), - TlsHandshakeError::WouldBlock(_) => { - panic!("Bug: TLS handshake not blocked") - } - }) - .map(StreamSwitcher::Tls) - } - } - } -} - -#[cfg(all( - any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"), - not(feature = "native-tls") -))] -mod encryption { - use rustls::ClientConfig; - pub use rustls::{ClientSession, StreamOwned}; - use std::{net::TcpStream, sync::Arc}; - use webpki::DNSNameRef; - - pub use crate::stream::Stream as StreamSwitcher; - /// TCP stream switcher (plain/TLS). - pub type AutoStream = StreamSwitcher>; - - use crate::{ - error::{Result, TlsError}, - stream::Mode, - }; - - pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(stream)), - Mode::Tls => { - let config = { - #[allow(unused_mut)] - let mut config = ClientConfig::new(); - #[cfg(feature = "rustls-tls-native-roots")] - { - config.root_store = - rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?; - } - #[cfg(feature = "rustls-tls-webpki-roots")] - { - config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - } - - Arc::new(config) - }; - - let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; - let client = ClientSession::new(&config, domain); - let stream = StreamOwned::new(client, stream); - - Ok(StreamSwitcher::Tls(stream)) - } - } - } -} - -#[cfg(not(any( - feature = "native-tls", - feature = "rustls-tls-native-roots", - feature = "rustls-tls-webpki-roots" -)))] -mod encryption { - use std::net::TcpStream; - - use crate::{ - error::{Error, Result, UrlError}, - stream::Mode, - }; - - /// TLS support is not compiled in, this is just standard `TcpStream`. - pub type AutoStream = TcpStream; - - pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { - match mode { - Mode::Plain => Ok(stream), - Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), - } - } -} - -use self::encryption::wrap_stream; -pub use self::encryption::AutoStream; - use crate::{ error::{Error, Result, UrlError}, handshake::{client::ClientHandshake, HandshakeError}, @@ -152,11 +43,11 @@ pub fn connect_with_config( request: Req, config: Option, max_redirects: u8, -) -> Result<(WebSocket, Response)> { +) -> Result<(WebSocket>, Response)> { fn try_client_handshake( request: Request, config: Option, - ) -> Result<(WebSocket, Response)> { + ) -> Result<(WebSocket>, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; @@ -165,9 +56,15 @@ pub fn connect_with_config( Mode::Tls => 443, }); let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?; NoDelay::set_nodelay(&mut stream, true)?; - client_with_config(request, stream, config).map_err(|e| match e { + + #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] + let client = client_with_config(request, MaybeTlsStream::Plain(stream), config); + #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] + let client = crate::tls::client_tls_with_config(request, stream, config, None); + + client.map_err(|e| match e { HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), }) @@ -216,18 +113,17 @@ pub fn connect_with_config( /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of /// the `*-tls` features if you don't call `connect` since it's the only function that uses them. -pub fn connect(request: Req) -> Result<(WebSocket, Response)> { +pub fn connect( + request: Req, +) -> Result<(WebSocket>, Response)> { connect_with_config(request, None, 3) } -fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { - let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?; +fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result { for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); - if let Ok(raw_stream) = TcpStream::connect(addr) { - if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { - return Ok(stream); - } + if let Ok(stream) = TcpStream::connect(addr) { + return Ok(stream); } } Err(Error::Url(UrlError::UnableToConnect(uri.to_string()))) diff --git a/src/error.rs b/src/error.rs index d1d176b..510e3f4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,7 +7,7 @@ use http::Response; use thiserror::Error; /// Result type of all Tungstenite library calls. -pub type Result = result::Result; +pub type Result = result::Result; /// Possible WebSocket errors. #[derive(Error, Debug)] @@ -253,11 +253,11 @@ pub enum TlsError { #[error("native-tls error: {0}")] Native(#[from] native_tls_crate::Error), /// Rustls error. - #[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))] + #[cfg(feature = "__rustls-tls")] #[error("rustls error: {0}")] Rustls(#[from] rustls::TLSError), /// DNS name resolution error. - #[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))] + #[cfg(feature = "__rustls-tls")] #[error("Invalid DNS name: {0}")] Dns(#[from] webpki::InvalidDNSNameError), } diff --git a/src/lib.rs b/src/lib.rs index c21958e..c2f29c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,8 +19,10 @@ pub mod client; pub mod error; pub mod handshake; pub mod protocol; -pub mod server; +mod server; pub mod stream; +#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] +mod tls; pub mod util; const READ_BUFFER_CHUNK_SIZE: usize = 4096; @@ -31,5 +33,8 @@ pub use crate::{ error::{Error, Result}, handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, protocol::{Message, WebSocket}, - server::{accept, accept_hdr}, + server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, }; + +#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] +pub use tls::{client_tls, client_tls_with_config, Connector}; diff --git a/src/stream.rs b/src/stream.rs index 5edfe03..b7fe0e4 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,13 +4,16 @@ //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. -use std::io::{Read, Result as IoResult, Write}; +use std::{ + fmt::{self, Debug}, + io::{Read, Result as IoResult, Write}, +}; use std::net::TcpStream; #[cfg(feature = "native-tls")] use native_tls_crate::TlsStream; -#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))] +#[cfg(feature = "__rustls-tls")] use rustls::StreamOwned; /// Stream mode, either plain TCP or TLS. @@ -41,51 +44,95 @@ impl NoDelay for TlsStream { } } -#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))] +#[cfg(feature = "__rustls-tls")] impl NoDelay for StreamOwned { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.sock.set_nodelay(nodelay) } } -/// Stream, either plain TCP or TLS. -#[derive(Debug)] -pub enum Stream { +/// A stream that might be protected with TLS. +#[non_exhaustive] +pub enum MaybeTlsStream { /// Unencrypted socket stream. Plain(S), - /// Encrypted socket stream. - Tls(T), + #[cfg(feature = "native-tls")] + /// Encrypted socket stream using `native-tls`. + NativeTls(native_tls_crate::TlsStream), + #[cfg(feature = "__rustls-tls")] + /// Encrypted socket stream using `rustls`. + Rustls(rustls::StreamOwned), } -impl Read for Stream { +impl Debug for MaybeTlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(), + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(), + #[cfg(feature = "__rustls-tls")] + Self::Rustls(s) => { + struct RustlsStreamDebug<'a, S: Read + Write>( + &'a rustls::StreamOwned, + ); + + impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("StreamOwned") + .field("sess", &self.0.sess) + .field("sock", &self.0.sock) + .finish() + } + } + + f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish() + } + } + } +} + +impl Read for MaybeTlsStream { fn read(&mut self, buf: &mut [u8]) -> IoResult { match *self { - Stream::Plain(ref mut s) => s.read(buf), - Stream::Tls(ref mut s) => s.read(buf), + MaybeTlsStream::Plain(ref mut s) => s.read(buf), + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(ref mut s) => s.read(buf), + #[cfg(feature = "__rustls-tls")] + MaybeTlsStream::Rustls(ref mut s) => s.read(buf), } } } -impl Write for Stream { +impl Write for MaybeTlsStream { fn write(&mut self, buf: &[u8]) -> IoResult { match *self { - Stream::Plain(ref mut s) => s.write(buf), - Stream::Tls(ref mut s) => s.write(buf), + MaybeTlsStream::Plain(ref mut s) => s.write(buf), + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(ref mut s) => s.write(buf), + #[cfg(feature = "__rustls-tls")] + MaybeTlsStream::Rustls(ref mut s) => s.write(buf), } } + fn flush(&mut self) -> IoResult<()> { match *self { - Stream::Plain(ref mut s) => s.flush(), - Stream::Tls(ref mut s) => s.flush(), + MaybeTlsStream::Plain(ref mut s) => s.flush(), + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(ref mut s) => s.flush(), + #[cfg(feature = "__rustls-tls")] + MaybeTlsStream::Rustls(ref mut s) => s.flush(), } } } -impl NoDelay for Stream { +impl NoDelay for MaybeTlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { match *self { - Stream::Plain(ref mut s) => s.set_nodelay(nodelay), - Stream::Tls(ref mut s) => s.set_nodelay(nodelay), + MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay), + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay), + #[cfg(feature = "__rustls-tls")] + MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay), } } } diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..4f07a54 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,219 @@ +//! Connection helper. +use std::io::{Read, Write}; + +use crate::{ + client::{client_with_config, uri_mode, IntoClientRequest}, + error::UrlError, + handshake::client::Response, + protocol::WebSocketConfig, + stream::MaybeTlsStream, + ClientHandshake, Error, HandshakeError, Result, WebSocket, +}; + +/// A connector that can be used when establishing connections, allowing to control whether +/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the +/// `Plain` variant. +#[non_exhaustive] +#[allow(missing_debug_implementations)] +pub enum Connector { + /// Plain (non-TLS) connector. + Plain, + /// `native-tls` TLS connector. + #[cfg(feature = "native-tls")] + NativeTls(native_tls_crate::TlsConnector), + /// `rustls` TLS connector. + #[cfg(feature = "__rustls-tls")] + Rustls(std::sync::Arc), +} + +mod encryption { + #[cfg(feature = "native-tls")] + pub mod native_tls { + use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; + + use std::io::{Read, Write}; + + use crate::{ + error::TlsError, + stream::{MaybeTlsStream, Mode}, + Error, Result, + }; + + pub fn wrap_stream( + socket: S, + domain: &str, + mode: Mode, + tls_connector: Option, + ) -> Result> + where + S: Read + Write, + { + match mode { + Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), + Mode::Tls => { + let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok); + let connector = try_connector.map_err(TlsError::Native)?; + let connected = connector.connect(domain, socket); + match connected { + Err(e) => match e { + TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())), + TlsHandshakeError::WouldBlock(_) => { + panic!("Bug: TLS handshake not blocked") + } + }, + Ok(s) => Ok(MaybeTlsStream::NativeTls(s)), + } + } + } + } + } + + #[cfg(feature = "__rustls-tls")] + pub mod rustls { + use rustls::{ClientConfig, ClientSession, StreamOwned}; + use webpki::DNSNameRef; + + use std::{ + io::{Read, Write}, + sync::Arc, + }; + + use crate::{ + error::TlsError, + stream::{MaybeTlsStream, Mode}, + Result, + }; + + pub fn wrap_stream( + socket: S, + domain: &str, + mode: Mode, + tls_connector: Option>, + ) -> Result> + where + S: Read + Write, + { + match mode { + Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), + Mode::Tls => { + let config = match tls_connector { + Some(config) => config, + None => { + #[allow(unused_mut)] + let mut config = ClientConfig::new(); + #[cfg(feature = "rustls-tls-native-roots")] + { + config.root_store = rustls_native_certs::load_native_certs() + .map_err(|(_, err)| err)?; + } + #[cfg(feature = "rustls-tls-webpki-roots")] + { + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + } + + Arc::new(config) + } + }; + let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; + let client = ClientSession::new(&config, domain); + let stream = StreamOwned::new(client, socket); + + Ok(MaybeTlsStream::Rustls(stream)) + } + } + } + } + + pub mod plain { + use std::io::{Read, Write}; + + use crate::{ + error::UrlError, + stream::{MaybeTlsStream, Mode}, + Error, Result, + }; + + pub fn wrap_stream(socket: S, mode: Mode) -> Result> + where + S: Read + Write, + { + match mode { + Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), + Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), + } + } + } +} + +type TlsHandshakeError = HandshakeError>>; + +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub fn client_tls( + request: R, + stream: S, +) -> Result<(WebSocket>, Response), TlsHandshakeError> +where + R: IntoClientRequest, + S: Read + Write, +{ + client_tls_with_config(request, stream, None, None) +} + +/// The same as [`client_tls()`] but one can specify a websocket configuration, +/// and an optional connector. If no connector is specified, a default one will +/// be created. +/// +/// Please refer to [`client_tls()`] for more details. +pub fn client_tls_with_config( + request: R, + stream: S, + config: Option, + connector: Option, +) -> Result<(WebSocket>, Response), TlsHandshakeError> +where + R: IntoClientRequest, + S: Read + Write, +{ + let request = request.into_client_request()?; + + #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] + let domain = match request.uri().host() { + Some(d) => Ok(d.to_string()), + None => Err(Error::Url(UrlError::NoHostName)), + }?; + + let mode = uri_mode(&request.uri())?; + + let stream = match connector { + Some(conn) => match conn { + #[cfg(feature = "native-tls")] + Connector::NativeTls(conn) => { + self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn)) + } + #[cfg(feature = "__rustls-tls")] + Connector::Rustls(conn) => { + self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn)) + } + Connector::Plain => self::encryption::plain::wrap_stream(stream, mode), + }, + None => { + #[cfg(feature = "native-tls")] + { + self::encryption::native_tls::wrap_stream(stream, &domain, mode, None) + } + #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))] + { + self::encryption::rustls::wrap_stream(stream, &domain, mode, None) + } + #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] + { + self::encryption::plain::wrap_stream(stream, mode) + } + } + }?; + + client_with_config(request, stream, config) +} diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index c14f2ec..232c27e 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,10 +1,6 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closed from the server's point of view and drop the underlying tcp socket. -#![cfg(any( - feature = "native-tls", - feature = "rustls-tls-native-roots", - feature = "rustls-tls-webpki-roots" -))] +#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))] use std::{ net::{TcpListener, TcpStream}, @@ -14,16 +10,10 @@ use std::{ }; use net2::TcpStreamExt; -use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; +use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket}; use url::Url; -#[cfg(feature = "native-tls")] -type Sock = WebSocket>>; -#[cfg(all( - any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"), - not(feature = "native-tls") -))] -type Sock = WebSocket>>; +type Sock = WebSocket>; fn do_test(port: u16, client_task: CT, server_task: ST) where