use log::*; use std::io::{Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::{Error as WsError, WebSocket}; pub(crate) trait HasContext { fn set_context(&mut self, context: (bool, *mut ())); } #[derive(Debug)] pub struct AllowStd { pub(crate) inner: S, pub(crate) context: (bool, *mut ()), } impl HasContext for AllowStd { fn set_context(&mut self, context: (bool, *mut ())) { self.context = context; } } pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket>); impl Drop for Guard<'_, S> { fn drop(&mut self) { trace!("{}:{} Guard.drop", file!(), line!()); (self.0).get_mut().context = (true, std::ptr::null_mut()); } } // *mut () context is neither Send nor Sync unsafe impl Send for AllowStd {} unsafe impl Sync for AllowStd {} impl AllowStd where S: Unpin, { fn with_context(&mut self, f: F) -> Poll> where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { trace!("{}:{} AllowStd.with_context", file!(), line!()); unsafe { if !self.context.0 { //was called by start_send without context return Poll::Pending } assert!(!self.context.1.is_null()); let waker = &mut *(self.context.1 as *mut _); f(waker, Pin::new(&mut self.inner)) } } pub(crate) fn get_mut(&mut self) -> &mut S { &mut self.inner } pub(crate) fn get_ref(&self) -> &S { &self.inner } } impl Read for AllowStd where S: AsyncRead + Unpin, { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { trace!("{}:{} Read.read", file!(), line!()); match self.with_context(|ctx, stream| { trace!( "{}:{} Read.with_context read -> poll_read", file!(), line!() ); stream.poll_read(ctx, buf) }) { Poll::Ready(r) => r, Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), } } } impl Write for AllowStd where S: AsyncWrite + Unpin, { fn write(&mut self, buf: &[u8]) -> std::io::Result { trace!("{}:{} Write.write", file!(), line!()); match self.with_context(|ctx, stream| { trace!( "{}:{} Write.with_context write -> poll_write", file!(), line!() ); stream.poll_write(ctx, buf) }) { Poll::Ready(r) => r, Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), } } fn flush(&mut self) -> std::io::Result<()> { trace!("{}:{} Write.flush", file!(), line!()); match self.with_context(|ctx, stream| { trace!( "{}:{} Write.with_context flush -> poll_flush", file!(), line!() ); stream.poll_flush(ctx) }) { Poll::Ready(r) => r, Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), } } } pub(crate) fn cvt(r: Result) -> Poll> { match r { Ok(v) => Poll::Ready(Ok(v)), Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } }