Fix `poll_close` returning WouldBlock error kind

pull/104/head
Benoît CORTIER 3 years ago committed by Sebastian Dröge
parent 4cc0ffd5f9
commit 4dd8888a9d
  1. 32
      src/lib.rs

@ -217,6 +217,7 @@ where
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocketStream<S> { pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>, inner: WebSocket<AllowStd<S>>,
closing: bool,
} }
impl<S> WebSocketStream<S> { impl<S> WebSocketStream<S> {
@ -250,7 +251,10 @@ impl<S> WebSocketStream<S> {
} }
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self { pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
WebSocketStream { inner: ws } WebSocketStream {
inner: ws,
closing: false,
}
} }
fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
@ -336,9 +340,7 @@ where
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write_message(item)) { match (*self).with_context(None, |s| s.write_message(item)) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(::tungstenite::Error::Io(ref err)) Err(::tungstenite::Error::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
if err.kind() == std::io::ErrorKind::WouldBlock =>
{
// the message was accepted and queued // the message was accepted and queued
// isn't an error. // isn't an error.
Ok(()) Ok(())
@ -355,15 +357,37 @@ where
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.closing {
// After queing it, we call `write_pending` to drive the close handshake to completion.
match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.write_pending()) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::Io(err))
if err.kind() == std::io::ErrorKind::WouldBlock =>
{
trace!("WouldBlock");
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
} else {
match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) { match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())), Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::Io(err))
if err.kind() == std::io::ErrorKind::WouldBlock =>
{
trace!("WouldBlock");
self.closing = true;
Poll::Pending
}
Err(err) => { Err(err) => {
debug!("websocket close error: {}", err); debug!("websocket close error: {}", err);
Poll::Ready(Err(err)) Poll::Ready(Err(err))
} }
} }
} }
}
} }
#[cfg(any( #[cfg(any(

Loading…
Cancel
Save