Merge remote-tracking branch 'upstream/master' into deflate-merge

# Conflicts:
#	examples/autobahn-client.rs
#	examples/autobahn-server.rs
#	src/client.rs
#	src/handshake/client.rs
#	src/handshake/server.rs
#	src/lib.rs
#	src/protocol/frame/frame.rs
#	src/protocol/frame/mod.rs
#	src/protocol/mod.rs
#	tests/connection_reset.rs
pull/144/head
SirCipher 4 years ago
commit 655db5b2eb
  1. 4
      Cargo.toml
  2. 6
      examples/autobahn-client.rs
  3. 7
      examples/autobahn-server.rs
  4. 11
      examples/callback-error.rs
  5. 4
      examples/client.rs
  6. 9
      examples/server.rs
  7. 7
      rustfmt.toml
  8. 120
      src/client.rs
  9. 16
      src/error.rs
  10. 120
      src/handshake/client.rs
  11. 5
      src/handshake/headers.rs
  12. 45
      src/handshake/machine.rs
  13. 14
      src/handshake/mod.rs
  14. 96
      src/handshake/server.rs
  15. 17
      src/lib.rs
  16. 31
      src/protocol/frame/coding.rs
  17. 64
      src/protocol/frame/frame.rs
  18. 22
      src/protocol/frame/mod.rs
  19. 53
      src/protocol/message.rs
  20. 203
      src/protocol/mod.rs
  21. 6
      src/server.rs
  22. 6
      src/util.rs
  23. 31
      tests/connection_reset.rs
  24. 10
      tests/no_send_after_close.rs
  25. 14
      tests/receive_after_init_close.rs

@ -19,7 +19,7 @@ tls-vendored = ["native-tls", "native-tls/vendored"]
deflate = ["flate2"] deflate = ["flate2"]
[dependencies] [dependencies]
base64 = "0.12.0" base64 = "0.13.0"
byteorder = "1.3.2" byteorder = "1.3.2"
bytes = "0.5" bytes = "0.5"
http = "0.2" http = "0.2"
@ -42,7 +42,7 @@ optional = true
version = "0.2.3" version = "0.2.3"
[dev-dependencies] [dev-dependencies]
env_logger = "0.7.1" env_logger = "0.8.1"
net2 = "0.2.33" net2 = "0.2.33"
[[example]] [[example]]

@ -18,11 +18,7 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!( Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())

@ -1,7 +1,10 @@
use std::net::{TcpListener, TcpStream}; use std::{
use std::thread::spawn; net::{TcpListener, TcpStream},
thread::spawn,
};
use log::*; use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::extensions::compression::deflate::DeflateConfigBuilder; use tungstenite::extensions::compression::deflate::DeflateConfigBuilder;
use tungstenite::extensions::compression::WsCompression; use tungstenite::extensions::compression::WsCompression;
use tungstenite::handshake::HandshakeRole; use tungstenite::handshake::HandshakeRole;

@ -1,9 +1,10 @@
use std::net::TcpListener; use std::{net::TcpListener, thread::spawn};
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::{
use tungstenite::handshake::server::{Request, Response}; accept_hdr,
use tungstenite::http::StatusCode; handshake::server::{Request, Response},
http::StatusCode,
};
fn main() { fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap(); let server = TcpListener::bind("127.0.0.1:3012").unwrap();

@ -14,9 +14,7 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
loop { loop {
let msg = socket.read_message().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);

@ -1,8 +1,9 @@
use std::net::TcpListener; use std::{net::TcpListener, thread::spawn};
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::{
use tungstenite::handshake::server::{Request, Response}; accept_hdr,
handshake::server::{Request, Response},
};
fn main() { fn main() {
env_logger::init(); env_logger::init();

@ -0,0 +1,7 @@
# This project uses rustfmt to format source code. Run `cargo +nightly fmt [-- --check].
# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md
# Break complex but short statements a bit less.
use_small_heuristics = "Max"
merge_imports = true

@ -1,16 +1,20 @@
//! Methods to connect to a WebSocket as a client. //! Methods to connect to a WebSocket as a client.
use std::io::{Read, Write}; use std::{
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; io::{Read, Write},
use std::result::Result as StdResult; net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::Uri; use http::{request::Parts, Uri};
use log::*; use log::*;
use url::Url; use url::Url;
use crate::handshake::client::{Request, Response}; use crate::{
use crate::protocol::WebSocketConfig; handshake::client::{Request, Response},
protocol::WebSocketConfig,
};
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
mod encryption { mod encryption {
@ -22,8 +26,7 @@ mod encryption {
/// TCP stream switcher (plain/TLS). /// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>; pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use crate::error::Result; use crate::{error::Result, stream::Mode};
use crate::stream::Mode;
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> { pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode { match mode {
@ -48,8 +51,10 @@ mod encryption {
mod encryption { mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
use crate::error::{Error, Result}; use crate::{
use crate::stream::Mode; error::{Error, Result},
stream::Mode,
};
/// TLS support is nod compiled in, this is just standard `TcpStream`. /// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream; pub type AutoStream = TcpStream;
@ -65,11 +70,12 @@ mod encryption {
use self::encryption::wrap_stream; use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::error::{Error, Result}; use crate::{
use crate::handshake::client::ClientHandshake; error::{Error, Result},
use crate::handshake::HandshakeError; handshake::{client::ClientHandshake, HandshakeError},
use crate::protocol::WebSocket; protocol::WebSocket,
use crate::stream::{Mode, NoDelay}; stream::{Mode, NoDelay},
};
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
/// ///
@ -86,31 +92,63 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req>( pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> max_redirects: u8,
where ) -> Result<(WebSocket<AutoStream>, Response)> {
Req: IntoClientRequest, fn try_client_handshake(
{ request: Request,
let request: Request = request.into_client_request()?; config: Option<WebSocketConfig>,
let uri = request.uri(); ) -> Result<(WebSocket<AutoStream>, Response)> {
let mode = uri_mode(uri)?; let uri = request.uri();
let host = request let mode = uri_mode(uri)?;
.uri() let host =
.host() request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; let port = uri.port_u16().unwrap_or(match mode {
let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80,
Mode::Plain => 80, Mode::Tls => 443,
Mode::Tls => 443, });
}); let addrs = (host, port).to_socket_addrs()?;
let addrs = (host, port).to_socket_addrs()?; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?;
NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config).map_err(|e| match e {
client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f,
HandshakeError::Failure(f) => f, HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), })
}) }
fn create_request(parts: &Parts, uri: &Uri) -> Request {
let mut builder = Request::builder()
.uri(uri.clone())
.method(parts.method.clone())
.version(parts.version);
*builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
builder.body(()).expect("Failed to create `Request`")
}
let (parts, _) = request.into_client_request()?.into_parts();
let mut uri = parts.uri.clone();
for attempt in 0..(max_redirects + 1) {
let request = create_request(&parts, &uri);
match try_client_handshake(request, config) {
Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
if let Some(location) = res.headers().get("Location") {
uri = location.to_str()?.parse::<Uri>()?;
debug!("Redirecting to {:?}", uri);
continue;
} else {
warn!("No `Location` found in redirect");
return Err(Error::Http(res));
}
}
other => return other,
}
}
unreachable!("Bug in a redirect handling logic")
} }
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
@ -126,13 +164,11 @@ where
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> { pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None) connect_with_config(request, None, 3)
} }
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -1,17 +1,9 @@
//! Error handling. //! Error handling.
use std::borrow::Cow; use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string};
use std::error::Error as ErrorTrait;
use std::fmt;
use std::io;
use std::result;
use std::str;
use std::string;
use http;
use httparse;
use crate::protocol::Message; use crate::protocol::Message;
use http::Response;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
@ -64,7 +56,7 @@ pub enum Error {
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error. /// HTTP error.
Http(http::StatusCode), Http(Response<Option<String>>),
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), HttpFormat(http::Error),
/// An error from a WebSocket extension. /// An error from a WebSocket extension.
@ -84,7 +76,7 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP error: {}", code), Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
Error::ExtensionError(ref e) => write!(f, "Extension error: {}", e), Error::ExtensionError(ref e) => write!(f, "Extension error: {}", e),
} }

@ -1,18 +1,25 @@
//! Client handshake machine. //! Client handshake machine.
use std::io::{Read, Write}; use std::{
use std::marker::PhantomData; io::{Read, Write},
marker::PhantomData,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, MAX_HEADERS}; use super::{
use super::machine::{HandshakeMachine, StageResult, TryParse}; convert_key,
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; headers::{FromHttparse, MAX_HEADERS},
use crate::error::{Error, Result}; machine::{HandshakeMachine, StageResult, TryParse},
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers}; use crate::extensions::compression::{apply_compression_headers, verify_compression_resp_headers};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type. /// Client request type.
pub type Request = HttpRequest<()>; pub type Request = HttpRequest<()>;
@ -28,26 +35,19 @@ pub struct ClientHandshake<S> {
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<Stream> ClientHandshake<Stream> impl<S: Read + Write> ClientHandshake<S> {
where
Stream: Read + Write,
{
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start( pub fn start(
stream: Stream, stream: S,
request: Request, request: Request,
mut config: Option<WebSocketConfig>, mut config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
return Err(Error::Protocol( return Err(Error::Protocol("Invalid HTTP method, only GET supported".into()));
"Invalid HTTP method, only GET supported".into(),
));
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
// Check the URI scheme: only ws or wss are supported // Check the URI scheme: only ws or wss are supported
@ -62,29 +62,18 @@ where
let client = { let client = {
let accept_key = convert_key(key.as_ref()).unwrap(); let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake { ClientHandshake { verify_data: VerifyData { accept_key }, config: Some(config), _marker: PhantomData }
verify_data: VerifyData { accept_key },
config: Some(config),
_marker: PhantomData,
}
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
Ok(MidHandshake { Ok(MidHandshake { role: client, machine })
role: client,
machine,
})
} }
} }
impl<Stream> HandshakeRole for ClientHandshake<Stream> impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
where
Stream: Read + Write,
{
type IncomingData = Response; type IncomingData = Response;
type InternalStream = Stream; type InternalStream = S;
type FinalResult = (WebSocket<Stream>, Response); type FinalResult = (WebSocket<S>, Response);
fn stage_finished( fn stage_finished(
&mut self, &mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>, finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -93,16 +82,11 @@ where
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { StageResult::DoneReading { stream, result, tail } => {
stream, let result = self.verify_data.verify_response(result)?;
result,
tail,
} => {
let mut config = self.config.take().unwrap();
self.verify_data.verify_response(&result, &mut config)?;
debug!("Client handshake done."); debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config); let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
}) })
@ -119,10 +103,8 @@ fn generate_request(
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri(); let uri = request.uri();
let authority = uri let authority =
.authority() uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str();
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@ // handle possible name:password@
authority.split_at(idx + 1).1 authority.split_at(idx + 1).1
@ -144,10 +126,8 @@ fn generate_request(
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
version = request.version(), version = request.version(),
host = host, host = host,
path = uri path =
.path_and_query() uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(),
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
@ -176,12 +156,13 @@ impl VerifyData {
&self, &self,
response: &Response, response: &Response,
config: &mut Option<WebSocketConfig>, config: &mut Option<WebSocketConfig>,
) -> Result<()> { ) -> Result<Response> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status())); return Err(Error::Http(response.map(|_| None)));
} }
let headers = response.headers(); let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
@ -194,9 +175,7 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
"No \"Upgrade: websocket\" in server reply".into(),
));
} }
// 3. If the response lacks a |Connection| header field or the // 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
@ -208,22 +187,14 @@ impl VerifyData {
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
"No \"Connection: upgrade\" in server reply".into(),
));
} }
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or // 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !headers if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
.get("Sec-WebSocket-Accept") return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
} }
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
@ -231,7 +202,6 @@ impl VerifyData {
// that was not present in the client's handshake (the server has // that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client // indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
verify_compression_resp_headers(response, config)?; verify_compression_resp_headers(response, config)?;
// 6. If the response includes a |Sec-WebSocket-Protocol| header field // 6. If the response includes a |Sec-WebSocket-Protocol| header field
@ -241,7 +211,7 @@ impl VerifyData {
// the WebSocket Connection_. (RFC 6455) // the WebSocket Connection_. (RFC 6455)
// TODO // TODO
Ok(()) Ok(response)
} }
} }
@ -259,9 +229,7 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -287,8 +255,7 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, generate_key, generate_request, Response};
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest; use crate::client::IntoClientRequest;
#[test] #[test]
@ -367,9 +334,6 @@ mod tests {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.status(), http::StatusCode::OK); assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!( assert_eq!(resp.headers().get("Content-Type").unwrap(), &b"text/html"[..],);
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
} }
} }

@ -1,8 +1,6 @@
//! HTTP Request and response header handling. //! HTTP Request and response header handling.
use http;
use http::header::{HeaderMap, HeaderName, HeaderValue}; use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse;
use httparse::Status; use httparse::Status;
use super::machine::TryParse; use super::machine::TryParse;
@ -43,8 +41,7 @@ impl TryParse for HeaderMap {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, HeaderMap};
use super::HeaderMap;
#[test] #[test]
fn headers() { fn headers() {

@ -2,8 +2,10 @@ use bytes::Buf;
use log::*; use log::*;
use std::io::{Cursor, Read, Write}; use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result}; use crate::{
use crate::util::NonBlockingResult; error::{Error, Result},
util::NonBlockingResult,
};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
/// A generic handshake state machine. /// A generic handshake state machine.
@ -23,10 +25,7 @@ impl<Stream> HandshakeMachine<Stream> {
} }
/// Start writing data to the peer. /// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine { HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
@ -52,21 +51,19 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
.no_block()?; .no_block()?;
match read { match read {
Some(0) => Err(Error::Protocol("Handshake not finished".into())), Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok( Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { buf.advance(size);
buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading {
RoundResult::StageFinished(StageResult::DoneReading { result: obj,
result: obj, stream: self.stream,
stream: self.stream, tail: buf.into_vec(),
tail: buf.into_vec(), })
}) } else {
} else { RoundResult::Incomplete(HandshakeMachine {
RoundResult::Incomplete(HandshakeMachine { state: HandshakeState::Reading(buf),
state: HandshakeState::Reading(buf), ..self
..self })
}) }),
},
),
None => Ok(RoundResult::WouldBlock(HandshakeMachine { None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
@ -112,11 +109,7 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)] #[derive(Debug)]
pub enum StageResult<Obj, Stream> { pub enum StageResult<Obj, Stream> {
/// Reading round finished. /// Reading round finished.
DoneReading { DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
/// Writing round finished. /// Writing round finished.
DoneWriting(Stream), DoneWriting(Stream),
} }

@ -6,11 +6,12 @@ pub mod server;
mod machine; mod machine;
use std::error::Error as ErrorTrait; use std::{
use std::fmt; error::Error as ErrorTrait,
use std::io::{Read, Write}; fmt,
io::{Read, Write},
};
use base64;
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
@ -40,10 +41,7 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => { RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
machine: m,
..self
}))
} }
RoundResult::Incomplete(m) => m, RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? { RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {

@ -1,19 +1,26 @@
//! Server handshake machine. //! Server handshake machine.
use std::io::{self, Read, Write}; use std::{
use std::marker::PhantomData; io::{self, Read, Write},
use std::result::Result as StdResult; marker::PhantomData,
result::Result as StdResult,
};
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, MAX_HEADERS}; use super::{
use super::machine::{HandshakeMachine, StageResult, TryParse}; convert_key,
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; headers::{FromHttparse, MAX_HEADERS},
use crate::error::{Error, Result}; machine::{HandshakeMachine, StageResult, TryParse},
use crate::extensions::compression::verify_compression_req_headers; HandshakeRole, MidHandshake, ProcessingResult,
use crate::protocol::{Role, WebSocket, WebSocketConfig}; extensions::verify_compression_req_headers
};
use crate::{
error::{Error, Result},
protocol::{Role, WebSocket, WebSocketConfig},
};
/// Server request type. /// Server request type.
pub type Request = HttpRequest<()>; pub type Request = HttpRequest<()>;
@ -31,24 +38,17 @@ pub fn create_response(request: &Request) -> Result<Response> {
} }
if request.version() < http::Version::HTTP_11 { if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
if !request if !request
.headers() .headers()
.get("Connection") .get("Connection")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case("Upgrade"))
})
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into()));
"No \"Connection: upgrade\" in client request".into(),
));
} }
if !request if !request
@ -58,20 +58,11 @@ pub fn create_response(request: &Request) -> Result<Response> {
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into()));
"No \"Upgrade: websocket\" in client request".into(),
));
} }
if !request if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
.headers() return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into()));
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
} }
let key = request let key = request
@ -125,9 +116,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
"HTTP version should be 1.1 or higher".into(),
));
} }
let headers = HeaderMap::from_httparse(raw.headers)?; let headers = HeaderMap::from_httparse(raw.headers)?;
@ -199,16 +188,12 @@ pub struct ServerHandshake<S, C> {
/// WebSocket configuration. /// WebSocket configuration.
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client. /// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>, error_response: Option<ErrorResponse>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<S, C> ServerHandshake<S, C> impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
where
S: Read + Write,
C: Callback,
{
/// Start server handshake. `callback` specifies a custom callback which the user can pass to /// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the /// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client /// server, you can specify the callback if you want to add additional header to the client
@ -220,18 +205,14 @@ where
role: ServerHandshake { role: ServerHandshake {
callback: Some(callback), callback: Some(callback),
config, config,
error_code: None, error_response: None,
_marker: PhantomData, _marker: PhantomData,
}, },
} }
} }
} }
impl<S, C> HandshakeRole for ServerHandshake<S, C> impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
where
S: Read + Write,
C: Callback,
{
type IncomingData = Request; type IncomingData = Request;
type InternalStream = S; type InternalStream = S;
type FinalResult = WebSocket<S>; type FinalResult = WebSocket<S>;
@ -241,20 +222,16 @@ where
finish: StageResult<Self::IncomingData, Self::InternalStream>, finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> { ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { StageResult::DoneReading { stream, result, tail } => {
stream,
result: request,
tail,
} => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol("Junk after client request".into()));
} }
let mut response = create_response(&request)?; let mut response = create_response(&result)?;
verify_compression_req_headers(&request, &mut response, &mut self.config)?; verify_compression_req_headers(&request, &mut response, &mut self.config)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&request, response) callback.on_request(&result, response)
} else { } else {
Ok(response) Ok(response)
}; };
@ -273,22 +250,25 @@ where
)); ));
} }
self.error_code = Some(resp.status().as_u16()); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap();
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &resp)?; write_response(&mut output, &resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend_from_slice(body.as_bytes());
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
} }
} }
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_response.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::Http(StatusCode::from_u16(err)?)); return Err(Error::Http(err));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
@ -301,9 +281,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::machine::TryParse; use super::{super::machine::TryParse, create_response, Request};
use super::create_response;
use super::Request;
#[test] #[test]
fn request_parsing() { fn request_parsing() {

@ -16,18 +16,17 @@ pub use http;
pub mod client; pub mod client;
pub mod error; pub mod error;
pub mod extensions;
pub mod handshake; pub mod handshake;
pub mod protocol; pub mod protocol;
pub mod server; pub mod server;
pub mod stream; pub mod stream;
pub mod util; pub mod util;
pub mod extensions; pub use crate::{
client::{client, connect},
pub use crate::client::{client, connect}; error::{Error, Result},
pub use crate::error::{Error, Result}; handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
pub use crate::handshake::client::ClientHandshake; protocol::{Message, WebSocket},
pub use crate::handshake::server::ServerHandshake; server::{accept, accept_hdr},
pub use crate::handshake::HandshakeError; };
pub use crate::protocol::{Message, WebSocket};
pub use crate::server::{accept, accept_hdr};

@ -1,7 +1,9 @@
//! Various codes defined in RFC 6455. //! Various codes defined in RFC 6455.
use std::convert::{From, Into}; use std::{
use std::fmt; convert::{From, Into},
fmt,
};
/// WebSocket message opcode as in RFC 6455. /// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -71,9 +73,11 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode { impl Into<u8> for OpCode {
fn into(self) -> u8 { fn into(self) -> u8 {
use self::Control::{Close, Ping, Pong}; use self::{
use self::Data::{Binary, Continue, Text}; Control::{Close, Ping, Pong},
use self::OpCode::*; Data::{Binary, Continue, Text},
OpCode::*,
};
match self { match self {
Data(Continue) => 0, Data(Continue) => 0,
Data(Text) => 1, Data(Text) => 1,
@ -90,9 +94,11 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode { impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode { fn from(byte: u8) -> OpCode {
use self::Control::{Close, Ping, Pong}; use self::{
use self::Data::{Binary, Continue, Text}; Control::{Close, Ping, Pong},
use self::OpCode::*; Data::{Binary, Continue, Text},
OpCode::*,
};
match byte { match byte {
0 => Data(Continue), 0 => Data(Continue),
1 => Data(Text), 1 => Data(Text),
@ -184,14 +190,7 @@ pub enum CloseCode {
impl CloseCode { impl CloseCode {
/// Check if this CloseCode is allowed. /// Check if this CloseCode is allowed.
pub fn is_allowed(self) -> bool { pub fn is_allowed(self) -> bool {
match self { !matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls)
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
} }
} }

@ -1,14 +1,18 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*; use log::*;
use std::borrow::Cow; use std::{
use std::default::Default; borrow::Cow,
use std::fmt; default::Default,
use std::io::{Cursor, ErrorKind, Read, Write}; fmt,
use std::result::Result as StdResult; io::{Cursor, ErrorKind, Read, Write},
use std::string::{FromUtf8Error, String}; result::Result as StdResult,
string::{FromUtf8Error, String},
use super::coding::{CloseCode, Control, Data, OpCode}; };
use super::mask::{apply_mask, generate_mask};
use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
/// A struct representing the close command. /// A struct representing the close command.
@ -23,10 +27,7 @@ pub struct CloseFrame<'t> {
impl<'t> CloseFrame<'t> { impl<'t> CloseFrame<'t> {
/// Convert into a owned string. /// Convert into a owned string.
pub fn into_owned(self) -> CloseFrame<'static> { pub fn into_owned(self) -> CloseFrame<'static> {
CloseFrame { CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
code: self.code,
reason: self.reason.into_owned().into(),
}
} }
} }
@ -313,10 +314,7 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { Ok(Some(CloseFrame { code, reason: text.into() }))
code,
reason: text.into(),
}))
} }
} }
} }
@ -324,22 +322,9 @@ impl Frame {
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!( debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
match opcode {
OpCode::Data(_) => true,
_ => false,
},
"Invalid opcode for data frame."
);
Frame { Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
header: FrameHeader {
is_final,
opcode,
..FrameHeader::default()
},
payload: data,
}
} }
/// Create a new Pong control frame. /// Create a new Pong control frame.
@ -378,10 +363,7 @@ impl Frame {
Vec::new() Vec::new()
}; };
Frame { Frame { header: FrameHeader::default(), payload }
header: FrameHeader::default(),
payload,
}
} }
/// Create a frame from given header and data. /// Create a frame from given header and data.
@ -425,10 +407,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
) )
} }
} }
@ -500,10 +479,7 @@ mod tests {
let mut payload = Vec::new(); let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap(); raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload); let frame = Frame::from_payload(header, payload);
assert_eq!( assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
} }
#[test] #[test]

@ -6,8 +6,7 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
pub use self::frame::CloseFrame; pub use self::frame::{CloseFrame, ExtensionHeaders, Frame, FrameHeader};
pub use self::frame::{ExtensionHeaders, Frame, FrameHeader};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
@ -26,18 +25,12 @@ pub struct FrameSocket<Stream> {
impl<Stream> FrameSocket<Stream> { impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket. /// Create a new frame socket.
pub fn new(stream: Stream) -> Self { pub fn new(stream: Stream) -> Self {
FrameSocket { FrameSocket { stream, codec: FrameCodec::new() }
stream,
codec: FrameCodec::new(),
}
} }
/// Create a new frame socket from partially read data. /// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self { pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket { FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
stream,
codec: FrameCodec::from_partially_read(part),
}
} }
/// Extract a stream from the socket. /// Extract a stream from the socket.
@ -184,9 +177,7 @@ impl FrameCodec {
{ {
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending(stream) self.write_pending(stream)
} }
@ -231,10 +222,7 @@ mod tests {
sock.read_frame(None).unwrap().unwrap().into_data(), sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
); );
assert_eq!( assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();

@ -1,14 +1,14 @@
use std::convert::{AsRef, From, Into}; use std::{
use std::fmt; convert::{AsRef, From, Into},
use std::result::Result as StdResult; fmt,
use std::str; result::Result as StdResult,
str,
};
use super::frame::CloseFrame; use super::frame::CloseFrame;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
mod string_collect { mod string_collect {
use utf8;
use utf8::DecodeError; use utf8::DecodeError;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -21,10 +21,7 @@ mod string_collect {
impl StringCollector { impl StringCollector {
pub fn new() -> Self { pub fn new() -> Self {
StringCollector { StringCollector { data: String::new(), incomplete: None }
data: String::new(),
incomplete: None,
}
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
@ -56,10 +53,7 @@ mod string_collect {
self.data.push_str(text); self.data.push_str(text);
Ok(()) Ok(())
} }
Err(DecodeError::Incomplete { Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
valid_prefix,
incomplete_suffix,
}) => {
self.data.push_str(valid_prefix); self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix); self.incomplete = Some(incomplete_suffix);
Ok(()) Ok(())
@ -129,11 +123,7 @@ impl IncompleteMessage {
// Be careful about integer overflows here. // Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size { if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity( return Err(Error::Capacity(
format!( format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(),
"Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
)); ));
} }
@ -203,42 +193,27 @@ impl Message {
/// Indicates whether a message is a text message. /// Indicates whether a message is a text message.
pub fn is_text(&self) -> bool { pub fn is_text(&self) -> bool {
match *self { matches!(*self, Message::Text(_))
Message::Text(_) => true,
_ => false,
}
} }
/// Indicates whether a message is a binary message. /// Indicates whether a message is a binary message.
pub fn is_binary(&self) -> bool { pub fn is_binary(&self) -> bool {
match *self { matches!(*self, Message::Binary(_))
Message::Binary(_) => true,
_ => false,
}
} }
/// Indicates whether a message is a ping message. /// Indicates whether a message is a ping message.
pub fn is_ping(&self) -> bool { pub fn is_ping(&self) -> bool {
match *self { matches!(*self, Message::Ping(_))
Message::Ping(_) => true,
_ => false,
}
} }
/// Indicates whether a message is a pong message. /// Indicates whether a message is a pong message.
pub fn is_pong(&self) -> bool { pub fn is_pong(&self) -> bool {
match *self { matches!(*self, Message::Pong(_))
Message::Pong(_) => true,
_ => false,
}
} }
/// Indicates whether a message ia s close message. /// Indicates whether a message ia s close message.
pub fn is_close(&self) -> bool { pub fn is_close(&self) -> bool {
match *self { matches!(*self, Message::Close(_))
Message::Close(_) => true,
_ => false,
}
} }
/// Get the length of the WebSocket message. /// Get the length of the WebSocket message.

@ -2,24 +2,29 @@
pub mod frame; pub mod frame;
pub(crate) mod message; mod message;
pub use self::frame::CloseFrame; pub use self::{frame::CloseFrame, message::Message};
pub use self::message::Message;
use log::*; use log::*;
use std::collections::VecDeque; use std::{
use std::io::{ErrorKind as IoErrorKind, Read, Write}; collections::VecDeque,
use std::mem::replace; io::{ErrorKind as IoErrorKind, Read, Write},
mem::replace,
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode}; };
use self::frame::{Frame, FrameCodec};
use self::message::IncompleteMessage; use self::{
use crate::error::{Error, Result}; frame::{
use crate::extensions::compression::{CompressionSwitcher, WsCompression}; coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode},
use crate::extensions::WebSocketExtension; Frame, FrameCodec,
use crate::protocol::frame::coding::Data; },
use crate::util::NonBlockingResult; message::{IncompleteMessage, IncompleteMessageType},
extensions::{WebSocketExtension, compression::{CompressionSwitcher, WsCompression}};
};
use crate::{
error::{Error, Result},
util::NonBlockingResult,
};
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
@ -33,7 +38,7 @@ pub enum Role {
} }
/// The configuration for WebSocket connection. /// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)] #[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig { pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None` /// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited /// means here that the size of the queue is unlimited. The default value is the unlimited
@ -77,10 +82,7 @@ impl<Stream> WebSocket<Stream> {
/// or together with an existing one. If you need an initial handshake, use /// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket. /// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket { WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
socket: stream,
context: WebSocketContext::new(role, config),
}
} }
/// Convert a raw socket into a WebSocket without performing a handshake. /// Convert a raw socket into a WebSocket without performing a handshake.
@ -136,10 +138,7 @@ impl<Stream> WebSocket<Stream> {
} }
} }
impl<Stream> WebSocket<Stream> impl<Stream: Read + Write> WebSocket<Stream> {
where
Stream: Read + Write,
{
/// Read a message from stream, if possible. /// Read a message from stream, if possible.
/// ///
/// This will queue responses to ping and close messages to be sent. It will call /// This will queue responses to ping and close messages to be sent. It will call
@ -333,9 +332,7 @@ impl WebSocketContext {
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol( return Err(Error::Protocol("Sending after closing is not allowed".into()));
"Sending after closing is not allowed".into(),
));
} }
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
@ -457,9 +454,7 @@ impl WebSocketContext {
Role::Client => { Role::Client => {
if frame.is_masked() { if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455) // A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol( return Err(Error::Protocol("Received a masked frame from server".into()));
"Received a masked frame from server".into(),
));
} }
} }
} }
@ -476,9 +471,9 @@ impl WebSocketContext {
Err(Error::Protocol("Control frame too big".into())) Err(Error::Protocol("Control frame too big".into()))
} }
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => Err(Error::Protocol( OpCtl::Reserved(i) => {
format!("Unknown control frame type {}", i).into(), Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
)), }
OpCtl::Ping => { OpCtl::Ping => {
let data = frame.into_data(); let data = frame.into_data();
// No ping processing after we sent a close frame. // No ping processing after we sent a close frame.
@ -568,43 +563,8 @@ impl WebSocketContext {
} }
} }
if frame.header().is_final { trace!("Sending frame: {:?}", frame);
frame = self.decoder.on_send_frame(frame)?; self.frame.write_frame(stream, frame).check_connection_reset(self.state)
}
let max_frame_size = self.config.max_frame_size.unwrap_or_else(usize::max_value);
if frame.payload().len() > max_frame_size {
let mut chunks = frame.payload().chunks(max_frame_size).peekable();
let data_frame = Frame::message(
Vec::from(chunks.next().unwrap()),
frame.header().opcode,
false,
);
self.frame
.write_frame(stream, data_frame)
.check_connection_reset(self.state)?;
while let Some(chunk) = chunks.next() {
let frame = Frame::message(
Vec::from(chunk),
OpCode::Data(Data::Continue),
chunks.peek().is_none(),
);
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)?;
}
Ok(())
} else {
trace!("Sending frame: {:?}", frame);
self.frame
.write_frame(stream, frame)
.check_connection_reset(self.state)
}
} }
} }
@ -626,20 +586,14 @@ enum WebSocketState {
impl WebSocketState { impl WebSocketState {
/// Tell if we're allowed to process normal messages. /// Tell if we're allowed to process normal messages.
fn is_active(self) -> bool { fn is_active(self) -> bool {
match self { matches!(self, WebSocketState::Active)
WebSocketState::Active => true,
_ => false,
}
} }
/// Tell if we should process incoming data. Note that if we send a close frame /// Tell if we should process incoming data. Note that if we send a close frame
/// but the remote hasn't confirmed, they might have sent data before they receive our /// but the remote hasn't confirmed, they might have sent data before they receive our
/// close frame, so we should still pass those to client code, hence ClosedByUs is valid. /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
fn can_read(self) -> bool { fn can_read(self) -> bool {
match self { matches!(self, WebSocketState::Active | WebSocketState::ClosedByUs)
WebSocketState::Active | WebSocketState::ClosedByUs => true,
_ => false,
}
} }
/// Check if the state is active, return error if not. /// Check if the state is active, return error if not.
@ -675,11 +629,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests { mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::compression::WsCompression; use std::{io, io::Cursor};
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use std::io;
use std::io::Cursor;
struct WriteMoc<Stream>(Stream); struct WriteMoc<Stream>(Stream);
@ -708,14 +658,8 @@ mod tests {
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!( assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
socket.read_message().unwrap(), assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
} }
#[test] #[test]
@ -724,11 +668,7 @@ mod tests {
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x6c, 0x64, 0x21, 0x6c, 0x64, 0x21,
]); ]);
let limit = WebSocketConfig { let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() };
max_send_queue: None,
max_frame_size: Some(16 << 20),
compression: WsCompression::None(Some(10)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!( assert_eq!(
socket.read_message().unwrap_err().to_string(), socket.read_message().unwrap_err().to_string(),
@ -739,80 +679,11 @@ mod tests {
#[test] #[test]
fn size_limiting_binary() { fn size_limiting_binary() {
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig { let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() };
max_send_queue: None,
max_frame_size: Some(16 << 20),
compression: WsCompression::None(Some(2)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!( assert_eq!(
socket.read_message().unwrap_err().to_string(), socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2" "Space limit exceeded: Message too big: 0 + 3 > 2"
); );
} }
#[test]
fn fragmented_tx() {
let max_message_size = 2;
let input_str = "hello unit test";
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(2),
compression: WsCompression::None(Some(max_message_size)),
};
let mut socket =
WebSocket::from_raw_socket(Cursor::new(Vec::new()), Role::Client, Some(limit));
socket.write_message(Message::text(input_str)).unwrap();
socket.socket.set_position(0);
let WebSocket {
mut socket,
mut context,
} = socket;
let vec = input_str.chars().collect::<Vec<_>>();
let mut iter = vec
.chunks(max_message_size)
.map(|c| c.iter().collect::<String>())
.into_iter()
.peekable();
let frame_eq = |expected: Frame, actual: Frame| {
assert_eq!(expected.payload(), actual.payload());
assert_eq!(expected.header().opcode, actual.header().opcode);
assert_eq!(
expected.header().ext_headers.rsv1,
actual.header().ext_headers.rsv1
);
};
let expected = Frame::message(iter.next().unwrap().into(), OpCode::Data(Data::Text), false);
frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
while let Some(chars) = iter.next() {
let expected = Frame::message(
chars.into(),
OpCode::Data(Data::Continue),
iter.peek().is_none(),
);
frame_eq(
expected,
context
.frame
.read_frame(&mut socket, Some(max_message_size))
.unwrap()
.unwrap(),
);
}
}
} }

@ -2,8 +2,10 @@
pub use crate::handshake::server::ServerHandshake; pub use crate::handshake::server::ServerHandshake;
use crate::handshake::server::{Callback, NoCallback}; use crate::handshake::{
use crate::handshake::HandshakeError; server::{Callback, NoCallback},
HandshakeError,
};
use crate::protocol::{WebSocket, WebSocketConfig}; use crate::protocol::{WebSocket, WebSocketConfig};

@ -1,7 +1,9 @@
//! Helper traits to ease non-blocking handling. //! Helper traits to ease non-blocking handling.
use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::{
use std::result::Result as StdResult; io::{Error as IoError, ErrorKind as IoErrorKind},
result::Result as StdResult,
};
use crate::error::Error; use crate::error::Error;

@ -1,10 +1,12 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket. //! is closedd from the server's point of view and drop the underlying tcp socket.
use std::net::{TcpListener, TcpStream}; use std::{
use std::process::exit; net::{TcpListener, TcpStream},
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use native_tls::TlsStream; use native_tls::TlsStream;
use net2::TcpStreamExt; use net2::TcpStreamExt;
@ -49,9 +51,7 @@ fn test_server_close() {
do_test( do_test(
3012, 3012,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = cli_sock.read_message().unwrap(); // receive close from server let message = cli_sock.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
@ -85,9 +85,7 @@ fn test_evil_server_close() {
do_test( do_test(
3013, 3013,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
@ -109,10 +107,7 @@ fn test_evil_server_close() {
let message = srv_sock.read_message().unwrap(); // receive acknowledgement let message = srv_sock.read_message().unwrap(); // receive acknowledgement
assert!(message.is_close()); assert!(message.is_close());
// and now just drop the connection without waiting for `ConnectionClosed` // and now just drop the connection without waiting for `ConnectionClosed`
srv_sock srv_sock.get_mut().set_linger(Some(Duration::from_secs(0))).unwrap();
.get_mut()
.set_linger(Some(Duration::from_secs(0)))
.unwrap();
drop(srv_sock); drop(srv_sock);
}, },
); );
@ -123,9 +118,7 @@ fn test_client_close() {
do_test( do_test(
3014, 3014,
|mut cli_sock| { |mut cli_sock| {
cli_sock cli_sock.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = cli_sock.read_message().unwrap(); // receive answer from server let message = cli_sock.read_message().unwrap(); // receive answer from server
assert_eq!(message.into_data(), b"From Server"); assert_eq!(message.into_data(), b"From Server");
@ -145,9 +138,7 @@ fn test_client_close() {
let message = srv_sock.read_message().unwrap(); let message = srv_sock.read_message().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket"); assert_eq!(message.into_data(), b"Hello WebSocket");
srv_sock srv_sock.write_message(Message::Text("From Server".into())).unwrap();
.write_message(Message::Text("From Server".into()))
.unwrap();
let message = srv_sock.read_message().unwrap(); // receive close from client let message = srv_sock.read_message().unwrap(); // receive close from client
assert!(message.is_close()); assert!(message.is_close());

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
use std::net::TcpListener; use std::{
use std::process::exit; net::TcpListener,
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;

@ -1,10 +1,12 @@
//! Verifies that we can read data messages even if we have initiated a close handshake, //! Verifies that we can read data messages even if we have initiated a close handshake,
//! but before we got confirmation. //! but before we got confirmation.
use std::net::TcpListener; use std::{
use std::process::exit; net::TcpListener,
use std::thread::{sleep, spawn}; process::exit,
use std::time::Duration; thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
use url::Url; use url::Url;
@ -24,9 +26,7 @@ fn test_receive_after_init_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3013/socket").unwrap()).unwrap();
client client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = client.read_message().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());

Loading…
Cancel
Save