From 4d3be03d98fac702d02da0f262b1021e0f6d3269 Mon Sep 17 00:00:00 2001 From: liaozhou Date: Fri, 15 Oct 2021 15:50:51 +0800 Subject: [PATCH] support overriding DNS items for connect --- src/client.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 67a3c41..167a516 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,7 @@ //! Methods to connect to a WebSocket as a client. use std::{ + collections::HashMap, io::{Read, Write}, net::{SocketAddr, TcpStream, ToSocketAddrs}, result::Result as StdResult, @@ -43,10 +44,35 @@ pub fn connect_with_config( request: Req, config: Option, max_redirects: u8, +) -> Result<(WebSocket>, Response)> { + connect_with_config_overrides_dns(request, config, HashMap::new(), max_redirects) +} + +/// Connect to the given WebSocket in blocking mode, overrides DNS items. +/// +/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is +/// equal to calling `connect()` function. +/// +/// The URL may be either ws:// or wss://. +/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on. +/// +/// This function "just works" for those who wants a simple blocking solution +/// similar to `std::net::TcpStream`. If you want a non-blocking or other +/// custom stream, call `client` instead. +/// +/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If +/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of +/// the `*-tls` features if you don't call `connect` since it's the only function that uses them. +pub fn connect_with_config_overrides_dns( + request: Req, + config: Option, + dns_overrides: HashMap, + max_redirects: u8, ) -> Result<(WebSocket>, Response)> { fn try_client_handshake( request: Request, config: Option, + dns_overrides: &HashMap, ) -> Result<(WebSocket>, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; @@ -55,7 +81,11 @@ pub fn connect_with_config( Mode::Plain => 80, Mode::Tls => 443, }); - let addrs = (host, port).to_socket_addrs()?; + let addrs = if dns_overrides.contains_key(host) { + vec![dns_overrides.get(host).unwrap().clone()].into_iter() + } else { + (host, port).to_socket_addrs()? + }; let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?; NoDelay::set_nodelay(&mut stream, true)?; @@ -83,7 +113,7 @@ pub fn connect_with_config( for attempt in 0..(max_redirects + 1) { let request = create_request(&parts, &uri); - match try_client_handshake(request, config) { + match try_client_handshake(request, config, &dns_overrides) { Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => { if let Some(location) = res.headers().get("Location") { uri = location.to_str()?.parse::()?; @@ -119,6 +149,25 @@ pub fn connect( connect_with_config(request, None, 3) } +/// Connect to the given WebSocket in blocking mode, overrides DNS items. +/// +/// The URL may be either ws:// or wss://. +/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on. +/// +/// This function "just works" for those who wants a simple blocking solution +/// similar to `std::net::TcpStream`. If you want a non-blocking or other +/// custom stream, call `client` instead. +/// +/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If +/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of +/// the `*-tls` features if you don't call `connect` since it's the only function that uses them. +pub fn connect_overrides_dns( + request: Req, + dns_overrides: HashMap, +) -> Result<(WebSocket>, Response)> { + connect_with_config_overrides_dns(request, None, dns_overrides, 3) +} + fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result { for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr);