@ -4,54 +4,130 @@ use std::pin::Pin;
use std ::task ::{ Context , Poll } ;
use std ::task ::{ Context , Poll } ;
use futures ::io ::{ AsyncRead , AsyncWrite } ;
use futures ::io ::{ AsyncRead , AsyncWrite } ;
use tungstenite ::{ Error as WsError , WebSocket } ;
use futures ::task ;
use std ::sync ::Arc ;
use tungstenite ::Error as WsError ;
pub ( crate ) trait HasContext {
pub ( crate ) enum ContextWaker {
fn set_context ( & mut self , context : ( bool , * mut ( ) ) ) ;
Read ,
Write ,
}
}
#[ derive(Debug) ]
#[ derive(Debug) ]
pub struct AllowStd < S > {
pub struct AllowStd < S > {
pub ( crate ) inner : S ,
inner : S ,
pub ( crate ) context : ( bool , * mut ( ) ) ,
// We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
// the underyling stream. At the same time write operations (i.e. the Sink
// impl) can trigger write operations (AsyncWrite) too.
// Both the Stream and the Sink can be used on two different tasks, but it
// is required that AsyncRead and AsyncWrite are only ever used by a single
// task (or better: with a single waker) at a time.
//
// Doing otherwise would cause only the latest waker to be remembered, so
// in our case either the Stream or the Sink impl would potentially wait
// forever to be woken up because only the other one would've been woken
// up.
//
// To solve this we implement a waker proxy that has two slots (one for
// read, one for write) to store wakers. One waker proxy is always passed
// to the AsyncRead, the other to AsyncWrite so that they will only ever
// have to store a single waker, but internally we dispatch any wakeups to
// up to two actual wakers (one from the Sink impl and one from the Stream
// impl).
//
// write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
// AsyncRead. The read_waker slots of both are used for the Stream impl
// (and handshaking), the write_waker slots for the Sink impl.
write_waker_proxy : Arc < WakerProxy > ,
read_waker_proxy : Arc < WakerProxy > ,
}
// Internal trait used only in the Handshake module for registering
// the waker for the context used during handshaking. We're using the
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
pub ( crate ) trait SetWaker {
fn set_waker ( & self , waker : & task ::Waker ) ;
}
}
impl < S > HasContext for AllowStd < S > {
impl < S > SetWaker for AllowStd < S > {
fn set_context ( & mut self , context : ( bool , * mut ( ) ) ) {
fn set_waker ( & self , waker : & task ::Waker ) {
self . context = context ;
self . set_waker ( ContextWaker ::Read , waker ) ;
}
}
}
}
pub ( crate ) struct Guard < ' a , S > ( pub ( crate ) & ' a mut WebSocket < AllowStd < S > > ) ;
impl < S > AllowStd < S > {
pub ( crate ) fn new ( inner : S , waker : & task ::Waker ) -> Self {
let res = Self {
inner ,
write_waker_proxy : Default ::default ( ) ,
read_waker_proxy : Default ::default ( ) ,
} ;
// Register the handshake waker as read waker for both proxies,
// see also the SetWaker trait.
res . write_waker_proxy . read_waker . register ( waker ) ;
res . read_waker_proxy . read_waker . register ( waker ) ;
impl < S > Drop for Guard < ' _ , S > {
res
fn drop ( & mut self ) {
}
trace ! ( "{}:{} Guard.drop" , file! ( ) , line! ( ) ) ;
( self . 0 ) . get_mut ( ) . context = ( true , std ::ptr ::null_mut ( ) ) ;
// Set the read or write waker for our proxies.
//
// Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
// impl on the WebSocketStream.
// Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub ( crate ) fn set_waker ( & self , kind : ContextWaker , waker : & task ::Waker ) {
match kind {
ContextWaker ::Read = > {
self . write_waker_proxy . read_waker . register ( waker ) ;
self . read_waker_proxy . read_waker . register ( waker ) ;
}
ContextWaker ::Write = > {
self . write_waker_proxy . write_waker . register ( waker ) ;
self . read_waker_proxy . write_waker . register ( waker ) ;
}
}
}
}
}
}
// Proxy Waker that we pass to the internal AsyncRead/Write of the
// stream underlying the websocket. We have two slots here for the
// actual wakers to allow external read operations to trigger both
// reads and writes, and the same for writes.
#[ derive(Debug, Default) ]
struct WakerProxy {
read_waker : task ::AtomicWaker ,
write_waker : task ::AtomicWaker ,
}
// *mut () context is neither Send nor Sync
impl task ::ArcWake for WakerProxy {
unsafe impl < S : Send > Send for AllowStd < S > { }
fn wake_by_ref ( arc_self : & Arc < Self > ) {
unsafe impl < S : Sync > Sync for AllowStd < S > { }
arc_self . read_waker . wake ( ) ;
arc_self . write_waker . wake ( ) ;
}
}
impl < S > AllowStd < S >
impl < S > AllowStd < S >
where
where
S : Unpin ,
S : Unpin ,
{
{
fn with_context < F , R > ( & mut self , f : F ) -> Poll < std ::io ::Result < R > >
fn with_context < F , R > ( & mut self , kind : ContextWaker , f : F ) -> Poll < std ::io ::Result < R > >
where
where
F : FnOnce ( & mut Context < ' _ > , Pin < & mut S > ) -> Poll < std ::io ::Result < R > > ,
F : FnOnce ( & mut Context < ' _ > , Pin < & mut S > ) -> Poll < std ::io ::Result < R > > ,
{
{
trace ! ( "{}:{} AllowStd.with_context" , file! ( ) , line! ( ) ) ;
trace ! ( "{}:{} AllowStd.with_context" , file! ( ) , line! ( ) ) ;
unsafe {
let waker = match kind {
if ! self . context . 0 {
ContextWaker ::Read = > task ::waker_ref ( & self . read_waker_proxy ) ,
//was called by start_send without context
ContextWaker ::Write = > task ::waker_ref ( & self . write_waker_proxy ) ,
return Poll ::Pending
} ;
}
let mut context = task ::Context ::from_waker ( & waker ) ;
assert! ( ! self . context . 1. is_null ( ) ) ;
f ( & mut context , Pin ::new ( & mut self . inner ) )
let waker = & mut * ( self . context . 1 as * mut _ ) ;
f ( waker , Pin ::new ( & mut self . inner ) )
}
}
}
pub ( crate ) fn get_mut ( & mut self ) -> & mut S {
pub ( crate ) fn get_mut ( & mut self ) -> & mut S {
@ -69,7 +145,7 @@ where
{
{
fn read ( & mut self , buf : & mut [ u8 ] ) -> std ::io ::Result < usize > {
fn read ( & mut self , buf : & mut [ u8 ] ) -> std ::io ::Result < usize > {
trace ! ( "{}:{} Read.read" , file! ( ) , line! ( ) ) ;
trace ! ( "{}:{} Read.read" , file! ( ) , line! ( ) ) ;
match self . with_context ( | ctx , stream | {
match self . with_context ( ContextWaker ::Read , | ctx , stream | {
trace ! (
trace ! (
"{}:{} Read.with_context read -> poll_read" ,
"{}:{} Read.with_context read -> poll_read" ,
file! ( ) ,
file! ( ) ,
@ -89,7 +165,7 @@ where
{
{
fn write ( & mut self , buf : & [ u8 ] ) -> std ::io ::Result < usize > {
fn write ( & mut self , buf : & [ u8 ] ) -> std ::io ::Result < usize > {
trace ! ( "{}:{} Write.write" , file! ( ) , line! ( ) ) ;
trace ! ( "{}:{} Write.write" , file! ( ) , line! ( ) ) ;
match self . with_context ( | ctx , stream | {
match self . with_context ( ContextWaker ::Write , | ctx , stream | {
trace ! (
trace ! (
"{}:{} Write.with_context write -> poll_write" ,
"{}:{} Write.with_context write -> poll_write" ,
file! ( ) ,
file! ( ) ,
@ -104,7 +180,7 @@ where
fn flush ( & mut self ) -> std ::io ::Result < ( ) > {
fn flush ( & mut self ) -> std ::io ::Result < ( ) > {
trace ! ( "{}:{} Write.flush" , file! ( ) , line! ( ) ) ;
trace ! ( "{}:{} Write.flush" , file! ( ) , line! ( ) ) ;
match self . with_context ( | ctx , stream | {
match self . with_context ( ContextWaker ::Write , | ctx , stream | {
trace ! (
trace ! (
"{}:{} Write.with_context flush -> poll_flush" ,
"{}:{} Write.with_context flush -> poll_flush" ,
file! ( ) ,
file! ( ) ,
@ -124,7 +200,7 @@ pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
Err ( WsError ::Io ( ref e ) ) if e . kind ( ) = = std ::io ::ErrorKind ::WouldBlock = > {
Err ( WsError ::Io ( ref e ) ) if e . kind ( ) = = std ::io ::ErrorKind ::WouldBlock = > {
trace ! ( "WouldBlock" ) ;
trace ! ( "WouldBlock" ) ;
Poll ::Pending
Poll ::Pending
} ,
}
Err ( e ) = > Poll ::Ready ( Err ( e ) ) ,
Err ( e ) = > Poll ::Ready ( Err ( e ) ) ,
}
}
}
}