diff --git a/Cargo.toml b/Cargo.toml index 9820127..8a733cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.6.1" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.6.1" +version = "0.7.0" [features] default = ["tls"] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 21f1d9b..3e7b732 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -39,7 +39,8 @@ fn run_test(case: u32) -> Result<()> { socket.write_message(msg)?; } Message::Ping(_) | - Message::Pong(_) => {} + Message::Pong(_) | + Message::Close(_) => {} } } } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 10aba75..10957f5 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -24,7 +24,8 @@ fn handle_client(stream: TcpStream) -> Result<()> { socket.write_message(msg)?; } Message::Ping(_) | - Message::Pong(_) => {} + Message::Pong(_) | + Message::Close(_) => {} } } } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index c5bf9c2..d016c38 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -11,7 +11,7 @@ use super::coding::{OpCode, Control, Data, CloseCode}; use super::mask::{generate_mask, apply_mask}; /// A struct representing the close command. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct CloseFrame<'t> { /// The reason as a code. pub code: CloseCode, diff --git a/src/protocol/message.rs b/src/protocol/message.rs index a4cbc82..c98d22f 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -4,6 +4,7 @@ use std::result::Result as StdResult; use std::str; use error::{Result, Error}; +use super::frame::CloseFrame; mod string_collect { @@ -179,6 +180,8 @@ pub enum Message { /// /// The payload here must have a length less than 125 bytes Pong(Vec), + /// A close message with the optional close frame. + Close(Option>), } impl Message { @@ -229,6 +232,14 @@ impl Message { } } + /// Indicates whether a message ia s close message. + pub fn is_close(&self) -> bool { + match *self { + Message::Close(_) => true, + _ => false, + } + } + /// Get the length of the WebSocket message. pub fn len(&self) -> usize { match *self { @@ -236,6 +247,7 @@ impl Message { Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => data.len(), + Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), } } @@ -252,6 +264,8 @@ impl Message { Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, + Message::Close(None) => Vec::new(), + Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } } @@ -263,6 +277,8 @@ impl Message { Message::Ping(data) | Message::Pong(data) => Ok(try!( String::from_utf8(data).map_err(|err| err.utf8_error()))), + Message::Close(None) => Ok(String::new()), + Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), } } @@ -274,6 +290,8 @@ impl Message { Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => Ok(try!(str::from_utf8(data))), + Message::Close(None) => Ok(""), + Message::Close(Some(ref frame)) => Ok(&frame.reason), } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8b3de04..d82e261 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -8,7 +8,7 @@ pub use self::message::Message; pub use self::frame::CloseFrame; use std::collections::VecDeque; -use std::io::{Read, Write, ErrorKind as IoErrorKind}; +use std::io::{Read, Write}; use std::mem::replace; use error::{Error, Result}; @@ -146,8 +146,7 @@ impl WebSocket { self.write_pending().no_block()?; // If we get here, either write blocks or we have nothing to write. // Thus if read blocks, just let it return WouldBlock. - let res = self.read_message_frame(); - if let Some(message) = self.translate_close(res)? { + if let Some(message) = self.read_message_frame()? { trace!("Received message {}", message); return Ok(message) } @@ -188,34 +187,19 @@ impl WebSocket { self.pong = Some(Frame::pong(data)); return self.write_pending() } + Message::Close(code) => { + return self.close(code) + } }; self.send_queue.push_back(frame); self.write_pending() } - /// Close the connection. - /// - /// This function guarantees that the close frame will be queued. - /// There is no need to call it again. - pub fn close(&mut self, code: Option) -> Result<()> { - if let WebSocketState::Active = self.state { - self.state = WebSocketState::ClosedByUs; - let frame = Frame::close(code); - self.send_queue.push_back(frame); - } else { - // Already closed, nothing to do. - } - self.write_pending() - } - /// Flush the pending send queue. pub fn write_pending(&mut self) -> Result<()> { // First, make sure we have no pending frame sending. - { - let res = self.socket.write_pending(); - self.translate_close(res)?; - } + self.socket.write_pending()?; // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // response, unless it already received a Close frame. It SHOULD @@ -247,6 +231,22 @@ impl WebSocket { Ok(()) } } + + /// Close the connection. + /// + /// This function guarantees that the close frame will be queued. + /// There is no need to call it again. Calling this function is + /// the same as calling `write(Message::Close(..))`. + pub fn close(&mut self, code: Option) -> Result<()> { + if let WebSocketState::Active = self.state { + self.state = WebSocketState::ClosedByUs; + let frame = Frame::close(code); + self.send_queue.push_back(frame); + } else { + // Already closed, nothing to do. + } + self.write_pending() + } } impl WebSocket { @@ -299,7 +299,7 @@ impl WebSocket { Err(Error::Protocol("Control frame too big".into())) } OpCtl::Close => { - self.do_close(frame.into_close()?).map(|_| None) + Ok(self.do_close(frame.into_close()?).map(Message::Close)) } OpCtl::Reserved(i) => { Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) @@ -373,7 +373,7 @@ impl WebSocket { } else { match replace(&mut self.state, WebSocketState::Terminated) { WebSocketState::CloseAcknowledged(close) | WebSocketState::ClosedByPeer(close) => { - Err(Error::ConnectionClosed(close)) + Ok(Some(Message::Close(close))) } _ => { Err(Error::Protocol("Connection reset without closing handshake".into())) @@ -382,13 +382,14 @@ impl WebSocket { } } - /// Received a close frame. - fn do_close(&mut self, close: Option) -> Result<()> { + /// Received a close frame. Tells if we need to return a close frame to the user. + fn do_close(&mut self, close: Option) -> Option>> { debug!("Received close frame: {:?}", close); match self.state { WebSocketState::Active => { let close_code = close.as_ref().map(|f| f.code); - self.state = WebSocketState::ClosedByPeer(close.map(CloseFrame::into_owned)); + let close = close.map(CloseFrame::into_owned); + self.state = WebSocketState::ClosedByPeer(close.clone()); let reply = if let Some(code) = close_code { if code.is_allowed() { Frame::close(Some(CloseFrame { @@ -406,11 +407,12 @@ impl WebSocket { }; debug!("Replying to close with {:?}", reply); self.send_queue.push_back(reply); - Ok(()) + + Some(close) } WebSocketState::ClosedByPeer(_) | WebSocketState::CloseAcknowledged(_) => { // It is already closed, just ignore. - Ok(()) + None } WebSocketState::ClosedByUs => { // We received a reply. @@ -419,11 +421,11 @@ impl WebSocket { Role::Client => { // Client waits for the server to close the connection. self.state = WebSocketState::CloseAcknowledged(close); - Ok(()) + None } Role::Server => { // Server closes the connection. - Err(Error::ConnectionClosed(close)) + Some(close) } } } @@ -442,30 +444,8 @@ impl WebSocket { frame.set_random_mask(); } } - let res = self.socket.write_frame(frame); - self.translate_close(res) + self.socket.write_frame(frame) } - - /// Translate a "Connection reset by peer" into ConnectionClosed as needed. - fn translate_close(&mut self, res: Result) -> Result { - match res { - Err(Error::Io(err)) => Err({ - if err.kind() == IoErrorKind::ConnectionReset { - match self.state { - WebSocketState::ClosedByPeer(ref mut frame) => - Error::ConnectionClosed(frame.take()), - WebSocketState::CloseAcknowledged(ref mut frame) => - Error::ConnectionClosed(frame.take()), - _ => Error::Io(err), - } - } else { - Error::Io(err) - } - }), - x => x, - } - } - } /// The current connection state.