diff --git a/src/lib.rs b/src/lib.rs index f5a49ee..d2a0763 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -217,6 +217,7 @@ where #[derive(Debug)] pub struct WebSocketStream { inner: WebSocket>, + closing: bool, } impl WebSocketStream { @@ -250,7 +251,10 @@ impl WebSocketStream { } pub(crate) fn new(ws: WebSocket>) -> Self { - WebSocketStream { inner: ws } + WebSocketStream { + inner: ws, + closing: false, + } } fn with_context(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R @@ -336,9 +340,7 @@ where fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { match (*self).with_context(None, |s| s.write_message(item)) { Ok(()) => Ok(()), - Err(::tungstenite::Error::Io(ref err)) - if err.kind() == std::io::ErrorKind::WouldBlock => - { + Err(::tungstenite::Error::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => { // the message was accepted and queued // isn't an error. Ok(()) @@ -355,12 +357,34 @@ where } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) { - Ok(()) => Poll::Ready(Ok(())), - Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), - Err(err) => { - debug!("websocket close error: {}", err); - Poll::Ready(Err(err)) + if self.closing { + // After queing it, we call `write_pending` to drive the close handshake to completion. + match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.write_pending()) { + Ok(()) => Poll::Ready(Ok(())), + Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), + Err(::tungstenite::Error::Io(err)) + if err.kind() == std::io::ErrorKind::WouldBlock => + { + trace!("WouldBlock"); + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + } else { + match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) { + Ok(()) => Poll::Ready(Ok(())), + Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), + Err(::tungstenite::Error::Io(err)) + if err.kind() == std::io::ErrorKind::WouldBlock => + { + trace!("WouldBlock"); + self.closing = true; + Poll::Pending + } + Err(err) => { + debug!("websocket close error: {}", err); + Poll::Ready(Err(err)) + } } } }