Merge pull request #67 from vorot93/rust-2018

Edition 2018, formatting, clippy fixes
pull/75/head
Daniel Abramov 5 years ago committed by GitHub
commit c6c3db34cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      Cargo.toml
  2. 40
      examples/autobahn-client.rs
  3. 18
      examples/autobahn-server.rs
  4. 6
      examples/callback-error.rs
  5. 15
      examples/client.rs
  6. 10
      examples/server.rs
  7. 1
      fuzz/Cargo.toml
  8. 78
      src/client.rs
  9. 12
      src/error.rs
  10. 92
      src/handshake/client.rs
  11. 29
      src/handshake/headers.rs
  12. 37
      src/handshake/machine.rs
  13. 30
      src/handshake/mod.rs
  14. 83
      src/handshake/server.rs
  15. 38
      src/lib.rs
  16. 24
      src/protocol/frame/coding.rs
  17. 97
      src/protocol/frame/frame.rs
  18. 26
      src/protocol/frame/mask.rs
  19. 77
      src/protocol/frame/mod.rs
  20. 73
      src/protocol/message.rs
  21. 145
      src/protocol/mod.rs
  22. 30
      src/server.rs
  23. 6
      src/stream.rs
  24. 7
      src/util.rs
  25. 14
      tests/connection_reset.rs

@ -10,6 +10,7 @@ homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.9.1" documentation = "https://docs.rs/tungstenite/0.9.1"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.9.1" version = "0.9.1"
edition = "2018"
[features] [features]
default = ["tls"] default = ["tls"]

@ -1,18 +1,12 @@
#[macro_use] extern crate log; use log::*;
extern crate env_logger;
extern crate tungstenite;
extern crate url;
use url::Url; use url::Url;
use tungstenite::{connect, Error, Result, Message}; use tungstenite::{connect, Error, Message, Result};
const AGENT: &'static str = "Tungstenite"; const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect( let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
)?;
let msg = socket.read_message()?; let msg = socket.read_message()?;
socket.close(None)?; socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap()) Ok(msg.into_text()?.parse::<u32>().unwrap())
@ -20,7 +14,11 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> { fn update_reports() -> Result<()> {
let (mut socket, _) = connect( let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(), Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?; )?;
socket.close(None)?; socket.close(None)?;
Ok(()) Ok(())
@ -28,19 +26,18 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> { fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case); info!("Running test case {}", case);
let case_url = Url::parse( let case_url = Url::parse(&format!(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT) "ws://localhost:9001/runCase?case={}&agent={}",
).unwrap(); case, AGENT
))
.unwrap();
let (mut socket, _) = connect(case_url)?; let (mut socket, _) = connect(case_url)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) | msg @ Message::Binary(_) => {
msg @ Message::Binary(_) => {
socket.write_message(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
Message::Pong(_) |
Message::Close(_) => {}
} }
} }
} }
@ -53,12 +50,13 @@ fn main() {
for case in 1..(total + 1) { for case in 1..(total + 1) {
if let Err(e) = run_test(case) { if let Err(e) = run_test(case) {
match e { match e {
Error::Protocol(_) => { } Error::Protocol(_) => {}
err => { warn!("test: {}", err); } err => {
warn!("test: {}", err);
}
} }
} }
} }
update_reports().unwrap(); update_reports().unwrap();
} }

@ -1,12 +1,9 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate tungstenite;
use std::net::{TcpListener, TcpStream}; use std::net::{TcpListener, TcpStream};
use std::thread::spawn; use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message}; use log::*;
use tungstenite::handshake::HandshakeRole; use tungstenite::handshake::HandshakeRole;
use tungstenite::{accept, Error, HandshakeError, Message, Result};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error { fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err { match err {
@ -19,13 +16,10 @@ fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?; let mut socket = accept(stream).map_err(must_not_block)?;
loop { loop {
match socket.read_message()? { match socket.read_message()? {
msg @ Message::Text(_) | msg @ Message::Text(_) | msg @ Message::Binary(_) => {
msg @ Message::Binary(_) => {
socket.write_message(msg)?; socket.write_message(msg)?;
} }
Message::Ping(_) | Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
Message::Pong(_) |
Message::Close(_) => {}
} }
} }
} }
@ -36,14 +30,12 @@ fn main() {
let server = TcpListener::bind("127.0.0.1:9002").unwrap(); let server = TcpListener::bind("127.0.0.1:9002").unwrap();
for stream in server.incoming() { for stream in server.incoming() {
spawn(move || { spawn(move || match stream {
match stream {
Ok(stream) => match handle_client(stream) { Ok(stream) => match handle_client(stream) {
Ok(_) => (), Ok(_) => (),
Err(e) => warn!("Error in client: {}", e), Err(e) => warn!("Error in client: {}", e),
}, },
Err(e) => warn!("Error accepting stream: {}", e), Err(e) => warn!("Error accepting stream: {}", e),
}
}); });
} }
} }

@ -1,10 +1,8 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener; use std::net::TcpListener;
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, ErrorResponse}; use tungstenite::handshake::server::{ErrorResponse, Request};
use tungstenite::http::StatusCode; use tungstenite::http::StatusCode;
fn main() { fn main() {

@ -1,15 +1,11 @@
extern crate tungstenite; use tungstenite::{connect, Message};
extern crate url;
extern crate env_logger;
use url::Url; use url::Url;
use tungstenite::{Message, connect};
fn main() { fn main() {
env_logger::init(); env_logger::init();
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap()) let (mut socket, response) =
.expect("Can't connect"); connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");
println!("Connected to the server"); println!("Connected to the server");
println!("Response HTTP code: {}", response.code); println!("Response HTTP code: {}", response.code);
@ -18,11 +14,12 @@ fn main() {
println!("* {}", header); println!("* {}", header);
} }
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap(); socket
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
loop { loop {
let msg = socket.read_message().expect("Error reading message"); let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg); println!("Received: {}", msg);
} }
// socket.close(None); // socket.close(None);
} }

@ -1,8 +1,5 @@
extern crate tungstenite;
extern crate env_logger;
use std::thread::spawn;
use std::net::TcpListener; use std::net::TcpListener;
use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::accept_hdr;
use tungstenite::handshake::server::Request; use tungstenite::handshake::server::Request;
@ -23,7 +20,10 @@ fn main() {
// Let's add an additional header to our response to the client. // Let's add an additional header to our response to the client.
let extra_headers = vec![ let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")), (String::from("MyCustomHeader"), String::from(":)")),
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")), (
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"),
),
]; ];
Ok(Some(extra_headers)) Ok(Some(extra_headers))
}; };

@ -1,4 +1,3 @@
[package] [package]
name = "tungstenite-fuzz" name = "tungstenite-fuzz"
version = "0.0.1" version = "0.0.1"

@ -1,36 +1,40 @@
//! Methods to connect to an WebSocket as a client. //! Methods to connect to an WebSocket as a client.
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use std::result::Result as StdResult;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use log::*;
use url::Url; use url::Url;
use handshake::client::Response; use crate::handshake::client::Response;
use protocol::WebSocketConfig; use crate::protocol::WebSocketConfig;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
mod encryption { mod encryption {
use std::net::TcpStream;
use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError};
pub use native_tls::TlsStream; pub use native_tls::TlsStream;
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector};
use std::net::TcpStream;
pub use stream::Stream as StreamSwitcher; pub use crate::stream::Stream as StreamSwitcher;
/// TCP stream switcher (plain/TLS). /// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>; pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use stream::Mode; use crate::error::Result;
use error::Result; use crate::stream::Mode;
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> { pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode { match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)), Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => { Mode::Tls => {
let connector = TlsConnector::builder().build()?; let connector = TlsConnector::builder().build()?;
connector.connect(domain, stream) connector
.connect(domain, stream)
.map_err(|e| match e { .map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(), TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::WouldBlock(_) => panic!("Bug: TLS handshake not blocked"), TlsHandshakeError::WouldBlock(_) => {
panic!("Bug: TLS handshake not blocked")
}
}) })
.map(StreamSwitcher::Tls) .map(StreamSwitcher::Tls)
} }
@ -38,12 +42,12 @@ mod encryption {
} }
} }
#[cfg(not(feature="tls"))] #[cfg(not(feature = "tls"))]
mod encryption { mod encryption {
use std::net::TcpStream; use std::net::TcpStream;
use stream::Mode;
use error::{Error, Result}; use error::{Error, Result};
use stream::Mode;
/// TLS support is nod compiled in, this is just standard `TcpStream`. /// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream; pub type AutoStream = TcpStream;
@ -56,15 +60,14 @@ mod encryption {
} }
} }
pub use self::encryption::AutoStream;
use self::encryption::wrap_stream; use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use protocol::WebSocket; use crate::error::{Error, Result};
use handshake::HandshakeError; use crate::handshake::client::{ClientHandshake, Request};
use handshake::client::{ClientHandshake, Request}; use crate::handshake::HandshakeError;
use stream::{NoDelay, Mode}; use crate::protocol::WebSocket;
use error::{Error, Result}; use crate::stream::{Mode, NoDelay};
/// Connect to the given WebSocket in blocking mode. /// Connect to the given WebSocket in blocking mode.
/// ///
@ -83,13 +86,17 @@ use error::{Error, Result};
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>( pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
request: Req, request: Req,
config: Option<WebSocketConfig> config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into(); let request: Request = request.into();
let mode = url_mode(&request.url)?; let mode = url_mode(&request.url)?;
let host = request.url.host() let host = request
.url
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; .ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request.url.port_or_known_default() let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?; .ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs; let addrs;
let addr; let addr;
@ -109,8 +116,7 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
}; };
let mut stream = connect_to_some(addrs, &request.url, mode)?; let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config) client_with_config(request, stream, config).map_err(|e| match e {
.map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"), HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
}) })
@ -128,19 +134,21 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req) pub fn connect<'t, Req: Into<Request<'t>>>(
-> Result<(WebSocket<AutoStream>, Response)> request: Req,
{ ) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None) connect_with_config(request, None)
} }
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?; let domain = url
.host_str()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr); debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream) return Ok(stream);
} }
} }
} }
@ -155,7 +163,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() { match url.scheme() {
"ws" => Ok(Mode::Plain), "ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls), "wss" => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())) _ => Err(Error::Url("URL scheme not supported".into())),
} }
} }
@ -182,8 +190,10 @@ where
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream) pub fn client<'t, Stream, Req>(
-> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: Into<Request<'t>>,

@ -11,9 +11,9 @@ use std::string;
use httparse; use httparse;
use protocol::Message; use crate::protocol::Message;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
pub mod tls { pub mod tls {
//! TLS error wrapper module, feature-gated. //! TLS error wrapper module, feature-gated.
pub use native_tls::Error; pub use native_tls::Error;
@ -41,7 +41,7 @@ pub enum Error {
AlreadyClosed, AlreadyClosed,
/// Input-output error /// Input-output error
Io(io::Error), Io(io::Error),
#[cfg(feature="tls")] #[cfg(feature = "tls")]
/// TLS error /// TLS error
Tls(tls::Error), Tls(tls::Error),
/// Buffer capacity exhausted /// Buffer capacity exhausted
@ -64,7 +64,7 @@ impl fmt::Display for Error {
Error::ConnectionClosed => write!(f, "Connection closed normally"), Error::ConnectionClosed => write!(f, "Connection closed normally"),
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"), Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
Error::Io(ref err) => write!(f, "IO error: {}", err), Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature="tls")] #[cfg(feature = "tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err), Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
@ -82,7 +82,7 @@ impl ErrorTrait for Error {
Error::ConnectionClosed => "A close handshake is performed", Error::ConnectionClosed => "A close handshake is performed",
Error::AlreadyClosed => "Trying to read or write after getting close notification", Error::AlreadyClosed => "Trying to read or write after getting close notification",
Error::Io(ref err) => err.description(), Error::Io(ref err) => err.description(),
#[cfg(feature="tls")] #[cfg(feature = "tls")]
Error::Tls(ref err) => err.description(), Error::Tls(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(), Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(), Error::Protocol(ref msg) => msg.borrow(),
@ -112,7 +112,7 @@ impl From<string::FromUtf8Error> for Error {
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
impl From<tls::Error> for Error { impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self { fn from(err: tls::Error) -> Self {
Error::Tls(err) Error::Tls(err)

@ -4,17 +4,15 @@ use std::borrow::Cow;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use base64;
use httparse::Status; use httparse::Status;
use httparse; use log::*;
use rand;
use url::Url; use url::Url;
use error::{Error, Result}; use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request. /// Client request.
#[derive(Debug)] #[derive(Debug)]
@ -80,26 +78,32 @@ impl<S: Read + Write> ClientHandshake<S> {
pub fn start( pub fn start(
stream: S, stream: S,
request: Request, request: Request,
config: Option<WebSocketConfig> config: Option<WebSocketConfig>,
) -> MidHandshake<Self> { ) -> MidHandshake<Self> {
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
let mut req = Vec::new(); let mut req = Vec::new();
write!(req, "\ write!(
req,
"\
GET {path} HTTP/1.1\r\n\ GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\ Host: {host}\r\n\
Connection: upgrade\r\n\ Connection: upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), path = request.get_path(), key = key).unwrap(); host = request.get_host(),
path = request.get_path(),
key = key
)
.unwrap();
if let Some(eh) = request.extra_headers { if let Some(eh) = request.extra_headers {
for (k, v) in eh { for (k, v) in eh {
write!(req, "{}: {}\r\n", k, v).unwrap(); writeln!(req, "{}: {}\r", k, v).unwrap();
} }
} }
write!(req, "\r\n").unwrap(); writeln!(req, "\r").unwrap();
HandshakeMachine::start_write(stream, req) HandshakeMachine::start_write(stream, req)
}; };
@ -113,7 +117,10 @@ impl<S: Read + Write> ClientHandshake<S> {
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
MidHandshake { role: client, machine } MidHandshake {
role: client,
machine,
}
} }
} }
@ -121,22 +128,23 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response; type IncomingData = Response;
type InternalStream = S; type InternalStream = S;
type FinalResult = (WebSocket<S>, Response); type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>) fn stage_finished(
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> &mut self,
{ finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish { Ok(match finish {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream)) ProcessingResult::Continue(HandshakeMachine::start_read(stream))
} }
StageResult::DoneReading { stream, result, tail, } => { StageResult::DoneReading {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read(
stream, stream,
result,
tail, tail,
Role::Client, } => {
self.config.clone(), self.verify_data.verify_response(&result)?;
); debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
ProcessingResult::Done((websocket, result)) ProcessingResult::Done((websocket, result))
} }
}) })
@ -161,22 +169,37 @@ impl VerifyData {
// header field contains a value that is not an ASCII case- // header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST // insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455) // _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Upgrade", "websocket") { if !response
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); .headers
.header_is_ignore_case("Upgrade", "websocket")
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
));
} }
// 3. If the response lacks a |Connection| header field or the // 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client // ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Connection", "Upgrade") { if !response
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); .headers
.header_is_ignore_case("Connection", "Upgrade")
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
));
} }
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or // 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) { if !response
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); .headers
.header_is("Sec-WebSocket-Accept", &self.accept_key)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
} }
// 5. If the response includes a |Sec-WebSocket-Extensions| header // 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension // field and this header field indicates the use of an extension
@ -219,7 +242,9 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
} }
Ok(Response { Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"), code: raw.code.expect("Bug: no HTTP response code"),
@ -238,8 +263,8 @@ fn generate_key() -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Response, generate_key};
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::{generate_key, Response};
#[test] #[test]
fn random_keys() { fn random_keys() {
@ -262,6 +287,9 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200); assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..])); assert_eq!(
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
} }
} }

@ -1,13 +1,13 @@
//! HTTP Request and response header handling. //! HTTP Request and response header handling.
use std::str::from_utf8;
use std::slice; use std::slice;
use std::str::from_utf8;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
use error::Result;
use super::machine::TryParse; use super::machine::TryParse;
use crate::error::Result;
/// Limit for the number of header lines. /// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124; pub const MAX_HEADERS: usize = 124;
@ -19,7 +19,6 @@ pub struct Headers {
} }
impl Headers { impl Headers {
/// Get first header with the given name, if any. /// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> { pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next() self.find(name).next()
@ -29,7 +28,7 @@ impl Headers {
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter { HeadersIter {
name, name,
iter: self.data.iter() iter: self.data.iter(),
} }
} }
@ -42,7 +41,8 @@ impl Headers {
/// Check if the given header has the given value (case-insensitive). /// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(()) self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value)) .map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false) .unwrap_or(false)
@ -52,7 +52,6 @@ impl Headers {
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter() self.data.iter()
} }
} }
/// The iterator over headers. /// The iterator over headers.
@ -67,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() { while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) { if name.eq_ignore_ascii_case(self.name) {
return Some(value) return Some(value);
} }
} }
None None
} }
} }
/// Trait to convert raw objects into HTTP parseables. /// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized { pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers. /// Convert raw object into parsed HTTP headers.
@ -94,7 +92,8 @@ impl TryParse for Headers {
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> { fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers { Ok(Headers {
data: raw.iter() data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(), .collect(),
}) })
@ -104,13 +103,12 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Headers;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::Headers;
#[test] #[test]
fn headers() { fn headers() {
const DATA: &'static [u8] = const DATA: &'static [u8] = b"Host: foo.com\r\n\
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
\r\n"; \r\n";
@ -126,8 +124,7 @@ mod tests {
#[test] #[test]
fn headers_iter() { fn headers_iter() {
const DATA: &'static [u8] = const DATA: &'static [u8] = b"Host: foo.com\r\n\
b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
@ -142,12 +139,10 @@ mod tests {
#[test] #[test]
fn headers_incomplete() { fn headers_incomplete() {
const DATA: &'static [u8] = const DATA: &'static [u8] = b"Host: foo.com\r\n\
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n"; Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap(); let hdr = Headers::try_parse(DATA).unwrap();
assert!(hdr.is_none()); assert!(hdr.is_none());
} }
} }

@ -1,9 +1,10 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf; use bytes::Buf;
use log::*;
use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use util::NonBlockingResult;
/// A generic handshake state machine. /// A generic handshake state machine.
#[derive(Debug)] #[derive(Debug)]
@ -43,16 +44,16 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
trace!("Doing handshake round."); trace!("Doing handshake round.");
match self.state { match self.state {
HandshakeState::Reading(mut buf) => { HandshakeState::Reading(mut buf) => {
let read = buf.prepare_reserve(MIN_READ) let read = buf
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) // TODO limit size .with_limit(usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))? .map_err(|_| Error::Capacity("Header too long".into()))?
.read_from(&mut self.stream).no_block()?; .read_from(&mut self.stream)
.no_block()?;
match read { match read {
Some(0) => { Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Err(Error::Protocol("Handshake not finished".into())) Some(_) => Ok(
} if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
Some(_) => {
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size); buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading { RoundResult::StageFinished(StageResult::DoneReading {
result: obj, result: obj,
@ -64,14 +65,12 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
}) })
}) },
} ),
None => { None => Ok(RoundResult::WouldBlock(HandshakeMachine {
Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
})) })),
}
} }
} }
HandshakeState::Writing(mut buf) => { HandshakeState::Writing(mut buf) => {
@ -113,7 +112,11 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)] #[derive(Debug)]
pub enum StageResult<Obj, Stream> { pub enum StageResult<Obj, Stream> {
/// Reading round finished. /// Reading round finished.
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> }, DoneReading {
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
/// Writing round finished. /// Writing round finished.
DoneWriting(Stream), DoneWriting(Stream),
} }

@ -1,7 +1,7 @@
//! WebSocket handshake control. //! WebSocket handshake control.
pub mod headers;
pub mod client; pub mod client;
pub mod headers;
pub mod server; pub mod server;
mod machine; mod machine;
@ -11,10 +11,10 @@ use std::fmt;
use std::io::{Read, Write}; use std::io::{Read, Write};
use base64; use base64;
use sha1::{Sha1, Digest}; use sha1::{Digest, Sha1};
use error::Error;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse}; use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
use crate::error::Error;
/// A WebSocket handshake. /// A WebSocket handshake.
#[derive(Debug)] #[derive(Debug)]
@ -30,15 +30,16 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop { loop {
mach = match mach.single_round()? { mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => { RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self })) return Err(HandshakeError::Interrupted(MidHandshake {
machine: m,
..self
}))
} }
RoundResult::Incomplete(m) => m, RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => { RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m, ProcessingResult::Continue(m) => m,
ProcessingResult::Done(result) => return Ok(result), ProcessingResult::Done(result) => return Ok(result),
} },
}
} }
} }
} }
@ -94,8 +95,10 @@ pub trait HandshakeRole {
#[doc(hidden)] #[doc(hidden)]
type FinalResult; type FinalResult;
#[doc(hidden)] #[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>) fn stage_finished(
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>; &mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
} }
/// Stage processing result. /// Stage processing result.
@ -124,8 +127,9 @@ mod tests {
#[test] #[test]
fn key_conversion() { fn key_conversion() {
// example from RFC 6455 // example from RFC 6455
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(), assert_eq!(
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
);
} }
} }

@ -5,15 +5,15 @@ use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use httparse;
use httparse::Status;
use http::StatusCode; use http::StatusCode;
use httparse::Status;
use log::*;
use error::{Error, Result}; use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Request from the client. /// Request from the client.
#[derive(Debug)] #[derive(Debug)]
@ -27,7 +27,9 @@ pub struct Request {
impl Request { impl Request {
/// Reply to the response. /// Reply to the response.
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> { pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key") let key = self
.headers
.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!( let mut reply = format!(
"\ "\
@ -45,13 +47,12 @@ impl Request {
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
if let Some(eh) = extra_headers { if let Some(eh) = extra_headers {
for (k, v) in eh { for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap(); writeln!(reply, "{}: {}\r", k, v).unwrap();
} }
} }
write!(reply, "\r\n").unwrap(); writeln!(reply, "\r").unwrap();
} }
impl TryParse for Request { impl TryParse for Request {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -69,11 +70,13 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol("Method is not GET".into()));
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
} }
Ok(Request { Ok(Request {
path: raw.path.expect("Bug: no path in header").into(), path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)? headers: Headers::from_httparse(raw.headers)?,
}) })
} }
} }
@ -115,7 +118,10 @@ pub trait Callback: Sized {
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
} }
impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { impl<F> Callback for F
where
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>,
{
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
self(request) self(request)
} }
@ -160,7 +166,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback), callback: Some(callback),
config, config,
error_code: None, error_code: None,
_marker: PhantomData _marker: PhantomData,
}, },
} }
} }
@ -171,13 +177,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type InternalStream = S; type InternalStream = S;
type FinalResult = WebSocket<S>; type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>) fn stage_finished(
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> &mut self,
{ finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish { Ok(match finish {
StageResult::DoneReading { stream, result, tail } => { StageResult::DoneReading {
stream,
result,
tail,
} => {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())) return Err(Error::Protocol("Junk after client request".into()));
} }
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
@ -192,8 +203,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
Err(ErrorResponse { error_code, headers, body }) => { Err(ErrorResponse {
self.error_code= Some(error_code.as_u16()); error_code,
headers,
body,
}) => {
self.error_code = Some(error_code.as_u16());
let mut response = format!( let mut response = format!(
"HTTP/1.1 {} {}\r\n", "HTTP/1.1 {} {}\r\n",
error_code.as_str(), error_code.as_str(),
@ -214,11 +229,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(err)); return Err(Error::Http(err));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket( let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }
} }
@ -228,9 +239,9 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Request;
use super::super::machine::TryParse;
use super::super::client::Response; use super::super::client::Response;
use super::super::machine::TryParse;
use super::Request;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
@ -253,13 +264,19 @@ mod tests {
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap(); let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![(String::from("MyCustomHeader"), let extra_headers = Some(vec![
String::from("MyCustomValue")), (
(String::from("MyVersion"), String::from("MyCustomHeader"),
String::from("LOL"))]); String::from("MyCustomValue"),
),
(String::from("MyVersion"), String::from("LOL")),
]);
let reply = req.reply(extra_headers).unwrap(); let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref())); assert_eq!(
req.headers.find_first("MyCustomHeader"),
Some(b"MyCustomValue".as_ref())
);
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
} }
} }

@ -3,39 +3,29 @@
missing_docs, missing_docs,
missing_copy_implementations, missing_copy_implementations,
missing_debug_implementations, missing_debug_implementations,
trivial_casts, trivial_numeric_casts, trivial_casts,
trivial_numeric_casts,
unstable_features, unstable_features,
unused_must_use, unused_must_use,
unused_mut, unused_mut,
unused_imports, unused_imports,
unused_import_braces)] unused_import_braces
)]
#[macro_use] extern crate log; pub use http;
extern crate base64;
extern crate byteorder;
extern crate bytes;
extern crate httparse;
extern crate input_buffer;
extern crate rand;
extern crate sha1;
extern crate url;
extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls;
pub extern crate http;
pub mod client;
pub mod error; pub mod error;
pub mod handshake;
pub mod protocol; pub mod protocol;
pub mod client;
pub mod server; pub mod server;
pub mod handshake;
pub mod stream; pub mod stream;
pub mod util; pub mod util;
pub use client::{connect, client}; pub use crate::client::{client, connect};
pub use server::{accept, accept_hdr}; pub use crate::error::{Error, Result};
pub use error::{Error, Result}; pub use crate::handshake::client::ClientHandshake;
pub use protocol::{WebSocket, Message}; pub use crate::handshake::server::ServerHandshake;
pub use handshake::HandshakeError; pub use crate::handshake::HandshakeError;
pub use handshake::client::ClientHandshake; pub use crate::protocol::{Message, WebSocket};
pub use handshake::server::ServerHandshake; pub use crate::server::{accept, accept_hdr};

@ -1,7 +1,7 @@
//! Various codes defined in RFC 6455. //! Various codes defined in RFC 6455.
use std::convert::{From, Into};
use std::fmt; use std::fmt;
use std::convert::{Into, From};
/// WebSocket message opcode as in RFC 6455. /// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -71,8 +71,8 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode { impl Into<u8> for OpCode {
fn into(self) -> u8 { fn into(self) -> u8 {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong}; use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*; use self::OpCode::*;
match self { match self {
Data(Continue) => 0, Data(Continue) => 0,
@ -90,18 +90,18 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode { impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode { fn from(byte: u8) -> OpCode {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong}; use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*; use self::OpCode::*;
match byte { match byte {
0 => Data(Continue), 0 => Data(Continue),
1 => Data(Text), 1 => Data(Text),
2 => Data(Binary), 2 => Data(Binary),
i @ 3 ... 7 => Data(self::Data::Reserved(i)), i @ 3..=7 => Data(self::Data::Reserved(i)),
8 => Control(Close), 8 => Control(Close),
9 => Control(Ping), 9 => Control(Ping),
10 => Control(Pong), 10 => Control(Pong),
i @ 11 ... 15 => Control(self::Control::Reserved(i)), i @ 11..=15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"), _ => panic!("Bug: OpCode out of range"),
} }
} }
@ -183,8 +183,8 @@ pub enum CloseCode {
impl CloseCode { impl CloseCode {
/// Check if this CloseCode is allowed. /// Check if this CloseCode is allowed.
pub fn is_allowed(&self) -> bool { pub fn is_allowed(self) -> bool {
match *self { match self {
Bad(_) => false, Bad(_) => false,
Reserved(_) => false, Reserved(_) => false,
Status => false, Status => false,
@ -250,11 +250,11 @@ impl From<u16> for CloseCode {
1012 => Restart, 1012 => Restart,
1013 => Again, 1013 => Again,
1015 => Tls, 1015 => Tls,
1...999 => Bad(code), 1..=999 => Bad(code),
1000...2999 => Reserved(code), 1016..=2999 => Reserved(code),
3000...3999 => Iana(code), 3000..=3999 => Iana(code),
4000...4999 => Library(code), 4000..=4999 => Library(code),
_ => Bad(code) _ => Bad(code),
} }
} }
} }

@ -1,14 +1,15 @@
use std::fmt; use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*;
use std::borrow::Cow; use std::borrow::Cow;
use std::io::{Cursor, Read, Write, ErrorKind};
use std::default::Default; use std::default::Default;
use std::string::{String, FromUtf8Error}; use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use byteorder::{ByteOrder, ReadBytesExt, WriteBytesExt, NetworkEndian}; use std::string::{FromUtf8Error, String};
use error::{Error, Result}; use super::coding::{CloseCode, Control, Data, OpCode};
use super::coding::{OpCode, Control, Data, CloseCode}; use super::mask::{apply_mask, generate_mask};
use super::mask::{generate_mask, apply_mask}; use crate::error::{Error, Result};
/// A struct representing the close command. /// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
@ -77,15 +78,13 @@ impl FrameHeader {
cursor.set_position(initial); cursor.set_position(initial);
ret ret
} }
ret => ret ret => ret,
} }
} }
/// Get the size of the header formatted with given payload length. /// Get the size of the header formatted with given payload length.
pub fn len(&self, length: u64) -> usize { pub fn len(&self, length: u64) -> usize {
2 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
+ LengthFormat::for_length(length).extra_bytes()
+ if self.mask.is_some() { 4 } else { 0 }
} }
/// Format a header for given payload size. /// Format a header for given payload size.
@ -93,8 +92,7 @@ impl FrameHeader {
let code: u8 = self.opcode.into(); let code: u8 = self.opcode.into();
let one = { let one = {
code code | if self.is_final { 0x80 } else { 0 }
| if self.is_final { 0x80 } else { 0 }
| if self.rsv1 { 0x40 } else { 0 } | if self.rsv1 { 0x40 } else { 0 }
| if self.rsv2 { 0x20 } else { 0 } | if self.rsv2 { 0x20 } else { 0 }
| if self.rsv3 { 0x10 } else { 0 } | if self.rsv3 { 0x10 } else { 0 }
@ -102,10 +100,7 @@ impl FrameHeader {
let lenfmt = LengthFormat::for_length(length); let lenfmt = LengthFormat::for_length(length);
let two = { let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
lenfmt.length_byte()
| if self.mask.is_some() { 0x80 } else { 0 }
};
output.write_all(&[one, two])?; output.write_all(&[one, two])?;
match lenfmt { match lenfmt {
@ -137,7 +132,7 @@ impl FrameHeader {
let (first, second) = { let (first, second) = {
let mut head = [0u8; 2]; let mut head = [0u8; 2];
if cursor.read(&mut head)? != 2 { if cursor.read(&mut head)? != 2 {
return Ok(None) return Ok(None);
} }
trace!("Parsed headers {:?}", head); trace!("Parsed headers {:?}", head);
(head[0], head[1]) (head[0], head[1])
@ -169,17 +164,17 @@ impl FrameHeader {
Err(err) => { Err(err) => {
return Err(err.into()); return Err(err.into());
} }
Ok(read) => read Ok(read) => read,
} }
} else { } else {
length_byte as u64 u64::from(length_byte)
} }
}; };
let mask = if masked { let mask = if masked {
let mut mask_bytes = [0u8; 4]; let mut mask_bytes = [0u8; 4];
if cursor.read(&mut mask_bytes)? != 4 { if cursor.read(&mut mask_bytes)? != 4 {
return Ok(None) return Ok(None);
} else { } else {
Some(mask_bytes) Some(mask_bytes)
} }
@ -190,9 +185,11 @@ impl FrameHeader {
// Disallow bad opcode // Disallow bad opcode
match opcode { match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into())) return Err(Error::Protocol(
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
} }
_ => () _ => (),
} }
let hdr = FrameHeader { let hdr = FrameHeader {
@ -216,7 +213,6 @@ pub struct Frame {
} }
impl Frame { impl Frame {
/// Get the length of the frame. /// Get the length of the frame.
/// This is the length of the header + the length of the payload. /// This is the length of the header + the length of the payload.
#[inline] #[inline]
@ -225,6 +221,12 @@ impl Frame {
self.header.len(length as u64) + length self.header.len(length as u64) + length
} }
/// Check if the frame is empty.
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Get a reference to the frame's header. /// Get a reference to the frame's header.
#[inline] #[inline]
pub fn header(&self) -> &FrameHeader { pub fn header(&self) -> &FrameHeader {
@ -296,7 +298,10 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code, reason: text.into() })) Ok(Some(CloseFrame {
code,
reason: text.into(),
}))
} }
} }
} }
@ -304,16 +309,19 @@ impl Frame {
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(match opcode { debug_assert!(
match opcode {
OpCode::Data(_) => true, OpCode::Data(_) => true,
_ => false, _ => false,
}, "Invalid opcode for data frame."); },
"Invalid opcode for data frame."
);
Frame { Frame {
header: FrameHeader { header: FrameHeader {
is_final, is_final,
opcode, opcode,
.. FrameHeader::default() ..FrameHeader::default()
}, },
payload: data, payload: data,
} }
@ -325,7 +333,7 @@ impl Frame {
Frame { Frame {
header: FrameHeader { header: FrameHeader {
opcode: OpCode::Control(Control::Pong), opcode: OpCode::Control(Control::Pong),
.. FrameHeader::default() ..FrameHeader::default()
}, },
payload: data, payload: data,
} }
@ -337,7 +345,7 @@ impl Frame {
Frame { Frame {
header: FrameHeader { header: FrameHeader {
opcode: OpCode::Control(Control::Ping), opcode: OpCode::Control(Control::Ping),
.. FrameHeader::default() ..FrameHeader::default()
}, },
payload: data, payload: data,
} }
@ -363,10 +371,7 @@ impl Frame {
/// Create a frame from given header and data. /// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self { pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame { Frame { header, payload }
header,
payload,
}
} }
/// Write a frame out to a buffer /// Write a frame out to a buffer
@ -380,7 +385,8 @@ impl Frame {
impl fmt::Display for Frame { impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, write!(
f,
" "
<FRAME> <FRAME>
final: {} final: {}
@ -398,7 +404,11 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>()) self.payload
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
)
} }
} }
@ -448,7 +458,7 @@ impl LengthFormat {
match byte & 0x7F { match byte & 0x7F {
126 => LengthFormat::U16, 126 => LengthFormat::U16,
127 => LengthFormat::U64, 127 => LengthFormat::U64,
b => LengthFormat::U8(b) b => LengthFormat::U8(b),
} }
} }
} }
@ -457,20 +467,22 @@ impl LengthFormat {
mod tests { mod tests {
use super::*; use super::*;
use super::super::coding::{OpCode, Data}; use super::super::coding::{Data, OpCode};
use std::io::Cursor; use std::io::Cursor;
#[test] #[test]
fn parse() { fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![ let mut raw: Cursor<Vec<u8>> =
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
]);
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7); assert_eq!(length, 7);
let mut payload = Vec::new(); let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap(); raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload); let frame = Frame::from_payload(header, payload);
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); assert_eq!(
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
} }
#[test] #[test]
@ -487,5 +499,4 @@ mod tests {
let view = format!("{}", f); let view = format!("{}", f);
assert!(view.contains("payload:")); assert!(view.contains("payload:"));
} }
} }

@ -1,7 +1,8 @@
use rand;
use std::cmp::min; use std::cmp::min;
#[allow(deprecated)]
use std::mem::uninitialized; use std::mem::uninitialized;
use std::ptr::{copy_nonoverlapping, read_unaligned}; use std::ptr::{copy_nonoverlapping, read_unaligned};
use rand;
/// Generate a random frame mask. /// Generate a random frame mask.
#[inline] #[inline]
@ -26,11 +27,9 @@ fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) {
/// Faster version of `apply_mask()` which operates on 4-byte blocks. /// Faster version of `apply_mask()` which operates on 4-byte blocks.
#[inline] #[inline]
#[allow(dead_code)] #[allow(dead_code, clippy::cast_ptr_alignment)]
fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
let mask_u32: u32 = unsafe { let mask_u32: u32 = unsafe { read_unaligned(mask.as_ptr() as *const u32) };
read_unaligned(mask.as_ptr() as *const u32)
};
let mut ptr = buf.as_mut_ptr(); let mut ptr = buf.as_mut_ptr();
let mut len = buf.len(); let mut len = buf.len();
@ -40,7 +39,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
let mask_u32 = if head > 0 { let mask_u32 = if head > 0 {
unsafe { unsafe {
xor_mem(ptr, mask_u32, head); xor_mem(ptr, mask_u32, head);
ptr = ptr.offset(head as isize); ptr = ptr.add(head);
} }
len -= head; len -= head;
if cfg!(target_endian = "big") { if cfg!(target_endian = "big") {
@ -67,7 +66,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
// Possible last block. // Possible last block.
if len > 0 { if len > 0 {
unsafe { xor_mem(ptr, mask_u32, len); } unsafe {
xor_mem(ptr, mask_u32, len);
}
} }
} }
@ -75,6 +76,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient, // TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient,
// it could be done better. The compiler does not see that len is limited to 3. // it could be done better. The compiler does not see that len is limited to 3.
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) { unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
#[allow(deprecated)]
let mut b: u32 = uninitialized(); let mut b: u32 = uninitialized();
#[allow(trivial_casts)] #[allow(trivial_casts)]
copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
@ -90,12 +92,10 @@ mod tests {
#[test] #[test]
fn test_apply_mask() { fn test_apply_mask() {
let mask = [ let mask = [0x6d, 0xb6, 0xb2, 0x80];
0x6d, 0xb6, 0xb2, 0x80,
];
let unmasked = vec![ let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, 0x12, 0x03,
]; ];
// Check masking with proper alignment. // Check masking with proper alignment.
@ -120,6 +120,4 @@ mod tests {
assert_eq!(masked, masked_fast); assert_eq!(masked, masked_fast);
} }
} }
} }

@ -2,16 +2,17 @@
pub mod coding; pub mod coding;
#[allow(clippy::module_inception)]
mod frame; mod frame;
mod mask; mod mask;
pub use self::frame::{Frame, FrameHeader};
pub use self::frame::CloseFrame; pub use self::frame::CloseFrame;
pub use self::frame::{Frame, FrameHeader};
use std::io::{Read, Write}; use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result}; use log::*;
use std::io::{Read, Write};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
@ -56,7 +57,8 @@ impl<Stream> FrameSocket<Stream> {
} }
impl<Stream> FrameSocket<Stream> impl<Stream> FrameSocket<Stream>
where Stream: Read where
Stream: Read,
{ {
/// Read a frame from stream. /// Read a frame from stream.
pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> { pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
@ -65,7 +67,8 @@ impl<Stream> FrameSocket<Stream>
} }
impl<Stream> FrameSocket<Stream> impl<Stream> FrameSocket<Stream>
where Stream: Write where
Stream: Write,
{ {
/// Write a frame to stream. /// Write a frame to stream.
/// ///
@ -138,8 +141,8 @@ impl FrameCodec {
// is not too big (fits into `usize`). // is not too big (fits into `usize`).
if length > max_size as u64 { if length > max_size as u64 {
return Err(Error::Capacity( return Err(Error::Capacity(
format!("Message length too big: {} > {}", length, max_size).into() format!("Message length too big: {} > {}", length, max_size).into(),
)) ));
} }
let input_size = cursor.get_ref().len() as u64 - cursor.position(); let input_size = cursor.get_ref().len() as u64 - cursor.position();
@ -149,19 +152,21 @@ impl FrameCodec {
if length > 0 { if length > 0 {
cursor.take(length).read_to_end(&mut payload)?; cursor.take(length).read_to_end(&mut payload)?;
} }
break payload break payload;
} }
} }
} }
// Not enough data in buffer. // Not enough data in buffer.
let size = self.in_buffer.prepare_reserve(MIN_READ) let size = self
.in_buffer
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) .with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
.read_from(stream)?; .read_from(stream)?;
if size == 0 { if size == 0 {
trace!("no frame received"); trace!("no frame received");
return Ok(None) return Ok(None);
} }
}; };
@ -173,17 +178,15 @@ impl FrameCodec {
} }
/// Write a frame to the provided stream. /// Write a frame to the provided stream.
pub(super) fn write_frame<Stream>( pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
&mut self,
stream: &mut Stream,
frame: Frame,
) -> Result<()>
where where
Stream: Write, Stream: Write,
{ {
trace!("writing frame {}", frame); trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len()); self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending(stream) self.write_pending(stream)
} }
@ -211,16 +214,19 @@ mod tests {
#[test] #[test]
fn read_frames() { fn read_frames() {
let raw = Cursor::new(vec![ let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
0x82, 0x03, 0x03, 0x02, 0x01,
0x99, 0x99,
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), assert_eq!(
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); sock.read_frame(None).unwrap().unwrap().into_data(),
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
vec![0x03, 0x02, 0x01]); );
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();
@ -229,12 +235,12 @@ mod tests {
#[test] #[test]
fn from_partially_read() { fn from_partially_read() {
let raw = Cursor::new(vec![ let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(), assert_eq!(
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
} }
#[test] #[test]
@ -248,17 +254,13 @@ mod tests {
sock.write_frame(frame).unwrap(); sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner(); let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![ assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
0x89, 0x02, 0x04, 0x05,
0x8a, 0x01, 0x01
]);
} }
#[test] #[test]
fn parse_overflow() { fn parse_overflow() {
let raw = Cursor::new(vec![ let raw = Cursor::new(vec![
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
let _ = sock.read_frame(None); // should not crash let _ = sock.read_frame(None); // should not crash
@ -266,11 +268,10 @@ mod tests {
#[test] #[test]
fn size_limit_hit() { fn size_limit_hit() {
let raw = Cursor::new(vec![ let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(), assert_eq!(
sock.read_frame(Some(5)).unwrap_err().to_string(),
"Space limit exceeded: Message length too big: 7 > 5" "Space limit exceeded: Message length too big: 7 > 5"
); );
} }

@ -1,17 +1,17 @@
use std::convert::{From, Into, AsRef}; use std::convert::{AsRef, From, Into};
use std::fmt; use std::fmt;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::str; use std::str;
use error::{Result, Error};
use super::frame::CloseFrame; use super::frame::CloseFrame;
use crate::error::{Error, Result};
mod string_collect { mod string_collect {
use utf8; use utf8;
use utf8::DecodeError; use utf8::DecodeError;
use error::{Error, Result}; use crate::error::{Error, Result};
#[derive(Debug)] #[derive(Debug)]
pub struct StringCollector { pub struct StringCollector {
@ -28,7 +28,8 @@ mod string_collect {
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.data.len() self.data
.len()
.saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0)) .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
} }
@ -41,7 +42,7 @@ mod string_collect {
if let Ok(text) = result { if let Ok(text) = result {
self.data.push_str(text); self.data.push_str(text);
} else { } else {
return Err(Error::Utf8) return Err(Error::Utf8);
} }
true true
} else { } else {
@ -59,7 +60,10 @@ mod string_collect {
self.data.push_str(text); self.data.push_str(text);
Ok(()) Ok(())
} }
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => { Err(DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => {
self.data.push_str(valid_prefix); self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix); self.incomplete = Some(incomplete_suffix);
Ok(()) Ok(())
@ -82,7 +86,6 @@ mod string_collect {
} }
} }
} }
} }
use self::string_collect::StringCollector; use self::string_collect::StringCollector;
@ -104,11 +107,11 @@ impl IncompleteMessage {
pub fn new(message_type: IncompleteMessageType) -> Self { pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage { IncompleteMessage {
collector: match message_type { collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageCollector::Binary(Vec::new()), IncompleteMessageType::Text => {
IncompleteMessageType::Text => IncompleteMessageCollector::Text(StringCollector::new())
IncompleteMessageCollector::Text(StringCollector::new()),
} }
},
} }
} }
@ -130,8 +133,12 @@ impl IncompleteMessage {
// Be careful about integer overflows here. // Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size { if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity( return Err(Error::Capacity(
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into() format!(
)) "Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
));
} }
match self.collector { match self.collector {
@ -139,18 +146,14 @@ impl IncompleteMessage {
v.extend(tail.as_ref()); v.extend(tail.as_ref());
Ok(()) Ok(())
} }
IncompleteMessageCollector::Text(ref mut t) => { IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
t.extend(tail)
}
} }
} }
/// Convert an incomplete message into a complete one. /// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> { pub fn complete(self) -> Result<Message> {
match self.collector { match self.collector {
IncompleteMessageCollector::Binary(v) => { IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
Ok(Message::Binary(v))
}
IncompleteMessageCollector::Text(t) => { IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?; let text = t.into_string()?;
Ok(Message::Text(text)) Ok(Message::Text(text))
@ -185,17 +188,18 @@ pub enum Message {
} }
impl Message { impl Message {
/// Create a new text WebSocket message from a stringable. /// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message pub fn text<S>(string: S) -> Message
where S: Into<String> where
S: Into<String>,
{ {
Message::Text(string.into()) Message::Text(string.into())
} }
/// Create a new binary WebSocket message by converting to Vec<u8>. /// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message pub fn binary<B>(bin: B) -> Message
where B: Into<Vec<u8>> where
B: Into<Vec<u8>>,
{ {
Message::Binary(bin.into()) Message::Binary(bin.into())
} }
@ -244,9 +248,9 @@ impl Message {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match *self { match *self {
Message::Text(ref string) => string.len(), Message::Text(ref string) => string.len(),
Message::Binary(ref data) | Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
Message::Ping(ref data) | data.len()
Message::Pong(ref data) => data.len(), }
Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
} }
} }
@ -261,9 +265,7 @@ impl Message {
pub fn into_data(self) -> Vec<u8> { pub fn into_data(self) -> Vec<u8> {
match self { match self {
Message::Text(string) => string.into_bytes(), Message::Text(string) => string.into_bytes(),
Message::Binary(data) | Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
Message::Ping(data) |
Message::Pong(data) => data,
Message::Close(None) => Vec::new(), Message::Close(None) => Vec::new(),
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
} }
@ -273,10 +275,9 @@ impl Message {
pub fn into_text(self) -> Result<String> { pub fn into_text(self) -> Result<String> {
match self { match self {
Message::Text(string) => Ok(string), Message::Text(string) => Ok(string),
Message::Binary(data) | Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
Message::Ping(data) | Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?)
Message::Pong(data) => Ok(try!( }
String::from_utf8(data).map_err(|err| err.utf8_error()))),
Message::Close(None) => Ok(String::new()), Message::Close(None) => Ok(String::new()),
Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
} }
@ -287,14 +288,13 @@ impl Message {
pub fn to_text(&self) -> Result<&str> { pub fn to_text(&self) -> Result<&str> {
match *self { match *self {
Message::Text(ref string) => Ok(string), Message::Text(ref string) => Ok(string),
Message::Binary(ref data) | Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
Message::Ping(ref data) | Ok(str::from_utf8(data)?)
Message::Pong(ref data) => Ok(try!(str::from_utf8(data))), }
Message::Close(None) => Ok(""), Message::Close(None) => Ok(""),
Message::Close(Some(ref frame)) => Ok(&frame.reason), Message::Close(Some(ref frame)) => Ok(&frame.reason),
} }
} }
} }
impl From<String> for Message { impl From<String> for Message {
@ -358,7 +358,6 @@ mod tests {
assert!(msg.into_text().is_err()); assert!(msg.into_text().is_err());
} }
#[test] #[test]
fn binary_convert_vec() { fn binary_convert_vec() {
let bin = vec![6u8, 7, 8, 9, 10, 241]; let bin = vec![6u8, 7, 8, 9, 10, 241];

@ -4,18 +4,19 @@ pub mod frame;
mod message; mod message;
pub use self::message::Message;
pub use self::frame::CloseFrame; pub use self::frame::CloseFrame;
pub use self::message::Message;
use log::*;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io::{Read, Write, ErrorKind as IoErrorKind}; use std::io::{ErrorKind as IoErrorKind, Read, Write};
use std::mem::replace; use std::mem::replace;
use error::{Error, Result}; use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::frame::{Frame, FrameCodec}; use self::frame::{Frame, FrameCodec};
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode}; use self::message::{IncompleteMessage, IncompleteMessageType};
use util::NonBlockingResult; use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
/// Indicates a Client or Server role of the websocket /// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -147,7 +148,6 @@ impl<Stream: Read + Write> WebSocket<Stream> {
} }
} }
/// A context for managing WebSocket stream. /// A context for managing WebSocket stream.
#[derive(Debug)] #[derive(Debug)]
pub struct WebSocketContext { pub struct WebSocketContext {
@ -182,11 +182,7 @@ impl WebSocketContext {
} }
/// Create a WebSocket context that manages an post-handshake stream. /// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read( pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
WebSocketContext { WebSocketContext {
frame: FrameCodec::from_partially_read(part), frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config) ..WebSocketContext::new(role, config)
@ -217,7 +213,7 @@ impl WebSocketContext {
// Thus if read blocks, just let it return WouldBlock. // Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? { if let Some(message) = self.read_message_frame(stream)? {
trace!("Received message {}", message); trace!("Received message {}", message);
return Ok(message) return Ok(message);
} }
} }
} }
@ -251,20 +247,14 @@ impl WebSocketContext {
} }
let frame = match message { let frame = match message {
Message::Text(data) => { Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Frame::message(data.into(), OpCode::Data(OpData::Text), true) Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
}
Message::Binary(data) => {
Frame::message(data, OpCode::Data(OpData::Binary), true)
}
Message::Ping(data) => Frame::ping(data), Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => { Message::Pong(data) => {
self.pong = Some(Frame::pong(data)); self.pong = Some(Frame::pong(data));
return self.write_pending(stream) return self.write_pending(stream);
}
Message::Close(code) => {
return self.close(stream, code)
} }
Message::Close(code) => return self.close(stream, code),
}; };
self.send_queue.push_back(frame); self.send_queue.push_back(frame);
@ -342,7 +332,6 @@ impl WebSocketContext {
Stream: Read + Write, Stream: Read + Write,
{ {
if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? {
// MUST be 0 unless an extension is negotiated that defines meanings // MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of // for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero // the negotiated extensions defines the meaning of such a nonzero
@ -351,7 +340,7 @@ impl WebSocketContext {
{ {
let hdr = frame.header(); let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into())) return Err(Error::Protocol("Reserved bits are non-zero".into()));
} }
} }
@ -364,19 +353,22 @@ impl WebSocketContext {
} else { } else {
// The server MUST close the connection upon receiving a // The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455) // frame that is not masked. (RFC 6455)
return Err(Error::Protocol("Received an unmasked frame from client".into())) return Err(Error::Protocol(
"Received an unmasked frame from client".into(),
));
} }
} }
Role::Client => { Role::Client => {
if frame.is_masked() { if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455) // A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into())) return Err(Error::Protocol(
"Received a masked frame from server".into(),
));
} }
} }
} }
match frame.header().opcode { match frame.header().opcode {
OpCode::Control(ctl) => { OpCode::Control(ctl) => {
match ctl { match ctl {
// All control frames MUST have a payload length of 125 bytes or less // All control frames MUST have a payload length of 125 bytes or less
@ -387,12 +379,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into())) Err(Error::Protocol("Control frame too big".into()))
} }
OpCtl::Close => { OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
Ok(self.do_close(frame.into_close()?).map(Message::Close)) OpCtl::Reserved(i) => Err(Error::Protocol(
} format!("Unknown control frame type {}", i).into(),
OpCtl::Reserved(i) => { )),
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => { OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
// No ping processing while closing. // No ping processing while closing.
Ok(None) Ok(None)
@ -402,9 +392,7 @@ impl WebSocketContext {
self.pong = Some(Frame::pong(data.clone())); self.pong = Some(Frame::pong(data.clone()));
Ok(Some(Message::Ping(data))) Ok(Some(Message::Ping(data)))
} }
OpCtl::Pong => { OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
Ok(Some(Message::Pong(frame.into_data())))
}
} }
} }
@ -420,7 +408,9 @@ impl WebSocketContext {
if let Some(ref mut msg) = self.incomplete { if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?; msg.extend(frame.into_data(), self.config.max_message_size)?;
} else { } else {
return Err(Error::Protocol("Continue frame but nothing to continue".into())) return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
} }
if fin { if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?)) Ok(Some(self.incomplete.take().unwrap().complete()?))
@ -428,11 +418,9 @@ impl WebSocketContext {
Ok(None) Ok(None)
} }
} }
c if self.incomplete.is_some() => { c if self.incomplete.is_some() => Err(Error::Protocol(
Err(Error::Protocol( format!("Received {} while waiting for more fragments", c).into(),
format!("Received {} while waiting for more fragments", c).into() )),
))
}
OpData::Text | OpData::Binary => { OpData::Text | OpData::Binary => {
let msg = { let msg = {
let message_type = match data { let message_type = match data {
@ -451,28 +439,27 @@ impl WebSocketContext {
Ok(None) Ok(None)
} }
} }
OpData::Reserved(i) => { OpData::Reserved(i) => Err(Error::Protocol(
Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) format!("Unknown data frame type {}", i).into(),
} )),
} }
} }
} // match opcode } // match opcode
} else { } else {
// Connection closed by peer // Connection closed by peer
match replace(&mut self.state, WebSocketState::Terminated) { match replace(&mut self.state, WebSocketState::Terminated) {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed) Err(Error::ConnectionClosed)
} }
_ => { _ => Err(Error::Protocol(
Err(Error::Protocol("Connection reset without closing handshake".into())) "Connection reset without closing handshake".into(),
} )),
} }
} }
} }
/// Received a close frame. Tells if we need to return a close frame to the user. /// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> { fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
debug!("Received close frame: {:?}", close); debug!("Received close frame: {:?}", close);
match self.state { match self.state {
@ -488,7 +475,7 @@ impl WebSocketContext {
} else { } else {
Frame::close(Some(CloseFrame { Frame::close(Some(CloseFrame {
code: CloseCode::Protocol, code: CloseCode::Protocol,
reason: "Protocol violation".into() reason: "Protocol violation".into(),
})) }))
} }
} else { } else {
@ -518,8 +505,7 @@ impl WebSocketContext {
Stream: Read + Write, Stream: Read + Write,
{ {
match self.role { match self.role {
Role::Server => { Role::Server => {}
}
Role::Client => { Role::Client => {
// 5. If the data is being sent by the client, the frame(s) MUST be // 5. If the data is being sent by the client, the frame(s) MUST be
// masked as defined in Section 5.3. (RFC 6455) // masked as defined in Section 5.3. (RFC 6455)
@ -535,7 +521,9 @@ impl WebSocketContext {
match self.state { match self.state {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged
if err.kind() == IoErrorKind::ConnectionReset => if err.kind() == IoErrorKind::ConnectionReset =>
Error::ConnectionClosed, {
Error::ConnectionClosed
}
_ => Error::Io(err), _ => Error::Io(err),
} }
}), }),
@ -544,7 +532,6 @@ impl WebSocketContext {
} }
} }
/// The current connection state. /// The current connection state.
#[derive(Debug)] #[derive(Debug)]
enum WebSocketState { enum WebSocketState {
@ -580,7 +567,7 @@ impl WebSocketState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{WebSocket, Role, Message, WebSocketConfig}; use super::{Message, Role, WebSocket, WebSocketConfig};
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;
@ -602,57 +589,53 @@ mod tests {
} }
} }
#[test] #[test]
fn receive_messages() { fn receive_messages() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![
0x89, 0x02, 0x01, 0x02, 0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
0x8a, 0x01, 0x03, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x01, 0x07, 0x03,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
0x82, 0x03,
0x01, 0x02, 0x03,
]); ]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2])); assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3])); assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into())); assert_eq!(
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); socket.read_message().unwrap(),
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
} }
#[test] #[test]
fn size_limiting_text_fragmented() { fn size_limiting_text_fragmented() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![
0x01, 0x07, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x6c, 0x64, 0x21,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
]); ]);
let limit = WebSocketConfig { let limit = WebSocketConfig {
max_message_size: Some(10), max_message_size: Some(10),
.. WebSocketConfig::default() ..WebSocketConfig::default()
}; };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(), assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 7 + 6 > 10" "Space limit exceeded: Message too big: 7 + 6 > 10"
); );
} }
#[test] #[test]
fn size_limiting_binary() { fn size_limiting_binary() {
let incoming = Cursor::new(vec![ let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let limit = WebSocketConfig { let limit = WebSocketConfig {
max_message_size: Some(2), max_message_size: Some(2),
.. WebSocketConfig::default() ..WebSocketConfig::default()
}; };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(), assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2" "Space limit exceeded: Message too big: 0 + 3 > 2"
); );
} }

@ -1,11 +1,11 @@
//! Methods to accept an incoming WebSocket connection on a server. //! Methods to accept an incoming WebSocket connection on a server.
pub use handshake::server::ServerHandshake; pub use crate::handshake::server::ServerHandshake;
use handshake::HandshakeError; use crate::handshake::server::{Callback, NoCallback};
use handshake::server::{Callback, NoCallback}; use crate::handshake::HandshakeError;
use protocol::{WebSocket, WebSocketConfig}; use crate::protocol::{WebSocket, WebSocketConfig};
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -18,9 +18,10 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketConfig>) pub fn accept_with_config<S: Read + Write>(
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> stream: S,
{ config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_hdr_with_config(stream, NoCallback, config) accept_hdr_with_config(stream, NoCallback, config)
} }
@ -30,9 +31,9 @@ pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketCo
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` /// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including /// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others. /// those from `Mio` and others.
pub fn accept<S: Read + Write>(stream: S) pub fn accept<S: Read + Write>(
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> stream: S,
{ ) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_with_config(stream, None) accept_with_config(stream, None)
} }
@ -47,7 +48,7 @@ pub fn accept<S: Read + Write>(stream: S)
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>( pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
stream: S, stream: S,
callback: C, callback: C,
config: Option<WebSocketConfig> config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> { ) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback, config).handshake() ServerHandshake::start(stream, callback, config).handshake()
} }
@ -57,8 +58,9 @@ pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
/// This function does the same as `accept()` but accepts an extra callback /// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming /// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply. /// requests and is able to add extra headers to the reply.
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C) pub fn accept_hdr<S: Read + Write, C: Callback>(
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> stream: S,
{ callback: C,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
accept_hdr_with_config(stream, callback, None) accept_hdr_with_config(stream, callback, None)
} }

@ -4,11 +4,11 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits. //! `Read + Write` traits.
use std::io::{Read, Write, Result as IoResult}; use std::io::{Read, Result as IoResult, Write};
use std::net::TcpStream; use std::net::TcpStream;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
use native_tls::TlsStream; use native_tls::TlsStream;
/// Stream mode, either plain TCP or TLS. /// Stream mode, either plain TCP or TLS.
@ -32,7 +32,7 @@ impl NoDelay for TcpStream {
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> { impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().set_nodelay(nodelay) self.get_mut().set_nodelay(nodelay)

@ -3,7 +3,7 @@
use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use error::Error; use crate::error::Error;
/// Non-blocking IO handling. /// Non-blocking IO handling.
pub trait NonBlockingError: Sized { pub trait NonBlockingError: Sized {
@ -40,7 +40,8 @@ pub trait NonBlockingResult {
} }
impl<T, E> NonBlockingResult for StdResult<T, E> impl<T, E> NonBlockingResult for StdResult<T, E>
where E : NonBlockingError where
E: NonBlockingError,
{ {
type Result = StdResult<Option<T>, E>; type Result = StdResult<Option<T>, E>;
fn no_block(self) -> Self::Result { fn no_block(self) -> Self::Result {
@ -49,7 +50,7 @@ impl<T, E> NonBlockingResult for StdResult<T, E>
Err(e) => match e.into_non_blocking() { Err(e) => match e.into_non_blocking() {
Some(e) => Err(e), Some(e) => Err(e),
None => Ok(None), None => Ok(None),
} },
} }
} }
} }

@ -1,13 +1,9 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection //! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket. //! is closedd from the server's point of view and drop the underlying tcp socket.
extern crate env_logger;
extern crate tungstenite;
extern crate url;
use std::net::TcpListener; use std::net::TcpListener;
use std::process::exit; use std::process::exit;
use std::thread::{spawn, sleep}; use std::thread::{sleep, spawn};
use std::time::Duration; use std::time::Duration;
use tungstenite::{accept, connect, Error, Message}; use tungstenite::{accept, connect, Error, Message};
@ -28,14 +24,16 @@ fn test_close() {
let client_thread = spawn(move || { let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()).unwrap(); let (mut client, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()).unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap(); client
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = client.read_message().unwrap(); // receive close from server let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close()); assert!(message.is_close());
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => { }, Error::ConnectionClosed => {}
_ => panic!("unexpected error"), _ => panic!("unexpected error"),
} }
}); });
@ -52,7 +50,7 @@ fn test_close() {
let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
match err { match err {
Error::ConnectionClosed => { }, Error::ConnectionClosed => {}
_ => panic!("unexpected error"), _ => panic!("unexpected error"),
} }

Loading…
Cancel
Save