From eda57a012ce9006e6165c8b408b87d1bf3b8dbce Mon Sep 17 00:00:00 2001 From: Jeffrey Esquivel S Date: Mon, 12 Nov 2018 14:03:54 -0600 Subject: [PATCH] Implement support to get the peer address --- examples/client.rs | 4 ++++ examples/server.rs | 1 + src/connect.rs | 18 ++++++++++++++++-- src/lib.rs | 16 ++++++++++++++++ src/stream.rs | 16 ++++++++++++++++ 5 files changed, 53 insertions(+), 2 deletions(-) diff --git a/examples/client.rs b/examples/client.rs index 6ceb9da..4d746b5 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -25,6 +25,7 @@ use futures::{Future, Sink, Stream}; use tungstenite::protocol::Message; use tokio_tungstenite::connect_async; +use tokio_tungstenite::stream::PeerAddr; fn main() { // Specify the server address to which the client will be connecting. @@ -61,6 +62,9 @@ fn main() { let client = connect_async(url).and_then(move |(ws_stream, _)| { println!("WebSocket handshake has been successfully completed"); + let addr = ws_stream.peer_addr().expect("connected streams should have a peer address"); + println!("Peer address: {}", addr); + // `sink` is the stream of messages going out. // `stream` is the stream of incoming messages. let (sink, stream) = ws_stream.split(); diff --git a/examples/server.rs b/examples/server.rs index 1a3dfce..f24a71c 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -49,6 +49,7 @@ fn main() { let srv = socket.incoming().for_each(move |stream| { let addr = stream.peer_addr().expect("connected streams should have a peer address"); + println!("Peer address: {}", addr); // We have to clone both of these values, because the `and_then` // function below constructs a new future, `and_then` requires diff --git a/src/connect.rs b/src/connect.rs index 2573ab9..dfb44de 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -3,6 +3,7 @@ extern crate tokio_dns; extern crate tokio_tcp; +use std::net::SocketAddr; use std::io::Result as IoResult; use self::tokio_tcp::TcpStream; @@ -14,7 +15,7 @@ use tungstenite::Error; use tungstenite::client::url_mode; use tungstenite::handshake::client::Response; -use stream::NoDelay; +use stream::{NoDelay, PeerAddr}; use super::{WebSocketStream, Request, client_async}; impl NoDelay for TcpStream { @@ -23,6 +24,12 @@ impl NoDelay for TcpStream { } } +impl PeerAddr for TcpStream { + fn peer_addr(&self) -> IoResult { + self.peer_addr() + } +} + #[cfg(feature="tls")] mod encryption { extern crate native_tls; @@ -31,6 +38,7 @@ mod encryption { use self::native_tls::TlsConnector; use self::tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream}; + use std::net::SocketAddr; use std::io::{Read, Write, Result as IoResult}; use futures::{future, Future}; @@ -39,7 +47,7 @@ mod encryption { use tungstenite::Error; use tungstenite::stream::Mode; - use stream::{NoDelay, Stream as StreamSwitcher}; + use stream::{NoDelay, PeerAddr, Stream as StreamSwitcher}; /// A stream that might be protected with TLS. pub type MaybeTlsStream = StreamSwitcher>; @@ -52,6 +60,12 @@ mod encryption { } } + impl PeerAddr for TlsStream { + fn peer_addr(&self) -> IoResult { + self.get_ref().get_ref().peer_addr() + } + } + pub fn wrap_stream(socket: S, domain: String, mode: Mode) -> Box, Error=Error> + Send> where diff --git a/src/lib.rs b/src/lib.rs index a55d696..d8a6e5e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,12 @@ pub mod stream; use std::io::ErrorKind; +#[cfg(feature="stream")] +use std::{ + net::SocketAddr, + io::Result as IoResult, +}; + use futures::{Poll, Future, Async, AsyncSink, Stream, Sink, StartSend}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -45,6 +51,9 @@ use tungstenite::{ #[cfg(feature="connect")] pub use connect::{connect_async, client_async_tls}; +#[cfg(feature="stream")] +pub use stream::PeerAddr; + #[cfg(all(feature="connect", feature="tls"))] pub use connect::MaybeTlsStream; @@ -194,6 +203,13 @@ impl WebSocketStream { } } +#[cfg(feature="stream")] +impl PeerAddr for WebSocketStream { + fn peer_addr(&self) -> IoResult { + self.inner.get_ref().peer_addr() + } +} + impl Stream for WebSocketStream where T: AsyncRead + AsyncWrite { type Item = Message; type Error = WsError; diff --git a/src/stream.rs b/src/stream.rs index d06cb1e..944972e 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -6,6 +6,7 @@ extern crate bytes; +use std::net::SocketAddr; use std::io::{Read, Write, Result as IoResult, Error as IoError}; use self::bytes::{Buf, BufMut}; @@ -18,6 +19,12 @@ pub trait NoDelay { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>; } +/// Trait to get the remote address from the underlying stream. +pub trait PeerAddr { + /// Returns the remote address that this stream is connected to. + fn peer_addr(&self) -> IoResult; +} + /// Stream, either plain TCP or TLS. pub enum Stream { /// Unencrypted socket stream. @@ -59,6 +66,15 @@ impl NoDelay for Stream { } } +impl PeerAddr for Stream { + fn peer_addr(&self) -> IoResult { + match *self { + Stream::Plain(ref s) => s.peer_addr(), + Stream::Tls(ref s) => s.peer_addr(), + } + } +} + impl AsyncRead for Stream { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { match *self {