diff --git a/Cargo.toml b/Cargo.toml index 4a9901f..3581338 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.1.2" +version = "0.2.0" [features] default = ["tls"] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 6b72b60..e883ff4 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -14,7 +14,7 @@ fn get_case_count() -> Result { Url::parse("ws://localhost:9001/getCaseCount").unwrap() )?; let msg = socket.read_message()?; - socket.close()?; + socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) } @@ -22,7 +22,7 @@ fn update_reports() -> Result<()> { let mut socket = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() )?; - socket.close()?; + socket.close(None)?; Ok(()) } diff --git a/examples/client.rs b/examples/client.rs index d8a176f..1c97aba 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -16,6 +16,6 @@ fn main() { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); } - // socket.close(); + // socket.close(None); } diff --git a/src/error.rs b/src/error.rs index e885697..3942d6f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,8 @@ use std::string; use httparse; +use protocol::frame::CloseFrame; + #[cfg(feature="tls")] pub mod tls { pub use native_tls::Error; @@ -22,7 +24,7 @@ pub type Result = result::Result; #[derive(Debug)] pub enum Error { /// WebSocket connection closed (normally) - ConnectionClosed, + ConnectionClosed(Option>), /// Input-output error Io(io::Error), #[cfg(feature="tls")] @@ -43,7 +45,13 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Error::ConnectionClosed => write!(f, "Connection closed"), + Error::ConnectionClosed(ref frame) => { + if let Some(ref cf) = *frame { + write!(f, "Connection closed: {}", cf) + } else { + write!(f, "Connection closed (empty close frame)") + } + } Error::Io(ref err) => write!(f, "IO error: {}", err), #[cfg(feature="tls")] Error::Tls(ref err) => write!(f, "TLS error: {}", err), @@ -59,7 +67,7 @@ impl fmt::Display for Error { impl ErrorTrait for Error { fn description(&self) -> &str { match *self { - Error::ConnectionClosed => "", + Error::ConnectionClosed(_) => "A close handshake is performed", Error::Io(ref err) => err.description(), #[cfg(feature="tls")] Error::Tls(ref err) => err.description(), diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 74a3cac..4709bf7 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -8,7 +8,7 @@ pub use self::message::Message; pub use self::frame::CloseFrame; use std::collections::VecDeque; -use std::io::{Read, Write}; +use std::io::{Read, Write, ErrorKind as IoErrorKind}; use std::mem::replace; use error::{Error, Result}; @@ -90,7 +90,8 @@ impl WebSocket { self.write_pending().no_block()?; // If we get here, either write blocks or we have nothing to write. // Thus if read blocks, just let it return WouldBlock. - if let Some(message) = self.read_message_frame()? { + let res = self.read_message_frame(); + if let Some(message) = self.translate_close(res)? { trace!("Received message {}", message); return Ok(message) } @@ -124,11 +125,11 @@ impl WebSocket { /// /// This function guarantees that the close frame will be queued. /// There is no need to call it again, just like write_message(). - pub fn close(&mut self) -> Result<()> { + pub fn close(&mut self, code: Option) -> Result<()> { match self.state { WebSocketState::Active => { self.state = WebSocketState::ClosedByUs; - let frame = Frame::close(None); + let frame = Frame::close(code); self.send_queue.push_back(frame); } _ => { @@ -141,7 +142,10 @@ impl WebSocket { /// Flush the pending send queue. pub fn write_pending(&mut self) -> Result<()> { // First, make sure we have no pending frame sending. - self.socket.write_pending()?; + { + let res = self.socket.write_pending(); + self.translate_close(res)?; + } // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // response, unless it already received a Close frame. It SHOULD @@ -155,8 +159,8 @@ impl WebSocket { } // If we're closing and there is nothing to send anymore, we should close the connection. - match self.state { - WebSocketState::ClosedByPeer if self.send_queue.is_empty() => { + if self.send_queue.is_empty() { + if let WebSocketState::ClosedByPeer(ref mut frame) = self.state { // The underlying TCP connection, in most normal cases, SHOULD be closed // first by the server, so that it holds the TIME_WAIT state and not the // client (as this would prevent it from re-opening the connection for 2 @@ -165,10 +169,13 @@ impl WebSocket { // a new SYN with a higher seq number). (RFC 6455) match self.role { Role::Client => Ok(()), - Role::Server => Err(Error::ConnectionClosed), + Role::Server => Err(Error::ConnectionClosed(replace(frame, None))), } + } else { + Ok(()) } - _ => Ok(()), + } else { + Ok(()) } } @@ -295,10 +302,12 @@ impl WebSocket { /// Received a close frame. fn do_close(&mut self, close: Option) -> Result<()> { + debug!("Received close frame: {:?}", close); match self.state { WebSocketState::Active => { - self.state = WebSocketState::ClosedByPeer; - let reply = if let Some(CloseFrame { code, .. }) = close { + let close_code = close.as_ref().map(|f| f.code); + self.state = WebSocketState::ClosedByPeer(close.map(CloseFrame::into_owned)); + let reply = if let Some(code) = close_code { if code.is_allowed() { Frame::close(Some(CloseFrame { code: CloseCode::Normal, @@ -313,23 +322,26 @@ impl WebSocket { } else { Frame::close(None) }; + debug!("Replying to close with {:?}", reply); self.send_queue.push_back(reply); Ok(()) } - WebSocketState::ClosedByPeer => { + WebSocketState::ClosedByPeer(_) | WebSocketState::CloseAcknowledged(_) => { // It is already closed, just ignore. Ok(()) } WebSocketState::ClosedByUs => { // We received a reply. + let close = close.map(CloseFrame::into_owned); match self.role { Role::Client => { // Client waits for the server to close the connection. + self.state = WebSocketState::CloseAcknowledged(close); Ok(()) } Role::Server => { // Server closes the connection. - Err(Error::ConnectionClosed) + Err(Error::ConnectionClosed(close)) } } } @@ -368,16 +380,42 @@ impl WebSocket { frame.set_mask(); } } - self.socket.write_frame(frame) + let res = self.socket.write_frame(frame); + self.translate_close(res) + } + + /// Translate a "Connection reset by peer" into ConnectionClosed as needed. + fn translate_close(&mut self, res: Result) -> Result { + match res { + Err(Error::Io(err)) => Err({ + if err.kind() == IoErrorKind::ConnectionReset { + match self.state { + WebSocketState::ClosedByPeer(ref mut frame) => + Error::ConnectionClosed(replace(frame, None)), + WebSocketState::CloseAcknowledged(ref mut frame) => + Error::ConnectionClosed(replace(frame, None)), + _ => Error::Io(err), + } + } else { + Error::Io(err) + } + }), + x => x, + } } } /// The current connection state. enum WebSocketState { + /// The connection is active. Active, + /// We initiated a close handshake. ClosedByUs, - ClosedByPeer, + /// The peer initiated a close handshake. + ClosedByPeer(Option>), + /// The peer replied to our close handshake. + CloseAcknowledged(Option>), } impl WebSocketState {