From 770b8f3582d8744ed56149e083683f3c1c2a70f2 Mon Sep 17 00:00:00 2001 From: "error.d" Date: Mon, 27 Apr 2020 17:39:58 +0800 Subject: [PATCH] connect support timout --- src/client.rs | 27 ++++++++++++++++++++------- tests/connection_reset.rs | 2 +- tests/no_send_after_close.rs | 2 +- tests/receive_after_init_close.rs | 2 +- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/client.rs b/src/client.rs index d9d7151..011c86a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,6 +9,8 @@ use log::*; use url::Url; +use std::time::Duration; + use crate::handshake::client::{Request, Response}; use crate::protocol::WebSocketConfig; @@ -88,6 +90,7 @@ use crate::stream::{Mode, NoDelay}; /// `connect` since it's the only function that uses native_tls. pub fn connect_with_config( request: Req, + timeout: Option, config: Option, ) -> Result<(WebSocket, Response)> { let request: Request = request.into_client_request()?; @@ -102,7 +105,7 @@ pub fn connect_with_config( Mode::Tls => 443, }); let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode, timeout)?; NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, @@ -122,19 +125,29 @@ pub fn connect_with_config( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect(request: Req) -> Result<(WebSocket, Response)> { - connect_with_config(request, None) +pub fn connect(request: Req, timeout: Option) -> Result<(WebSocket, Response)> { + connect_with_config(request, timeout, None) } -fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { +fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode, timeout: Option) -> Result { let domain = uri .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); - if let Ok(raw_stream) = TcpStream::connect(addr) { - if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { - return Ok(stream); + let raw_stream = if let Some(timeout) = timeout { + TcpStream::connect_timeout(addr, timeout) + } else { + TcpStream::connect(addr) + }; + if let Err(err) = raw_stream { + debug!("connect {} at {} error: {:?}", uri, addr, err); + } else { + let stream = wrap_stream(raw_stream.unwrap(), domain, mode); + if let Err(err) = stream { + debug!("warp_stream error: {:?}", err); + } else { + return Ok(stream.unwrap()); } } } diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index d95ee81..06cfdd2 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -30,7 +30,7 @@ where .expect("Can't listen, is port already in use?"); let client_thread = spawn(move || { - let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap()) + let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap(), None) .expect("Can't connect to port"); client_task(client); diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index d8e20e5..d41ae53 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -22,7 +22,7 @@ fn test_no_send_after_close() { let server = TcpListener::bind("127.0.0.1:3013").unwrap(); let client_thread = spawn(move || { - let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); + let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap(), None).unwrap(); let message = client.read_message().unwrap(); // receive close from server assert!(message.is_close()); diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 352020e..90b345d 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -22,7 +22,7 @@ fn test_receive_after_init_close() { let server = TcpListener::bind("127.0.0.1:3013").unwrap(); let client_thread = spawn(move || { - let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); + let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap(), None).unwrap(); client .write_message(Message::Text("Hello WebSocket".into()))