diff --git a/Cargo.toml b/Cargo.toml index ec72c6e..29cddea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"] license = "MIT" readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" -documentation = "https://docs.rs/tungstenite/0.1.1" +documentation = "https://docs.rs/tungstenite/0.2.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.1.2" +version = "0.2.0" [features] default = ["tls"] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 6b72b60..e883ff4 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -14,7 +14,7 @@ fn get_case_count() -> Result { Url::parse("ws://localhost:9001/getCaseCount").unwrap() )?; let msg = socket.read_message()?; - socket.close()?; + socket.close(None)?; Ok(msg.into_text()?.parse::().unwrap()) } @@ -22,7 +22,7 @@ fn update_reports() -> Result<()> { let mut socket = connect( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap() )?; - socket.close()?; + socket.close(None)?; Ok(()) } diff --git a/examples/client.rs b/examples/client.rs index d8a176f..1c97aba 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -16,6 +16,6 @@ fn main() { let msg = socket.read_message().expect("Error reading message"); println!("Received: {}", msg); } - // socket.close(); + // socket.close(None); } diff --git a/src/client.rs b/src/client.rs index c0166e6..83e0062 100644 --- a/src/client.rs +++ b/src/client.rs @@ -119,6 +119,6 @@ pub fn url_mode(url: &Url) -> Result { pub fn client(url: Url, stream: Stream) -> StdResult, HandshakeError> { - let request = Request { url: url }; + let request = Request { url: url, extra_headers: None }; ClientHandshake::start(stream, request).handshake() } diff --git a/src/error.rs b/src/error.rs index e885697..3942d6f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,8 @@ use std::string; use httparse; +use protocol::frame::CloseFrame; + #[cfg(feature="tls")] pub mod tls { pub use native_tls::Error; @@ -22,7 +24,7 @@ pub type Result = result::Result; #[derive(Debug)] pub enum Error { /// WebSocket connection closed (normally) - ConnectionClosed, + ConnectionClosed(Option>), /// Input-output error Io(io::Error), #[cfg(feature="tls")] @@ -43,7 +45,13 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Error::ConnectionClosed => write!(f, "Connection closed"), + Error::ConnectionClosed(ref frame) => { + if let Some(ref cf) = *frame { + write!(f, "Connection closed: {}", cf) + } else { + write!(f, "Connection closed (empty close frame)") + } + } Error::Io(ref err) => write!(f, "IO error: {}", err), #[cfg(feature="tls")] Error::Tls(ref err) => write!(f, "TLS error: {}", err), @@ -59,7 +67,7 @@ impl fmt::Display for Error { impl ErrorTrait for Error { fn description(&self) -> &str { match *self { - Error::ConnectionClosed => "", + Error::ConnectionClosed(_) => "A close handshake is performed", Error::Io(ref err) => err.description(), #[cfg(feature="tls")] Error::Tls(ref err) => err.description(), diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 9beb5eb..157f309 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -13,12 +13,12 @@ use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::machine::{HandshakeMachine, StageResult, TryParse}; /// Client request. -pub struct Request { +pub struct Request<'t> { pub url: Url, - // TODO extra headers + pub extra_headers: Option<&'t [(&'t str, &'t str)]>, } -impl Request { +impl<'t> Request<'t> { /// The GET part of the request. fn get_path(&self) -> String { if let Some(query) = self.url.query() { @@ -56,9 +56,14 @@ impl ClientHandshake { Connection: upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ - Sec-WebSocket-Key: {key}\r\n\ - \r\n", host = request.get_host(), path = request.get_path(), key = key) - .unwrap(); + Sec-WebSocket-Key: {key}\r\n", + host = request.get_host(), path = request.get_path(), key = key).unwrap(); + if let Some(eh) = request.extra_headers { + for &(k, v) in eh { + write!(req, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(req, "\r\n").unwrap(); HandshakeMachine::start_write(stream, req) }; diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index 3476c68..97f9fd0 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -193,9 +193,16 @@ impl CloseCode { } } -impl Into for CloseCode { +impl fmt::Display for CloseCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let code: u16 = self.into(); + write!(f, "{}", code) + } +} + +impl<'t> Into for &'t CloseCode { fn into(self) -> u16 { - match self { + match *self { Normal => 1000, Away => 1001, Protocol => 1002, @@ -218,6 +225,12 @@ impl Into for CloseCode { } } +impl Into for CloseCode { + fn into(self) -> u16 { + (&self).into() + } +} + impl From for CloseCode { fn from(code: u16) -> CloseCode { match code { diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index d41cde8..3bdc539 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::borrow::Cow; use std::mem::transmute; use std::io::{Cursor, Read, Write, ErrorKind}; use std::default::Default; @@ -43,6 +44,31 @@ fn generate_mask() -> [u8; 4] { rand::random() } +/// A struct representing the close command. +#[derive(Debug, Clone)] +pub struct CloseFrame<'t> { + /// The reason as a code. + pub code: CloseCode, + /// The reason as text string. + pub reason: Cow<'t, str>, +} + +impl<'t> CloseFrame<'t> { + /// Convert into a owned string. + pub fn into_owned(self) -> CloseFrame<'static> { + CloseFrame { + code: self.code, + reason: self.reason.into_owned().into(), + } + } +} + +impl<'t> fmt::Display for CloseFrame<'t> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} ({})", self.reason, self.code) + } +} + /// A struct representing a WebSocket frame. #[derive(Debug, Clone)] pub struct Frame { @@ -215,7 +241,7 @@ impl Frame { /// Consume the frame into a closing frame. #[inline] - pub fn into_close(self) -> Result> { + pub fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), 1 => Err(Error::Protocol("Invalid close sequence".into())), @@ -224,7 +250,7 @@ impl Frame { let code = NetworkEndian::read_u16(&data[0..2]).into(); data.drain(0..2); let text = String::from_utf8(data)?; - Ok(Some((code, text))) + Ok(Some(CloseFrame { code: code, reason: text.into() })) } } } @@ -267,8 +293,8 @@ impl Frame { /// Create a new Close control frame. #[inline] - pub fn close(msg: Option<(CloseCode, &str)>) -> Frame { - let payload = if let Some((code, reason)) = msg { + pub fn close(msg: Option) -> Frame { + let payload = if let Some(CloseFrame { code, reason }) = msg { let raw: [u8; 2] = unsafe { let u: u16 = code.into(); transmute(u.to_be()) diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 14b937b..8973db7 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,6 +5,7 @@ pub mod coding; mod frame; pub use self::frame::Frame; +pub use self::frame::CloseFrame; use std::io::{Read, Write}; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index b03d954..b4ac9c2 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -5,9 +5,10 @@ pub mod frame; mod message; pub use self::message::Message; +pub use self::frame::CloseFrame; use std::collections::VecDeque; -use std::io::{Read, Write}; +use std::io::{Read, Write, ErrorKind as IoErrorKind}; use std::mem::replace; use error::{Error, Result}; @@ -89,7 +90,8 @@ impl WebSocket { self.write_pending().no_block()?; // If we get here, either write blocks or we have nothing to write. // Thus if read blocks, just let it return WouldBlock. - if let Some(message) = self.read_message_frame()? { + let res = self.read_message_frame(); + if let Some(message) = self.translate_close(res)? { trace!("Received message {}", message); return Ok(message) } @@ -123,10 +125,10 @@ impl WebSocket { /// /// This function guarantees that the close frame will be queued. /// There is no need to call it again, just like write_message(). - pub fn close(&mut self) -> Result<()> { + pub fn close(&mut self, code: Option) -> Result<()> { if let WebSocketState::Active = self.state { self.state = WebSocketState::ClosedByUs; - let frame = Frame::close(None); + let frame = Frame::close(code); self.send_queue.push_back(frame); } else { // Already closed, nothing to do. @@ -137,7 +139,10 @@ impl WebSocket { /// Flush the pending send queue. pub fn write_pending(&mut self) -> Result<()> { // First, make sure we have no pending frame sending. - self.socket.write_pending()?; + { + let res = self.socket.write_pending(); + self.translate_close(res)?; + } // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in // response, unless it already received a Close frame. It SHOULD @@ -151,8 +156,8 @@ impl WebSocket { } // If we're closing and there is nothing to send anymore, we should close the connection. - match self.state { - WebSocketState::ClosedByPeer if self.send_queue.is_empty() => { + if self.send_queue.is_empty() { + if let WebSocketState::ClosedByPeer(ref mut frame) = self.state { // The underlying TCP connection, in most normal cases, SHOULD be closed // first by the server, so that it holds the TIME_WAIT state and not the // client (as this would prevent it from re-opening the connection for 2 @@ -161,10 +166,13 @@ impl WebSocket { // a new SYN with a higher seq number). (RFC 6455) match self.role { Role::Client => Ok(()), - Role::Server => Err(Error::ConnectionClosed), + Role::Server => Err(Error::ConnectionClosed(replace(frame, None))), } + } else { + Ok(()) } - _ => Ok(()), + } else { + Ok(()) } } @@ -290,36 +298,47 @@ impl WebSocket { } /// Received a close frame. - fn do_close(&mut self, close: Option<(CloseCode, String)>) -> Result<()> { + fn do_close(&mut self, close: Option) -> Result<()> { + debug!("Received close frame: {:?}", close); match self.state { WebSocketState::Active => { - self.state = WebSocketState::ClosedByPeer; - let reply = if let Some((code, _)) = close { + let close_code = close.as_ref().map(|f| f.code); + self.state = WebSocketState::ClosedByPeer(close.map(CloseFrame::into_owned)); + let reply = if let Some(code) = close_code { if code.is_allowed() { - Frame::close(Some((CloseCode::Normal, ""))) + Frame::close(Some(CloseFrame { + code: CloseCode::Normal, + reason: "".into(), + })) } else { - Frame::close(Some((CloseCode::Protocol, "Protocol violation"))) + Frame::close(Some(CloseFrame { + code: CloseCode::Protocol, + reason: "Protocol violation".into() + })) } } else { Frame::close(None) }; + debug!("Replying to close with {:?}", reply); self.send_queue.push_back(reply); Ok(()) } - WebSocketState::ClosedByPeer => { + WebSocketState::ClosedByPeer(_) | WebSocketState::CloseAcknowledged(_) => { // It is already closed, just ignore. Ok(()) } WebSocketState::ClosedByUs => { // We received a reply. + let close = close.map(CloseFrame::into_owned); match self.role { Role::Client => { // Client waits for the server to close the connection. + self.state = WebSocketState::CloseAcknowledged(close); Ok(()) } Role::Server => { // Server closes the connection. - Err(Error::ConnectionClosed) + Err(Error::ConnectionClosed(close)) } } } @@ -358,16 +377,42 @@ impl WebSocket { frame.set_mask(); } } - self.socket.write_frame(frame) + let res = self.socket.write_frame(frame); + self.translate_close(res) + } + + /// Translate a "Connection reset by peer" into ConnectionClosed as needed. + fn translate_close(&mut self, res: Result) -> Result { + match res { + Err(Error::Io(err)) => Err({ + if err.kind() == IoErrorKind::ConnectionReset { + match self.state { + WebSocketState::ClosedByPeer(ref mut frame) => + Error::ConnectionClosed(replace(frame, None)), + WebSocketState::CloseAcknowledged(ref mut frame) => + Error::ConnectionClosed(replace(frame, None)), + _ => Error::Io(err), + } + } else { + Error::Io(err) + } + }), + x => x, + } } } /// The current connection state. enum WebSocketState { + /// The connection is active. Active, + /// We initiated a close handshake. ClosedByUs, - ClosedByPeer, + /// The peer initiated a close handshake. + ClosedByPeer(Option>), + /// The peer replied to our close handshake. + CloseAcknowledged(Option>), } impl WebSocketState {