diff --git a/src/compat.rs b/src/compat.rs index 1e8505e..644f856 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -4,54 +4,130 @@ use std::pin::Pin; use std::task::{Context, Poll}; use futures::io::{AsyncRead, AsyncWrite}; -use tungstenite::{Error as WsError, WebSocket}; +use futures::task; +use std::sync::Arc; +use tungstenite::Error as WsError; -pub(crate) trait HasContext { - fn set_context(&mut self, context: (bool, *mut ())); +pub(crate) enum ContextWaker { + Read, + Write, } + #[derive(Debug)] pub struct AllowStd { - pub(crate) inner: S, - pub(crate) context: (bool, *mut ()), + inner: S, + // We have the problem that external read operations (i.e. the Stream impl) + // can trigger both read (AsyncRead) and write (AsyncWrite) operations on + // the underyling stream. At the same time write operations (i.e. the Sink + // impl) can trigger write operations (AsyncWrite) too. + // Both the Stream and the Sink can be used on two different tasks, but it + // is required that AsyncRead and AsyncWrite are only ever used by a single + // task (or better: with a single waker) at a time. + // + // Doing otherwise would cause only the latest waker to be remembered, so + // in our case either the Stream or the Sink impl would potentially wait + // forever to be woken up because only the other one would've been woken + // up. + // + // To solve this we implement a waker proxy that has two slots (one for + // read, one for write) to store wakers. One waker proxy is always passed + // to the AsyncRead, the other to AsyncWrite so that they will only ever + // have to store a single waker, but internally we dispatch any wakeups to + // up to two actual wakers (one from the Sink impl and one from the Stream + // impl). + // + // write_waker_proxy is always used for AsyncWrite, read_waker_proxy for + // AsyncRead. The read_waker slots of both are used for the Stream impl + // (and handshaking), the write_waker slots for the Sink impl. + write_waker_proxy: Arc, + read_waker_proxy: Arc, +} + +// Internal trait used only in the Handshake module for registering +// the waker for the context used during handshaking. We're using the +// read waker slot for this, but any would do. +// +// Don't ever use this from multiple tasks at the same time! +pub(crate) trait SetWaker { + fn set_waker(&self, waker: &task::Waker); } -impl HasContext for AllowStd { - fn set_context(&mut self, context: (bool, *mut ())) { - self.context = context; +impl SetWaker for AllowStd { + fn set_waker(&self, waker: &task::Waker) { + self.set_waker(ContextWaker::Read, waker); } } -pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket>); +impl AllowStd { + pub(crate) fn new(inner: S, waker: &task::Waker) -> Self { + let res = Self { + inner, + write_waker_proxy: Default::default(), + read_waker_proxy: Default::default(), + }; + + // Register the handshake waker as read waker for both proxies, + // see also the SetWaker trait. + res.write_waker_proxy.read_waker.register(waker); + res.read_waker_proxy.read_waker.register(waker); -impl Drop for Guard<'_, S> { - fn drop(&mut self) { - trace!("{}:{} Guard.drop", file!(), line!()); - (self.0).get_mut().context = (true, std::ptr::null_mut()); + res + } + + // Set the read or write waker for our proxies. + // + // Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream + // impl on the WebSocketStream. + // Reading can also cause writes to happen, e.g. in case of Message::Ping handling. + // + // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the + // WebSocketStream. + pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) { + match kind { + ContextWaker::Read => { + self.write_waker_proxy.read_waker.register(waker); + self.read_waker_proxy.read_waker.register(waker); + } + ContextWaker::Write => { + self.write_waker_proxy.write_waker.register(waker); + self.read_waker_proxy.write_waker.register(waker); + } + } } } -// *mut () context is neither Send nor Sync -unsafe impl Send for AllowStd {} -unsafe impl Sync for AllowStd {} +// Proxy Waker that we pass to the internal AsyncRead/Write of the +// stream underlying the websocket. We have two slots here for the +// actual wakers to allow external read operations to trigger both +// reads and writes, and the same for writes. +#[derive(Debug, Default)] +struct WakerProxy { + read_waker: task::AtomicWaker, + write_waker: task::AtomicWaker, +} + +impl task::ArcWake for WakerProxy { + fn wake_by_ref(arc_self: &Arc) { + arc_self.read_waker.wake(); + arc_self.write_waker.wake(); + } +} impl AllowStd where S: Unpin, { - fn with_context(&mut self, f: F) -> Poll> + fn with_context(&mut self, kind: ContextWaker, f: F) -> Poll> where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { trace!("{}:{} AllowStd.with_context", file!(), line!()); - unsafe { - if !self.context.0 { - //was called by start_send without context - return Poll::Pending - } - assert!(!self.context.1.is_null()); - let waker = &mut *(self.context.1 as *mut _); - f(waker, Pin::new(&mut self.inner)) - } + let waker = match kind { + ContextWaker::Read => task::waker_ref(&self.read_waker_proxy), + ContextWaker::Write => task::waker_ref(&self.write_waker_proxy), + }; + let mut context = task::Context::from_waker(&waker); + f(&mut context, Pin::new(&mut self.inner)) } pub(crate) fn get_mut(&mut self) -> &mut S { @@ -69,7 +145,7 @@ where { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { trace!("{}:{} Read.read", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Read, |ctx, stream| { trace!( "{}:{} Read.with_context read -> poll_read", file!(), @@ -89,7 +165,7 @@ where { fn write(&mut self, buf: &[u8]) -> std::io::Result { trace!("{}:{} Write.write", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Write, |ctx, stream| { trace!( "{}:{} Write.with_context write -> poll_write", file!(), @@ -104,7 +180,7 @@ where fn flush(&mut self) -> std::io::Result<()> { trace!("{}:{} Write.flush", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Write, |ctx, stream| { trace!( "{}:{} Write.with_context flush -> poll_flush", file!(), @@ -124,7 +200,7 @@ pub(crate) fn cvt(r: Result) -> Poll> { Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => { trace!("WouldBlock"); Poll::Pending - }, + } Err(e) => Poll::Ready(Err(e)), } } diff --git a/src/handshake.rs b/src/handshake.rs index 7d517ab..ecf6172 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -1,4 +1,4 @@ -use crate::compat::{AllowStd, HasContext}; +use crate::compat::{AllowStd, SetWaker}; use crate::WebSocketStream; use futures::io::{AsyncRead, AsyncWrite}; use log::*; @@ -45,10 +45,7 @@ where .take() .expect("future polled after completion"); trace!("Setting context when skipping handshake"); - let stream = AllowStd { - inner: inner.stream, - context: (true, ctx as *mut _ as *mut ()), - }; + let stream = AllowStd::new(inner.stream, ctx.waker()); Poll::Ready((inner.f)(stream)) } @@ -71,7 +68,7 @@ struct StartedHandshakeFutureInner { async fn handshake(stream: S, f: F) -> Result> where Role: HandshakeRole + Unpin, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, F: FnOnce(AllowStd) -> Result> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { @@ -125,7 +122,7 @@ where impl Future for StartedHandshakeFuture where Role: HandshakeRole, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, F: FnOnce(AllowStd) -> Result> + Unpin, S: Unpin, AllowStd: Read + Write, @@ -135,18 +132,11 @@ where fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { let inner = self.0.take().expect("future polled after completion"); trace!("Setting ctx when starting handshake"); - let stream = AllowStd { - inner: inner.stream, - context: (true, ctx as *mut _ as *mut ()), - }; + let stream = AllowStd::new(inner.stream, ctx.waker()); match (inner.f)(stream) { Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), - Err(Error::Interrupted(mut mid)) => { - let machine = mid.get_mut(); - machine.get_mut().set_context((true, std::ptr::null_mut())); - Poll::Ready(Ok(StartedHandshake::Mid(mid))) - } + Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), } } @@ -155,7 +145,7 @@ where impl Future for MidHandshake where Role: HandshakeRole + Unpin, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, { type Output = Result>; @@ -165,16 +155,12 @@ where let machine = s.get_mut(); trace!("Setting context in handshake"); - machine - .get_mut() - .set_context((true, cx as *mut _ as *mut ())); + machine.get_mut().set_waker(cx.waker()); match s.handshake() { Ok(stream) => Poll::Ready(Ok(stream)), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), - Err(Error::Interrupted(mut mid)) => { - let machine = mid.get_mut(); - machine.get_mut().set_context((true, std::ptr::null_mut())); + Err(Error::Interrupted(mid)) => { *this.0 = Some(mid); Poll::Pending } diff --git a/src/lib.rs b/src/lib.rs index 2cbb5b0..6106612 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub mod stream; use std::io::{Read, Write}; -use compat::{cvt, AllowStd}; +use compat::{cvt, AllowStd, ContextWaker}; use futures::io::{AsyncRead, AsyncWrite}; use futures::{Sink, Stream}; use log::*; @@ -216,19 +216,17 @@ impl WebSocketStream { WebSocketStream { inner: ws } } - fn with_context(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R + fn with_context(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R where S: Unpin, F: FnOnce(&mut WebSocket>) -> R, AllowStd: Read + Write, { trace!("{}:{} WebSocketStream.with_context", file!(), line!()); - self.inner.get_mut().context = match ctx { - None => (false, std::ptr::null_mut()), - Some(cx) => (true, cx as *mut _ as *mut ()), - }; - let mut g = compat::Guard(&mut self.inner); - f(&mut (g.0)) + if let Some((kind, ctx)) = ctx { + self.inner.get_mut().set_waker(kind, &ctx.waker()); + } + f(&mut self.inner) } /// Returns a shared reference to the inner stream. @@ -281,7 +279,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{}:{} Stream.poll_next", file!(), line!()); - match futures::ready!(self.with_context(Some(cx), |s| { + match futures::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| { trace!( "{}:{} Stream.with_context poll_next -> read_message()", file!(), @@ -304,7 +302,7 @@ where type Error = WsError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending())) } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { @@ -325,11 +323,11 @@ where } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match (*self).with_context(Some(cx), |s| s.close(None)) { + match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) { Ok(()) => Poll::Ready(Ok(())), Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(err) => { @@ -358,7 +356,9 @@ where let message = this.message.take().expect("Cannot poll twice"); Poll::Ready( this.stream - .with_context(Some(cx), |s| s.write_message(message)), + .with_context(Some((ContextWaker::Write, cx)), |s| { + s.write_message(message) + }), ) } } @@ -379,7 +379,10 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let message = this.message.take().expect("Cannot poll twice"); - Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message))) + Poll::Ready( + this.stream + .with_context(Some((ContextWaker::Write, cx)), |s| s.close(message)), + ) } }