Add proxy waiters to allow the Stream to trigger AsyncWrite operations

As a side effect also gets rid of unsafe code and raw pointers.

We have the problem that external read operations (i.e. the Stream impl)
can trigger both read (AsyncRead) and write (AsyncWrite) operations on
the underyling stream. At the same time write operations (i.e. the Sink
impl) can trigger write operations (AsyncWrite) too.
Both the Stream and the Sink can be used on two different tasks, but it
is required that AsyncRead and AsyncWrite are only ever used by a single
task (or better: with a single waker) at a time.

Doing otherwise would cause only the latest waker to be remembered, so
in our case either the Stream or the Sink impl would potentially wait
forever to be woken up because only the other one would've been woken
up.

To solve this we implement a waker proxy that has two slots (one for
read, one for write) to store wakers. One waker proxy is always passed
to the AsyncRead, the other to AsyncWrite so that they will only ever
have to store a single waker, but internally we dispatch any wakeups to
up to two actual wakers (one from the Sink impl and one from the Stream
impl).
pull/3/head
Sebastian Dröge 5 years ago
parent dc9c1b3d5f
commit 9f3a1ee30b
  1. 136
      src/compat.rs
  2. 32
      src/handshake.rs
  3. 31
      src/lib.rs

@ -4,54 +4,130 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use futures::io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket};
use futures::task;
use std::sync::Arc;
use tungstenite::Error as WsError;
pub(crate) trait HasContext {
fn set_context(&mut self, context: (bool, *mut ()));
pub(crate) enum ContextWaker {
Read,
Write,
}
#[derive(Debug)]
pub struct AllowStd<S> {
pub(crate) inner: S,
pub(crate) context: (bool, *mut ()),
inner: S,
// We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
// the underyling stream. At the same time write operations (i.e. the Sink
// impl) can trigger write operations (AsyncWrite) too.
// Both the Stream and the Sink can be used on two different tasks, but it
// is required that AsyncRead and AsyncWrite are only ever used by a single
// task (or better: with a single waker) at a time.
//
// Doing otherwise would cause only the latest waker to be remembered, so
// in our case either the Stream or the Sink impl would potentially wait
// forever to be woken up because only the other one would've been woken
// up.
//
// To solve this we implement a waker proxy that has two slots (one for
// read, one for write) to store wakers. One waker proxy is always passed
// to the AsyncRead, the other to AsyncWrite so that they will only ever
// have to store a single waker, but internally we dispatch any wakeups to
// up to two actual wakers (one from the Sink impl and one from the Stream
// impl).
//
// write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
// AsyncRead. The read_waker slots of both are used for the Stream impl
// (and handshaking), the write_waker slots for the Sink impl.
write_waker_proxy: Arc<WakerProxy>,
read_waker_proxy: Arc<WakerProxy>,
}
// Internal trait used only in the Handshake module for registering
// the waker for the context used during handshaking. We're using the
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
}
impl<S> HasContext for AllowStd<S> {
fn set_context(&mut self, context: (bool, *mut ())) {
self.context = context;
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}
pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>);
impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
read_waker_proxy: Default::default(),
};
// Register the handshake waker as read waker for both proxies,
// see also the SetWaker trait.
res.write_waker_proxy.read_waker.register(waker);
res.read_waker_proxy.read_waker.register(waker);
impl<S> Drop for Guard<'_, S> {
fn drop(&mut self) {
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = (true, std::ptr::null_mut());
res
}
// Set the read or write waker for our proxies.
//
// Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
// impl on the WebSocketStream.
// Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
self.read_waker_proxy.read_waker.register(waker);
}
ContextWaker::Write => {
self.write_waker_proxy.write_waker.register(waker);
self.read_waker_proxy.write_waker.register(waker);
}
}
}
}
// *mut () context is neither Send nor Sync
unsafe impl<S: Send> Send for AllowStd<S> {}
unsafe impl<S: Sync> Sync for AllowStd<S> {}
// Proxy Waker that we pass to the internal AsyncRead/Write of the
// stream underlying the websocket. We have two slots here for the
// actual wakers to allow external read operations to trigger both
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
}
impl task::ArcWake for WakerProxy {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.read_waker.wake();
arc_self.write_waker.wake();
}
}
impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> Poll<std::io::Result<R>>
fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe {
if !self.context.0 {
//was called by start_send without context
return Poll::Pending
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
let waker = match kind {
ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
};
let mut context = task::Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}
pub(crate) fn get_mut(&mut self) -> &mut S {
@ -69,7 +145,7 @@ where
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Read, |ctx, stream| {
trace!(
"{}:{} Read.with_context read -> poll_read",
file!(),
@ -89,7 +165,7 @@ where
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!(
"{}:{} Write.with_context write -> poll_write",
file!(),
@ -104,7 +180,7 @@ where
fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!(
"{}:{} Write.with_context flush -> poll_flush",
file!(),
@ -124,7 +200,7 @@ pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock");
Poll::Pending
},
}
Err(e) => Poll::Ready(Err(e)),
}
}

@ -1,4 +1,4 @@
use crate::compat::{AllowStd, HasContext};
use crate::compat::{AllowStd, SetWaker};
use crate::WebSocketStream;
use futures::io::{AsyncRead, AsyncWrite};
use log::*;
@ -45,10 +45,7 @@ where
.take()
.expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
let stream = AllowStd::new(inner.stream, ctx.waker());
Poll::Ready((inner.f)(stream))
}
@ -71,7 +68,7 @@ struct StartedHandshakeFutureInner<F, S> {
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
@ -125,7 +122,7 @@ where
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
@ -135,18 +132,11 @@ where
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
let stream = AllowStd::new(inner.stream, ctx.waker());
match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
@ -155,7 +145,7 @@ where
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
{
type Output = Result<Role::FinalResult, Error<Role>>;
@ -165,16 +155,12 @@ where
let machine = s.get_mut();
trace!("Setting context in handshake");
machine
.get_mut()
.set_context((true, cx as *mut _ as *mut ()));
machine.get_mut().set_waker(cx.waker());
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Err(Error::Interrupted(mid)) => {
*this.0 = Some(mid);
Poll::Pending
}

@ -27,7 +27,7 @@ pub mod stream;
use std::io::{Read, Write};
use compat::{cvt, AllowStd};
use compat::{cvt, AllowStd, ContextWaker};
use futures::io::{AsyncRead, AsyncWrite};
use futures::{Sink, Stream};
use log::*;
@ -216,19 +216,17 @@ impl<S> WebSocketStream<S> {
WebSocketStream { inner: ws }
}
fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R
fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
where
S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
self.inner.get_mut().context = match ctx {
None => (false, std::ptr::null_mut()),
Some(cx) => (true, cx as *mut _ as *mut ()),
};
let mut g = compat::Guard(&mut self.inner);
f(&mut (g.0))
if let Some((kind, ctx)) = ctx {
self.inner.get_mut().set_waker(kind, &ctx.waker());
}
f(&mut self.inner)
}
/// Returns a shared reference to the inner stream.
@ -281,7 +279,7 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("{}:{} Stream.poll_next", file!(), line!());
match futures::ready!(self.with_context(Some(cx), |s| {
match futures::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
trace!(
"{}:{} Stream.with_context poll_next -> read_message()",
file!(),
@ -304,7 +302,7 @@ where
type Error = WsError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending()))
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
@ -325,11 +323,11 @@ where
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending()))
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match (*self).with_context(Some(cx), |s| s.close(None)) {
match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
@ -358,7 +356,9 @@ where
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(
this.stream
.with_context(Some(cx), |s| s.write_message(message)),
.with_context(Some((ContextWaker::Write, cx)), |s| {
s.write_message(message)
}),
)
}
}
@ -379,7 +379,10 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message)))
Poll::Ready(
this.stream
.with_context(Some((ContextWaker::Write, cx)), |s| s.close(message)),
)
}
}

Loading…
Cancel
Save