use crate::compat::{AllowStd, HasContext}; use crate::WebSocketStream; use log::*; use pin_project::pin_project; use std::future::Future; use std::io::{Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; use futures::io::{AsyncRead, AsyncWrite}; use tungstenite::handshake::client::Response; use tungstenite::handshake::server::Callback; use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake}; use tungstenite::{ClientHandshake, ServerHandshake, WebSocket}; pub(crate) async fn without_handshake(stream: S, f: F) -> WebSocketStream where F: FnOnce(AllowStd) -> WebSocket> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream })); let ws = start.await; WebSocketStream::new(ws) } struct SkippedHandshakeFuture(Option>); struct SkippedHandshakeFutureInner { f: F, stream: S, } impl Future for SkippedHandshakeFuture where F: FnOnce(AllowStd) -> WebSocket> + Unpin, S: Unpin, AllowStd: Read + Write, { type Output = WebSocket>; fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { let inner = self .get_mut() .0 .take() .expect("future polled after completion"); trace!("Setting context when skipping handshake"); let stream = AllowStd { inner: inner.stream, context: (true, ctx as *mut _ as *mut ()), }; Poll::Ready((inner.f)(stream)) } } #[pin_project] struct MidHandshake(Option>); enum StartedHandshake { Done(Role::FinalResult), Mid(WsHandshake), } struct StartedHandshakeFuture(Option>); struct StartedHandshakeFutureInner { f: F, stream: S, } async fn handshake(stream: S, f: F) -> Result> where Role: HandshakeRole + Unpin, Role::InternalStream: HasContext, F: FnOnce(AllowStd) -> Result> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); match start.await? { StartedHandshake::Done(r) => Ok(r), StartedHandshake::Mid(s) => { let res: Result> = MidHandshake::(Some(s)).await; res } } } pub(crate) async fn client_handshake( stream: S, f: F, ) -> Result<(WebSocketStream, Response), Error>>> where F: FnOnce( AllowStd, ) -> Result< > as HandshakeRole>::FinalResult, Error>>, > + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let result = handshake(stream, f).await?; let (s, r) = result; Ok((WebSocketStream::new(s), r)) } pub(crate) async fn server_handshake( stream: S, f: F, ) -> Result, Error, C>>> where C: Callback + Unpin, F: FnOnce( AllowStd, ) -> Result< , C> as HandshakeRole>::FinalResult, Error, C>>, > + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let s: WebSocket> = handshake(stream, f).await?; Ok(WebSocketStream::new(s)) } impl Future for StartedHandshakeFuture where Role: HandshakeRole, Role::InternalStream: HasContext, F: FnOnce(AllowStd) -> Result> + Unpin, S: Unpin, AllowStd: Read + Write, { type Output = Result, Error>; fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { let inner = self.0.take().expect("future polled after completion"); trace!("Setting ctx when starting handshake"); let stream = AllowStd { inner: inner.stream, context: (true, ctx as *mut _ as *mut ()), }; match (inner.f)(stream) { Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), Err(Error::Interrupted(mut mid)) => { let machine = mid.get_mut(); machine.get_mut().set_context((true, std::ptr::null_mut())); Poll::Ready(Ok(StartedHandshake::Mid(mid))) } Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), } } } impl Future for MidHandshake where Role: HandshakeRole + Unpin, Role::InternalStream: HasContext, { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut s = this.0.take().expect("future polled after completion"); let machine = s.get_mut(); trace!("Setting context in handshake"); machine.get_mut().set_context((true, cx as *mut _ as *mut ())); match s.handshake() { Ok(stream) => Poll::Ready(Ok(stream)), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), Err(Error::Interrupted(mut mid)) => { let machine = mid.get_mut(); machine.get_mut().set_context((true, std::ptr::null_mut())); *this.0 = Some(mid); Poll::Pending } } } }