parent
59ca2c885e
commit
3821e0952a
@ -1,42 +1,42 @@ |
||||
use futures::{Future, Stream}; |
||||
use futures::StreamExt; |
||||
use log::*; |
||||
use tokio::net::TcpListener; |
||||
use tokio_tungstenite::{accept_async, tungstenite::Error as WsError}; |
||||
use std::net::{SocketAddr, ToSocketAddrs}; |
||||
use tokio::net::{TcpListener, TcpStream}; |
||||
use tokio_tungstenite::accept_async; |
||||
|
||||
fn main() { |
||||
env_logger::init(); |
||||
async fn accept_connection(peer: SocketAddr, stream: TcpStream) { |
||||
let mut ws_stream = accept_async(stream).await.expect("Failed to accept"); |
||||
|
||||
let mut runtime = tokio::runtime::Builder::new().build().unwrap(); |
||||
info!("New WebSocket connection: {}", peer); |
||||
|
||||
let addr = "127.0.0.1:9002".parse().unwrap(); |
||||
let socket = TcpListener::bind(&addr).unwrap(); |
||||
info!("Listening on: {}", addr); |
||||
while let Some(msg) = ws_stream.next().await { |
||||
let msg = msg.expect("Failed to get request"); |
||||
if msg.is_text() || msg.is_binary() { |
||||
ws_stream.send(msg).await.expect("Failed to send response"); |
||||
} |
||||
} |
||||
} |
||||
|
||||
let srv = socket |
||||
.incoming() |
||||
.map_err(Into::into) |
||||
.for_each(move |stream| { |
||||
let peer = stream |
||||
.peer_addr() |
||||
.expect("connected streams should have a peer address"); |
||||
info!("Peer address: {}", peer); |
||||
#[tokio::main] |
||||
async fn main() { |
||||
env_logger::init(); |
||||
|
||||
accept_async(stream).and_then(move |ws_stream| { |
||||
info!("New WebSocket connection: {}", peer); |
||||
let (sink, stream) = ws_stream.split(); |
||||
let job = stream |
||||
.filter(|msg| msg.is_text() || msg.is_binary()) |
||||
.forward(sink) |
||||
.and_then(|(_stream, _sink)| Ok(())) |
||||
.map_err(|err| match err { |
||||
WsError::ConnectionClosed => (), |
||||
err => info!("WS error: {}", err), |
||||
}); |
||||
let addr = "127.0.0.1:9002" |
||||
.to_socket_addrs() |
||||
.expect("Not a valid address") |
||||
.next() |
||||
.expect("Not a socket address"); |
||||
let socket = TcpListener::bind(&addr).await.unwrap(); |
||||
let mut incoming = socket.incoming(); |
||||
info!("Listening on: {}", addr); |
||||
|
||||
tokio::spawn(job); |
||||
Ok(()) |
||||
}) |
||||
}); |
||||
while let Some(stream) = incoming.next().await { |
||||
let stream = stream.expect("Failed to get stream"); |
||||
let peer = stream |
||||
.peer_addr() |
||||
.expect("connected streams should have a peer address"); |
||||
info!("Peer address: {}", peer); |
||||
|
||||
runtime.block_on(srv).unwrap(); |
||||
tokio::spawn(accept_connection(peer, stream)); |
||||
} |
||||
} |
||||
|
@ -0,0 +1,123 @@ |
||||
use log::*; |
||||
use std::io::{Read, Write}; |
||||
use std::pin::Pin; |
||||
use std::task::{Context, Poll}; |
||||
|
||||
use tokio_io::{AsyncRead, AsyncWrite}; |
||||
use tungstenite::{Error as WsError, WebSocket}; |
||||
|
||||
pub(crate) trait HasContext { |
||||
fn set_context(&mut self, context: *mut ()); |
||||
} |
||||
#[derive(Debug)] |
||||
pub struct AllowStd<S> { |
||||
pub(crate) inner: S, |
||||
pub(crate) context: *mut (), |
||||
} |
||||
|
||||
impl<S> HasContext for AllowStd<S> { |
||||
fn set_context(&mut self, context: *mut ()) { |
||||
self.context = context; |
||||
} |
||||
} |
||||
|
||||
pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>); |
||||
|
||||
impl<S> Drop for Guard<'_, S> { |
||||
fn drop(&mut self) { |
||||
trace!("{}:{} Guard.drop", file!(), line!()); |
||||
(self.0).get_mut().context = std::ptr::null_mut(); |
||||
} |
||||
} |
||||
|
||||
// *mut () context is neither Send nor Sync
|
||||
unsafe impl<S: Send> Send for AllowStd<S> {} |
||||
unsafe impl<S: Sync> Sync for AllowStd<S> {} |
||||
|
||||
impl<S> AllowStd<S> |
||||
where |
||||
S: Unpin, |
||||
{ |
||||
fn with_context<F, R>(&mut self, f: F) -> R |
||||
where |
||||
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, |
||||
{ |
||||
trace!("{}:{} AllowStd.with_context", file!(), line!()); |
||||
unsafe { |
||||
assert!(!self.context.is_null()); |
||||
let waker = &mut *(self.context as *mut _); |
||||
f(waker, Pin::new(&mut self.inner)) |
||||
} |
||||
} |
||||
|
||||
pub(crate) fn get_mut(&mut self) -> &mut S { |
||||
&mut self.inner |
||||
} |
||||
|
||||
pub(crate) fn get_ref(&self) -> &S { |
||||
&self.inner |
||||
} |
||||
} |
||||
|
||||
impl<S> Read for AllowStd<S> |
||||
where |
||||
S: AsyncRead + Unpin, |
||||
{ |
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { |
||||
trace!("{}:{} Read.read", file!(), line!()); |
||||
match self.with_context(|ctx, stream| { |
||||
trace!( |
||||
"{}:{} Read.with_context read -> poll_read", |
||||
file!(), |
||||
line!() |
||||
); |
||||
stream.poll_read(ctx, buf) |
||||
}) { |
||||
Poll::Ready(r) => r, |
||||
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<S> Write for AllowStd<S> |
||||
where |
||||
S: AsyncWrite + Unpin, |
||||
{ |
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { |
||||
trace!("{}:{} Write.write", file!(), line!()); |
||||
match self.with_context(|ctx, stream| { |
||||
trace!( |
||||
"{}:{} Write.with_context write -> poll_write", |
||||
file!(), |
||||
line!() |
||||
); |
||||
stream.poll_write(ctx, buf) |
||||
}) { |
||||
Poll::Ready(r) => r, |
||||
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), |
||||
} |
||||
} |
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> { |
||||
trace!("{}:{} Write.flush", file!(), line!()); |
||||
match self.with_context(|ctx, stream| { |
||||
trace!( |
||||
"{}:{} Write.with_context flush -> poll_flush", |
||||
file!(), |
||||
line!() |
||||
); |
||||
stream.poll_flush(ctx) |
||||
}) { |
||||
Poll::Ready(r) => r, |
||||
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> { |
||||
match r { |
||||
Ok(v) => Poll::Ready(Ok(v)), |
||||
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending, |
||||
Err(e) => Poll::Ready(Err(e)), |
||||
} |
||||
} |
@ -0,0 +1,181 @@ |
||||
use crate::compat::{AllowStd, HasContext}; |
||||
use crate::WebSocketStream; |
||||
use log::*; |
||||
use pin_project::pin_project; |
||||
use std::future::Future; |
||||
use std::io::{Read, Write}; |
||||
use std::pin::Pin; |
||||
use std::task::{Context, Poll}; |
||||
use tokio_io::{AsyncRead, AsyncWrite}; |
||||
use tungstenite::handshake::client::Response; |
||||
use tungstenite::handshake::server::Callback; |
||||
use tungstenite::handshake::{HandshakeError as Error, HandshakeRole, MidHandshake as WsHandshake}; |
||||
use tungstenite::{ClientHandshake, ServerHandshake, WebSocket}; |
||||
|
||||
pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S> |
||||
where |
||||
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin, |
||||
S: AsyncRead + AsyncWrite + Unpin, |
||||
{ |
||||
let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream })); |
||||
|
||||
let ws = start.await; |
||||
|
||||
WebSocketStream::new(ws) |
||||
} |
||||
|
||||
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>); |
||||
struct SkippedHandshakeFutureInner<F, S> { |
||||
f: F, |
||||
stream: S, |
||||
} |
||||
|
||||
impl<F, S> Future for SkippedHandshakeFuture<F, S> |
||||
where |
||||
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin, |
||||
S: Unpin, |
||||
AllowStd<S>: Read + Write, |
||||
{ |
||||
type Output = WebSocket<AllowStd<S>>; |
||||
|
||||
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { |
||||
let inner = self |
||||
.get_mut() |
||||
.0 |
||||
.take() |
||||
.expect("future polled after completion"); |
||||
trace!("Setting context when skipping handshake"); |
||||
let stream = AllowStd { |
||||
inner: inner.stream, |
||||
context: ctx as *mut _ as *mut (), |
||||
}; |
||||
|
||||
Poll::Ready((inner.f)(stream)) |
||||
} |
||||
} |
||||
|
||||
#[pin_project] |
||||
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>); |
||||
|
||||
enum StartedHandshake<Role: HandshakeRole> { |
||||
Done(Role::FinalResult), |
||||
Mid(WsHandshake<Role>), |
||||
} |
||||
|
||||
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>); |
||||
struct StartedHandshakeFutureInner<F, S> { |
||||
f: F, |
||||
stream: S, |
||||
} |
||||
|
||||
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>> |
||||
where |
||||
Role: HandshakeRole + Unpin, |
||||
Role::InternalStream: HasContext, |
||||
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin, |
||||
S: AsyncRead + AsyncWrite + Unpin, |
||||
{ |
||||
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); |
||||
|
||||
match start.await? { |
||||
StartedHandshake::Done(r) => Ok(r), |
||||
StartedHandshake::Mid(s) => { |
||||
let res: Result<Role::FinalResult, Error<Role>> = MidHandshake::<Role>(Some(s)).await; |
||||
res |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub(crate) async fn client_handshake<F, S>( |
||||
stream: S, |
||||
f: F, |
||||
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>> |
||||
where |
||||
F: FnOnce( |
||||
AllowStd<S>, |
||||
) -> Result< |
||||
<ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult, |
||||
Error<ClientHandshake<AllowStd<S>>>, |
||||
> + Unpin, |
||||
S: AsyncRead + AsyncWrite + Unpin, |
||||
{ |
||||
let result = handshake(stream, f).await?; |
||||
let (s, r) = result; |
||||
Ok((WebSocketStream::new(s), r)) |
||||
} |
||||
|
||||
pub(crate) async fn server_handshake<C, F, S>( |
||||
stream: S, |
||||
f: F, |
||||
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>> |
||||
where |
||||
C: Callback + Unpin, |
||||
F: FnOnce( |
||||
AllowStd<S>, |
||||
) -> Result< |
||||
<ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult, |
||||
Error<ServerHandshake<AllowStd<S>, C>>, |
||||
> + Unpin, |
||||
S: AsyncRead + AsyncWrite + Unpin, |
||||
{ |
||||
let s: WebSocket<AllowStd<S>> = handshake(stream, f).await?; |
||||
Ok(WebSocketStream::new(s)) |
||||
} |
||||
|
||||
impl<Role, F, S> Future for StartedHandshakeFuture<F, S> |
||||
where |
||||
Role: HandshakeRole, |
||||
Role::InternalStream: HasContext, |
||||
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin, |
||||
S: Unpin, |
||||
AllowStd<S>: Read + Write, |
||||
{ |
||||
type Output = Result<StartedHandshake<Role>, Error<Role>>; |
||||
|
||||
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { |
||||
let inner = self.0.take().expect("future polled after completion"); |
||||
trace!("Setting ctx when starting handshake"); |
||||
let stream = AllowStd { |
||||
inner: inner.stream, |
||||
context: ctx as *mut _ as *mut (), |
||||
}; |
||||
|
||||
match (inner.f)(stream) { |
||||
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), |
||||
Err(Error::Interrupted(mut mid)) => { |
||||
let machine = mid.get_mut(); |
||||
machine.get_mut().set_context(std::ptr::null_mut()); |
||||
Poll::Ready(Ok(StartedHandshake::Mid(mid))) |
||||
} |
||||
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<Role> Future for MidHandshake<Role> |
||||
where |
||||
Role: HandshakeRole + Unpin, |
||||
Role::InternalStream: HasContext, |
||||
{ |
||||
type Output = Result<Role::FinalResult, Error<Role>>; |
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
||||
let this = self.project(); |
||||
let mut s = this.0.take().expect("future polled after completion"); |
||||
|
||||
let machine = s.get_mut(); |
||||
trace!("Setting context in handshake"); |
||||
machine.get_mut().set_context(cx as *mut _ as *mut ()); |
||||
|
||||
match s.handshake() { |
||||
Ok(stream) => Poll::Ready(Ok(stream)), |
||||
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), |
||||
Err(Error::Interrupted(mut mid)) => { |
||||
let machine = mid.get_mut(); |
||||
machine.get_mut().set_context(std::ptr::null_mut()); |
||||
*this.0 = Some(mid); |
||||
Poll::Pending |
||||
} |
||||
} |
||||
} |
||||
} |