Remove custom Request/Response types from client code

Fixes https://github.com/snapview/tungstenite-rs/issues/92
pull/93/head
Sebastian Dröge 5 years ago
parent 38a7d1a375
commit 9020840f84
  1. 153
      src/client.rs
  2. 3
      src/error.rs
  3. 153
      src/handshake/client.rs
  4. 8
      src/handshake/server.rs

@ -4,10 +4,11 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use http::{Request, Response, Uri};
use log::*;
use url::Url;
use crate::handshake::client::Response;
use crate::protocol::WebSocketConfig;
#[cfg(feature = "tls")]
@ -64,7 +65,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
let request: Request<()> = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.url
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
connect_with_config(request, None)
}
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url
.host_str()
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
Err(Error::Url(format!("Unable to connect to {}", uri).into()))
}
/// Get the mode of the given URL.
///
/// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
@ -173,16 +160,16 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
ClientHandshake::start(stream, request.into(), config).handshake()
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}
/// Do the client handshake over the given stream.
@ -190,13 +177,87 @@ where
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
client_with_config(request, stream, None)
}
/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request<()>` that can be used for a client connection.
fn into_client_request(self) -> Result<Request<()>>;
}
impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self.clone()).body(())?)
}
}
impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self).body(())?)
}
}
impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;
Ok(Request::get(uri).body(())?)
}
}
impl IntoClientRequest for Request<()> {
fn into_client_request(self) -> Result<Request<()>> {
Ok(self)
}
}
impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request<()>> {
use crate::handshake::headers::FromHttparse;
Request::<()>::from_httparse(self)
}
}

@ -9,6 +9,7 @@ use std::result;
use std::str;
use std::string;
use http;
use httparse;
use crate::protocol::Message;
@ -64,7 +65,7 @@ pub enum Error {
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
}

@ -1,13 +1,11 @@
//! Client handshake machine.
use std::borrow::Cow;
use std::io::{Read, Write};
use std::marker::PhantomData;
use http::HeaderMap;
use http::{HeaderMap, Request, Response, StatusCode};
use httparse::Status;
use log::*;
use url::Url;
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
@ -15,57 +13,6 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request.
#[derive(Debug)]
pub struct Request<'t> {
/// `ws://` or `wss://` URL to connect to.
pub url: Url,
/// Extra HTTP headers to append to the request.
pub extra_headers: Option<Vec<(Cow<'t, str>, Cow<'t, str>)>>,
}
impl<'t> Request<'t> {
/// Returns the GET part of the request.
fn get_path(&self) -> String {
if let Some(query) = self.url.query() {
format!("{path}?{query}", path = self.url.path(), query = query)
} else {
self.url.path().into()
}
}
/// Returns the host part of the request.
fn get_host(&self) -> String {
let host = self.url.host_str().expect("Bug: URL without host");
if let Some(port) = self.url.port() {
format!("{host}:{port}", host = host, port = port)
} else {
host.into()
}
}
/// Adds a WebSocket protocol to the request.
pub fn add_protocol(&mut self, protocol: Cow<'t, str>) {
self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol);
}
/// Adds a custom header to the request.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) {
let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new);
headers.push((name, value));
self.extra_headers = Some(headers);
}
}
impl From<Url> for Request<'static> {
fn from(value: Url) -> Self {
Request {
url: value,
extra_headers: None,
}
}
}
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S> {
@ -78,31 +25,51 @@ impl<S: Read + Write> ClientHandshake<S> {
/// Initiate a client handshake.
pub fn start(
stream: S,
request: Request,
request: Request<()>,
config: Option<WebSocketConfig>,
) -> MidHandshake<Self> {
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
"Invalid HTTP method, only GET supported".into(),
));
}
if request.version() < http::Version::HTTP_11 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
// Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?;
let key = generate_key();
let machine = {
let mut req = Vec::new();
let uri = request.uri();
write!(
req,
"\
GET {path} HTTP/1.1\r\n\
GET {path} {version:?}\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(),
path = request.get_path(),
version = request.version(),
host = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?,
path = uri
.path_and_query()
.ok_or_else(|| Error::Url("No path/query in URL".into()))?
.as_str(),
key = key
)
.unwrap();
if let Some(eh) = request.extra_headers {
for (k, v) in eh {
writeln!(req, "{}: {}\r", k, v).unwrap();
}
for (k, v) in request.headers() {
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
}
writeln!(req, "\r").unwrap();
HandshakeMachine::start_write(stream, req)
@ -118,17 +85,17 @@ impl<S: Read + Write> ClientHandshake<S> {
};
trace!("Client handshake initiated.");
MidHandshake {
Ok(MidHandshake {
role: client,
machine,
}
})
}
}
impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
type IncomingData = Response<()>;
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
type FinalResult = (WebSocket<S>, Response<()>);
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
@ -160,18 +127,19 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: &Response) -> Result<()> {
pub fn verify_response(&self, response: &Response<()>) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.code != 101 {
return Err(Error::Http(response.code));
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::Http(response.status()));
}
let headers = response.headers();
// 2. If the response lacks an |Upgrade| header field or the |Upgrade|
// header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
if !headers
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
@ -185,8 +153,7 @@ impl VerifyData {
// |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
if !headers
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
@ -200,8 +167,7 @@ impl VerifyData {
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !response
.headers
if !headers
.get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
@ -228,16 +194,7 @@ impl VerifyData {
}
}
/// Server response.
#[derive(Debug)]
pub struct Response {
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: HeaderMap,
}
impl TryParse for Response {
impl TryParse for Response<()> {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Response::new(&mut hbuffer);
@ -248,17 +205,24 @@ impl TryParse for Response {
}
}
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response<()> {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
headers: HeaderMap::from_httparse(raw.headers)?,
})
let headers = HeaderMap::from_httparse(raw.headers)?;
let mut response = Response::new(());
*response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?;
*response.headers_mut() = headers;
// TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
// so the only valid value we could get in the response would be 1.1.
*response.version_mut() = http::Version::HTTP_11;
Ok(response)
}
}
@ -295,7 +259,10 @@ mod tests {
fn response_parsing() {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200);
assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],);
assert_eq!(resp.status(), http::StatusCode::OK);
assert_eq!(
resp.headers().get("Content-Type").unwrap(),
&b"text/html"[..],
);
}
}

@ -227,7 +227,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() {
debug!("Server handshake failed.");
return Err(Error::Http(err));
return Err(Error::Http(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
@ -240,10 +240,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)]
mod tests {
use super::super::client::Response;
use super::super::machine::TryParse;
use super::{HeaderMap, Request};
use http::header::HeaderName;
use http::Response;
#[test]
fn request_parsing() {
@ -282,9 +282,9 @@ mod tests {
let reply = req.reply(Some(extra_headers)).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(
req.headers.get("MyCustomHeader").unwrap(),
req.headers().get("MyCustomHeader").unwrap(),
b"MyCustomValue".as_ref()
);
assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref());
assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref());
}
}

Loading…
Cancel
Save