From e0eecd28b12d9c2413add3e6b9fae0a543083424 Mon Sep 17 00:00:00 2001 From: Alexey Galakhov Date: Wed, 22 Mar 2017 19:43:20 +0100 Subject: [PATCH] Refactor TLS handling Signed-off-by: Alexey Galakhov --- src/client.rs | 78 ++++++++++++++++++++++++++++++--------------------- src/error.rs | 10 ++++--- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/src/client.rs b/src/client.rs index d7b3f1f..5300ab2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,7 +7,52 @@ use std::io::{Read, Write}; use url::Url; #[cfg(feature="tls")] -use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError}; +mod encryption { + use std::net::TcpStream; + use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError}; + pub use native_tls::TlsStream; + + pub use stream::Stream as StreamSwitcher; + pub type AutoStream = StreamSwitcher>; + + use stream::Mode; + use error::Result; + + pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(stream)), + Mode::Tls => { + let connector = TlsConnector::builder()?.build()?; + connector.connect(domain, stream) + .map_err(|e| match e { + TlsHandshakeError::Failure(f) => f.into(), + TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), + }) + .map(|s| StreamSwitcher::Tls(s)) + } + } + } +} + +#[cfg(not(feature="tls"))] +mod encryption { + use std::net::TcpStream; + + use stream::Mode; + use error::{Error, Result}; + + pub type AutoStream = TcpStream; + + pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(stream), + Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + } + } +} + +pub use self::encryption::AutoStream; +use self::encryption::wrap_stream; use protocol::WebSocket; use handshake::HandshakeError; @@ -15,13 +60,6 @@ use handshake::client::{ClientHandshake, Request}; use stream::Mode; use error::{Error, Result}; -#[cfg(feature="tls")] -use stream::Stream as StreamSwitcher; - -#[cfg(feature="tls")] -pub type AutoStream = StreamSwitcher>; -#[cfg(not(feature="tls"))] -pub type AutoStream = TcpStream; /// Connect to the given WebSocket in blocking mode. /// @@ -46,30 +84,6 @@ pub fn connect(url: Url) -> Result> { }) } -#[cfg(feature="tls")] -fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { - match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(stream)), - Mode::Tls => { - let connector = TlsConnector::builder()?.build()?; - connector.connect(domain, stream) - .map_err(|e| match e { - TlsHandshakeError::Failure(f) => f.into(), - TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), - }) - .map(|s| StreamSwitcher::Tls(s)) - } - } -} - -#[cfg(not(feature="tls"))] -fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { - match mode { - Mode::Plain => Ok(stream), - Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), - } -} - fn connect_to_some(addrs: A, url: &Url, mode: Mode) -> Result where A: Iterator { diff --git a/src/error.rs b/src/error.rs index 1a9c7f5..e885697 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,7 +12,9 @@ use std::string; use httparse; #[cfg(feature="tls")] -use native_tls; +pub mod tls { + pub use native_tls::Error; +} pub type Result = result::Result; @@ -25,7 +27,7 @@ pub enum Error { Io(io::Error), #[cfg(feature="tls")] /// TLS error - Tls(native_tls::Error), + Tls(tls::Error), /// Buffer capacity exhausted Capacity(Cow<'static, str>), /// Protocol violation @@ -89,8 +91,8 @@ impl From for Error { } #[cfg(feature="tls")] -impl From for Error { - fn from(err: native_tls::Error) -> Self { +impl From for Error { + fn from(err: tls::Error) -> Self { Error::Tls(err) } }