Merge pull request #93 from sdroege/http-types

Base HTTP-types (request, headers, response, status code, etc) on the ones from the http crate
pull/95/head
Daniel Abramov 5 years ago committed by GitHub
commit 345d262972
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      Cargo.toml
  2. 4
      examples/autobahn-client.rs
  3. 14
      examples/callback-error.rs
  4. 4
      examples/client.rs
  5. 21
      examples/server.rs
  6. 148
      src/client.rs
  7. 49
      src/error.rs
  8. 163
      src/handshake/client.rs
  9. 124
      src/handshake/headers.rs
  10. 252
      src/handshake/server.rs
  11. 8
      src/protocol/frame/mod.rs
  12. 2
      src/protocol/message.rs
  13. 17
      src/protocol/mod.rs
  14. 7
      tests/no_send_after_close.rs

@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.9.3" documentation = "https://docs.rs/tungstenite/0.10.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.9.3" version = "0.10.0"
edition = "2018" edition = "2018"
[features] [features]

@ -3,7 +3,7 @@ use url::Url;
use tungstenite::{connect, Error, Message, Result}; use tungstenite::{connect, Error, Message, Result};
const AGENT: &'static str = "Tungstenite"; const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> { fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
@ -47,7 +47,7 @@ fn main() {
let total = get_case_count().unwrap(); let total = get_case_count().unwrap();
for case in 1..(total + 1) { for case in 1..=total {
if let Err(e) = run_test(case) { if let Err(e) = run_test(case) {
match e { match e {
Error::Protocol(_) => {} Error::Protocol(_) => {}

@ -2,19 +2,19 @@ use std::net::TcpListener;
use std::thread::spawn; use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::accept_hdr;
use tungstenite::handshake::server::{ErrorResponse, Request}; use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::StatusCode; use tungstenite::http::StatusCode;
fn main() { fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap(); let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() { for stream in server.incoming() {
spawn(move || { spawn(move || {
let callback = |_req: &Request| { let callback = |_req: &Request, _resp| {
Err(ErrorResponse { let resp = Response::builder()
error_code: StatusCode::FORBIDDEN, .status(StatusCode::FORBIDDEN)
headers: None, .body(Some("Access denied".into()))
body: Some("Access denied".into()), .unwrap();
}) Err(resp)
}; };
accept_hdr(stream.unwrap(), callback).unwrap_err(); accept_hdr(stream.unwrap(), callback).unwrap_err();
}); });

@ -8,9 +8,9 @@ fn main() {
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect"); connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");
println!("Connected to the server"); println!("Connected to the server");
println!("Response HTTP code: {}", response.code); println!("Response HTTP code: {}", response.status());
println!("Response contains the following headers:"); println!("Response contains the following headers:");
for &(ref header, _ /*value*/) in response.headers.iter() { for (ref header, _value) in response.headers() {
println!("* {}", header); println!("* {}", header);
} }

@ -2,30 +2,27 @@ use std::net::TcpListener;
use std::thread::spawn; use std::thread::spawn;
use tungstenite::accept_hdr; use tungstenite::accept_hdr;
use tungstenite::handshake::server::Request; use tungstenite::handshake::server::{Request, Response};
fn main() { fn main() {
env_logger::init(); env_logger::init();
let server = TcpListener::bind("127.0.0.1:3012").unwrap(); let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() { for stream in server.incoming() {
spawn(move || { spawn(move || {
let callback = |req: &Request| { let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake"); println!("Received a new ws handshake");
println!("The request's path is: {}", req.path); println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:"); println!("The request's headers are:");
for &(ref header, _ /* value */) in req.headers.iter() { for (ref header, _value) in req.headers() {
println!("* {}", header); println!("* {}", header);
} }
// Let's add an additional header to our response to the client. // Let's add an additional header to our response to the client.
let extra_headers = vec![ let headers = response.headers_mut();
(String::from("MyCustomHeader"), String::from(":)")), headers.append("MyCustomHeader", ":)".parse().unwrap());
( headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"), Ok(response)
),
];
Ok(Some(extra_headers))
}; };
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();

@ -4,10 +4,12 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::Uri;
use log::*; use log::*;
use url::Url; use url::Url;
use crate::handshake::client::Response; use crate::handshake::client::{Request, Response};
use crate::protocol::WebSocketConfig; use crate::protocol::WebSocketConfig;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
@ -64,7 +66,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request}; use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError; use crate::handshake::HandshakeError;
use crate::protocol::WebSocket; use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay}; use crate::stream::{Mode, NoDelay};
@ -84,37 +86,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>( pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into(); let request: Request = request.into_client_request()?;
let mode = url_mode(&request.url)?; let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request let host = request
.url .uri()
.host() .host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; .ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request let port = uri.port_u16().unwrap_or(match mode {
.url Mode::Plain => 80,
.port_or_known_default() Mode::Tls => 443,
.ok_or_else(|| Error::Url("No port number in the URL".into()))?; });
let addrs; let addrs = (host, port).to_socket_addrs()?;
let addr; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e { client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
@ -134,35 +122,33 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>( pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None) connect_with_config(request, None)
} }
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = url let domain = uri
.host_str() .host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; .ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream); return Ok(stream);
} }
} }
} }
Err(Error::Url(format!("Unable to connect to {}", url).into())) Err(Error::Url(format!("Unable to connect to {}", uri).into()))
} }
/// Get the mode of the given URL. /// Get the mode of the given URL.
/// ///
/// This function may be used to ease the creation of custom TLS streams /// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`. /// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> { pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match url.scheme() { match uri.scheme_str() {
"ws" => Ok(Mode::Plain), Some("ws") => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls), Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())), _ => Err(Error::Url("URL scheme not supported".into())),
} }
} }
@ -173,16 +159,16 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>( pub fn client_with_config<Stream, Req>(
request: Req, request: Req,
stream: Stream, stream: Stream,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: IntoClientRequest,
{ {
ClientHandshake::start(stream, request.into(), config).handshake() ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
} }
/// Do the client handshake over the given stream. /// Do the client handshake over the given stream.
@ -190,13 +176,87 @@ where
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>( pub fn client<Stream, Req>(
request: Req, request: Req,
stream: Stream, stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: IntoClientRequest,
{ {
client_with_config(request, stream, None) client_with_config(request, stream, None)
} }
/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request` that can be used for a client connection.
fn into_client_request(self) -> Result<Request>;
}
impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self.clone()).body(())?)
}
}
impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Request {
fn into_client_request(self) -> Result<Request> {
Ok(self)
}
}
impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request> {
use crate::handshake::headers::FromHttparse;
Request::from_httparse(self)
}
}

@ -9,6 +9,7 @@ use std::result;
use std::str; use std::str;
use std::string; use std::string;
use http;
use httparse; use httparse;
use crate::protocol::Message; use crate::protocol::Message;
@ -45,7 +46,7 @@ pub enum Error {
/// connection when it really shouldn't anymore, so this really indicates a programmer /// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part. /// error on your part.
AlreadyClosed, AlreadyClosed,
/// Input-output error. Appart from WouldBlock, these are generally errors with the /// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal. /// underlying connection and you should probably consider them fatal.
Io(io::Error), Io(io::Error),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
@ -61,10 +62,12 @@ pub enum Error {
SendQueueFull(Message), SendQueueFull(Message),
/// UTF coding error /// UTF coding error
Utf8, Utf8,
/// Invlid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error. /// HTTP error.
Http(u16), Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
} }
impl fmt::Display for Error { impl fmt::Display for Error {
@ -80,7 +83,8 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg), Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code), Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
} }
} }
} }
@ -99,6 +103,7 @@ impl ErrorTrait for Error {
Error::Utf8 => "", Error::Utf8 => "",
Error::Url(ref msg) => msg.borrow(), Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "", Error::Http(_) => "",
Error::HttpFormat(ref err) => err.description(),
} }
} }
} }
@ -121,6 +126,42 @@ impl From<string::FromUtf8Error> for Error {
} }
} }
impl From<http::header::InvalidHeaderValue> for Error {
fn from(err: http::header::InvalidHeaderValue) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::InvalidHeaderName> for Error {
fn from(err: http::header::InvalidHeaderName) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::header::ToStrError> for Error {
fn from(_: http::header::ToStrError) -> Self {
Error::Utf8
}
}
impl From<http::uri::InvalidUri> for Error {
fn from(err: http::uri::InvalidUri) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::status::InvalidStatusCode> for Error {
fn from(err: http::status::InvalidStatusCode) -> Self {
Error::HttpFormat(err.into())
}
}
impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
impl From<tls::Error> for Error { impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self { fn from(err: tls::Error) -> Self {

@ -1,69 +1,23 @@
//! Client handshake machine. //! Client handshake machine.
use std::borrow::Cow;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use url::Url;
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request. /// Client request type.
#[derive(Debug)] pub type Request = HttpRequest<()>;
pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to.
pub url: Url,
/// Extra HTTP headers to append to the request.
pub extra_headers: Option<Vec<(Cow<'t, str>, Cow<'t, str>)>>,
}
impl<'t> Request<'t> {
/// Returns the GET part of the request.
fn get_path(&self) -> String {
if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query)
} else {
self.url.path().into()
}
}
/// Returns the host part of the request.
fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() {
format!("{host}:{port}", host = host, port = port)
} else {
host.into()
}
}
/// Adds a WebSocket protocol to the request.
pub fn add_protocol(&mut self, protocol: Cow<'t, str>) {
self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol);
}
/// Adds a custom header to the request. /// Client response type.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { pub type Response = HttpResponse<()>;
let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new);
headers.push((name, value));
self.extra_headers = Some(headers);
}
}
impl From<Url> for Request<'static> {
fn from(value: Url) -> Self {
Request {
url: value,
extra_headers: None,
}
}
}
/// Client handshake role. /// Client handshake role.
#[derive(Debug)] #[derive(Debug)]
@ -79,29 +33,49 @@ impl<S: Read + Write> ClientHandshake<S> {
stream: S, stream: S,
request: Request, request: Request,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> MidHandshake<Self> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
"Invalid HTTP method, only GET supported".into(),
));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
// Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?;
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri();
write!( write!(
req, req,
"\ "\
GET {path} HTTP/1.1\r\n\ GET {path} {version:?}\r\n\
Host: {host}\r\n\ Host: {host}\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), version = request.version(),
path = request.get_path(), host = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?,
path = uri
.path_and_query()
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
if let Some(eh) = request.extra_headers { for (k, v) in request.headers() {
for (k, v) in eh { writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
writeln!(req, "{}: {}\r", k, v).unwrap();
}
} }
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
HandshakeMachine::start_write(stream, req) HandshakeMachine::start_write(stream, req)
@ -117,10 +91,10 @@ impl<S: Read + Write> ClientHandshake<S> {
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
MidHandshake { Ok(MidHandshake {
role: client, role: client,
machine, machine,
} })
} }
} }
@ -162,16 +136,20 @@ impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> { pub fn verify_response(&self, response: &Response) -> Result<()> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.code != 101 { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.code)); return Err(Error::Http(response.status()));
} }
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
// header field contains a value that is not an ASCII case- // header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST // insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455) // _Fail the WebSocket Connection_. (RFC 6455)
if !response if !headers
.headers .get("Upgrade")
.header_is_ignore_case("Upgrade", "websocket") .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(), "No \"Upgrade: websocket\" in server reply".into(),
@ -181,9 +159,11 @@ impl VerifyData {
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client // ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response if !headers
.headers .get("Connection")
.header_is_ignore_case("Connection", "Upgrade") .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(), "No \"Connection: upgrade\" in server reply".into(),
@ -193,9 +173,10 @@ impl VerifyData {
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !response if !headers
.headers .get("Sec-WebSocket-Accept")
.header_is("Sec-WebSocket-Accept", &self.accept_key) .map(|h| h == &self.accept_key)
.unwrap_or(false)
{ {
return Err(Error::Protocol( return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(), "Key mismatch in Sec-WebSocket-Accept".into(),
@ -219,15 +200,6 @@ impl VerifyData {
} }
} }
/// Server response.
#[derive(Debug)]
pub struct Response {
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: Headers,
}
impl TryParse for Response { impl TryParse for Response {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -246,10 +218,17 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
"HTTP version should be 1.1 or higher".into(), "HTTP version should be 1.1 or higher".into(),
)); ));
} }
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"), let headers = HeaderMap::from_httparse(raw.headers)?;
headers: Headers::from_httparse(raw.headers)?,
}) let mut response = Response::new(());
*response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?;
*response.headers_mut() = headers;
// TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
// so the only valid value we could get in the response would be 1.1.
*response.version_mut() = http::Version::HTTP_11;
Ok(response)
} }
} }
@ -278,18 +257,18 @@ mod tests {
assert_eq!(k2.len(), 24); assert_eq!(k2.len(), 24);
assert!(k1.ends_with("==")); assert!(k1.ends_with("=="));
assert!(k2.ends_with("==")); assert!(k2.ends_with("=="));
assert!(k1[..22].find("=").is_none()); assert!(k1[..22].find('=').is_none());
assert!(k2[..22].find("=").is_none()); assert!(k2[..22].find('=').is_none());
} }
#[test] #[test]
fn response_parsing() { fn response_parsing() {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200); assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!( assert_eq!(
resp.headers.find_first("Content-Type"), resp.headers().get("Content-Type").unwrap(),
Some(&b"text/html"[..]) &b"text/html"[..],
); );
} }
} }

@ -1,8 +1,7 @@
//! HTTP Request and response header handling. //! HTTP Request and response header handling.
use std::slice; use http;
use std::str::from_utf8; use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse; use httparse;
use httparse::Status; use httparse::Status;
@ -12,90 +11,31 @@ use crate::error::Result;
/// Limit for the number of header lines. /// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124; pub const MAX_HEADERS: usize = 124;
/// HTTP request or response headers. /// Trait to convert raw objects into HTTP parseables.
#[derive(Debug)] pub(crate) trait FromHttparse<T>: Sized {
pub struct Headers { /// Convert raw object into parsed HTTP headers.
data: Vec<(String, Box<[u8]>)>, fn from_httparse(raw: T) -> Result<Self>;
} }
impl Headers { impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap {
/// Get first header with the given name, if any. fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
pub fn find_first(&self, name: &str) -> Option<&[u8]> { let mut headers = HeaderMap::new();
self.find(name).next() for h in raw {
} headers.append(
HeaderName::from_bytes(h.name.as_bytes())?,
/// Iterate over all headers with the given name. HeaderValue::from_bytes(h.value)?,
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { );
HeadersIter {
name,
iter: self.data.iter(),
} }
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers. Ok(headers)
#[derive(Debug)]
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value);
}
}
None
} }
} }
impl TryParse for HeaderMap {
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}
impl TryParse for Headers {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None, Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
}) })
} }
} }
@ -104,45 +44,41 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
mod tests { mod tests {
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::Headers; use super::HeaderMap;
#[test] #[test]
fn headers() { fn headers() {
const DATA: &'static [u8] = b"Host: foo.com\r\n\ const DATA: &[u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
\r\n"; \r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]);
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]);
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]);
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
} }
#[test] #[test]
fn headers_iter() { fn headers_iter() {
const DATA: &'static [u8] = b"Host: foo.com\r\n\ const DATA: &[u8] = b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
\r\n"; \r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions"); let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter();
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]);
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]);
assert_eq!(iter.next(), None); assert_eq!(iter.next(), None);
} }
#[test] #[test]
fn headers_incomplete() { fn headers_incomplete() {
const DATA: &'static [u8] = b"Host: foo.com\r\n\ const DATA: &[u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n"; Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap(); let hdr = HeaderMap::try_parse(DATA).unwrap();
assert!(hdr.is_none()); assert!(hdr.is_none());
} }
} }

@ -1,56 +1,108 @@
//! Server handshake machine. //! Server handshake machine.
use std::fmt::Write as FmtWrite; use std::io::{self, Read, Write};
use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::StatusCode; use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use super::headers::{FromHttparse, Headers, MAX_HEADERS}; use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Request from the client. /// Server request type.
#[derive(Debug)] pub type Request = HttpRequest<()>;
pub struct Request {
/// Path part of the URL.
pub path: String,
/// HTTP headers.
pub headers: Headers,
}
impl Request { /// Server response type.
/// Reply to the response. pub type Response = HttpResponse<()>;
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self /// Server error response type.
.headers pub type ErrorResponse = HttpResponse<Option<String>>;
.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; /// Create a response for the request.
let mut reply = format!( pub fn create_response(request: &Request) -> Result<Response> {
"\ if request.method() != http::Method::GET {
HTTP/1.1 101 Switching Protocols\r\n\ return Err(Error::Protocol("Method is not GET".into()));
Connection: Upgrade\r\n\ }
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n", if request.version() < http::Version::HTTP_11 {
convert_key(key)? return Err(Error::Protocol(
); "HTTP version should be 1.1 or higher".into(),
add_headers(&mut reply, extra_headers); ));
Ok(reply.into()) }
if !request
.headers()
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in client request".into(),
));
}
if !request
.headers()
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in client request".into(),
));
}
if !request
.headers()
.get("Sec-WebSocket-Version")
.map(|h| h == "13")
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Sec-WebSocket-Version: 13\" in client request".into(),
));
} }
let key = request
.headers()
.get("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut response = Response::builder();
response.status(StatusCode::SWITCHING_PROTOCOLS);
response.version(request.version());
response.header("Connection", "Upgrade");
response.header("Upgrade", "websocket");
response.header("Sec-WebSocket-Accept", convert_key(key.as_bytes())?);
Ok(response.body(())?)
} }
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { // Assumes that this is a valid response
if let Some(eh) = extra_headers { fn write_response<T>(w: &mut dyn io::Write, response: &HttpResponse<T>) -> Result<()> {
for (k, v) in eh { writeln!(
writeln!(reply, "{}: {}\r", k, v).unwrap(); w,
} "{version:?} {status} {reason}\r",
version = response.version(),
status = response.status(),
reason = response.status().canonical_reason().unwrap_or(""),
)?;
for (k, v) in response.headers() {
writeln!(w, "{}: {}\r", k, v.to_str()?).unwrap();
} }
writeln!(reply, "\r").unwrap();
writeln!(w, "\r")?;
Ok(())
} }
impl TryParse for Request { impl TryParse for Request {
@ -69,39 +121,24 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
if raw.method.expect("Bug: no method in header") != "GET" { if raw.method.expect("Bug: no method in header") != "GET" {
return Err(Error::Protocol("Method is not GET".into())); return Err(Error::Protocol("Method is not GET".into()));
} }
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(), "HTTP version should be 1.1 or higher".into(),
)); ));
} }
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?,
})
}
}
/// Extra headers for responses. let headers = HeaderMap::from_httparse(raw.headers)?;
pub type ExtraHeaders = Vec<(String, String)>;
/// An error response sent to the client. let mut request = Request::new(());
#[derive(Debug)] *request.method_mut() = http::Method::GET;
pub struct ErrorResponse { *request.headers_mut() = headers;
/// HTTP error code. *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
pub error_code: StatusCode, // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
/// Extra response headers, if any. // so the only valid value we could get in the response would be 1.1.
pub headers: Option<ExtraHeaders>, *request.version_mut() = http::Version::HTTP_11;
/// Response body, if any.
pub body: Option<String>,
}
impl From<StatusCode> for ErrorResponse { Ok(request)
fn from(error_code: StatusCode) -> Self {
ErrorResponse {
error_code,
headers: None,
body: None,
}
} }
} }
@ -115,15 +152,23 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; fn on_request(
self,
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>, F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
{ {
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { fn on_request(
self(request) self,
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
self(request, response)
} }
} }
@ -132,8 +177,12 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { fn on_request(
Ok(None) self,
_request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
Ok(response)
} }
} }
@ -191,34 +240,35 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol("Junk after client request".into())); return Err(Error::Protocol("Junk after client request".into()));
} }
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result) callback.on_request(&result, response)
} else { } else {
Ok(None) Ok(response)
}; };
match callback_result { match callback_result {
Ok(extra_headers) => { Ok(response) => {
let response = result.reply(extra_headers)?; let mut output = vec![];
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) write_response(&mut output, &response)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
Err(ErrorResponse { Err(resp) => {
error_code, if resp.status().is_success() {
headers, return Err(Error::Protocol(
body, "Custom response must not be successful".into(),
}) => { ));
self.error_code = Some(error_code.as_u16());
let mut response = format!(
"HTTP/1.1 {} {}\r\n",
error_code.as_str(),
error_code.canonical_reason().unwrap_or("")
);
add_headers(&mut response, headers);
if let Some(body) = body {
response += &body;
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
self.error_code = Some(resp.status().as_u16());
let mut output = vec![];
write_response(&mut output, &resp)?;
if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes());
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
} }
} }
@ -226,7 +276,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_code.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::Http(err)); return Err(Error::Http(StatusCode::from_u16(err)?));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
@ -239,21 +289,21 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::client::Response;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::create_response;
use super::Request; use super::Request;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
assert_eq!(req.path, "/script.ws"); assert_eq!(req.uri().path(), "/script.ws");
assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]);
} }
#[test] #[test]
fn request_replying() { fn request_replying() {
const DATA: &'static [u8] = b"\ const DATA: &[u8] = b"\
GET /script.ws HTTP/1.1\r\n\ GET /script.ws HTTP/1.1\r\n\
Host: foo.com\r\n\ Host: foo.com\r\n\
Connection: upgrade\r\n\ Connection: upgrade\r\n\
@ -262,21 +312,11 @@ mod tests {
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
\r\n"; \r\n";
let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap(); let response = create_response(&req).unwrap();
let extra_headers = Some(vec![
(
String::from("MyCustomHeader"),
String::from("MyCustomValue"),
),
(String::from("MyVersion"), String::from("LOL")),
]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!( assert_eq!(
req.headers.find_first("MyCustomHeader"), response.headers().get("Sec-WebSocket-Accept").unwrap(),
Some(b"MyCustomValue".as_ref()) b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref()
); );
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
} }
} }

@ -12,7 +12,7 @@ pub use self::frame::{Frame, FrameHeader};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ}; use input_buffer::{InputBuffer, MIN_READ};
use log::*; use log::*;
use std::io::{Read, Write, Error as IoError, ErrorKind as IoErrorKind}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
@ -199,7 +199,11 @@ impl FrameCodec {
let len = stream.write(&self.out_buffer)?; let len = stream.write(&self.out_buffer)?;
if len == 0 { if len == 0 {
// This is the same as "Connection reset by peer" // This is the same as "Connection reset by peer"
return Err(IoError::new(IoErrorKind::ConnectionReset, "Connection reset while sending").into()) return Err(IoError::new(
IoErrorKind::ConnectionReset,
"Connection reset while sending",
)
.into());
} }
self.out_buffer.drain(0..len); self.out_buffer.drain(0..len);
} }

@ -343,7 +343,7 @@ mod tests {
#[test] #[test]
fn display() { fn display() {
let t = Message::text(format!("test")); let t = Message::text("test".to_owned());
assert_eq!(t.to_string(), "test".to_owned()); assert_eq!(t.to_string(), "test".to_owned());
let bin = Message::binary(vec![0, 1, 3, 4, 241]); let bin = Message::binary(vec![0, 1, 3, 4, 241]);

@ -280,7 +280,9 @@ impl WebSocketContext {
// Do not write after sending a close frame. // Do not write after sending a close frame.
if !self.state.is_active() { if !self.state.is_active() {
return Err(Error::Protocol("Sending after closing is not allowed".into())); return Err(Error::Protocol(
"Sending after closing is not allowed".into(),
));
} }
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
@ -378,7 +380,9 @@ impl WebSocketContext {
{ {
if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? {
if !self.state.can_read() { if !self.state.can_read() {
return Err(Error::Protocol("Remote sent frame after having sent a Close Frame".into())); return Err(Error::Protocol(
"Remote sent frame after having sent a Close Frame".into(),
));
} }
// MUST be 0 unless an extension is negotiated that defines meanings // MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of // for non-zero values. If a nonzero value is received and none of
@ -588,7 +592,7 @@ enum WebSocketState {
impl WebSocketState { impl WebSocketState {
/// Tell if we're allowed to process normal messages. /// Tell if we're allowed to process normal messages.
fn is_active(&self) -> bool { fn is_active(self) -> bool {
match self { match self {
WebSocketState::Active => true, WebSocketState::Active => true,
_ => false, _ => false,
@ -598,16 +602,15 @@ impl WebSocketState {
/// Tell if we should process incoming data. Note that if we send a close frame /// Tell if we should process incoming data. Note that if we send a close frame
/// but the remote hasn't confirmed, they might have sent data before they receive our /// but the remote hasn't confirmed, they might have sent data before they receive our
/// close frame, so we should still pass those to client code, hence ClosedByUs is valid. /// close frame, so we should still pass those to client code, hence ClosedByUs is valid.
fn can_read(&self) -> bool { fn can_read(self) -> bool {
match self { match self {
WebSocketState::Active | WebSocketState::Active | WebSocketState::ClosedByUs => true,
WebSocketState::ClosedByUs => true,
_ => false, _ => false,
} }
} }
/// Check if the state is active, return error if not. /// Check if the state is active, return error if not.
fn check_active(&self) -> Result<()> { fn check_active(self) -> Result<()> {
match self { match self {
WebSocketState::Terminated => Err(Error::AlreadyClosed), WebSocketState::Terminated => Err(Error::AlreadyClosed),
_ => Ok(()), _ => Ok(()),

@ -39,13 +39,12 @@ fn test_no_send_after_close() {
client_handler.close(None).unwrap(); // send close to client client_handler.close(None).unwrap(); // send close to client
let err = client_handler let err = client_handler.write_message(Message::Text("Hello WebSocket".into()));
.write_message(Message::Text("Hello WebSocket".into()));
assert!( err.is_err() ); assert!(err.is_err());
match err.unwrap_err() { match err.unwrap_err() {
Error::Protocol(s) => { assert_eq!( "Sending after closing is not allowed", s )} Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s),
e => panic!("unexpected error: {:?}", e), e => panic!("unexpected error: {:?}", e),
} }

Loading…
Cancel
Save