Merge remote-tracking branch 'upstream/master' into deflate-merge

# Conflicts:
#	examples/autobahn-client.rs
#	examples/autobahn-server.rs
#	src/client.rs
#	src/handshake/client.rs
#	src/handshake/server.rs
#	src/lib.rs
#	src/protocol/frame/frame.rs
#	src/protocol/frame/mod.rs
#	src/protocol/mod.rs
#	tests/connection_reset.rs
pull/144/head
SirCipher 4 years ago
commit 655db5b2eb
  1. 4
      Cargo.toml
  2. 6
      examples/autobahn-client.rs
  3. 7
      examples/autobahn-server.rs
  4. 11
      examples/callback-error.rs
  5. 4
      examples/client.rs
  6. 9
      examples/server.rs
  7. 7
      rustfmt.toml
  8. 120
      src/client.rs
  9. 16
      src/error.rs
  10. 120
      src/handshake/client.rs
  11. 5
      src/handshake/headers.rs
  12. 45
      src/handshake/machine.rs
  13. 14
      src/handshake/mod.rs
  14. 96
      src/handshake/server.rs
  15. 17
      src/lib.rs
  16. 31
      src/protocol/frame/coding.rs
  17. 64
      src/protocol/frame/frame.rs
  18. 22
      src/protocol/frame/mod.rs
  19. 53
      src/protocol/message.rs
  20. 203
      src/protocol/mod.rs
  21. 6
      src/server.rs
  22. 6
      src/util.rs
  23. 31
      tests/connection_reset.rs
  24. 10
      tests/no_send_after_close.rs
  25. 14
      tests/receive_after_init_close.rs

@ -19,7 +19,7 @@ tls-vendored = ["native-tls", "native-tls/vendored"]
deflate = ["flate2"]
[dependencies]
base64 = "0.12.0"
base64 = "0.13.0"
byteorder = "1.3.2"
bytes = "0.5"
http = "0.2"
@ -42,7 +42,7 @@ optional = true
version = "0.2.3"
[dev-dependencies]
env_logger = "0.7.1"
env_logger = "0.8.1"
net2 = "0.2.33"
[[example]]

@ -18,11 +18,7 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
)?;
socket.close(None)?;
Ok(())

@ -1,7 +1,10 @@
use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use std::{
net::{TcpListener, TcpStream},
thread::spawn,
};
use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::extensions::compression::deflate::DeflateConfigBuilder;
use tungstenite::extensions::compression::WsCompression;
use tungstenite::handshake::HandshakeRole;

@ -1,9 +1,10 @@
use std::net::TcpListener;
use std::thread::spawn;
use std::{net::TcpListener, thread::spawn};
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::StatusCode;
use tungstenite::{
accept_hdr,
handshake::server::{Request, Response},
http::StatusCode,
};
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();

@ -14,9 +14,7 @@ fn main() {
println!("* {}", header);
}
socket
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop {
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);

@ -1,8 +1,9 @@
use std::net::TcpListener;
use std::thread::spawn;
use std::{net::TcpListener, thread::spawn};
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, Response};
use tungstenite::{
accept_hdr,
handshake::server::{Request, Response},
};
fn main() {
env_logger::init();

@ -0,0 +1,7 @@
# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check].
# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md
# Break complex but short statements a bit less.
use_small_heuristics = "Max"
merge_imports = true

@ -1,16 +1,20 @@
//! Methods to connect to a WebSocket as a client.
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use std::{
io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::Uri;
use http::{request::Parts, Uri};
use log::*;
use url::Url;
use crate::handshake::client::{Request, Response};
use crate::protocol::WebSocketConfig;
use crate::{
handshake::client::{Request, Response},
protocol::WebSocketConfig,
};
#[cfg(feature = "tls")]
mod encryption {
@ -22,8 +26,7 @@ mod encryption {
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use crate::error::Result;
use crate::stream::Mode;
use crate::{error::Result, stream::Mode};
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
@ -48,8 +51,10 @@ mod encryption {
mod encryption {
use std::net::TcpStream;
use crate::error::{Error, Result};
use crate::stream::Mode;
use crate::{
error::{Error, Result},
stream::Mode,
};
/// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream;
@ -65,11 +70,12 @@ mod encryption {
use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
use crate::{
error::{Error, Result},
handshake::{client::ClientHandshake, HandshakeError},
protocol::WebSocket,
stream::{Mode, NoDelay},
};
/// Connect to the given WebSocket in blocking mode.
///
@ -86,31 +92,63 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)>
where
Req: IntoClientRequest,
{
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> {
fn try_client_handshake(
request: Request,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri();
let mode = uri_mode(uri)?;
let host =
request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version);
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
for attempt in 0..(max_redirects + 1) {
let request = create_request(&parts, &uri);
match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
if let Some(location) = res.headers().get("Location") {
uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri);
continue;
} else {
warn!("No `Location` found in redirect");
return Err(Error::Http(res));
}
}
other => return other,
}
}
unreachable!("Bug in a redirect handling logic")
}
/// Connect to the given WebSocket in blocking mode.
@ -126,13 +164,11 @@ where
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
connect_with_config(request, None, 3)
}
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -1,17 +1,9 @@
//! Error handling.
use std::borrow::Cow;
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io;
use std::result;
use std::str;
use std::string;
use http;
use httparse;
use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string};
use crate::protocol::Message;
use http::Response;
#[cfg(feature = "tls")]
pub mod tls {
@ -64,7 +56,7 @@ pub enum Error {
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(http::StatusCode),
Http(Response<Option<String>>),
/// HTTP format error.
HttpFormat(http::Error),
/// An error from a WebSocket extension.
@ -84,7 +76,7 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
Error::ExtensionError(ref e) => write!(f, "Extension error: {}", e),
}

@ -1,18 +1,25 @@
//! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::{
io::{Read, Write},
marker::PhantomData,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status;
use log::*;
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use super::{
convert_key,
headers::{FromHttparse, MAX_HEADERS},
machine::{HandshakeMachine, StageResult, TryParse},
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type.
pub type Request = HttpRequest<()>;
@ -28,26 +35,19 @@ pub struct ClientHandshake<S> {
_marker: PhantomData<S>,
}
impl<Stream> ClientHandshake<Stream>
where
Stream: Read + Write,
{
impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake.
pub fn start(
stream: Stream,
stream: S,
request: Request,
mut config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
"Invalid HTTP method, only GET supported".into(),
));
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into()));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
// Check the URI scheme: only ws or wss are supported
@ -62,29 +62,18 @@ where
let client = {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData { accept_key },
config: Some(config),
_marker: PhantomData,
}
ClientHandshake { verify_data: VerifyData { accept_key }, config: Some(config), _marker: PhantomData }
};
trace!("Client handshake initiated.");
Ok(MidHandshake {
role: client,
machine,
})
Ok(MidHandshake { role: client, machine })
}
}
impl<Stream> HandshakeRole for ClientHandshake<Stream>
where
Stream: Read + Write,
{
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
type InternalStream = Stream;
type FinalResult = (WebSocket<Stream>, Response);
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -93,16 +82,11 @@ where
StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading {
stream,
result,
tail,
} => {
let mut config = self.config.take().unwrap();
self.verify_data.verify_response(&result, &mut config)?;
StageResult::DoneReading { stream, result, tail } => {
let result = self.verify_data.verify_response(result)?;
debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config);
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
ProcessingResult::Done((websocket, result))
}
})
@ -119,10 +103,8 @@ fn generate_request(
let mut req = Vec::new();
let uri = request.uri();
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let authority =
uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
@ -144,10 +126,8 @@ fn generate_request(
Sec-WebSocket-Key: {key}\r\n",
version = request.version(),
host = host,
path = uri
.path_and_query()
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
path =
uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
key = key
)
.unwrap();
@ -176,12 +156,13 @@ impl VerifyData {
&self,
response: &Response,
config: &mut Option<WebSocketConfig>,
) -> Result<()> {
) -> Result<Response> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status()));
return Err(Error::Http(response.map(|_| None)));
}
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
@ -194,9 +175,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
));
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
@ -208,22 +187,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
));
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !headers
.get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
@ -231,7 +202,6 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
verify_compression_resp_headers(response, config)?;
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
@ -241,7 +211,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455)
// TODO
Ok(())
Ok(response)
}
}
@ -259,9 +229,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -287,8 +255,7 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::{generate_key, generate_request, Response};
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
#[test]
@ -367,9 +334,6 @@ mod tests {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!(
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],);
}
}

@ -1,8 +1,6 @@
//! HTTP Request and response header handling.
use http;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse;
use httparse::Status;
use super::machine::TryParse;
@ -43,8 +41,7 @@ impl TryParse for HeaderMap {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::HeaderMap;
use super::{super::machine::TryParse, HeaderMap};
#[test]
fn headers() {

@ -2,8 +2,10 @@ use bytes::Buf;
use log::*;
use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ};
/// A generic handshake state machine.
@ -23,10 +25,7 @@ impl<Stream> HandshakeMachine<Stream> {
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine {
stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
@ -52,21 +51,19 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.no_block()?;
match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok(
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
},
),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}),
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
@ -112,11 +109,7 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading {
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
/// Writing round finished.
DoneWriting(Stream),
}

@ -6,11 +6,12 @@ pub mod server;
mod machine;
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io::{Read, Write};
use std::{
error::Error as ErrorTrait,
fmt,
io::{Read, Write},
};
use base64;
use sha1::{Digest, Sha1};
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
@ -40,10 +41,7 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop {
mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake {
machine: m,
..self
}))
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
}
RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {

@ -1,19 +1,26 @@
//! Server handshake machine.
use std::io::{self, Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;
use std::{
io::{self, Read, Write},
marker::PhantomData,
result::Result as StdResult,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status;
use log::*;
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::compression::verify_compression_req_headers;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
use super::{
convert_key,
headers::{FromHttparse, MAX_HEADERS},
machine::{HandshakeMachine, StageResult, TryParse},
HandshakeRole, MidHandshake, ProcessingResult,
extensions::verify_compression_req_headers
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Server request type.
pub type Request = HttpRequest<()>;
@ -31,24 +38,17 @@ pub fn create_response(request: &Request) -> Result<Response> {
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
if !request
.headers()
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("Upgrade"))
})
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in client request".into(),
));
return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into()));
}
if !request
@ -58,20 +58,11 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in client request".into(),
));
return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into()));
}
if !request
.headers()
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into()));
}
let key = request
@ -125,9 +116,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -199,16 +188,12 @@ pub struct ServerHandshake<S, C> {
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
error_response: Option<ErrorResponse>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl<S, C> ServerHandshake<S, C>
where
S: Read + Write,
C: Callback,
{
impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
@ -220,18 +205,14 @@ where
role: ServerHandshake {
callback: Some(callback),
config,
error_code: None,
error_response: None,
_marker: PhantomData,
},
}
}
}
impl<S, C> HandshakeRole for ServerHandshake<S, C>
where
S: Read + Write,
C: Callback,
{
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S>;
@ -241,20 +222,16 @@ where
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneReading {
stream,
result: request,
tail,
} => {
StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()));
}
let mut response = create_response(&request)?;
let mut response = create_response(&result)?;
verify_compression_req_headers(&request, &mut response, &mut self.config)?;
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&request, response)
callback.on_request(&result, response)
} else {
Ok(response)
};
@ -273,22 +250,25 @@ where
));
}
self.error_code = Some(resp.status().as_u16());
self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap();
let mut output = vec![];
write_response(&mut output, &resp)?;
if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes());
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
}
}
}
StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() {
if let Some(err) = self.error_response.take() {
debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?));
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
@ -301,9 +281,7 @@ where
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::create_response;
use super::Request;
use super::{super::machine::TryParse, create_response, Request};
#[test]
fn request_parsing() {

@ -16,18 +16,17 @@ pub use http;
pub mod client;
pub mod error;
pub mod extensions;
pub mod handshake;
pub mod protocol;
pub mod server;
pub mod stream;
pub mod util;
pub mod extensions;
pub use crate::client::{client, connect};
pub use crate::error::{Error, Result};
pub use crate::handshake::client::ClientHandshake;
pub use crate::handshake::server::ServerHandshake;
pub use crate::handshake::HandshakeError;
pub use crate::protocol::{Message, WebSocket};
pub use crate::server::{accept, accept_hdr};
pub use crate::{
client::{client, connect},
error::{Error, Result},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
protocol::{Message, WebSocket},
server::{accept, accept_hdr},
};

@ -1,7 +1,9 @@
//! Various codes defined in RFC 6455.
use std::convert::{From, Into};
use std::fmt;
use std::{
convert::{From, Into},
fmt,
};
/// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -71,9 +73,11 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode {
fn into(self) -> u8 {
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
use self::{
Control::{Close, Ping, Pong},
Data::{Binary, Continue, Text},
OpCode::*,
};
match self {
Data(Continue) => 0,
Data(Text) => 1,
@ -90,9 +94,11 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
use self::{
Control::{Close, Ping, Pong},
Data::{Binary, Continue, Text},
OpCode::*,
};
match byte {
0 => Data(Continue),
1 => Data(Text),
@ -184,14 +190,7 @@ pub enum CloseCode {
impl CloseCode {
/// Check if this CloseCode is allowed.
pub fn is_allowed(self) -> bool {
match self {
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
!matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls)
}
}

@ -1,14 +1,18 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*;
use std::borrow::Cow;
use std::default::Default;
use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult;
use std::string::{FromUtf8Error, String};
use super::coding::{CloseCode, Control, Data, OpCode};
use super::mask::{apply_mask, generate_mask};
use std::{
borrow::Cow,
default::Default,
fmt,
io::{Cursor, ErrorKind, Read, Write},
result::Result as StdResult,
string::{FromUtf8Error, String},
};
use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, Result};
/// A struct representing the close command.
@ -23,10 +27,7 @@ pub struct CloseFrame<'t> {
impl<'t> CloseFrame<'t> {
/// Convert into a owned string.
pub fn into_owned(self) -> CloseFrame<'static> {
CloseFrame {
code: self.code,
reason: self.reason.into_owned().into(),
}
CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
}
}
@ -313,10 +314,7 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some(CloseFrame {
code,
reason: text.into(),
}))
Ok(Some(CloseFrame { code, reason: text.into() }))
}
}
}
@ -324,22 +322,9 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(
match opcode {
OpCode::Data(_) => true,
_ => false,
},
"Invalid opcode for data frame."
);
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader {
is_final,
opcode,
..FrameHeader::default()
},
payload: data,
}
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
}
/// Create a new Pong control frame.
@ -378,10 +363,7 @@ impl Frame {
Vec::new()
};
Frame {
header: FrameHeader::default(),
payload,
}
Frame { header: FrameHeader::default(), payload }
}
/// Create a frame from given header and data.
@ -425,10 +407,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()
)
}
}
@ -500,10 +479,7 @@ mod tests {
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
}
#[test]

@ -6,8 +6,7 @@ pub mod coding;
mod frame;
mod mask;
pub use self::frame::CloseFrame;
pub use self::frame::{ExtensionHeaders, Frame, FrameHeader};
pub use self::frame::{CloseFrame, ExtensionHeaders, Frame, FrameHeader};
use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ};
@ -26,18 +25,12 @@ pub struct FrameSocket<Stream> {
impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket.
pub fn new(stream: Stream) -> Self {
FrameSocket {
stream,
codec: FrameCodec::new(),
}
FrameSocket { stream, codec: FrameCodec::new() }
}
/// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket {
stream,
codec: FrameCodec::from_partially_read(part),
}
FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
}
/// Extract a stream from the socket.
@ -184,9 +177,7 @@ impl FrameCodec {
{
trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
self.write_pending(stream)
}
@ -231,10 +222,7 @@ mod tests {
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner();

@ -1,14 +1,14 @@
use std::convert::{AsRef, From, Into};
use std::fmt;
use std::result::Result as StdResult;
use std::str;
use std::{
convert::{AsRef, From, Into},
fmt,
result::Result as StdResult,
str,
};
use super::frame::CloseFrame;
use crate::error::{Error, Result};
mod string_collect {
use utf8;
use utf8::DecodeError;
use crate::error::{Error, Result};
@ -21,10 +21,7 @@ mod string_collect {
impl StringCollector {
pub fn new() -> Self {
StringCollector {
data: String::new(),
incomplete: None,
}
StringCollector { data: String::new(), incomplete: None }
}
pub fn len(&self) -> usize {
@ -56,10 +53,7 @@ mod string_collect {
self.data.push_str(text);
Ok(())
}
Err(DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => {
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix);
Ok(())
@ -129,11 +123,7 @@ impl IncompleteMessage {
// Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity(
format!(
"Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(),
));
}
@ -203,42 +193,27 @@ impl Message {
/// Indicates whether a message is a text message.
pub fn is_text(&self) -> bool {
match *self {
Message::Text(_) => true,
_ => false,
}
matches!(*self, Message::Text(_))
}
/// Indicates whether a message is a binary message.
pub fn is_binary(&self) -> bool {
match *self {
Message::Binary(_) => true,
_ => false,
}
matches!(*self, Message::Binary(_))
}
/// Indicates whether a message is a ping message.
pub fn is_ping(&self) -> bool {
match *self {
Message::Ping(_) => true,
_ => false,
}
matches!(*self, Message::Ping(_))
}
/// Indicates whether a message is a pong message.
pub fn is_pong(&self) -> bool {
match *self {
Message::Pong(_) => true,
_ => false,
}
matches!(*self, Message::Pong(_))
}
/// Indicates whether a message ia s close message.
pub fn is_close(&self) -> bool {
match *self {
Message::Close(_) => true,
_ => false,
}
matches!(*self, Message::Close(_))
}
/// Get the length of the WebSocket message.

@ -2,24 +2,29 @@
pub mod frame;
pub(crate) mod message;
mod message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::collections::VecDeque;
use std::io::{ErrorKind as IoErrorKind, Read, Write};
use std::mem::replace;
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::IncompleteMessage;
use crate::error::{Error, Result};
use crate::extensions::compression::{CompressionSwitcher, WsCompression};
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::Data;
use crate::util::NonBlockingResult;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{
frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
Frame, FrameCodec,
},
message::{IncompleteMessage, IncompleteMessageType},
extensions::{WebSocketExtension, compression::{CompressionSwitcher, WsCompression}};
};
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
@ -33,7 +38,7 @@ pub enum Role {
}
/// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
@ -77,10 +82,7 @@ impl<Stream> WebSocket<Stream> {
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
}
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
@ -136,10 +138,7 @@ impl<Stream> WebSocket<Stream> {
}
}
impl<Stream> WebSocket<Stream>
where
Stream: Read + Write,
{
impl<Stream: Read + Write> WebSocket<Stream> {
/// Read a message from stream, if possible.
///
/// This will queue responses to ping and close messages to be sent. It will call
@ -333,9 +332,7 @@ impl WebSocketContext {
// Do not write after sending a close frame.
if !self.state.is_active() {
return Err(Error::Protocol(
"Sending after closing is not allowed".into(),
));
return Err(Error::Protocol("Sending after closing is not allowed".into()));
}
if let Some(max_send_queue) = self.config.max_send_queue {
@ -457,9 +454,7 @@ impl WebSocketContext {
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol(
"Received a masked frame from server".into(),
));
return Err(Error::Protocol("Received a masked frame from server".into()));
}
}
}
@ -476,9 +471,9 @@ impl WebSocketContext {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => Err(Error::Protocol(
format!("Unknown control frame type {}", i).into(),
)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Ping => {
let data = frame.into_data();
// No ping processing after we sent a close frame.
@ -568,43 +563,8 @@ impl WebSocketContext {
}
}
if frame.header().is_final {
frame = self.decoder.on_send_frame(frame)?;
}
let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value);
if frame.payload().len() > max_frame_size {
let mut chunks = frame.payload().chunks(max_frame_size).peekable();
let data_frame = Frame::message(
Vec::from(chunks.next().unwrap()),
frame.header().opcode,
false,
);
self.frame
.write_frame(stream, data_frame)
.check_connection_reset(self.state)?;
while let Some(chunk) = chunks.next() {
let frame = Frame::message(
Vec::from(chunk),
OpCode::Data(Data::Continue),
chunks.peek().is_none(),
);
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)?;
}
Ok(())
} else {
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)
}
trace!("Sending frame: {:?}", frame);
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
}
@ -626,20 +586,14 @@ enum WebSocketState {
impl WebSocketState {
/// Tell if we're allowed to process normal messages.
fn is_active(self) -> bool {
match self {
WebSocketState::Active => true,
_ => false,
}
matches!(self, WebSocketState::Active)
}
/// Tell if we should process incoming data. Note that if we send a close frame
/// but the remote hasn't confirmed, they might have sent data before they receive our
/// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
fn can_read(self) -> bool {
match self {
WebSocketState::Active | WebSocketState::ClosedByUs => true,
_ => false,
}
matches!(self, WebSocketState::Active | WebSocketState::ClosedByUs)
}
/// Check if the state is active, return error if not.
@ -675,11 +629,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::compression::WsCompression;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use std::io;
use std::io::Cursor;
use std::{io, io::Cursor};
struct WriteMoc<Stream>(Stream);
@ -708,14 +658,8 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(
socket.read_message().unwrap(),
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
}
#[test]
@ -724,11 +668,7 @@ mod tests {
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
compression: WsCompression::None(Some(10)),
};
let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),
@ -739,80 +679,11 @@ mod tests {
#[test]
fn size_limiting_binary() {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
compression: WsCompression::None(Some(2)),
};
let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2"
);
}
#[test]
fn fragmented_tx() {
let max_message_size = 2;
let input_str = "hello unit test";
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(2),
compression: WsCompression::None(Some(max_message_size)),
};
let mut socket =
WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit));
socket.write_message(Message::text(input_str)).unwrap();
socket.socket.set_position(0);
let WebSocket {
mut socket,
mut context,
} = socket;
let vec = input_str.chars().collect::<Vec<_>>();
let mut iter = vec
.chunks(max_message_size)
.map(|c| c.iter().collect::<String>())
.into_iter()
.peekable();
let frame_eq = |expected: Frame, actual: Frame| {
assert_eq!(expected.payload(), actual.payload());
assert_eq!(expected.header().opcode, actual.header().opcode);
assert_eq!(
expected.header().ext_headers.rsv1,
actual.header().ext_headers.rsv1
);
};
let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false);
frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
while let Some(chars) = iter.next() {
let expected = Frame::message(
chars.into(),
OpCode::Data(Data::Continue),
iter.peek().is_none(),
);
frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
}
}
}

@ -2,8 +2,10 @@
pub use crate::handshake::server::ServerHandshake;
use crate::handshake::server::{Callback, NoCallback};
use crate::handshake::HandshakeError;
use crate::handshake::{
server::{Callback, NoCallback},
HandshakeError,
};
use crate::protocol::{WebSocket, WebSocketConfig};

@ -1,7 +1,9 @@
//! Helper traits to ease non-blocking handling.
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::result::Result as StdResult;
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
result::Result as StdResult,
};
use crate::error::Error;

@ -1,10 +1,12 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpListener, TcpStream};
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::{TcpListener, TcpStream},
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use native_tls::TlsStream;
use net2::TcpStreamExt;
@ -49,9 +51,7 @@ fn test_server_close() {
do_test(
3012,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close());
@ -85,9 +85,7 @@ fn test_evil_server_close() {
do_test(
3013,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1));
@ -109,10 +107,7 @@ fn test_evil_server_close() {
let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock
.get_mut()
.set_linger(Some(Duration::from_secs(0)))
.unwrap();
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
drop(srv_sock);
},
);
@ -123,9 +118,7 @@ fn test_client_close() {
do_test(
3014,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server");
@ -145,9 +138,7 @@ fn test_client_close() {
let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock
.write_message(Message::Text("From Server".into()))
.unwrap();
srv_sock.write_message(Message::Text("From Server".into())).unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close());

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
use std::net::TcpListener;
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message};
use url::Url;

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
use std::net::TcpListener;
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message};
use url::Url;
@ -24,9 +26,7 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close());

Loading…
Cancel
Save