From 9f3a1ee30ba047ecd222b60d6c0a1ee17a706b4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Mon, 25 Nov 2019 21:44:15 +0100 Subject: [PATCH] Add proxy waiters to allow the Stream to trigger AsyncWrite operations As a side effect also gets rid of unsafe code and raw pointers. We have the problem that external read operations (i.e. the Stream impl) can trigger both read (AsyncRead) and write (AsyncWrite) operations on the underyling stream. At the same time write operations (i.e. the Sink impl) can trigger write operations (AsyncWrite) too. Both the Stream and the Sink can be used on two different tasks, but it is required that AsyncRead and AsyncWrite are only ever used by a single task (or better: with a single waker) at a time. Doing otherwise would cause only the latest waker to be remembered, so in our case either the Stream or the Sink impl would potentially wait forever to be woken up because only the other one would've been woken up. To solve this we implement a waker proxy that has two slots (one for read, one for write) to store wakers. One waker proxy is always passed to the AsyncRead, the other to AsyncWrite so that they will only ever have to store a single waker, but internally we dispatch any wakeups to up to two actual wakers (one from the Sink impl and one from the Stream impl). --- src/compat.rs | 136 ++++++++++++++++++++++++++++++++++++----------- src/handshake.rs | 32 ++++------- src/lib.rs | 31 ++++++----- 3 files changed, 132 insertions(+), 67 deletions(-) diff --git a/src/compat.rs b/src/compat.rs index 1e8505e..644f856 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -4,54 +4,130 @@ use std::pin::Pin; use std::task::{Context, Poll}; use futures::io::{AsyncRead, AsyncWrite}; -use tungstenite::{Error as WsError, WebSocket}; +use futures::task; +use std::sync::Arc; +use tungstenite::Error as WsError; -pub(crate) trait HasContext { - fn set_context(&mut self, context: (bool, *mut ())); +pub(crate) enum ContextWaker { + Read, + Write, } + #[derive(Debug)] pub struct AllowStd { - pub(crate) inner: S, - pub(crate) context: (bool, *mut ()), + inner: S, + // We have the problem that external read operations (i.e. the Stream impl) + // can trigger both read (AsyncRead) and write (AsyncWrite) operations on + // the underyling stream. At the same time write operations (i.e. the Sink + // impl) can trigger write operations (AsyncWrite) too. + // Both the Stream and the Sink can be used on two different tasks, but it + // is required that AsyncRead and AsyncWrite are only ever used by a single + // task (or better: with a single waker) at a time. + // + // Doing otherwise would cause only the latest waker to be remembered, so + // in our case either the Stream or the Sink impl would potentially wait + // forever to be woken up because only the other one would've been woken + // up. + // + // To solve this we implement a waker proxy that has two slots (one for + // read, one for write) to store wakers. One waker proxy is always passed + // to the AsyncRead, the other to AsyncWrite so that they will only ever + // have to store a single waker, but internally we dispatch any wakeups to + // up to two actual wakers (one from the Sink impl and one from the Stream + // impl). + // + // write_waker_proxy is always used for AsyncWrite, read_waker_proxy for + // AsyncRead. The read_waker slots of both are used for the Stream impl + // (and handshaking), the write_waker slots for the Sink impl. + write_waker_proxy: Arc, + read_waker_proxy: Arc, +} + +// Internal trait used only in the Handshake module for registering +// the waker for the context used during handshaking. We're using the +// read waker slot for this, but any would do. +// +// Don't ever use this from multiple tasks at the same time! +pub(crate) trait SetWaker { + fn set_waker(&self, waker: &task::Waker); } -impl HasContext for AllowStd { - fn set_context(&mut self, context: (bool, *mut ())) { - self.context = context; +impl SetWaker for AllowStd { + fn set_waker(&self, waker: &task::Waker) { + self.set_waker(ContextWaker::Read, waker); } } -pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket>); +impl AllowStd { + pub(crate) fn new(inner: S, waker: &task::Waker) -> Self { + let res = Self { + inner, + write_waker_proxy: Default::default(), + read_waker_proxy: Default::default(), + }; + + // Register the handshake waker as read waker for both proxies, + // see also the SetWaker trait. + res.write_waker_proxy.read_waker.register(waker); + res.read_waker_proxy.read_waker.register(waker); -impl Drop for Guard<'_, S> { - fn drop(&mut self) { - trace!("{}:{} Guard.drop", file!(), line!()); - (self.0).get_mut().context = (true, std::ptr::null_mut()); + res + } + + // Set the read or write waker for our proxies. + // + // Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream + // impl on the WebSocketStream. + // Reading can also cause writes to happen, e.g. in case of Message::Ping handling. + // + // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the + // WebSocketStream. + pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) { + match kind { + ContextWaker::Read => { + self.write_waker_proxy.read_waker.register(waker); + self.read_waker_proxy.read_waker.register(waker); + } + ContextWaker::Write => { + self.write_waker_proxy.write_waker.register(waker); + self.read_waker_proxy.write_waker.register(waker); + } + } } } -// *mut () context is neither Send nor Sync -unsafe impl Send for AllowStd {} -unsafe impl Sync for AllowStd {} +// Proxy Waker that we pass to the internal AsyncRead/Write of the +// stream underlying the websocket. We have two slots here for the +// actual wakers to allow external read operations to trigger both +// reads and writes, and the same for writes. +#[derive(Debug, Default)] +struct WakerProxy { + read_waker: task::AtomicWaker, + write_waker: task::AtomicWaker, +} + +impl task::ArcWake for WakerProxy { + fn wake_by_ref(arc_self: &Arc) { + arc_self.read_waker.wake(); + arc_self.write_waker.wake(); + } +} impl AllowStd where S: Unpin, { - fn with_context(&mut self, f: F) -> Poll> + fn with_context(&mut self, kind: ContextWaker, f: F) -> Poll> where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { trace!("{}:{} AllowStd.with_context", file!(), line!()); - unsafe { - if !self.context.0 { - //was called by start_send without context - return Poll::Pending - } - assert!(!self.context.1.is_null()); - let waker = &mut *(self.context.1 as *mut _); - f(waker, Pin::new(&mut self.inner)) - } + let waker = match kind { + ContextWaker::Read => task::waker_ref(&self.read_waker_proxy), + ContextWaker::Write => task::waker_ref(&self.write_waker_proxy), + }; + let mut context = task::Context::from_waker(&waker); + f(&mut context, Pin::new(&mut self.inner)) } pub(crate) fn get_mut(&mut self) -> &mut S { @@ -69,7 +145,7 @@ where { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { trace!("{}:{} Read.read", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Read, |ctx, stream| { trace!( "{}:{} Read.with_context read -> poll_read", file!(), @@ -89,7 +165,7 @@ where { fn write(&mut self, buf: &[u8]) -> std::io::Result { trace!("{}:{} Write.write", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Write, |ctx, stream| { trace!( "{}:{} Write.with_context write -> poll_write", file!(), @@ -104,7 +180,7 @@ where fn flush(&mut self) -> std::io::Result<()> { trace!("{}:{} Write.flush", file!(), line!()); - match self.with_context(|ctx, stream| { + match self.with_context(ContextWaker::Write, |ctx, stream| { trace!( "{}:{} Write.with_context flush -> poll_flush", file!(), @@ -124,7 +200,7 @@ pub(crate) fn cvt(r: Result) -> Poll> { Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => { trace!("WouldBlock"); Poll::Pending - }, + } Err(e) => Poll::Ready(Err(e)), } } diff --git a/src/handshake.rs b/src/handshake.rs index 7d517ab..ecf6172 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -1,4 +1,4 @@ -use crate::compat::{AllowStd, HasContext}; +use crate::compat::{AllowStd, SetWaker}; use crate::WebSocketStream; use futures::io::{AsyncRead, AsyncWrite}; use log::*; @@ -45,10 +45,7 @@ where .take() .expect("future polled after completion"); trace!("Setting context when skipping handshake"); - let stream = AllowStd { - inner: inner.stream, - context: (true, ctx as *mut _ as *mut ()), - }; + let stream = AllowStd::new(inner.stream, ctx.waker()); Poll::Ready((inner.f)(stream)) } @@ -71,7 +68,7 @@ struct StartedHandshakeFutureInner { async fn handshake(stream: S, f: F) -> Result> where Role: HandshakeRole + Unpin, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, F: FnOnce(AllowStd) -> Result> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { @@ -125,7 +122,7 @@ where impl Future for StartedHandshakeFuture where Role: HandshakeRole, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, F: FnOnce(AllowStd) -> Result> + Unpin, S: Unpin, AllowStd: Read + Write, @@ -135,18 +132,11 @@ where fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { let inner = self.0.take().expect("future polled after completion"); trace!("Setting ctx when starting handshake"); - let stream = AllowStd { - inner: inner.stream, - context: (true, ctx as *mut _ as *mut ()), - }; + let stream = AllowStd::new(inner.stream, ctx.waker()); match (inner.f)(stream) { Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), - Err(Error::Interrupted(mut mid)) => { - let machine = mid.get_mut(); - machine.get_mut().set_context((true, std::ptr::null_mut())); - Poll::Ready(Ok(StartedHandshake::Mid(mid))) - } + Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), } } @@ -155,7 +145,7 @@ where impl Future for MidHandshake where Role: HandshakeRole + Unpin, - Role::InternalStream: HasContext, + Role::InternalStream: SetWaker, { type Output = Result>; @@ -165,16 +155,12 @@ where let machine = s.get_mut(); trace!("Setting context in handshake"); - machine - .get_mut() - .set_context((true, cx as *mut _ as *mut ())); + machine.get_mut().set_waker(cx.waker()); match s.handshake() { Ok(stream) => Poll::Ready(Ok(stream)), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), - Err(Error::Interrupted(mut mid)) => { - let machine = mid.get_mut(); - machine.get_mut().set_context((true, std::ptr::null_mut())); + Err(Error::Interrupted(mid)) => { *this.0 = Some(mid); Poll::Pending } diff --git a/src/lib.rs b/src/lib.rs index 2cbb5b0..6106612 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub mod stream; use std::io::{Read, Write}; -use compat::{cvt, AllowStd}; +use compat::{cvt, AllowStd, ContextWaker}; use futures::io::{AsyncRead, AsyncWrite}; use futures::{Sink, Stream}; use log::*; @@ -216,19 +216,17 @@ impl WebSocketStream { WebSocketStream { inner: ws } } - fn with_context(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R + fn with_context(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R where S: Unpin, F: FnOnce(&mut WebSocket>) -> R, AllowStd: Read + Write, { trace!("{}:{} WebSocketStream.with_context", file!(), line!()); - self.inner.get_mut().context = match ctx { - None => (false, std::ptr::null_mut()), - Some(cx) => (true, cx as *mut _ as *mut ()), - }; - let mut g = compat::Guard(&mut self.inner); - f(&mut (g.0)) + if let Some((kind, ctx)) = ctx { + self.inner.get_mut().set_waker(kind, &ctx.waker()); + } + f(&mut self.inner) } /// Returns a shared reference to the inner stream. @@ -281,7 +279,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{}:{} Stream.poll_next", file!(), line!()); - match futures::ready!(self.with_context(Some(cx), |s| { + match futures::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| { trace!( "{}:{} Stream.with_context poll_next -> read_message()", file!(), @@ -304,7 +302,7 @@ where type Error = WsError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending())) } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { @@ -325,11 +323,11 @@ where } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match (*self).with_context(Some(cx), |s| s.close(None)) { + match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) { Ok(()) => Poll::Ready(Ok(())), Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(err) => { @@ -358,7 +356,9 @@ where let message = this.message.take().expect("Cannot poll twice"); Poll::Ready( this.stream - .with_context(Some(cx), |s| s.write_message(message)), + .with_context(Some((ContextWaker::Write, cx)), |s| { + s.write_message(message) + }), ) } } @@ -379,7 +379,10 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let message = this.message.take().expect("Cannot poll twice"); - Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message))) + Poll::Ready( + this.stream + .with_context(Some((ContextWaker::Write, cx)), |s| s.close(message)), + ) } }