Add proxy waiters to allow the Stream to trigger AsyncWrite operations

As a side effect also gets rid of unsafe code and raw pointers.

We have the problem that external read operations (i.e. the Stream impl)
can trigger both read (AsyncRead) and write (AsyncWrite) operations on
the underyling stream. At the same time write operations (i.e. the Sink
impl) can trigger write operations (AsyncWrite) too.
Both the Stream and the Sink can be used on two different tasks, but it
is required that AsyncRead and AsyncWrite are only ever used by a single
task (or better: with a single waker) at a time.

Doing otherwise would cause only the latest waker to be remembered, so
in our case either the Stream or the Sink impl would potentially wait
forever to be woken up because only the other one would've been woken
up.

To solve this we implement a waker proxy that has two slots (one for
read, one for write) to store wakers. One waker proxy is always passed
to the AsyncRead, the other to AsyncWrite so that they will only ever
have to store a single waker, but internally we dispatch any wakeups to
up to two actual wakers (one from the Sink impl and one from the Stream
impl).
pull/3/head
Sebastian Dröge 5 years ago
parent dc9c1b3d5f
commit 9f3a1ee30b
  1. 136
      src/compat.rs
  2. 32
      src/handshake.rs
  3. 31
      src/lib.rs

@ -4,54 +4,130 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket}; use futures::task;
use std::sync::Arc;
use tungstenite::Error as WsError;
pub(crate) trait HasContext { pub(crate) enum ContextWaker {
fn set_context(&mut self, context: (bool, *mut ())); Read,
Write,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct AllowStd<S> { pub struct AllowStd<S> {
pub(crate) inner: S, inner: S,
pub(crate) context: (bool, *mut ()), // We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
// the underyling stream. At the same time write operations (i.e. the Sink
// impl) can trigger write operations (AsyncWrite) too.
// Both the Stream and the Sink can be used on two different tasks, but it
// is required that AsyncRead and AsyncWrite are only ever used by a single
// task (or better: with a single waker) at a time.
//
// Doing otherwise would cause only the latest waker to be remembered, so
// in our case either the Stream or the Sink impl would potentially wait
// forever to be woken up because only the other one would've been woken
// up.
//
// To solve this we implement a waker proxy that has two slots (one for
// read, one for write) to store wakers. One waker proxy is always passed
// to the AsyncRead, the other to AsyncWrite so that they will only ever
// have to store a single waker, but internally we dispatch any wakeups to
// up to two actual wakers (one from the Sink impl and one from the Stream
// impl).
//
// write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
// AsyncRead. The read_waker slots of both are used for the Stream impl
// (and handshaking), the write_waker slots for the Sink impl.
write_waker_proxy: Arc<WakerProxy>,
read_waker_proxy: Arc<WakerProxy>,
}
// Internal trait used only in the Handshake module for registering
// the waker for the context used during handshaking. We're using the
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
} }
impl<S> HasContext for AllowStd<S> { impl<S> SetWaker for AllowStd<S> {
fn set_context(&mut self, context: (bool, *mut ())) { fn set_waker(&self, waker: &task::Waker) {
self.context = context; self.set_waker(ContextWaker::Read, waker);
} }
} }
pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>); impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
read_waker_proxy: Default::default(),
};
// Register the handshake waker as read waker for both proxies,
// see also the SetWaker trait.
res.write_waker_proxy.read_waker.register(waker);
res.read_waker_proxy.read_waker.register(waker);
impl<S> Drop for Guard<'_, S> { res
fn drop(&mut self) { }
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = (true, std::ptr::null_mut()); // Set the read or write waker for our proxies.
//
// Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
// impl on the WebSocketStream.
// Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
self.read_waker_proxy.read_waker.register(waker);
}
ContextWaker::Write => {
self.write_waker_proxy.write_waker.register(waker);
self.read_waker_proxy.write_waker.register(waker);
}
}
} }
} }
// *mut () context is neither Send nor Sync // Proxy Waker that we pass to the internal AsyncRead/Write of the
unsafe impl<S: Send> Send for AllowStd<S> {} // stream underlying the websocket. We have two slots here for the
unsafe impl<S: Sync> Sync for AllowStd<S> {} // actual wakers to allow external read operations to trigger both
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
}
impl task::ArcWake for WakerProxy {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.read_waker.wake();
arc_self.write_waker.wake();
}
}
impl<S> AllowStd<S> impl<S> AllowStd<S>
where where
S: Unpin, S: Unpin,
{ {
fn with_context<F, R>(&mut self, f: F) -> Poll<std::io::Result<R>> fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
where where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>, F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{ {
trace!("{}:{} AllowStd.with_context", file!(), line!()); trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe { let waker = match kind {
if !self.context.0 { ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
//was called by start_send without context ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
return Poll::Pending };
} let mut context = task::Context::from_waker(&waker);
assert!(!self.context.1.is_null()); f(&mut context, Pin::new(&mut self.inner))
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
} }
pub(crate) fn get_mut(&mut self) -> &mut S { pub(crate) fn get_mut(&mut self) -> &mut S {
@ -69,7 +145,7 @@ where
{ {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!()); trace!("{}:{} Read.read", file!(), line!());
match self.with_context(|ctx, stream| { match self.with_context(ContextWaker::Read, |ctx, stream| {
trace!( trace!(
"{}:{} Read.with_context read -> poll_read", "{}:{} Read.with_context read -> poll_read",
file!(), file!(),
@ -89,7 +165,7 @@ where
{ {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!()); trace!("{}:{} Write.write", file!(), line!());
match self.with_context(|ctx, stream| { match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!( trace!(
"{}:{} Write.with_context write -> poll_write", "{}:{} Write.with_context write -> poll_write",
file!(), file!(),
@ -104,7 +180,7 @@ where
fn flush(&mut self) -> std::io::Result<()> { fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!()); trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(|ctx, stream| { match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!( trace!(
"{}:{} Write.with_context flush -> poll_flush", "{}:{} Write.with_context flush -> poll_flush",
file!(), file!(),
@ -124,7 +200,7 @@ pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => { Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock"); trace!("WouldBlock");
Poll::Pending Poll::Pending
}, }
Err(e) => Poll::Ready(Err(e)), Err(e) => Poll::Ready(Err(e)),
} }
} }

@ -1,4 +1,4 @@
use crate::compat::{AllowStd, HasContext}; use crate::compat::{AllowStd, SetWaker};
use crate::WebSocketStream; use crate::WebSocketStream;
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use log::*; use log::*;
@ -45,10 +45,7 @@ where
.take() .take()
.expect("future polled after completion"); .expect("future polled after completion");
trace!("Setting context when skipping handshake"); trace!("Setting context when skipping handshake");
let stream = AllowStd { let stream = AllowStd::new(inner.stream, ctx.waker());
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
Poll::Ready((inner.f)(stream)) Poll::Ready((inner.f)(stream))
} }
@ -71,7 +68,7 @@ struct StartedHandshakeFutureInner<F, S> {
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>> async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where where
Role: HandshakeRole + Unpin, Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext, Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin, F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
@ -125,7 +122,7 @@ where
impl<Role, F, S> Future for StartedHandshakeFuture<F, S> impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where where
Role: HandshakeRole, Role: HandshakeRole,
Role::InternalStream: HasContext, Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin, F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin, S: Unpin,
AllowStd<S>: Read + Write, AllowStd<S>: Read + Write,
@ -135,18 +132,11 @@ where
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion"); let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake"); trace!("Setting ctx when starting handshake");
let stream = AllowStd { let stream = AllowStd::new(inner.stream, ctx.waker());
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
match (inner.f)(stream) { match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => { Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))),
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
} }
} }
@ -155,7 +145,7 @@ where
impl<Role> Future for MidHandshake<Role> impl<Role> Future for MidHandshake<Role>
where where
Role: HandshakeRole + Unpin, Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext, Role::InternalStream: SetWaker,
{ {
type Output = Result<Role::FinalResult, Error<Role>>; type Output = Result<Role::FinalResult, Error<Role>>;
@ -165,16 +155,12 @@ where
let machine = s.get_mut(); let machine = s.get_mut();
trace!("Setting context in handshake"); trace!("Setting context in handshake");
machine machine.get_mut().set_waker(cx.waker());
.get_mut()
.set_context((true, cx as *mut _ as *mut ()));
match s.handshake() { match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)), Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => { Err(Error::Interrupted(mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
*this.0 = Some(mid); *this.0 = Some(mid);
Poll::Pending Poll::Pending
} }

@ -27,7 +27,7 @@ pub mod stream;
use std::io::{Read, Write}; use std::io::{Read, Write};
use compat::{cvt, AllowStd}; use compat::{cvt, AllowStd, ContextWaker};
use futures::io::{AsyncRead, AsyncWrite}; use futures::io::{AsyncRead, AsyncWrite};
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use log::*; use log::*;
@ -216,19 +216,17 @@ impl<S> WebSocketStream<S> {
WebSocketStream { inner: ws } WebSocketStream { inner: ws }
} }
fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
where where
S: Unpin, S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R, F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write, AllowStd<S>: Read + Write,
{ {
trace!("{}:{} WebSocketStream.with_context", file!(), line!()); trace!("{}:{} WebSocketStream.with_context", file!(), line!());
self.inner.get_mut().context = match ctx { if let Some((kind, ctx)) = ctx {
None => (false, std::ptr::null_mut()), self.inner.get_mut().set_waker(kind, &ctx.waker());
Some(cx) => (true, cx as *mut _ as *mut ()), }
}; f(&mut self.inner)
let mut g = compat::Guard(&mut self.inner);
f(&mut (g.0))
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
@ -281,7 +279,7 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("{}:{} Stream.poll_next", file!(), line!()); trace!("{}:{} Stream.poll_next", file!(), line!());
match futures::ready!(self.with_context(Some(cx), |s| { match futures::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
trace!( trace!(
"{}:{} Stream.with_context poll_next -> read_message()", "{}:{} Stream.with_context poll_next -> read_message()",
file!(), file!(),
@ -304,7 +302,7 @@ where
type Error = WsError; type Error = WsError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending())) (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
} }
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
@ -325,11 +323,11 @@ where
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending())) (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match (*self).with_context(Some(cx), |s| s.close(None)) { match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())), Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => { Err(err) => {
@ -358,7 +356,9 @@ where
let message = this.message.take().expect("Cannot poll twice"); let message = this.message.take().expect("Cannot poll twice");
Poll::Ready( Poll::Ready(
this.stream this.stream
.with_context(Some(cx), |s| s.write_message(message)), .with_context(Some((ContextWaker::Write, cx)), |s| {
s.write_message(message)
}),
) )
} }
} }
@ -379,7 +379,10 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
let message = this.message.take().expect("Cannot poll twice"); let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message))) Poll::Ready(
this.stream
.with_context(Some((ContextWaker::Write, cx)), |s| s.close(message)),
)
} }
} }

Loading…
Cancel
Save