From b01d17e317df327ff0fe79fde106214b3b8a6826 Mon Sep 17 00:00:00 2001 From: Brian Schwind Date: Fri, 25 Jan 2019 12:57:12 +0900 Subject: [PATCH] Add support for HTTP proxies with optional Basic Authentication --- src/client.rs | 97 ++++++++++++++++++++++++++++++++++++----- src/handshake/client.rs | 46 +++++++++++++++++++ 2 files changed, 132 insertions(+), 11 deletions(-) diff --git a/src/client.rs b/src/client.rs index 4fc2b46..b5b2ea0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,15 @@ //! Methods to connect to an WebSocket as a client. -use std::net::{TcpStream, SocketAddr, ToSocketAddrs}; +use std::net::{TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; use std::io::{Read, Write}; use url::Url; -use handshake::client::Response; +use httparse; +use httparse::Status; + +use handshake::client::{Response, ProxyAuth}; use protocol::WebSocketConfig; #[cfg(feature="tls")] @@ -87,8 +90,8 @@ pub fn connect_with_config<'t, Req: Into>>( ) -> Result<(WebSocket, Response)> { let request: Request = request.into(); let mode = url_mode(&request.url)?; - let addrs = request.url.to_socket_addrs()?; - let mut stream = connect_to_some(addrs, &request.url, mode)?; + // let addrs = request.url.to_socket_addrs()?; + let mut stream = connect_to_some(&request, mode)?; NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config) .map_err(|e| match e { @@ -115,19 +118,91 @@ pub fn connect<'t, Req: Into>>(request: Req) connect_with_config(request, None) } -fn connect_to_some(addrs: A, url: &Url, mode: Mode) -> Result - where A: Iterator -{ - let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?; +fn connect_to_some(req: &Request, mode: Mode) -> Result { + let addrs = if let Some(ref p) = req.proxy { + p.url.to_socket_addrs()? + } else { + req.url.to_socket_addrs()? + }; + + let domain = req.url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + for addr in addrs { - debug!("Trying to contact {} at {}...", url, addr); - if let Ok(raw_stream) = TcpStream::connect(addr) { + debug!("Trying to contact {} at {}...", req.url, addr); + + if let Ok(mut raw_stream) = TcpStream::connect(addr) { + if let Some(ref proxy) = req.proxy { + let port = req.url.port_or_known_default().unwrap_or_else(|| { + match mode { + Mode::Plain => 80, + Mode::Tls => 443 + } + }); + + // Connect to a proxy here + let mut buf = format!("\ + CONNECT {host}:{port} HTTP/1.1\r\n\ + Host: {host}:{port}\r\n\ + Proxy-Connection: Keep-Alive\r\n", + host = domain, + port = port, + ).into_bytes(); + + // Add any headers required for authentication + if let Some(ref auth) = proxy.auth { + match auth { + ProxyAuth::Basic(basic_auth) => { + buf.extend_from_slice(&format!("Proxy-Authorization: {}\r\n", basic_auth.to_header_value()).as_bytes()); + } + } + } + + // Add the trailing empty line + buf.extend_from_slice(b"\r\n"); + + let _ = raw_stream.write(&buf)?; + + let mut read_buf = [0; 256]; + let mut response_vec = Vec::with_capacity(1024); + + loop { + let n = raw_stream.read(&mut read_buf)?; + + response_vec.extend_from_slice(&read_buf[..n]); + + let mut hbuffer = [httparse::EMPTY_HEADER; super::handshake::headers::MAX_HEADERS]; + let mut res = httparse::Response::new(&mut hbuffer); + + match res.parse(&response_vec)? { + Status::Partial => { + // We only received part of the proxy HTTP response, + // continue reading from the socket to get the rest + } + Status::Complete(_size) => { + match res.code { + Some(200) => { + // We're connected to the proxy and good to go + break; + } + Some(code) => { + return Err(Error::Http(code)) + } + None => { + // https://support.cloudflare.com/hc/en-us/articles/200171936-Error-520-Web-server-is-returning-an-unknown-error + return Err(Error::Http(520)) + } + } + } + } + } + } + if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { return Ok(stream) } } } - Err(Error::Url(format!("Unable to connect to {}", url).into())) + Err(Error::Url(format!("Unable to connect to {}", req.url).into())) } /// Get the mode of the given URL. diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 014af64..c96c8ef 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,6 +16,44 @@ use super::headers::{Headers, FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; +/// HTTP Basic Authentication Credentials +#[derive(Debug)] +pub struct BasicAuth { + /// Basic Auth username + pub username: String, + + /// Optional Basic Auth password + pub password: Option +} + +impl BasicAuth { + /// Converts the credential information into a format suitable for use + /// in HTTP headers. + /// Example: Username: user, Password: pass -> "Basic dXNlcjpwYXNz" + pub fn to_header_value(&self) -> String { + // TODO - clean this up + let creds = format!("{}:{}", self.username, self.password.as_ref().unwrap_or(&"".to_string())); + format!("Basic {}", base64::encode(&creds)) + } +} + +/// The different types of Proxy Authentication +#[derive(Debug)] +pub enum ProxyAuth { + /// Sets the `Proxy-Authorization` header using Basic Auth + Basic(BasicAuth) +} + +/// A configuration for using an HTTP proxy +#[derive(Debug)] +pub struct Proxy { + /// The URL of the proxy server to connect to + pub url: Url, + + /// Optional proxy authentication configuration + pub auth: Option, +} + /// Client request. #[derive(Debug)] pub struct Request<'t> { @@ -23,6 +61,8 @@ pub struct Request<'t> { pub url: Url, /// Extra HTTP headers to append to the request. pub extra_headers: Option, Cow<'t, str>)>>, + /// Proxy Configuration + pub proxy: Option, } impl<'t> Request<'t> { @@ -56,6 +96,11 @@ impl<'t> Request<'t> { headers.push((name, value)); self.extra_headers = Some(headers); } + + /// Sets an HTTP proxy to use when connecting to the WebSocket server + pub fn set_proxy(&mut self, proxy: Proxy) { + self.proxy = Some(proxy); + } } impl From for Request<'static> { @@ -63,6 +108,7 @@ impl From for Request<'static> { Request { url: value, extra_headers: None, + proxy: None, } } }