Merge pull request #156 from snapview/clean-ups

Regular maintenance
pull/161/head
Daniel Abramov 4 years ago committed by GitHub
commit a7daafdfc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      examples/autobahn-client.rs
  2. 21
      examples/autobahn-server.rs
  3. 11
      examples/callback-error.rs
  4. 4
      examples/client.rs
  5. 9
      examples/server.rs
  6. 7
      rustfmt.toml
  7. 56
      src/client.rs
  8. 8
      src/error.rs
  9. 94
      src/handshake/client.rs
  10. 3
      src/handshake/headers.rs
  11. 45
      src/handshake/machine.rs
  12. 13
      src/handshake/mod.rs
  13. 60
      src/handshake/server.rs
  14. 14
      src/lib.rs
  15. 31
      src/protocol/frame/coding.rs
  16. 70
      src/protocol/frame/frame.rs
  17. 22
      src/protocol/frame/mod.rs
  18. 26
      src/protocol/message.rs
  19. 85
      src/protocol/mod.rs
  20. 6
      src/server.rs
  21. 6
      src/util.rs
  22. 49
      tests/connection_reset.rs
  23. 10
      tests/no_send_after_close.rs
  24. 14
      tests/receive_after_init_close.rs

@ -14,11 +14,7 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
)?;
socket.close(None)?;
Ok(())
@ -26,11 +22,8 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url = Url::parse(&format!(
"ws://localhost:9001/runCase?case={}&agent={}",
case, AGENT
))
.unwrap();
let case_url =
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
let (mut socket, _) = connect(case_url)?;
loop {
match socket.read_message()? {

@ -1,9 +1,10 @@
use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use std::{
net::{TcpListener, TcpStream},
thread::spawn,
};
use log::*;
use tungstenite::handshake::HandshakeRole;
use tungstenite::{accept, Error, HandshakeError, Message, Result};
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -32,12 +33,14 @@ fn main() {
for stream in server.incoming() {
spawn(move || match stream {
Ok(stream) => if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
Ok(stream) => {
if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
}
}
},
}
Err(e) => error!("Error accepting stream: {}", e),
});
}

@ -1,9 +1,10 @@
use std::net::TcpListener;
use std::thread::spawn;
use std::{net::TcpListener, thread::spawn};
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::StatusCode;
use tungstenite::{
accept_hdr,
handshake::server::{Request, Response},
http::StatusCode,
};
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();

@ -14,9 +14,7 @@ fn main() {
println!("* {}", header);
}
socket
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop {
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);

@ -1,8 +1,9 @@
use std::net::TcpListener;
use std::thread::spawn;
use std::{net::TcpListener, thread::spawn};
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, Response};
use tungstenite::{
accept_hdr,
handshake::server::{Request, Response},
};
fn main() {
env_logger::init();

@ -0,0 +1,7 @@
# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check].
# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md
# Break complex but short statements a bit less.
use_small_heuristics = "Max"
merge_imports = true

@ -1,16 +1,20 @@
//! Methods to connect to a WebSocket as a client.
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use std::{
io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::{Uri, request::Parts};
use http::{request::Parts, Uri};
use log::*;
use url::Url;
use crate::handshake::client::{Request, Response};
use crate::protocol::WebSocketConfig;
use crate::{
handshake::client::{Request, Response},
protocol::WebSocketConfig,
};
#[cfg(feature = "tls")]
mod encryption {
@ -22,8 +26,7 @@ mod encryption {
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use crate::error::Result;
use crate::stream::Mode;
use crate::{error::Result, stream::Mode};
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
@ -48,8 +51,10 @@ mod encryption {
mod encryption {
use std::net::TcpStream;
use crate::error::{Error, Result};
use crate::stream::Mode;
use crate::{
error::{Error, Result},
stream::Mode,
};
/// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream;
@ -65,11 +70,12 @@ mod encryption {
use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
use crate::{
error::{Error, Result},
handshake::{client::ClientHandshake, HandshakeError},
protocol::WebSocket,
stream::{Mode, NoDelay},
};
/// Connect to the given WebSocket in blocking mode.
///
@ -91,16 +97,14 @@ pub fn connect_with_config<Req: IntoClientRequest>(
config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> {
fn try_client_handshake(request: Request, config: Option<WebSocketConfig>)
-> Result<(WebSocket<AutoStream>, Response)>
{
fn try_client_handshake(
request: Request,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let host =
request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
@ -118,7 +122,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version.clone());
.version(parts.version);
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
@ -164,9 +168,7 @@ pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoSt
}
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -1,12 +1,6 @@
//! Error handling.
use std::borrow::Cow;
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io;
use std::result;
use std::str;
use std::string;
use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string};
use crate::protocol::Message;
use http::Response;

@ -1,17 +1,24 @@
//! Client handshake machine.
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::{
io::{Read, Write},
marker::PhantomData,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status;
use log::*;
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
use super::{
convert_key,
headers::{FromHttparse, MAX_HEADERS},
machine::{HandshakeMachine, StageResult, TryParse},
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Client request type.
pub type Request = HttpRequest<()>;
@ -35,15 +42,11 @@ impl<S: Read + Write> ClientHandshake<S> {
config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
"Invalid HTTP method, only GET supported".into(),
));
return Err(Error::Protocol("Invalid HTTP method, only GET supported".into()));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
// Check the URI scheme: only ws or wss are supported
@ -58,18 +61,11 @@ impl<S: Read + Write> ClientHandshake<S> {
let client = {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData { accept_key },
config,
_marker: PhantomData,
}
ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData }
};
trace!("Client handshake initiated.");
Ok(MidHandshake {
role: client,
machine,
})
Ok(MidHandshake { role: client, machine })
}
}
@ -85,11 +81,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading {
stream,
result,
tail,
} => {
StageResult::DoneReading { stream, result, tail } => {
let result = self.verify_data.verify_response(result)?;
debug!("Client handshake done.");
let websocket =
@ -105,16 +97,16 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();
let authority = uri.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let authority =
uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}
write!(
@ -128,17 +120,15 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
Sec-WebSocket-Key: {key}\r\n",
version = request.version(),
host = host,
path = uri
.path_and_query()
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
path =
uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
key = key
)
.unwrap();
for (k, v) in request.headers() {
let mut k = k.as_str();
if k == "sec-websocket-protocol" {
if k == "sec-websocket-protocol" {
k = "Sec-WebSocket-Protocol";
}
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
@ -175,9 +165,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
));
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
@ -189,22 +177,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
));
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !headers
.get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
@ -238,9 +218,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -266,9 +244,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::{super::machine::TryParse, generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
#[test]
fn random_keys() {
@ -342,9 +319,6 @@ mod tests {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!(
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],);
}
}

@ -41,8 +41,7 @@ impl TryParse for HeaderMap {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::HeaderMap;
use super::{super::machine::TryParse, HeaderMap};
#[test]
fn headers() {

@ -2,8 +2,10 @@ use bytes::Buf;
use log::*;
use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ};
/// A generic handshake state machine.
@ -23,10 +25,7 @@ impl<Stream> HandshakeMachine<Stream> {
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine {
stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
@ -52,21 +51,19 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.no_block()?;
match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok(
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
},
),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}),
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
@ -112,11 +109,7 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading {
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
/// Writing round finished.
DoneWriting(Stream),
}

@ -6,9 +6,11 @@ pub mod server;
mod machine;
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io::{Read, Write};
use std::{
error::Error as ErrorTrait,
fmt,
io::{Read, Write},
};
use sha1::{Digest, Sha1};
@ -39,10 +41,7 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop {
mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake {
machine: m,
..self
}))
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
}
RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {

@ -1,18 +1,25 @@
//! Server handshake machine.
use std::io::{self, Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;
use std::{
io::{self, Read, Write},
marker::PhantomData,
result::Result as StdResult,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status;
use log::*;
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
use super::{
convert_key,
headers::{FromHttparse, MAX_HEADERS},
machine::{HandshakeMachine, StageResult, TryParse},
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Server request type.
pub type Request = HttpRequest<()>;
@ -30,9 +37,7 @@ pub fn create_response(request: &Request) -> Result<Response> {
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
if !request
@ -42,9 +47,7 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in client request".into(),
));
return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into()));
}
if !request
@ -54,20 +57,11 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in client request".into(),
));
return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into()));
}
if !request
.headers()
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into()));
}
let key = request
@ -121,9 +115,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
}
let headers = HeaderMap::from_httparse(raw.headers)?;
@ -229,11 +221,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneReading {
stream,
result,
tail,
} => {
StageResult::DoneReading { stream, result, tail } => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()));
}
@ -290,9 +278,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use super::create_response;
use super::Request;
use super::{super::machine::TryParse, create_response, Request};
#[test]
fn request_parsing() {

@ -22,10 +22,10 @@ pub mod server;
pub mod stream;
pub mod util;
pub use crate::client::{client, connect};
pub use crate::error::{Error, Result};
pub use crate::handshake::client::ClientHandshake;
pub use crate::handshake::server::ServerHandshake;
pub use crate::handshake::HandshakeError;
pub use crate::protocol::{Message, WebSocket};
pub use crate::server::{accept, accept_hdr};
pub use crate::{
client::{client, connect},
error::{Error, Result},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
protocol::{Message, WebSocket},
server::{accept, accept_hdr},
};

@ -1,7 +1,9 @@
//! Various codes defined in RFC 6455.
use std::convert::{From, Into};
use std::fmt;
use std::{
convert::{From, Into},
fmt,
};
/// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -71,9 +73,11 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode {
fn into(self) -> u8 {
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
use self::{
Control::{Close, Ping, Pong},
Data::{Binary, Continue, Text},
OpCode::*,
};
match self {
Data(Continue) => 0,
Data(Text) => 1,
@ -90,9 +94,11 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
use self::{
Control::{Close, Ping, Pong},
Data::{Binary, Continue, Text},
OpCode::*,
};
match byte {
0 => Data(Continue),
1 => Data(Text),
@ -184,14 +190,7 @@ pub enum CloseCode {
impl CloseCode {
/// Check if this CloseCode is allowed.
pub fn is_allowed(self) -> bool {
match self {
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
!matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls)
}
}

@ -1,14 +1,18 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*;
use std::borrow::Cow;
use std::default::Default;
use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult;
use std::string::{FromUtf8Error, String};
use super::coding::{CloseCode, Control, Data, OpCode};
use super::mask::{apply_mask, generate_mask};
use std::{
borrow::Cow,
default::Default,
fmt,
io::{Cursor, ErrorKind, Read, Write},
result::Result as StdResult,
string::{FromUtf8Error, String},
};
use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, Result};
/// A struct representing the close command.
@ -23,10 +27,7 @@ pub struct CloseFrame<'t> {
impl<'t> CloseFrame<'t> {
/// Convert into a owned string.
pub fn into_owned(self) -> CloseFrame<'static> {
CloseFrame {
code: self.code,
reason: self.reason.into_owned().into(),
}
CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
}
}
@ -192,14 +193,7 @@ impl FrameHeader {
_ => (),
}
let hdr = FrameHeader {
is_final,
rsv1,
rsv2,
rsv3,
opcode,
mask,
};
let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
Ok(Some((hdr, length)))
}
@ -298,10 +292,7 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some(CloseFrame {
code,
reason: text.into(),
}))
Ok(Some(CloseFrame { code, reason: text.into() }))
}
}
}
@ -309,19 +300,9 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(
matches!(opcode, OpCode::Data(_)),
"Invalid opcode for data frame."
);
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader {
is_final,
opcode,
..FrameHeader::default()
},
payload: data,
}
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
}
/// Create a new Pong control frame.
@ -360,10 +341,7 @@ impl Frame {
Vec::new()
};
Frame {
header: FrameHeader::default(),
payload,
}
Frame { header: FrameHeader::default(), payload }
}
/// Create a frame from given header and data.
@ -401,10 +379,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()
)
}
}
@ -476,10 +451,7 @@ mod tests {
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
}
#[test]

@ -6,8 +6,7 @@ pub mod coding;
mod frame;
mod mask;
pub use self::frame::CloseFrame;
pub use self::frame::{Frame, FrameHeader};
pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ};
@ -26,18 +25,12 @@ pub struct FrameSocket<Stream> {
impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket.
pub fn new(stream: Stream) -> Self {
FrameSocket {
stream,
codec: FrameCodec::new(),
}
FrameSocket { stream, codec: FrameCodec::new() }
}
/// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket {
stream,
codec: FrameCodec::from_partially_read(part),
}
FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
}
/// Extract a stream from the socket.
@ -184,9 +177,7 @@ impl FrameCodec {
{
trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
self.write_pending(stream)
}
@ -231,10 +222,7 @@ mod tests {
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner();

@ -1,7 +1,9 @@
use std::convert::{AsRef, From, Into};
use std::fmt;
use std::result::Result as StdResult;
use std::str;
use std::{
convert::{AsRef, From, Into},
fmt,
result::Result as StdResult,
str,
};
use super::frame::CloseFrame;
use crate::error::{Error, Result};
@ -19,10 +21,7 @@ mod string_collect {
impl StringCollector {
pub fn new() -> Self {
StringCollector {
data: String::new(),
incomplete: None,
}
StringCollector { data: String::new(), incomplete: None }
}
pub fn len(&self) -> usize {
@ -54,10 +53,7 @@ mod string_collect {
self.data.push_str(text);
Ok(())
}
Err(DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => {
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix);
Ok(())
@ -127,11 +123,7 @@ impl IncompleteMessage {
// Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity(
format!(
"Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(),
));
}

@ -4,19 +4,26 @@ pub mod frame;
mod message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
pub use self::{frame::CloseFrame, message::Message};
use log::*;
use std::collections::VecDeque;
use std::io::{ErrorKind as IoErrorKind, Read, Write};
use std::mem::replace;
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::{IncompleteMessage, IncompleteMessageType};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
use std::{
collections::VecDeque,
io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
};
use self::{
frame::{
coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
Frame, FrameCodec,
},
message::{IncompleteMessage, IncompleteMessageType},
};
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -74,10 +81,7 @@ impl<Stream> WebSocket<Stream> {
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
}
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
@ -320,9 +324,7 @@ impl WebSocketContext {
// Do not write after sending a close frame.
if !self.state.is_active() {
return Err(Error::Protocol(
"Sending after closing is not allowed".into(),
));
return Err(Error::Protocol("Sending after closing is not allowed".into()));
}
if let Some(max_send_queue) = self.config.max_send_queue {
@ -455,9 +457,7 @@ impl WebSocketContext {
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol(
"Received a masked frame from server".into(),
));
return Err(Error::Protocol("Received a masked frame from server".into()));
}
}
}
@ -474,9 +474,9 @@ impl WebSocketContext {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => Err(Error::Protocol(
format!("Unknown control frame type {}", i).into(),
)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Ping => {
let data = frame.into_data();
// No ping processing after we sent a close frame.
@ -527,9 +527,9 @@ impl WebSocketContext {
Ok(None)
}
}
OpData::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
}
}
}
} // match opcode
@ -539,9 +539,7 @@ impl WebSocketContext {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed)
}
_ => Err(Error::Protocol(
"Connection reset without closing handshake".into(),
)),
_ => Err(Error::Protocol("Connection reset without closing handshake".into())),
}
}
}
@ -602,9 +600,7 @@ impl WebSocketContext {
}
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
}
@ -669,8 +665,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use std::io;
use std::io::Cursor;
use std::{io, io::Cursor};
struct WriteMoc<Stream>(Stream);
@ -699,14 +694,8 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(
socket.read_message().unwrap(),
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
}
#[test]
@ -715,10 +704,7 @@ mod tests {
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_message_size: Some(10),
..WebSocketConfig::default()
};
let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),
@ -729,10 +715,7 @@ mod tests {
#[test]
fn size_limiting_binary() {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig {
max_message_size: Some(2),
..WebSocketConfig::default()
};
let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
socket.read_message().unwrap_err().to_string(),

@ -2,8 +2,10 @@
pub use crate::handshake::server::ServerHandshake;
use crate::handshake::server::{Callback, NoCallback};
use crate::handshake::HandshakeError;
use crate::handshake::{
server::{Callback, NoCallback},
HandshakeError,
};
use crate::protocol::{WebSocket, WebSocketConfig};

@ -1,7 +1,9 @@
//! Helper traits to ease non-blocking handling.
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::result::Result as StdResult;
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
result::Result as StdResult,
};
use crate::error::Error;

@ -1,15 +1,17 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpStream, TcpListener};
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::{TcpListener, TcpStream},
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message, WebSocket, stream::Stream};
use native_tls::TlsStream;
use url::Url;
use net2::TcpStreamExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
@ -26,8 +28,8 @@ where
exit(1);
});
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is port already in use?");
let server =
TcpListener::bind(("127.0.0.1", port)).expect("Can't listen, is port already in use?");
let client_thread = spawn(move || {
let (client, _) = connect(Url::parse(&format!("ws://localhost:{}/socket", port)).unwrap())
@ -46,11 +48,10 @@ where
#[test]
fn test_server_close() {
do_test(3012,
do_test(
3012,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close());
@ -75,16 +76,16 @@ fn test_server_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}
#[test]
fn test_evil_server_close() {
do_test(3013,
do_test(
3013,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
sleep(Duration::from_secs(1));
@ -108,16 +109,16 @@ fn test_evil_server_close() {
// and now just drop the connection without waiting for `ConnectionClosed`
srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
drop(srv_sock);
});
},
);
}
#[test]
fn test_client_close() {
do_test(3014,
do_test(
3014,
|mut cli_sock| {
cli_sock
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server");
@ -147,6 +148,6 @@ fn test_client_close() {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
},
);
}

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
use std::net::TcpListener;
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message};
use url::Url;

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation.
use std::net::TcpListener;
use std::process::exit;
use std::thread::{sleep, spawn};
use std::time::Duration;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message};
use url::Url;
@ -24,9 +26,7 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close());

Loading…
Cancel
Save