From e9aaf9b1e987688e96398cdec93aff2d9d50b054 Mon Sep 17 00:00:00 2001 From: Dominik Nakamura Date: Sat, 2 Jan 2021 16:06:25 +0900 Subject: [PATCH] Add support for rustls as TLS backend --- Cargo.toml | 21 ++++++++++++++++++--- src/client.rs | 37 +++++++++++++++++++++++++++++++++++-- src/error.rs | 39 +++++++++++++++++++++++++++++++++++---- src/lib.rs | 3 +++ src/stream.rs | 13 +++++++++++-- tests/connection_reset.rs | 7 +++++-- 6 files changed, 107 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 173292e..63a07c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,13 @@ repository = "https://github.com/snapview/tungstenite-rs" version = "0.11.1" edition = "2018" +[package.metadata.docs.rs] +features = ["native-tls"] + [features] -default = ["tls"] -tls = ["native-tls"] -tls-vendored = ["native-tls", "native-tls/vendored"] +default = [] +native-tls-vendored = ["native-tls", "native-tls/vendored"] +rustls-tls = ["rustls", "webpki", "webpki-roots"] [dependencies] base64 = "0.13.0" @@ -34,6 +37,18 @@ utf-8 = "0.7.5" optional = true version = "0.2.3" +[dependencies.rustls] +optional = true +version = "0.19.0" + +[dependencies.webpki] +optional = true +version = "0.21.4" + +[dependencies.webpki-roots] +optional = true +version = "0.21.0" + [dev-dependencies] env_logger = "0.8.1" net2 = "0.2.33" diff --git a/src/client.rs b/src/client.rs index 1741fa2..1515f17 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,7 +16,7 @@ use crate::{ protocol::WebSocketConfig, }; -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] mod encryption { pub use native_tls::TlsStream; use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector}; @@ -47,7 +47,40 @@ mod encryption { } } -#[cfg(not(feature = "tls"))] +#[cfg(feature = "rustls-tls")] +mod encryption { + use rustls::ClientConfig; + pub use rustls::{ClientSession, StreamOwned}; + use std::{net::TcpStream, sync::Arc}; + use webpki::DNSNameRef; + + pub use crate::stream::Stream as StreamSwitcher; + /// TCP stream switcher (plain/TLS). + pub type AutoStream = StreamSwitcher>; + + use crate::{error::Result, stream::Mode}; + + pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(stream)), + Mode::Tls => { + let config = { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + + Arc::new(config) + }; + let domain = DNSNameRef::try_from_ascii_str(domain)?; + let client = ClientSession::new(&config, domain); + let stream = StreamOwned::new(client, stream); + + Ok(StreamSwitcher::Tls(stream)) + } + } + } +} + +#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] mod encryption { use std::net::TcpStream; diff --git a/src/error.rs b/src/error.rs index c2becc7..7f14229 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,12 +5,19 @@ use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string} use crate::protocol::Message; use http::Response; -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] pub mod tls { //! TLS error wrapper module, feature-gated. pub use native_tls::Error; } +#[cfg(feature = "rustls-tls")] +pub mod tls { + //! TLS error wrapper module, feature-gated. + pub use rustls::TLSError as Error; + pub use webpki::InvalidDNSNameError as DnsError; +} + /// Result type of all Tungstenite library calls. pub type Result = result::Result; @@ -40,9 +47,15 @@ pub enum Error { /// Input-output error. Apart from WouldBlock, these are generally errors with the /// underlying connection and you should probably consider them fatal. Io(io::Error), - #[cfg(feature = "tls")] + #[cfg(feature = "native-tls")] /// TLS error Tls(tls::Error), + #[cfg(feature = "rustls-tls")] + /// TLS error + Tls(tls::Error), + #[cfg(feature = "rustls-tls")] + /// DNS name resolution error. + Dns(tls::DnsError), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). @@ -67,8 +80,12 @@ impl fmt::Display for Error { Error::ConnectionClosed => write!(f, "Connection closed normally"), Error::AlreadyClosed => write!(f, "Trying to work with closed connection"), Error::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(feature = "tls")] + #[cfg(feature = "native-tls")] Error::Tls(ref err) => write!(f, "TLS error: {}", err), + #[cfg(feature = "rustls-tls")] + Error::Tls(ref err) => write!(f, "TLS error: {}", err), + #[cfg(feature = "rustls-tls")] + Error::Dns(ref err) => write!(f, "Invalid DNS name: {}", err), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), Error::SendQueueFull(_) => write!(f, "Send queue is full"), @@ -136,13 +153,27 @@ impl From for Error { } } -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] +impl From for Error { + fn from(err: tls::Error) -> Self { + Error::Tls(err) + } +} + +#[cfg(feature = "rustls-tls")] impl From for Error { fn from(err: tls::Error) -> Self { Error::Tls(err) } } +#[cfg(feature = "rustls-tls")] +impl From for Error { + fn from(err: tls::DnsError) -> Self { + Error::Dns(err) + } +} + impl From for Error { fn from(err: httparse::Error) -> Self { match err { diff --git a/src/lib.rs b/src/lib.rs index 82f7822..4ec74fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,3 +29,6 @@ pub use crate::{ protocol::{Message, WebSocket}, server::{accept, accept_hdr}, }; + +#[cfg(all(feature = "native-tls", feature = "rustls-tls"))] +compile_error!("either \"native-tls\" or \"rustls-tls\" can be enabled, but not both."); diff --git a/src/stream.rs b/src/stream.rs index 96d26d2..d52edae 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -8,8 +8,10 @@ use std::io::{Read, Result as IoResult, Write}; use std::net::TcpStream; -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] use native_tls::TlsStream; +#[cfg(feature = "rustls-tls")] +use rustls::StreamOwned as TlsStream; /// Stream mode, either plain TCP or TLS. #[derive(Clone, Copy, Debug)] @@ -32,13 +34,20 @@ impl NoDelay for TcpStream { } } -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] impl NoDelay for TlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.get_mut().set_nodelay(nodelay) } } +#[cfg(feature = "rustls-tls")] +impl NoDelay for TlsStream { + fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { + self.sock.set_nodelay(nodelay) + } +} + /// Stream, either plain TCP or TLS. #[derive(Debug)] pub enum Stream { diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 7e3e33f..17d7729 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,5 +1,6 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection //! is closedd from the server's point of view and drop the underlying tcp socket. +#![cfg(any(feature = "native-tls", feature = "rustls-tls"))] use std::{ net::{TcpListener, TcpStream}, @@ -8,12 +9,14 @@ use std::{ time::Duration, }; -use native_tls::TlsStream; use net2::TcpStreamExt; use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; use url::Url; -type Sock = WebSocket>>; +#[cfg(feature = "native-tls")] +type Sock = WebSocket>>; +#[cfg(feature = "rustls-tls")] +type Sock = WebSocket>>; fn do_test(port: u16, client_task: CT, server_task: ST) where