You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
async-tungstenite/src/handshake.rs

182 lines
5.6 KiB

use crate::compat::{AllowStd, HasContext};
use crate::WebSocketStream;
use log::*;
use pin_project::pin_project;
use std::future::Future;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::handshake::client::Response;
use tungstenite::handshake::server::Callback;
use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake};
use tungstenite::{ClientHandshake, ServerHandshake, WebSocket};
pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
let ws = start.await;
WebSocketStream::new(ws)
}
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
struct SkippedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
impl<F, S> Future for SkippedHandshakeFuture<F, S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = WebSocket<AllowStd<S>>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self
.get_mut()
.0
.take()
.expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
Poll::Ready((inner.f)(stream))
}
}
#[pin_project]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
enum StartedHandshake<Role: HandshakeRole> {
Done(Role::FinalResult),
Mid(WsHandshake<Role>),
}
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
struct StartedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
match start.await? {
StartedHandshake::Done(r) => Ok(r),
StartedHandshake::Mid(s) => {
let res: Result<Role::FinalResult, Error<Role>> = MidHandshake::<Role>(Some(s)).await;
res
}
}
}
pub(crate) async fn client_handshake<F, S>(
stream: S,
f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
F: FnOnce(
AllowStd<S>,
) -> Result<
<ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
Error<ClientHandshake<AllowStd<S>>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let result = handshake(stream, f).await?;
let (s, r) = result;
Ok((WebSocketStream::new(s), r))
}
pub(crate) async fn server_handshake<C, F, S>(
stream: S,
f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
C: Callback + Unpin,
F: FnOnce(
AllowStd<S>,
) -> Result<
<ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
Error<ServerHandshake<AllowStd<S>, C>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let s: WebSocket<AllowStd<S>> = handshake(stream, f).await?;
Ok(WebSocketStream::new(s))
}
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: HasContext,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = Result<StartedHandshake<Role>, Error<Role>>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
}
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
{
type Output = Result<Role::FinalResult, Error<Role>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut s = this.0.take().expect("future polled after completion");
let machine = s.get_mut();
trace!("Setting context in handshake");
machine.get_mut().set_context((true, cx as *mut _ as *mut ()));
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
*this.0 = Some(mid);
Poll::Pending
}
}
}
}