chore: apply `fmt` to the whole project

pull/156/head
Daniel Abramov 4 years ago
parent 2638bd69c7
commit 96d9eb75e5
  1. 13
      examples/autobahn-client.rs
  2. 15
      examples/autobahn-server.rs
  3. 11
      examples/callback-error.rs
  4. 4
      examples/client.rs
  5. 9
      examples/server.rs
  6. 7
      rustfmt.toml
  7. 54
      src/client.rs
  8. 8
      src/error.rs
  9. 92
      src/handshake/client.rs
  10. 3
      src/handshake/headers.rs
  11. 23
      src/handshake/machine.rs
  12. 13
      src/handshake/mod.rs
  13. 60
      src/handshake/server.rs
  14. 14
      src/lib.rs
  15. 22
      src/protocol/frame/coding.rs
  16. 70
      src/protocol/frame/frame.rs
  17. 22
      src/protocol/frame/mod.rs
  18. 26
      src/protocol/message.rs
  19. 83
      src/protocol/mod.rs
  20. 6
      src/server.rs
  21. 6
      src/util.rs
  22. 49
      tests/connection_reset.rs
  23. 10
      tests/no_send_after_close.rs
  24. 14
      tests/receive_after_init_close.rs

@ -14,11 +14,7 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())
@ -26,11 +22,8 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> { fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = Url::parse(&format!( let case_url =
"ws://localhost:9001/runCase?case={}&agent={}", Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
case, AGENT
))
.unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect(case_url)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {

@ -1,9 +1,10 @@
use std::net::{TcpListener, TcpStream}; use std::{
use std::thread::spawn; net::{TcpListener, TcpStream},
thread::spawn,
};
use log::*; use log::*;
use tungstenite::handshake::HandshakeRole; use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::{accept, Error, HandshakeError, Message, Result};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
@ -32,12 +33,14 @@ fn main() {
for stream in server.incoming() { for stream in server.incoming() {
spawn(move || match stream { spawn(move || match stream {
Ok(stream) => if let Err(err) = handle_client(stream) { Ok(stream) => {
if let Err(err) = handle_client(stream) {
match err { match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (), Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e), e => error!("test: {}", e),
} }
}, }
}
Err(e) => error!("Error accepting stream: {}", e), Err(e) => error!("Error accepting stream: {}", e),
}); });
} }

@ -1,9 +1,10 @@
use std::net::TcpListener; use std::{net::TcpListener, thread::spawn};
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::{
use tungstenite::handshake::server::{Request, Response}; accept_hdr,
use tungstenite::http::StatusCode; handshake::server::{Request, Response},
http::StatusCode,
};
fn main() { fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap(); let server = TcpListener::bind("127.0.0.1:3012").unwrap();

@ -14,9 +14,7 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
loop { loop {
let msg = socket.read_message().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);

@ -1,8 +1,9 @@
use std::net::TcpListener; use std::{net::TcpListener, thread::spawn};
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::{
use tungstenite::handshake::server::{Request, Response}; accept_hdr,
handshake::server::{Request, Response},
};
fn main() { fn main() {
env_logger::init(); env_logger::init();

@ -0,0 +1,7 @@
# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check].
# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md
# Break complex but short statements a bit less.
use_small_heuristics = "Max"
merge_imports = true

@ -1,16 +1,20 @@
//! Methods to connect to a WebSocket as a client. //! Methods to connect to a WebSocket as a client.
use std::io::{Read, Write}; use std::{
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; io::{Read, Write},
use std::result::Result as StdResult; net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::{Uri, request::Parts}; use http::{request::Parts, Uri};
use log::*; use log::*;
use url::Url; use url::Url;
use crate::handshake::client::{Request, Response}; use crate::{
use crate::protocol::WebSocketConfig; handshake::client::{Request, Response},
protocol::WebSocketConfig,
};
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
mod encryption { mod encryption {
@ -22,8 +26,7 @@ mod encryption {
/// TCP stream switcher (plain/TLS). /// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>; pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use crate::error::Result; use crate::{error::Result, stream::Mode};
use crate::stream::Mode;
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> { pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode { match mode {
@ -48,8 +51,10 @@ mod encryption {
mod encryption { mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
use crate::error::{Error, Result}; use crate::{
use crate::stream::Mode; error::{Error, Result},
stream::Mode,
};
/// TLS support is nod compiled in, this is just standard `TcpStream`. /// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream; pub type AutoStream = TcpStream;
@ -65,11 +70,12 @@ mod encryption {
use self::encryption::wrap_stream; use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::error::{Error, Result}; use crate::{
use crate::handshake::client::ClientHandshake; error::{Error, Result},
use crate::handshake::HandshakeError; handshake::{client::ClientHandshake, HandshakeError},
use crate::protocol::WebSocket; protocol::WebSocket,
use crate::stream::{Mode, NoDelay}; stream::{Mode, NoDelay},
};
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
/// ///
@ -91,16 +97,14 @@ pub fn connect_with_config<Req: IntoClientRequest>(
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
max_redirects: u8, max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
fn try_client_handshake(
fn try_client_handshake(request: Request, config: Option<WebSocketConfig>) request: Request,
-> Result<(WebSocket<AutoStream>, Response)> config: Option<WebSocketConfig>,
{ ) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri(); let uri = request.uri();
let mode = uri_mode(uri)?; let mode = uri_mode(uri)?;
let host = request let host =
.uri() request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode { let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80, Mode::Plain => 80,
Mode::Tls => 443, Mode::Tls => 443,
@ -164,9 +168,7 @@ pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoSt
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -1,12 +1,6 @@
//! Error handling. //! Error handling.
use std::borrow::Cow; use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string};
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io;
use std::result;
use std::str;
use std::string;
use crate::protocol::Message; use crate::protocol::Message;
use http::Response; use http::Response;

@ -1,17 +1,24 @@
//! Client handshake machine. //! Client handshake machine.
use std::io::{Read, Write}; use std::{
use std::marker::PhantomData; io::{Read, Write},
marker::PhantomData,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, MAX_HEADERS}; use super::{
use super::machine::{HandshakeMachine, StageResult, TryParse}; convert_key,
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; headers::{FromHttparse, MAX_HEADERS},
use crate::error::{Error, Result}; machine::{HandshakeMachine, StageResult, TryParse},
use crate::protocol::{Role, WebSocket, WebSocketConfig}; HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Client request type. /// Client request type.
pub type Request = HttpRequest<()>; pub type Request = HttpRequest<()>;
@ -35,15 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol( return Err(Error::Protocol("Invalid HTTP method, only GET supported".into()));
"Invalid HTTP method, only GET supported".into(),
));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
// Check the URI scheme: only ws or wss are supported // Check the URI scheme: only ws or wss are supported
@ -58,18 +61,11 @@ impl<S: Read + Write> ClientHandshake<S> {
let client = { let client = {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData }
verify_data: VerifyData { accept_key },
config,
_marker: PhantomData,
}
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
Ok(MidHandshake { Ok(MidHandshake { role: client, machine })
role: client,
machine,
})
} }
} }
@ -85,11 +81,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { StageResult::DoneReading { stream, result, tail } => {
stream,
result,
tail,
} => {
let result = self.verify_data.verify_response(result)?; let result = self.verify_data.verify_response(result)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = let websocket =
@ -105,16 +97,16 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri(); let uri = request.uri();
let authority = uri.authority() let authority =
.ok_or_else(|| Error::Url("No host name in the URL".into()))? uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
.as_str(); let host = if let Some(idx) = authority.find('@') {
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ // handle possible name:password@
authority.split_at(idx + 1).1 authority.split_at(idx + 1).1
} else { } else {
authority authority
}; };
if authority.is_empty() { if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into())) return Err(Error::Url("URL contains empty host name".into()));
} }
write!( write!(
@ -128,10 +120,8 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
version = request.version(), version = request.version(),
host = host, host = host,
path = uri path =
.path_and_query() uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
@ -175,9 +165,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
"No \"Upgrade: websocket\" in server reply".into(),
));
} }
// 3. If the response lacks a |Connection| header field or the // 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
@ -189,22 +177,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
"No \"Connection: upgrade\" in server reply".into(),
));
} }
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or // 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !headers if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
.get("Sec-WebSocket-Accept") return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
} }
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
@ -238,9 +218,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -266,9 +244,8 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest; use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -342,9 +319,6 @@ mod tests {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.status(), http::StatusCode::OK); assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!( assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],);
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
} }
} }

@ -41,8 +41,7 @@ impl TryParse for HeaderMap {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, HeaderMap};
use super::HeaderMap;
#[test] #[test]
fn headers() { fn headers() {

@ -2,8 +2,10 @@ use bytes::Buf;
use log::*; use log::*;
use std::io::{Cursor, Read, Write}; use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result}; use crate::{
use crate::util::NonBlockingResult; error::{Error, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
/// A generic handshake state machine. /// A generic handshake state machine.
@ -23,10 +25,7 @@ impl<Stream> HandshakeMachine<Stream> {
} }
/// Start writing data to the peer. /// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine { HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
@ -52,8 +51,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.no_block()?; .no_block()?;
match read { match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())), Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok( Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size); buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading { RoundResult::StageFinished(StageResult::DoneReading {
result: obj, result: obj,
@ -65,8 +63,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
}) })
}, }),
),
None => Ok(RoundResult::WouldBlock(HandshakeMachine { None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
@ -112,11 +109,7 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)] #[derive(Debug)]
pub enum StageResult<Obj, Stream> { pub enum StageResult<Obj, Stream> {
/// Reading round finished. /// Reading round finished.
DoneReading { DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
/// Writing round finished. /// Writing round finished.
DoneWriting(Stream), DoneWriting(Stream),
} }

@ -6,9 +6,11 @@ pub mod server;
mod machine; mod machine;
use std::error::Error as ErrorTrait; use std::{
use std::fmt; error::Error as ErrorTrait,
use std::io::{Read, Write}; fmt,
io::{Read, Write},
};
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
@ -39,10 +41,7 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => { RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
machine: m,
..self
}))
} }
RoundResult::Incomplete(m) => m, RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {

@ -1,18 +1,25 @@
//! Server handshake machine. //! Server handshake machine.
use std::io::{self, Read, Write}; use std::{
use std::marker::PhantomData; io::{self, Read, Write},
use std::result::Result as StdResult; marker::PhantomData,
result::Result as StdResult,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, MAX_HEADERS}; use super::{
use super::machine::{HandshakeMachine, StageResult, TryParse}; convert_key,
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; headers::{FromHttparse, MAX_HEADERS},
use crate::error::{Error, Result}; machine::{HandshakeMachine, StageResult, TryParse},
use crate::protocol::{Role, WebSocket, WebSocketConfig}; HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Server request type. /// Server request type.
pub type Request = HttpRequest<()>; pub type Request = HttpRequest<()>;
@ -30,9 +37,7 @@ pub fn create_response(request: &Request) -> Result<Response> {
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
if !request if !request
@ -42,9 +47,7 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into()));
"No \"Connection: upgrade\" in client request".into(),
));
} }
if !request if !request
@ -54,20 +57,11 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into()));
"No \"Upgrade: websocket\" in client request".into(),
));
} }
if !request if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
.headers() return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into()));
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
} }
let key = request let key = request
@ -121,9 +115,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -229,11 +221,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
finish: StageResult<Self::IncomingData, Self::InternalStream>, finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> { ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { StageResult::DoneReading { stream, result, tail } => {
stream,
result,
tail,
} => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol("Junk after client request".into()));
} }
@ -290,9 +278,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, create_response, Request};
use super::create_response;
use super::Request;
#[test] #[test]
fn request_parsing() { fn request_parsing() {

@ -22,10 +22,10 @@ pub mod server;
pub mod stream; pub mod stream;
pub mod util; pub mod util;
pub use crate::client::{client, connect}; pub use crate::{
pub use crate::error::{Error, Result}; client::{client, connect},
pub use crate::handshake::client::ClientHandshake; error::{Error, Result},
pub use crate::handshake::server::ServerHandshake; handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
pub use crate::handshake::HandshakeError; protocol::{Message, WebSocket},
pub use crate::protocol::{Message, WebSocket}; server::{accept, accept_hdr},
pub use crate::server::{accept, accept_hdr}; };

@ -1,7 +1,9 @@
//! Various codes defined in RFC 6455. //! Various codes defined in RFC 6455.
use std::convert::{From, Into}; use std::{
use std::fmt; convert::{From, Into},
fmt,
};
/// WebSocket message opcode as in RFC 6455. /// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -71,9 +73,11 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode { impl Into<u8> for OpCode {
fn into(self) -> u8 { fn into(self) -> u8 {
use self::Control::{Close, Ping, Pong}; use self::{
use self::Data::{Binary, Continue, Text}; Control::{Close, Ping, Pong},
use self::OpCode::*; Data::{Binary, Continue, Text},
OpCode::*,
};
match self { match self {
Data(Continue) => 0, Data(Continue) => 0,
Data(Text) => 1, Data(Text) => 1,
@ -90,9 +94,11 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode { impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode { fn from(byte: u8) -> OpCode {
use self::Control::{Close, Ping, Pong}; use self::{
use self::Data::{Binary, Continue, Text}; Control::{Close, Ping, Pong},
use self::OpCode::*; Data::{Binary, Continue, Text},
OpCode::*,
};
match byte { match byte {
0 => Data(Continue), 0 => Data(Continue),
1 => Data(Text), 1 => Data(Text),

@ -1,14 +1,18 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*; use log::*;
use std::borrow::Cow; use std::{
use std::default::Default; borrow::Cow,
use std::fmt; default::Default,
use std::io::{Cursor, ErrorKind, Read, Write}; fmt,
use std::result::Result as StdResult; io::{Cursor, ErrorKind, Read, Write},
use std::string::{FromUtf8Error, String}; result::Result as StdResult,
string::{FromUtf8Error, String},
use super::coding::{CloseCode, Control, Data, OpCode}; };
use super::mask::{apply_mask, generate_mask};
use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
/// A struct representing the close command. /// A struct representing the close command.
@ -23,10 +27,7 @@ pub struct CloseFrame<'t> {
impl<'t> CloseFrame<'t> { impl<'t> CloseFrame<'t> {
/// Convert into a owned string. /// Convert into a owned string.
pub fn into_owned(self) -> CloseFrame<'static> { pub fn into_owned(self) -> CloseFrame<'static> {
CloseFrame { CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
code: self.code,
reason: self.reason.into_owned().into(),
}
} }
} }
@ -192,14 +193,7 @@ impl FrameHeader {
_ => (), _ => (),
} }
let hdr = FrameHeader { let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
is_final,
rsv1,
rsv2,
rsv3,
opcode,
mask,
};
Ok(Some((hdr, length))) Ok(Some((hdr, length)))
} }
@ -298,10 +292,7 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { Ok(Some(CloseFrame { code, reason: text.into() }))
code,
reason: text.into(),
}))
} }
} }
} }
@ -309,19 +300,9 @@ impl Frame {
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!( debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
matches!(opcode, OpCode::Data(_)),
"Invalid opcode for data frame."
);
Frame { Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
header: FrameHeader {
is_final,
opcode,
..FrameHeader::default()
},
payload: data,
}
} }
/// Create a new Pong control frame. /// Create a new Pong control frame.
@ -360,10 +341,7 @@ impl Frame {
Vec::new() Vec::new()
}; };
Frame { Frame { header: FrameHeader::default(), payload }
header: FrameHeader::default(),
payload,
}
} }
/// Create a frame from given header and data. /// Create a frame from given header and data.
@ -401,10 +379,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
) )
} }
} }
@ -476,10 +451,7 @@ mod tests {
let mut payload = Vec::new(); let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap(); raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload); let frame = Frame::from_payload(header, payload);
assert_eq!( assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
} }
#[test] #[test]

@ -6,8 +6,7 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
pub use self::frame::CloseFrame; pub use self::frame::{CloseFrame, Frame, FrameHeader};
pub use self::frame::{Frame, FrameHeader};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
@ -26,18 +25,12 @@ pub struct FrameSocket<Stream> {
impl<Stream> FrameSocket<Stream> { impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket. /// Create a new frame socket.
pub fn new(stream: Stream) -> Self { pub fn new(stream: Stream) -> Self {
FrameSocket { FrameSocket { stream, codec: FrameCodec::new() }
stream,
codec: FrameCodec::new(),
}
} }
/// Create a new frame socket from partially read data. /// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self { pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket { FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
stream,
codec: FrameCodec::from_partially_read(part),
}
} }
/// Extract a stream from the socket. /// Extract a stream from the socket.
@ -184,9 +177,7 @@ impl FrameCodec {
{ {
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending(stream) self.write_pending(stream)
} }
@ -231,10 +222,7 @@ mod tests {
sock.read_frame(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
assert_eq!( assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();

@ -1,7 +1,9 @@
use std::convert::{AsRef, From, Into}; use std::{
use std::fmt; convert::{AsRef, From, Into},
use std::result::Result as StdResult; fmt,
use std::str; result::Result as StdResult,
str,
};
use super::frame::CloseFrame; use super::frame::CloseFrame;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -19,10 +21,7 @@ mod string_collect {
impl StringCollector { impl StringCollector {
pub fn new() -> Self { pub fn new() -> Self {
StringCollector { StringCollector { data: String::new(), incomplete: None }
data: String::new(),
incomplete: None,
}
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
@ -54,10 +53,7 @@ mod string_collect {
self.data.push_str(text); self.data.push_str(text);
Ok(()) Ok(())
} }
Err(DecodeError::Incomplete { Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
valid_prefix,
incomplete_suffix,
}) => {
self.data.push_str(valid_prefix); self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix); self.incomplete = Some(incomplete_suffix);
Ok(()) Ok(())
@ -127,11 +123,7 @@ impl IncompleteMessage {
// Be careful about integer overflows here. // Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size { if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity( return Err(Error::Capacity(
format!( format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(),
"Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
)); ));
} }

@ -4,19 +4,26 @@ pub mod frame;
mod message; mod message;
pub use self::frame::CloseFrame; pub use self::{frame::CloseFrame, message::Message};
pub use self::message::Message;
use log::*; use log::*;
use std::collections::VecDeque; use std::{
use std::io::{ErrorKind as IoErrorKind, Read, Write}; collections::VecDeque,
use std::mem::replace; io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; use self::{
use self::frame::{Frame, FrameCodec}; frame::{
use self::message::{IncompleteMessage, IncompleteMessageType}; coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
use crate::error::{Error, Result}; Frame, FrameCodec,
use crate::util::NonBlockingResult; },
message::{IncompleteMessage, IncompleteMessageType},
};
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -74,10 +81,7 @@ impl<Stream> WebSocket<Stream> {
/// or together with an existing one. If you need an initial handshake, use /// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket. /// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket { WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
socket: stream,
context: WebSocketContext::new(role, config),
}
} }
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
@ -320,9 +324,7 @@ impl WebSocketContext {
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol( return Err(Error::Protocol("Sending after closing is not allowed".into()));
"Sending after closing is not allowed".into(),
));
} }
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
@ -455,9 +457,7 @@ impl WebSocketContext {
Role::Client => { Role::Client => {
if frame.is_masked() { if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455) // A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol( return Err(Error::Protocol("Received a masked frame from server".into()));
"Received a masked frame from server".into(),
));
} }
} }
} }
@ -474,9 +474,9 @@ impl WebSocketContext {
Err(Error::Protocol("Control frame too big".into())) Err(Error::Protocol("Control frame too big".into()))
} }
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => Err(Error::Protocol( OpCtl::Reserved(i) => {
format!("Unknown control frame type {}", i).into(), Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
)), }
OpCtl::Ping => { OpCtl::Ping => {
let data = frame.into_data(); let data = frame.into_data();
// No ping processing after we sent a close frame. // No ping processing after we sent a close frame.
@ -527,9 +527,9 @@ impl WebSocketContext {
Ok(None) Ok(None)
} }
} }
OpData::Reserved(i) => Err(Error::Protocol( OpData::Reserved(i) => {
format!("Unknown data frame type {}", i).into(), Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
)), }
} }
} }
} // match opcode } // match opcode
@ -539,9 +539,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} }
_ => Err(Error::Protocol( _ => Err(Error::Protocol("Connection reset without closing handshake".into())),
"Connection reset without closing handshake".into(),
)),
} }
} }
} }
@ -602,9 +600,7 @@ impl WebSocketContext {
} }
trace!("Sending frame: {:?}", frame); trace!("Sending frame: {:?}", frame);
self.frame self.frame.write_frame(stream, frame).check_connection_reset(self.state)
.write_frame(stream, frame)
.check_connection_reset(self.state)
} }
} }
@ -669,8 +665,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests { mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use std::io; use std::{io, io::Cursor};
use std::io::Cursor;
struct WriteMoc<Stream>(Stream); struct WriteMoc<Stream>(Stream);
@ -699,14 +694,8 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!( assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
socket.read_message().unwrap(), assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
} }
#[test] #[test]
@ -715,10 +704,7 @@ mod tests {
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x6c, 0x64, 0x21, 0x6c, 0x64, 0x21,
]); ]);
let limit = WebSocketConfig { let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
max_message_size: Some(10),
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!( assert_eq!(
socket.read_message().unwrap_err().to_string(), socket.read_message().unwrap_err().to_string(),
@ -729,10 +715,7 @@ mod tests {
#[test] #[test]
fn size_limiting_binary() { fn size_limiting_binary() {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig { let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
max_message_size: Some(2),
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!( assert_eq!(
socket.read_message().unwrap_err().to_string(), socket.read_message().unwrap_err().to_string(),

@ -2,8 +2,10 @@
pub use crate::handshake::server::ServerHandshake; pub use crate::handshake::server::ServerHandshake;
use crate::handshake::server::{Callback, NoCallback}; use crate::handshake::{
use crate::handshake::HandshakeError; server::{Callback, NoCallback},
HandshakeError,
};
use crate::protocol::{WebSocket, WebSocketConfig}; use crate::protocol::{WebSocket, WebSocketConfig};

@ -1,7 +1,9 @@
//! Helper traits to ease non-blocking handling. //! Helper traits to ease non-blocking handling.
use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::{
use std::result::Result as StdResult; io::{Error as IoError, ErrorKind as IoErrorKind},
result::Result as StdResult,
};
use crate::error::Error; use crate::error::Error;

@ -1,15 +1,17 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket. //! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpStream, TcpListener}; use std::{
use std::process::exit; net::{TcpListener, TcpStream},
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream};
use native_tls::TlsStream; use native_tls::TlsStream;
use url::Url;
use net2::TcpStreamExt; use net2::TcpStreamExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>; type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
@ -26,8 +28,8 @@ where
exit(1); exit(1);
}); });
let server = TcpListener::bind(("127.0.0.1", port)) let server =
.expect("Can't listen, is port already in use?"); TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?");
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap()) let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap())
@ -46,11 +48,10 @@ where
#[test] #[test]
fn test_server_close() { fn test_server_close() {
do_test(3012, do_test(
3012,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = cli_sock.read_message().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
@ -75,16 +76,16 @@ fn test_server_close() {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}); },
);
} }
#[test] #[test]
fn test_evil_server_close() { fn test_evil_server_close() {
do_test(3013, do_test(
3013,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
@ -108,16 +109,16 @@ fn test_evil_server_close() {
// and now just drop the connection without waiting for `ConnectionClosed` // and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap(); srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
drop(srv_sock); drop(srv_sock);
}); },
);
} }
#[test] #[test]
fn test_client_close() { fn test_client_close() {
do_test(3014, do_test(
3014,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = cli_sock.read_message().unwrap(); // receive answer from server let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server"); assert_eq!(message.into_data(), b"From Server");
@ -147,6 +148,6 @@ fn test_client_close() {
Error::ConnectionClosed => {} Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err), _ => panic!("unexpected error: {:?}", err),
} }
}); },
);
} }

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
use std::net::TcpListener; use std::{
use std::process::exit; net::TcpListener,
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
use std::net::TcpListener; use std::{
use std::process::exit; net::TcpListener,
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;
@ -24,9 +26,7 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = client.read_message().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());

Loading…
Cancel
Save