diff --git a/Cargo.toml b/Cargo.toml index d4ebd5b..11b7b3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a log = "0.4" futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } futures-io = { version = "0.3", default-features = false, features = ["std"] } -pin-project = "1" +pin-project-lite = "0.2" [dependencies.tungstenite] version = "0.11.0" diff --git a/src/tokio.rs b/src/tokio.rs index bab4700..19fe904 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -103,7 +103,7 @@ where R: IntoClientRequest + Unpin, S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { - crate::client_async_with_config(request, TokioAdapter(stream), config).await + crate::client_async_with_config(request, TokioAdapter::new(stream), config).await } /// Accepts a new WebSocket connection with the provided stream. @@ -163,7 +163,7 @@ where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, C: Callback + Unpin, { - crate::accept_hdr_async_with_config(TokioAdapter(stream), callback, config).await + crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await } /// Type alias for the stream type of the `client_async()` functions. @@ -379,15 +379,30 @@ where client_async_tls_with_connector_and_config(request, socket, connector, config).await } -use pin_project::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; -/// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide -/// the variants from the `futures` crate and the other way around. -#[pin_project] -#[derive(Debug, Clone)] -pub struct TokioAdapter(#[pin] pub T); +pin_project_lite::pin_project! { + /// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide + /// the variants from the `futures` crate and the other way around. + #[derive(Debug, Clone)] + pub struct TokioAdapter { + #[pin] + inner: T, + } +} + +impl TokioAdapter { + /// Creates a new `TokioAdapter` wrapping the provided value. + pub fn new(inner: T) -> Self { + Self { inner } + } + + /// Consumes this `TokioAdapter`, returning the underlying value. + pub fn into_inner(self) -> T { + self.inner + } +} impl AsyncRead for TokioAdapter { fn poll_read( @@ -396,7 +411,7 @@ impl AsyncRead for TokioAdapter { buf: &mut [u8], ) -> Poll> { let mut buf = tokio::io::ReadBuf::new(buf); - match self.project().0.poll_read(cx, &mut buf)? { + match self.project().inner.poll_read(cx, &mut buf)? { Poll::Pending => Poll::Pending, Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())), } @@ -409,15 +424,15 @@ impl AsyncWrite for TokioAdapter { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.project().0.poll_write(cx, buf) + self.project().inner.poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_shutdown(cx) + self.project().inner.poll_shutdown(cx) } } @@ -428,7 +443,7 @@ impl tokio::io::AsyncRead for TokioAdapter { buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { let slice = buf.initialize_unfilled(); - let n = match self.project().0.poll_read(cx, slice)? { + let n = match self.project().inner.poll_read(cx, slice)? { Poll::Pending => return Poll::Pending, Poll::Ready(n) => n, }; @@ -443,17 +458,17 @@ impl tokio::io::AsyncWrite for TokioAdapter { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.project().0.poll_write(cx, buf) + self.project().inner.poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - self.project().0.poll_close(cx) + self.project().inner.poll_close(cx) } } diff --git a/src/tokio/dummy_tls.rs b/src/tokio/dummy_tls.rs index 090a769..1dd2a38 100644 --- a/src/tokio/dummy_tls.rs +++ b/src/tokio/dummy_tls.rs @@ -21,7 +21,7 @@ where S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { match mode { - Mode::Plain => Ok(TokioAdapter(socket)), + Mode::Plain => Ok(TokioAdapter::new(socket)), Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), } } diff --git a/src/tokio/native_tls.rs b/src/tokio/native_tls.rs index 8c45118..ae27281 100644 --- a/src/tokio/native_tls.rs +++ b/src/tokio/native_tls.rs @@ -28,7 +28,7 @@ where S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter::new(socket))), Mode::Tls => { let stream = { let connector = if let Some(connector) = connector { @@ -42,7 +42,7 @@ where .await .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + Ok(StreamSwitcher::Tls(TokioAdapter::new(stream))) } } } diff --git a/src/tokio/openssl.rs b/src/tokio/openssl.rs index 9c3d461..f878f8f 100644 --- a/src/tokio/openssl.rs +++ b/src/tokio/openssl.rs @@ -34,7 +34,7 @@ where + Sync, { match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter::new(socket))), Mode::Tls => { let stream = { let connector = if let Some(connector) = connector { @@ -54,7 +54,7 @@ where TlsStream::new(ssl, socket) .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + Ok(StreamSwitcher::Tls(TokioAdapter::new(stream))) } } } diff --git a/src/tokio/rustls.rs b/src/tokio/rustls.rs index 9c7d885..8b1a9cf 100644 --- a/src/tokio/rustls.rs +++ b/src/tokio/rustls.rs @@ -29,7 +29,7 @@ where S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { match mode { - Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter(socket))), + Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter::new(socket))), Mode::Tls => { let stream = { let connector = if let Some(connector) = connector { @@ -48,7 +48,7 @@ where .await .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))? }; - Ok(StreamSwitcher::Tls(TokioAdapter(stream))) + Ok(StreamSwitcher::Tls(TokioAdapter::new(stream))) } } }