Add support for rustls as TLS backend

pull/166/head
Dominik Nakamura 4 years ago
parent 208061ba28
commit e9aaf9b1e9
No known key found for this signature in database
GPG Key ID: E4C6A749B2491910
  1. 21
      Cargo.toml
  2. 37
      src/client.rs
  3. 39
      src/error.rs
  4. 3
      src/lib.rs
  5. 13
      src/stream.rs
  6. 7
      tests/connection_reset.rs

@ -12,10 +12,13 @@ repository = "https://github.com/snapview/tungstenite-rs"
version = "0.11.1"
edition = "2018"
[package.metadata.docs.rs]
features = ["native-tls"]
[features]
default = ["tls"]
tls = ["native-tls"]
tls-vendored = ["native-tls", "native-tls/vendored"]
default = []
native-tls-vendored = ["native-tls", "native-tls/vendored"]
rustls-tls = ["rustls", "webpki", "webpki-roots"]
[dependencies]
base64 = "0.13.0"
@ -34,6 +37,18 @@ utf-8 = "0.7.5"
optional = true
version = "0.2.3"
[dependencies.rustls]
optional = true
version = "0.19.0"
[dependencies.webpki]
optional = true
version = "0.21.4"
[dependencies.webpki-roots]
optional = true
version = "0.21.0"
[dev-dependencies]
env_logger = "0.8.1"
net2 = "0.2.33"

@ -16,7 +16,7 @@ use crate::{
protocol::WebSocketConfig,
};
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
mod encryption {
pub use native_tls::TlsStream;
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector};
@ -47,7 +47,40 @@ mod encryption {
}
}
#[cfg(not(feature = "tls"))]
#[cfg(feature = "rustls-tls")]
mod encryption {
use rustls::ClientConfig;
pub use rustls::{ClientSession, StreamOwned};
use std::{net::TcpStream, sync::Arc};
use webpki::DNSNameRef;
pub use crate::stream::Stream as StreamSwitcher;
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, StreamOwned<ClientSession, TcpStream>>;
use crate::{error::Result, stream::Mode};
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let config = {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
Arc::new(config)
};
let domain = DNSNameRef::try_from_ascii_str(domain)?;
let client = ClientSession::new(&config, domain);
let stream = StreamOwned::new(client, stream);
Ok(StreamSwitcher::Tls(stream))
}
}
}
}
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
mod encryption {
use std::net::TcpStream;

@ -5,12 +5,19 @@ use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}
use crate::protocol::Message;
use http::Response;
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
pub mod tls {
//! TLS error wrapper module, feature-gated.
pub use native_tls::Error;
}
#[cfg(feature = "rustls-tls")]
pub mod tls {
//! TLS error wrapper module, feature-gated.
pub use rustls::TLSError as Error;
pub use webpki::InvalidDNSNameError as DnsError;
}
/// Result type of all Tungstenite library calls.
pub type Result<T> = result::Result<T, Error>;
@ -40,9 +47,15 @@ pub enum Error {
/// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal.
Io(io::Error),
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
/// TLS error
Tls(tls::Error),
#[cfg(feature = "rustls-tls")]
/// TLS error
Tls(tls::Error),
#[cfg(feature = "rustls-tls")]
/// DNS name resolution error.
Dns(tls::DnsError),
/// - When reading: buffer capacity exhausted.
/// - When writing: your message is bigger than the configured max message size
/// (64MB by default).
@ -67,8 +80,12 @@ impl fmt::Display for Error {
Error::ConnectionClosed => write!(f, "Connection closed normally"),
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
#[cfg(feature = "rustls-tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
#[cfg(feature = "rustls-tls")]
Error::Dns(ref err) => write!(f, "Invalid DNS name: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
@ -136,13 +153,27 @@ impl From<http::Error> for Error {
}
}
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)
}
}
#[cfg(feature = "rustls-tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)
}
}
#[cfg(feature = "rustls-tls")]
impl From<tls::DnsError> for Error {
fn from(err: tls::DnsError) -> Self {
Error::Dns(err)
}
}
impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self {
match err {

@ -29,3 +29,6 @@ pub use crate::{
protocol::{Message, WebSocket},
server::{accept, accept_hdr},
};
#[cfg(all(feature = "native-tls", feature = "rustls-tls"))]
compile_error!("either \"native-tls\" or \"rustls-tls\" can be enabled, but not both.");

@ -8,8 +8,10 @@ use std::io::{Read, Result as IoResult, Write};
use std::net::TcpStream;
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
use native_tls::TlsStream;
#[cfg(feature = "rustls-tls")]
use rustls::StreamOwned as TlsStream;
/// Stream mode, either plain TCP or TLS.
#[derive(Clone, Copy, Debug)]
@ -32,13 +34,20 @@ impl NoDelay for TcpStream {
}
}
#[cfg(feature = "tls")]
#[cfg(feature = "native-tls")]
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().set_nodelay(nodelay)
}
}
#[cfg(feature = "rustls-tls")]
impl<S: rustls::Session, T: Read + Write + NoDelay> NoDelay for TlsStream<S, T> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.sock.set_nodelay(nodelay)
}
}
/// Stream, either plain TCP or TLS.
#[derive(Debug)]
pub enum Stream<S, T> {

@ -1,5 +1,6 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
#![cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use std::{
net::{TcpListener, TcpStream},
@ -8,12 +9,14 @@ use std::{
time::Duration,
};
use native_tls::TlsStream;
use net2::TcpStreamExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
#[cfg(feature = "native-tls")]
type Sock = WebSocket<Stream<TcpStream, native_tls::TlsStream<TcpStream>>>;
#[cfg(feature = "rustls-tls")]
type Sock = WebSocket<Stream<TcpStream, rustls::StreamOwned<rustls::ClientSession, TcpStream>>>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where

Loading…
Cancel
Save