Remove custom Request/Response types from client code

Fixes https://github.com/snapview/tungstenite-rs/issues/92
pull/93/head
Sebastian Dröge 5 years ago
parent 38a7d1a375
commit 9020840f84
  1. 153
      src/client.rs
  2. 3
      src/error.rs
  3. 153
      src/handshake/client.rs
  4. 8
      src/handshake/server.rs

@ -4,10 +4,11 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use http::{Request, Response, Uri};
use log::*; use log::*;
use url::Url; use url::Url;
use crate::handshake::client::Response;
use crate::protocol::WebSocketConfig; use crate::protocol::WebSocketConfig;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
@ -64,7 +65,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream; pub use self::encryption::AutoStream;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request}; use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError; use crate::handshake::HandshakeError;
use crate::protocol::WebSocket; use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay}; use crate::stream::{Mode, NoDelay};
@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>( pub fn connect_with_config<Req: IntoClientRequest>(
request: Req, request: Req,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response<()>)> {
let request: Request = request.into(); let request: Request<()> = request.into_client_request()?;
let mode = url_mode(&request.url)?; let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request let host = request
.url .uri()
.host() .host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; .ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request let port = uri.port_u16().unwrap_or(match mode {
.url Mode::Plain => 80,
.port_or_known_default() Mode::Tls => 443,
.ok_or_else(|| Error::Url("No port number in the URL".into()))?; });
let addrs; let addrs = (host, port).to_socket_addrs()?;
let addr; let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e { client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f, HandshakeError::Failure(f) => f,
@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call /// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls. /// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>( pub fn connect<Req: IntoClientRequest>(
request: Req, request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> { ) -> Result<(WebSocket<AutoStream>, Response<()>)> {
connect_with_config(request, None) connect_with_config(request, None)
} }
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> { fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = url let domain = uri
.host_str() .host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?; .ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr); debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream); return Ok(stream);
} }
} }
} }
Err(Error::Url(format!("Unable to connect to {}", url).into())) Err(Error::Url(format!("Unable to connect to {}", uri).into()))
} }
/// Get the mode of the given URL. /// Get the mode of the given URL.
/// ///
/// This function may be used to ease the creation of custom TLS streams /// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`. /// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> { pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match url.scheme() { match uri.scheme_str() {
"ws" => Ok(Mode::Plain), Some("ws") => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls), Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())), _ => Err(Error::Url("URL scheme not supported".into())),
} }
} }
@ -173,16 +160,16 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>( pub fn client_with_config<Stream, Req>(
request: Req, request: Req,
stream: Stream, stream: Stream,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> ) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: IntoClientRequest,
{ {
ClientHandshake::start(stream, request.into(), config).handshake() ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
} }
/// Do the client handshake over the given stream. /// Do the client handshake over the given stream.
@ -190,13 +177,87 @@ where
/// Use this function if you need a nonblocking handshake support or if you /// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do. /// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>( pub fn client<Stream, Req>(
request: Req, request: Req,
stream: Stream, stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> ) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where where
Stream: Read + Write, Stream: Read + Write,
Req: Into<Request<'t>>, Req: IntoClientRequest,
{ {
client_with_config(request, stream, None) client_with_config(request, stream, None)
} }
/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request<()>` that can be used for a client connection.
fn into_client_request(self) -> Result<Request<()>>;
}
impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self.clone()).body(())?)
}
}
impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Request<()> {
fn into_client_request(self) -> Result<Request<()>> {
Ok(self)
}
}
impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request<()>> {
use crate::handshake::headers::FromHttparse;
Request::<()>::from_httparse(self)
}
}

@ -9,6 +9,7 @@ use std::result;
use std::str; use std::str;
use std::string; use std::string;
use http;
use httparse; use httparse;
use crate::protocol::Message; use crate::protocol::Message;
@ -64,7 +65,7 @@ pub enum Error {
/// Invalid URL. /// Invalid URL.
Url(Cow<'static, str>), Url(Cow<'static, str>),
/// HTTP error. /// HTTP error.
Http(u16), Http(http::StatusCode),
/// HTTP format error. /// HTTP format error.
HttpFormat(http::Error), HttpFormat(http::Error),
} }

@ -1,13 +1,11 @@
//! Client handshake machine. //! Client handshake machine.
use std::borrow::Cow;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use http::HeaderMap; use http::{HeaderMap, Request, Response, StatusCode};
use httparse::Status; use httparse::Status;
use log::*; use log::*;
use url::Url;
use super::headers::{FromHttparse, MAX_HEADERS}; use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::machine::{HandshakeMachine, StageResult, TryParse};
@ -15,57 +13,6 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig}; use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request.
#[derive(Debug)]
pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to.
pub url: Url,
/// Extra HTTP headers to append to the request.
pub extra_headers: Option<Vec<(Cow<'t, str>, Cow<'t, str>)>>,
}
impl<'t> Request<'t> {
/// Returns the GET part of the request.
fn get_path(&self) -> String {
if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query)
} else {
self.url.path().into()
}
}
/// Returns the host part of the request.
fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() {
format!("{host}:{port}", host = host, port = port)
} else {
host.into()
}
}
/// Adds a WebSocket protocol to the request.
pub fn add_protocol(&mut self, protocol: Cow<'t, str>) {
self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol);
}
/// Adds a custom header to the request.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) {
let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new);
headers.push((name, value));
self.extra_headers = Some(headers);
}
}
impl From<Url> for Request<'static> {
fn from(value: Url) -> Self {
Request {
url: value,
extra_headers: None,
}
}
}
/// Client handshake role. /// Client handshake role.
#[derive(Debug)] #[derive(Debug)]
pub struct ClientHandshake<S> { pub struct ClientHandshake<S> {
@ -78,31 +25,51 @@ impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake. /// Initiate a client handshake.
pub fn start( pub fn start(
stream: S, stream: S,
request: Request, request: Request<()>,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
) -> MidHandshake<Self> { ) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
"Invalid HTTP method, only GET supported".into(),
));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
// Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?;
let key = generate_key(); let key = generate_key();
let machine = { let machine = {
let mut req = Vec::new(); let mut req = Vec::new();
let uri = request.uri();
write!( write!(
req, req,
"\ "\
GET {path} HTTP/1.1\r\n\ GET {path} {version:?}\r\n\
Host: {host}\r\n\ Host: {host}\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n", Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), version = request.version(),
path = request.get_path(), host = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?,
path = uri
.path_and_query()
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
key = key key = key
) )
.unwrap(); .unwrap();
if let Some(eh) = request.extra_headers { for (k, v) in request.headers() {
for (k, v) in eh { writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
writeln!(req, "{}: {}\r", k, v).unwrap();
}
} }
writeln!(req, "\r").unwrap(); writeln!(req, "\r").unwrap();
HandshakeMachine::start_write(stream, req) HandshakeMachine::start_write(stream, req)
@ -118,17 +85,17 @@ impl<S: Read + Write> ClientHandshake<S> {
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
MidHandshake { Ok(MidHandshake {
role: client, role: client,
machine, machine,
} })
} }
} }
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> { impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response; type IncomingData = Response<()>;
type InternalStream = S; type InternalStream = S;
type FinalResult = (WebSocket<S>, Response); type FinalResult = (WebSocket<S>, Response<()>);
fn stage_finished( fn stage_finished(
&mut self, &mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>, finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -160,18 +127,19 @@ struct VerifyData {
} }
impl VerifyData { impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> { pub fn verify_response(&self, response: &Response<()>) -> Result<()> {
// 1. If the status code received from the server is not 101, the // 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455) // client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.code != 101 { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.code)); return Err(Error::Http(response.status()));
} }
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade| // 2. If the response lacks an |Upgrade| header field or the |Upgrade|
// header field contains a value that is not an ASCII case- // header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST // insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455) // _Fail the WebSocket Connection_. (RFC 6455)
if !response if !headers
.headers
.get("Upgrade") .get("Upgrade")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
@ -185,8 +153,7 @@ impl VerifyData {
// |Connection| header field doesn't contain a token that is an // |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client // ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455) // MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response if !headers
.headers
.get("Connection") .get("Connection")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
@ -200,8 +167,7 @@ impl VerifyData {
// the |Sec-WebSocket-Accept| contains a value other than the // the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455) // Connection_. (RFC 6455)
if !response if !headers
.headers
.get("Sec-WebSocket-Accept") .get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key) .map(|h| h == &self.accept_key)
.unwrap_or(false) .unwrap_or(false)
@ -228,16 +194,7 @@ impl VerifyData {
} }
} }
/// Server response. impl TryParse for Response<()> {
#[derive(Debug)]
pub struct Response {
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: HeaderMap,
}
impl TryParse for Response {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Response::new(&mut hbuffer); let mut req = httparse::Response::new(&mut hbuffer);
@ -248,17 +205,24 @@ impl TryParse for Response {
} }
} }
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response { impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response<()> {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol( return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(), "HTTP version should be 1.1 or higher".into(),
)); ));
} }
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"), let headers = HeaderMap::from_httparse(raw.headers)?;
headers: HeaderMap::from_httparse(raw.headers)?,
}) let mut response = Response::new(());
*response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?;
*response.headers_mut() = headers;
// TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
// so the only valid value we could get in the response would be 1.1.
*response.version_mut() = http::Version::HTTP_11;
Ok(response)
} }
} }
@ -295,7 +259,10 @@ mod tests {
fn response_parsing() { fn response_parsing() {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200); assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],); assert_eq!(
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
} }
} }

@ -227,7 +227,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() { if let Some(err) = self.error_code.take() {
debug!("Server handshake failed."); debug!("Server handshake failed.");
return Err(Error::Http(err)); return Err(Error::Http(StatusCode::from_u16(err)?));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
@ -240,10 +240,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::client::Response;
use super::super::machine::TryParse; use super::super::machine::TryParse;
use super::{HeaderMap, Request}; use super::{HeaderMap, Request};
use http::header::HeaderName; use http::header::HeaderName;
use http::Response;
#[test] #[test]
fn request_parsing() { fn request_parsing() {
@ -282,9 +282,9 @@ mod tests {
let reply = req.reply(Some(extra_headers)).unwrap(); let reply = req.reply(Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!( assert_eq!(
req.headers.get("MyCustomHeader").unwrap(), req.headers().get("MyCustomHeader").unwrap(),
b"MyCustomValue".as_ref() b"MyCustomValue".as_ref()
); );
assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref());
} }
} }

Loading…
Cancel
Save