Refactor TLS handling

Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
pull/12/head
Alexey Galakhov 8 years ago
parent 450790725d
commit e0eecd28b1
  1. 78
      src/client.rs
  2. 10
      src/error.rs

@ -7,7 +7,52 @@ use std::io::{Read, Write};
use url::Url;
#[cfg(feature="tls")]
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError};
mod encryption {
use std::net::TcpStream;
use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError};
pub use native_tls::TlsStream;
pub use stream::Stream as StreamSwitcher;
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use stream::Mode;
use error::Result;
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder()?.build()?;
connector.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
}
}
}
}
#[cfg(not(feature="tls"))]
mod encryption {
use std::net::TcpStream;
use stream::Mode;
use error::{Error, Result};
pub type AutoStream = TcpStream;
pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
}
pub use self::encryption::AutoStream;
use self::encryption::wrap_stream;
use protocol::WebSocket;
use handshake::HandshakeError;
@ -15,13 +60,6 @@ use handshake::client::{ClientHandshake, Request};
use stream::Mode;
use error::{Error, Result};
#[cfg(feature="tls")]
use stream::Stream as StreamSwitcher;
#[cfg(feature="tls")]
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
#[cfg(not(feature="tls"))]
pub type AutoStream = TcpStream;
/// Connect to the given WebSocket in blocking mode.
///
@ -46,30 +84,6 @@ pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> {
})
}
#[cfg(feature="tls")]
fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder()?.build()?;
connector.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
}
}
}
#[cfg(not(feature="tls"))]
fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr>
{

@ -12,7 +12,9 @@ use std::string;
use httparse;
#[cfg(feature="tls")]
use native_tls;
pub mod tls {
pub use native_tls::Error;
}
pub type Result<T> = result::Result<T, Error>;
@ -25,7 +27,7 @@ pub enum Error {
Io(io::Error),
#[cfg(feature="tls")]
/// TLS error
Tls(native_tls::Error),
Tls(tls::Error),
/// Buffer capacity exhausted
Capacity(Cow<'static, str>),
/// Protocol violation
@ -89,8 +91,8 @@ impl From<string::FromUtf8Error> for Error {
}
#[cfg(feature="tls")]
impl From<native_tls::Error> for Error {
fn from(err: native_tls::Error) -> Self {
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)
}
}

Loading…
Cancel
Save