From 38a7d1a3753820b0057c8f71f559fd5c8e2317d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Thu, 21 Nov 2019 19:10:47 +0200 Subject: [PATCH 1/9] Remove custom Headers type and use http::HeaderMap instead Fixes https://github.com/snapview/tungstenite-rs/issues/92 --- src/error.rs | 46 ++++++++++++++- src/handshake/client.rs | 26 +++++---- src/handshake/headers.rs | 118 +++++++++------------------------------ src/handshake/server.rs | 72 +++++++++++++----------- 4 files changed, 126 insertions(+), 136 deletions(-) diff --git a/src/error.rs b/src/error.rs index 8629efc..e04c199 100644 --- a/src/error.rs +++ b/src/error.rs @@ -45,7 +45,7 @@ pub enum Error { /// connection when it really shouldn't anymore, so this really indicates a programmer /// error on your part. AlreadyClosed, - /// Input-output error. Appart from WouldBlock, these are generally errors with the + /// Input-output error. Apart from WouldBlock, these are generally errors with the /// underlying connection and you should probably consider them fatal. Io(io::Error), #[cfg(feature = "tls")] @@ -61,10 +61,12 @@ pub enum Error { SendQueueFull(Message), /// UTF coding error Utf8, - /// Invlid URL. + /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. Http(u16), + /// HTTP format error. + HttpFormat(http::Error), } impl fmt::Display for Error { @@ -80,7 +82,8 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP code: {}", code), + Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } } @@ -99,6 +102,7 @@ impl ErrorTrait for Error { Error::Utf8 => "", Error::Url(ref msg) => msg.borrow(), Error::Http(_) => "", + Error::HttpFormat(ref err) => err.description(), } } } @@ -121,6 +125,42 @@ impl From for Error { } } +impl From for Error { + fn from(err: http::header::InvalidHeaderValue) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::header::InvalidHeaderName) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(_: http::header::ToStrError) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(err: http::uri::InvalidUri) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::status::InvalidStatusCode) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::Error) -> Self { + Error::HttpFormat(err) + } +} + #[cfg(feature = "tls")] impl From for Error { fn from(err: tls::Error) -> Self { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 7e23af4..8aec86b 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -4,11 +4,12 @@ use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; +use http::HeaderMap; use httparse::Status; use log::*; use url::Url; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; @@ -171,7 +172,10 @@ impl VerifyData { // _Fail the WebSocket Connection_. (RFC 6455) if !response .headers - .header_is_ignore_case("Upgrade", "websocket") + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Upgrade: websocket\" in server reply".into(), @@ -183,7 +187,10 @@ impl VerifyData { // MUST _Fail the WebSocket Connection_. (RFC 6455) if !response .headers - .header_is_ignore_case("Connection", "Upgrade") + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Connection: upgrade\" in server reply".into(), @@ -195,7 +202,9 @@ impl VerifyData { // Connection_. (RFC 6455) if !response .headers - .header_is("Sec-WebSocket-Accept", &self.accept_key) + .get("Sec-WebSocket-Accept") + .map(|h| h == &self.accept_key) + .unwrap_or(false) { return Err(Error::Protocol( "Key mismatch in Sec-WebSocket-Accept".into(), @@ -225,7 +234,7 @@ pub struct Response { /// HTTP response code of the response. pub code: u16, /// Received headers. - pub headers: Headers, + pub headers: HeaderMap, } impl TryParse for Response { @@ -248,7 +257,7 @@ impl<'h, 'b: 'h> FromHttparse> for Response { } Ok(Response { code: raw.code.expect("Bug: no HTTP response code"), - headers: Headers::from_httparse(raw.headers)?, + headers: HeaderMap::from_httparse(raw.headers)?, }) } } @@ -287,9 +296,6 @@ mod tests { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.code, 200); - assert_eq!( - resp.headers.find_first("Content-Type"), - Some(&b"text/html"[..]) - ); + assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 097b22b..ba12954 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,8 +1,7 @@ //! HTTP Request and response header handling. -use std::slice; -use std::str::from_utf8; - +use http; +use http::header::{HeaderMap, HeaderName, HeaderValue}; use httparse; use httparse::Status; @@ -12,90 +11,31 @@ use crate::error::Result; /// Limit for the number of header lines. pub const MAX_HEADERS: usize = 124; -/// HTTP request or response headers. -#[derive(Debug)] -pub struct Headers { - data: Vec<(String, Box<[u8]>)>, +/// Trait to convert raw objects into HTTP parseables. +pub(crate) trait FromHttparse: Sized { + /// Convert raw object into parsed HTTP headers. + fn from_httparse(raw: T) -> Result; } -impl Headers { - /// Get first header with the given name, if any. - pub fn find_first(&self, name: &str) -> Option<&[u8]> { - self.find(name).next() - } - - /// Iterate over all headers with the given name. - pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { - HeadersIter { - name, - iter: self.data.iter(), +impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap { + fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { + let mut headers = HeaderMap::new(); + for h in raw { + headers.append( + HeaderName::from_bytes(h.name.as_bytes())?, + HeaderValue::from_bytes(h.value)?, + ); } - } - - /// Check if the given header has the given value. - pub fn header_is(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .map(|v| v == value.as_bytes()) - .unwrap_or(false) - } - - /// Check if the given header has the given value (case-insensitive). - pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .ok_or(()) - .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) - .map(|val| val.eq_ignore_ascii_case(value)) - .unwrap_or(false) - } - - /// Allows to iterate over available headers. - pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { - self.data.iter() - } -} -/// The iterator over headers. -#[derive(Debug)] -pub struct HeadersIter<'name, 'headers> { - name: &'name str, - iter: slice::Iter<'headers, (String, Box<[u8]>)>, -} - -impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { - type Item = &'headers [u8]; - fn next(&mut self) -> Option { - while let Some(&(ref name, ref value)) = self.iter.next() { - if name.eq_ignore_ascii_case(self.name) { - return Some(value); - } - } - None + Ok(headers) } } - -/// Trait to convert raw objects into HTTP parseables. -pub trait FromHttparse: Sized { - /// Convert raw object into parsed HTTP headers. - fn from_httparse(raw: T) -> Result; -} - -impl TryParse for Headers { +impl TryParse for HeaderMap { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Status::Partial => None, - Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), - }) - } -} - -impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { - fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { - Ok(Headers { - data: raw - .iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), + Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)), }) } } @@ -104,7 +44,7 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { mod tests { use super::super::machine::TryParse; - use super::Headers; + use super::HeaderMap; #[test] fn headers() { @@ -112,14 +52,10 @@ mod tests { Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); - assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); - assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); - - assert!(hdr.header_is("upgrade", "websocket")); - assert!(!hdr.header_is("upgrade", "Websocket")); - assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]); + assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]); + assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]); } #[test] @@ -130,10 +66,10 @@ mod tests { Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - let mut iter = hdr.find("Sec-WebSocket-Extensions"); - assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); - assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter(); + assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]); + assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]); assert_eq!(iter.next(), None); } @@ -142,7 +78,7 @@ mod tests { const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; - let hdr = Headers::try_parse(DATA).unwrap(); + let hdr = HeaderMap::try_parse(DATA).unwrap(); assert!(hdr.is_none()); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index d3624c3..3f551ea 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,11 +5,11 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; @@ -21,15 +21,15 @@ pub struct Request { /// Path part of the URL. pub path: String, /// HTTP headers. - pub headers: Headers, + pub headers: HeaderMap, } impl Request { /// Reply to the response. - pub fn reply(&self, extra_headers: Option>) -> Result> { + pub fn reply(&self, extra_headers: Option) -> Result> { let key = self .headers - .find_first("Sec-WebSocket-Key") + .get("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; let mut reply = format!( "\ @@ -37,20 +37,24 @@ impl Request { Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Accept: {}\r\n", - convert_key(key)? + convert_key(key.as_bytes())? ); - add_headers(&mut reply, extra_headers); + add_headers(&mut reply, extra_headers)?; Ok(reply.into()) } } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) { +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> Result<()> { if let Some(eh) = extra_headers { for (k, v) in eh { - writeln!(reply, "{}: {}\r", k, v).unwrap(); + if let Some(k) = k { + writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); + } } } writeln!(reply, "\r").unwrap(); + + Ok(()) } impl TryParse for Request { @@ -76,21 +80,18 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } Ok(Request { path: raw.path.expect("Bug: no path in header").into(), - headers: Headers::from_httparse(raw.headers)?, + headers: HeaderMap::from_httparse(raw.headers)?, }) } } -/// Extra headers for responses. -pub type ExtraHeaders = Vec<(String, String)>; - /// An error response sent to the client. #[derive(Debug)] pub struct ErrorResponse { /// HTTP error code. pub error_code: StatusCode, /// Extra response headers, if any. - pub headers: Option, + pub headers: Option, /// Response body, if any. pub body: Option, } @@ -115,14 +116,14 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request) -> StdResult, ErrorResponse>, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { + fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { self(request) } } @@ -132,7 +133,7 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { + fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { Ok(None) } } @@ -214,7 +215,7 @@ impl HandshakeRole for ServerHandshake { error_code.as_str(), error_code.canonical_reason().unwrap_or("") ); - add_headers(&mut response, headers); + add_headers(&mut response, headers)?; if let Some(body) = body { response += &body; } @@ -241,14 +242,15 @@ impl HandshakeRole for ServerHandshake { mod tests { use super::super::client::Response; use super::super::machine::TryParse; - use super::Request; + use super::{HeaderMap, Request}; + use http::header::HeaderName; #[test] fn request_parsing() { const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); assert_eq!(req.path, "/script.ws"); - assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); + assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); } #[test] @@ -264,19 +266,25 @@ mod tests { let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let _ = req.reply(None).unwrap(); - let extra_headers = Some(vec![ - ( - String::from("MyCustomHeader"), - String::from("MyCustomValue"), - ), - (String::from("MyVersion"), String::from("LOL")), - ]); - let reply = req.reply(extra_headers).unwrap(); + let extra_headers = { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(), + "MyCustomValue".parse().unwrap(), + ); + headers.insert( + HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(), + "LOL".parse().unwrap(), + ); + + headers + }; + let reply = req.reply(Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( - req.headers.find_first("MyCustomHeader"), - Some(b"MyCustomValue".as_ref()) + req.headers.get("MyCustomHeader").unwrap(), + b"MyCustomValue".as_ref() ); - assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); + assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); } } From 9020840f84b03498720b6a365cfb38907c86d17a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Fri, 22 Nov 2019 15:59:27 +0200 Subject: [PATCH 2/9] Remove custom Request/Response types from client code Fixes https://github.com/snapview/tungstenite-rs/issues/92 --- src/client.rs | 153 ++++++++++++++++++++++++++++------------ src/error.rs | 3 +- src/handshake/client.rs | 153 ++++++++++++++++------------------------ src/handshake/server.rs | 8 +-- 4 files changed, 173 insertions(+), 144 deletions(-) diff --git a/src/client.rs b/src/client.rs index e35a24d..66b73fd 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,10 +4,11 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; +use http::{Request, Response, Uri}; use log::*; + use url::Url; -use crate::handshake::client::Response; use crate::protocol::WebSocketConfig; #[cfg(feature = "tls")] @@ -64,7 +65,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::error::{Error, Result}; -use crate::handshake::client::{ClientHandshake, Request}; +use crate::handshake::client::ClientHandshake; use crate::handshake::HandshakeError; use crate::protocol::WebSocket; use crate::stream::{Mode, NoDelay}; @@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config<'t, Req: Into>>( +pub fn connect_with_config( request: Req, config: Option, -) -> Result<(WebSocket, Response)> { - let request: Request = request.into(); - let mode = url_mode(&request.url)?; +) -> Result<(WebSocket, Response<()>)> { + let request: Request<()> = request.into_client_request()?; + let uri = request.uri(); + let mode = uri_mode(uri)?; let host = request - .url + .uri() .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = request - .url - .port_or_known_default() - .ok_or_else(|| Error::Url("No port number in the URL".into()))?; - let addrs; - let addr; - let addrs = match host { - url::Host::Domain(domain) => { - addrs = (domain, port).to_socket_addrs()?; - addrs.as_slice() - } - url::Host::Ipv4(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - url::Host::Ipv6(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - }; - let mut stream = connect_to_some(addrs, &request.url, mode)?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, @@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into>>( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>( +pub fn connect( request: Req, -) -> Result<(WebSocket, Response)> { +) -> Result<(WebSocket, Response<()>)> { connect_with_config(request, None) } -fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result { - let domain = url - .host_str() +fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { + let domain = uri + .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { - debug!("Trying to contact {} at {}...", url, addr); + debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { return Ok(stream); } } } - Err(Error::Url(format!("Unable to connect to {}", url).into())) + Err(Error::Url(format!("Unable to connect to {}", uri).into())) } /// Get the mode of the given URL. /// /// This function may be used to ease the creation of custom TLS streams /// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`. -pub fn url_mode(url: &Url) -> Result { - match url.scheme() { - "ws" => Ok(Mode::Plain), - "wss" => Ok(Mode::Tls), +pub fn uri_mode(uri: &Uri) -> Result { + match uri.scheme_str() { + Some("ws") => Ok(Mode::Plain), + Some("wss") => Ok(Mode::Tls), _ => Err(Error::Url("URL scheme not supported".into())), } } @@ -173,16 +160,16 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client_with_config<'t, Stream, Req>( +pub fn client_with_config( request: Req, stream: Stream, config: Option, -) -> StdResult<(WebSocket, Response), HandshakeError>> +) -> StdResult<(WebSocket, Response<()>), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { - ClientHandshake::start(stream, request.into(), config).handshake() + ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake() } /// Do the client handshake over the given stream. @@ -190,13 +177,87 @@ where /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>( +pub fn client( request: Req, stream: Stream, -) -> StdResult<(WebSocket, Response), HandshakeError>> +) -> StdResult<(WebSocket, Response<()>), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { client_with_config(request, stream, None) } + +/// Trait for converting various types into HTTP requests used for a client connection. +/// +/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and +/// `http::Request<()>`. +pub trait IntoClientRequest { + /// Convert into a `Request<()>` that can be used for a client connection. + fn into_client_request(self) -> Result>; +} + +impl<'a> IntoClientRequest for &'a str { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a String { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for String { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Uri { + fn into_client_request(self) -> Result> { + Ok(Request::get(self.clone()).body(())?) + } +} + +impl IntoClientRequest for Uri { + fn into_client_request(self) -> Result> { + Ok(Request::get(self).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Url { + fn into_client_request(self) -> Result> { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Url { + fn into_client_request(self) -> Result> { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Request<()> { + fn into_client_request(self) -> Result> { + Ok(self) + } +} + +impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> { + fn into_client_request(self) -> Result> { + use crate::handshake::headers::FromHttparse; + Request::<()>::from_httparse(self) + } +} diff --git a/src/error.rs b/src/error.rs index e04c199..ac2de26 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::result; use std::str; use std::string; +use http; use httparse; use crate::protocol::Message; @@ -64,7 +65,7 @@ pub enum Error { /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. - Http(u16), + Http(http::StatusCode), /// HTTP format error. HttpFormat(http::Error), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8aec86b..a0bb951 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,13 +1,11 @@ //! Client handshake machine. -use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; -use http::HeaderMap; +use http::{HeaderMap, Request, Response, StatusCode}; use httparse::Status; use log::*; -use url::Url; use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; @@ -15,57 +13,6 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Client request. -#[derive(Debug)] -pub struct Request<'t> { - /// `ws://` or `wss://` URL to connect to. - pub url: Url, - /// Extra HTTP headers to append to the request. - pub extra_headers: Option, Cow<'t, str>)>>, -} - -impl<'t> Request<'t> { - /// Returns the GET part of the request. - fn get_path(&self) -> String { - if let Some(query) = self.url.query() { - format!("{path}?{query}", path = self.url.path(), query = query) - } else { - self.url.path().into() - } - } - - /// Returns the host part of the request. - fn get_host(&self) -> String { - let host = self.url.host_str().expect("Bug: URL without host"); - if let Some(port) = self.url.port() { - format!("{host}:{port}", host = host, port = port) - } else { - host.into() - } - } - - /// Adds a WebSocket protocol to the request. - pub fn add_protocol(&mut self, protocol: Cow<'t, str>) { - self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol); - } - - /// Adds a custom header to the request. - pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { - let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new); - headers.push((name, value)); - self.extra_headers = Some(headers); - } -} - -impl From for Request<'static> { - fn from(value: Url) -> Self { - Request { - url: value, - extra_headers: None, - } - } -} - /// Client handshake role. #[derive(Debug)] pub struct ClientHandshake { @@ -78,31 +25,51 @@ impl ClientHandshake { /// Initiate a client handshake. pub fn start( stream: S, - request: Request, + request: Request<()>, config: Option, - ) -> MidHandshake { + ) -> Result> { + if request.method() != http::Method::GET { + return Err(Error::Protocol( + "Invalid HTTP method, only GET supported".into(), + )); + } + + if request.version() < http::Version::HTTP_11 { + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); + } + + // Check the URI scheme: only ws or wss are supported + let _ = crate::client::uri_mode(request.uri())?; + let key = generate_key(); let machine = { let mut req = Vec::new(); + let uri = request.uri(); write!( req, "\ - GET {path} HTTP/1.1\r\n\ + GET {path} {version:?}\r\n\ Host: {host}\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: {key}\r\n", - host = request.get_host(), - path = request.get_path(), + version = request.version(), + host = uri + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?, + path = uri + .path_and_query() + .ok_or_else(|| Error::Url("No path/query in URL".into()))? + .as_str(), key = key ) .unwrap(); - if let Some(eh) = request.extra_headers { - for (k, v) in eh { - writeln!(req, "{}: {}\r", k, v).unwrap(); - } + for (k, v) in request.headers() { + writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); } writeln!(req, "\r").unwrap(); HandshakeMachine::start_write(stream, req) @@ -118,17 +85,17 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { + Ok(MidHandshake { role: client, machine, - } + }) } } impl HandshakeRole for ClientHandshake { - type IncomingData = Response; + type IncomingData = Response<()>; type InternalStream = S; - type FinalResult = (WebSocket, Response); + type FinalResult = (WebSocket, Response<()>); fn stage_finished( &mut self, finish: StageResult, @@ -160,18 +127,19 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: &Response<()>) -> Result<()> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.code != 101 { - return Err(Error::Http(response.code)); + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.status())); } + let headers = response.headers(); + // 2. If the response lacks an |Upgrade| header field or the |Upgrade| // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Upgrade") .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) @@ -185,8 +153,7 @@ impl VerifyData { // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Connection") .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("Upgrade")) @@ -200,8 +167,7 @@ impl VerifyData { // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Sec-WebSocket-Accept") .map(|h| h == &self.accept_key) .unwrap_or(false) @@ -228,16 +194,7 @@ impl VerifyData { } } -/// Server response. -#[derive(Debug)] -pub struct Response { - /// HTTP response code of the response. - pub code: u16, - /// Received headers. - pub headers: HeaderMap, -} - -impl TryParse for Response { +impl TryParse for Response<()> { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Response::new(&mut hbuffer); @@ -248,17 +205,24 @@ impl TryParse for Response { } } -impl<'h, 'b: 'h> FromHttparse> for Response { +impl<'h, 'b: 'h> FromHttparse> for Response<()> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( "HTTP version should be 1.1 or higher".into(), )); } - Ok(Response { - code: raw.code.expect("Bug: no HTTP response code"), - headers: HeaderMap::from_httparse(raw.headers)?, - }) + + let headers = HeaderMap::from_httparse(raw.headers)?; + + let mut response = Response::new(()); + *response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?; + *response.headers_mut() = headers; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *response.version_mut() = http::Version::HTTP_11; + + Ok(response) } } @@ -295,7 +259,10 @@ mod tests { fn response_parsing() { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); - assert_eq!(resp.code, 200); - assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],); + assert_eq!(resp.status(), http::StatusCode::OK); + assert_eq!( + resp.headers().get("Content-Type").unwrap(), + &b"text/html"[..], + ); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 3f551ea..5a5890c 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -227,7 +227,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(err)); + return Err(Error::Http(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); @@ -240,10 +240,10 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::super::client::Response; use super::super::machine::TryParse; use super::{HeaderMap, Request}; use http::header::HeaderName; + use http::Response; #[test] fn request_parsing() { @@ -282,9 +282,9 @@ mod tests { let reply = req.reply(Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( - req.headers.get("MyCustomHeader").unwrap(), + req.headers().get("MyCustomHeader").unwrap(), b"MyCustomValue".as_ref() ); - assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); + assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref()); } } From 09a9b7ceef30e9b7c4624c04b3f3cfb320440ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sat, 23 Nov 2019 13:17:51 +0200 Subject: [PATCH 3/9] Remove custom Request/Response types from server code Fixes https://github.com/snapview/tungstenite-rs/issues/92 --- src/handshake/server.rs | 144 ++++++++++++++++++---------------------- 1 file changed, 66 insertions(+), 78 deletions(-) diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 5a5890c..3293b80 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,7 +5,7 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::{HeaderMap, StatusCode}; +use http::{HeaderMap, Request, Response, StatusCode}; use httparse::Status; use log::*; @@ -15,41 +15,28 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Request from the client. -#[derive(Debug)] -pub struct Request { - /// Path part of the URL. - pub path: String, - /// HTTP headers. - pub headers: HeaderMap, -} - -impl Request { - /// Reply to the response. - pub fn reply(&self, extra_headers: Option) -> Result> { - let key = self - .headers - .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let mut reply = format!( - "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", - convert_key(key.as_bytes())? - ); - add_headers(&mut reply, extra_headers)?; - Ok(reply.into()) - } +/// Reply to the response. +fn reply(request: &Request<()>, extra_headers: Option) -> Result> { + let key = request + .headers() + .get("Sec-WebSocket-Key") + .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + let mut reply = format!( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: {}\r\n", + convert_key(key.as_bytes())? + ); + add_headers(&mut reply, extra_headers.as_ref())?; + Ok(reply.into()) } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> Result<()> { +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { if let Some(eh) = extra_headers { for (k, v) in eh { - if let Some(k) = k { - writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); - } + writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); } } writeln!(reply, "\r").unwrap(); @@ -57,7 +44,7 @@ fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> R Ok(()) } -impl TryParse for Request { +impl TryParse for Request<()> { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Request::new(&mut hbuffer); @@ -68,41 +55,29 @@ impl TryParse for Request { } } -impl<'h, 'b: 'h> FromHttparse> for Request { +impl<'h, 'b: 'h> FromHttparse> for Request<()> { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { return Err(Error::Protocol("Method is not GET".into())); } + if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( "HTTP version should be 1.1 or higher".into(), )); } - Ok(Request { - path: raw.path.expect("Bug: no path in header").into(), - headers: HeaderMap::from_httparse(raw.headers)?, - }) - } -} -/// An error response sent to the client. -#[derive(Debug)] -pub struct ErrorResponse { - /// HTTP error code. - pub error_code: StatusCode, - /// Extra response headers, if any. - pub headers: Option, - /// Response body, if any. - pub body: Option, -} + let headers = HeaderMap::from_httparse(raw.headers)?; -impl From for ErrorResponse { - fn from(error_code: StatusCode) -> Self { - ErrorResponse { - error_code, - headers: None, - body: None, - } + let mut request = Request::new(()); + *request.method_mut() = http::Method::GET; + *request.headers_mut() = headers; + *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *request.version_mut() = http::Version::HTTP_11; + + Ok(request) } } @@ -116,14 +91,20 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request( + self, + request: &Request<()>, + ) -> StdResult, Response>>; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request<()>) -> StdResult, Response>>, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + request: &Request<()>, + ) -> StdResult, Response>> { self(request) } } @@ -133,7 +114,10 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + _request: &Request<()>, + ) -> StdResult, Response>> { Ok(None) } } @@ -174,7 +158,7 @@ impl ServerHandshake { } impl HandshakeRole for ServerHandshake { - type IncomingData = Request; + type IncomingData = Request<()>; type InternalStream = S; type FinalResult = WebSocket; @@ -200,23 +184,26 @@ impl HandshakeRole for ServerHandshake { match callback_result { Ok(extra_headers) => { - let response = result.reply(extra_headers)?; + let response = reply(&result, extra_headers)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } - Err(ErrorResponse { - error_code, - headers, - body, - }) => { - self.error_code = Some(error_code.as_u16()); + Err(resp) => { + if resp.status().is_success() { + return Err(Error::Protocol( + "Custom response must not be successful".into(), + )); + } + + self.error_code = Some(resp.status().as_u16()); let mut response = format!( - "HTTP/1.1 {} {}\r\n", - error_code.as_str(), - error_code.canonical_reason().unwrap_or("") + "{version:?} {status} {reason}\r\n", + version = resp.version(), + status = resp.status().as_u16(), + reason = resp.status().canonical_reason().unwrap_or("") ); - add_headers(&mut response, headers)?; - if let Some(body) = body { + add_headers(&mut response, Some(resp.headers()))?; + if let Some(body) = resp.body() { response += &body; } ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) @@ -241,6 +228,7 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { use super::super::machine::TryParse; + use super::reply; use super::{HeaderMap, Request}; use http::header::HeaderName; use http::Response; @@ -249,8 +237,8 @@ mod tests { fn request_parsing() { const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - assert_eq!(req.path, "/script.ws"); - assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); + assert_eq!(req.uri().path(), "/script.ws"); + assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); } #[test] @@ -264,7 +252,7 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply(None).unwrap(); + let _ = reply(&req, None).unwrap(); let extra_headers = { let mut headers = HeaderMap::new(); @@ -279,7 +267,7 @@ mod tests { headers }; - let reply = req.reply(Some(extra_headers)).unwrap(); + let reply = reply(&req, Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( req.headers().get("MyCustomHeader").unwrap(), From 1ecc4f900d8abb7f5791152bd41c9d0c7fdde9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 15:33:18 +0200 Subject: [PATCH 4/9] Use Response for the server handshake callback too And add a public create_response(&Request) function that creates an initial response. This can be used to simplify integration into existing HTTP libraries. --- src/handshake/server.rs | 163 ++++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 64 deletions(-) diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 3293b80..c2b7af4 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,7 +1,6 @@ //! Server handshake machine. -use std::fmt::Write as FmtWrite; -use std::io::{Read, Write}; +use std::io::{self, Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; @@ -15,31 +14,84 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Reply to the response. -fn reply(request: &Request<()>, extra_headers: Option) -> Result> { +/// Create a response for the request. +pub fn create_response(request: &Request<()>) -> Result> { + if request.method() != http::Method::GET { + return Err(Error::Protocol("Method is not GET".into())); + } + + if request.version() < http::Version::HTTP_11 { + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); + } + + if !request + .headers() + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Connection: upgrade\" in client request".into(), + )); + } + + if !request + .headers() + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Upgrade: websocket\" in client request".into(), + )); + } + + if !request + .headers() + .get("Sec-WebSocket-Version") + .map(|h| h == "13") + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Sec-WebSocket-Version: 13\" in client request".into(), + )); + } + let key = request .headers() .get("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let mut reply = format!( - "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", - convert_key(key.as_bytes())? - ); - add_headers(&mut reply, extra_headers.as_ref())?; - Ok(reply.into()) + + let mut response = Response::builder(); + + response.status(StatusCode::SWITCHING_PROTOCOLS); + response.version(request.version()); + response.header("Connection", "Upgrade"); + response.header("Upgrade", "websocket"); + response.header("Sec-WebSocket-Accept", convert_key(key.as_bytes())?); + + Ok(response.body(())?) } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { - if let Some(eh) = extra_headers { - for (k, v) in eh { - writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); - } +// Assumes that this is a valid response +fn write_response(w: &mut dyn io::Write, response: &Response) -> Result<()> { + writeln!( + w, + "{version:?} {status} {reason}\r", + version = response.version(), + status = response.status(), + reason = response.status().canonical_reason().unwrap_or(""), + )?; + + for (k, v) in response.headers() { + writeln!(w, "{}: {}\r", k, v.to_str()?).unwrap(); } - writeln!(reply, "\r").unwrap(); + + writeln!(w, "\r")?; Ok(()) } @@ -94,18 +146,20 @@ pub trait Callback: Sized { fn on_request( self, request: &Request<()>, - ) -> StdResult, Response>>; + response: Response<()>, + ) -> StdResult, Response>>; } impl Callback for F where - F: FnOnce(&Request<()>) -> StdResult, Response>>, + F: FnOnce(&Request<()>, Response<()>) -> StdResult, Response>>, { fn on_request( self, request: &Request<()>, - ) -> StdResult, Response>> { - self(request) + response: Response<()>, + ) -> StdResult, Response>> { + self(request, response) } } @@ -117,8 +171,9 @@ impl Callback for NoCallback { fn on_request( self, _request: &Request<()>, - ) -> StdResult, Response>> { - Ok(None) + response: Response<()>, + ) -> StdResult, Response>> { + Ok(response) } } @@ -176,16 +231,18 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol("Junk after client request".into())); } + let response = create_response(&result)?; let callback_result = if let Some(callback) = self.callback.take() { - callback.on_request(&result) + callback.on_request(&result, response) } else { - Ok(None) + Ok(response) }; match callback_result { - Ok(extra_headers) => { - let response = reply(&result, extra_headers)?; - ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + Ok(response) => { + let mut output = vec![]; + write_response(&mut output, &response)?; + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } Err(resp) => { @@ -196,17 +253,13 @@ impl HandshakeRole for ServerHandshake { } self.error_code = Some(resp.status().as_u16()); - let mut response = format!( - "{version:?} {status} {reason}\r\n", - version = resp.version(), - status = resp.status().as_u16(), - reason = resp.status().canonical_reason().unwrap_or("") - ); - add_headers(&mut response, Some(resp.headers()))?; + + let mut output = vec![]; + write_response(&mut output, &resp)?; if let Some(body) = resp.body() { - response += &body; + output.extend_from_slice(body.as_bytes()); } - ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } @@ -228,10 +281,8 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { use super::super::machine::TryParse; - use super::reply; - use super::{HeaderMap, Request}; - use http::header::HeaderName; - use http::Response; + use super::create_response; + use super::Request; #[test] fn request_parsing() { @@ -252,27 +303,11 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = reply(&req, None).unwrap(); - - let extra_headers = { - let mut headers = HeaderMap::new(); - headers.insert( - HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(), - "MyCustomValue".parse().unwrap(), - ); - headers.insert( - HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(), - "LOL".parse().unwrap(), - ); - - headers - }; - let reply = reply(&req, Some(extra_headers)).unwrap(); - let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); + let response = create_response(&req).unwrap(); + assert_eq!( - req.headers().get("MyCustomHeader").unwrap(), - b"MyCustomValue".as_ref() + response.headers().get("Sec-WebSocket-Accept").unwrap(), + b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref() ); - assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref()); } } From 07d4721ffd992808e114572256a5de8f968ccade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 14:52:58 +0100 Subject: [PATCH 5/9] Add type aliases for Response/Request with a fixed body type Simplifies correct usage of the API. --- src/client.rs | 41 ++++++++++++++++++++--------------------- src/handshake/client.rs | 20 +++++++++++++------- src/handshake/server.rs | 41 +++++++++++++++++++++++++---------------- 3 files changed, 58 insertions(+), 44 deletions(-) diff --git a/src/client.rs b/src/client.rs index 66b73fd..200ca53 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,11 +4,12 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; -use http::{Request, Response, Uri}; +use http::Uri; use log::*; use url::Url; +use crate::handshake::client::{Request, Response}; use crate::protocol::WebSocketConfig; #[cfg(feature = "tls")] @@ -88,8 +89,8 @@ use crate::stream::{Mode, NoDelay}; pub fn connect_with_config( request: Req, config: Option, -) -> Result<(WebSocket, Response<()>)> { - let request: Request<()> = request.into_client_request()?; +) -> Result<(WebSocket, Response)> { + let request: Request = request.into_client_request()?; let uri = request.uri(); let mode = uri_mode(uri)?; let host = request @@ -121,9 +122,7 @@ pub fn connect_with_config( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect( - request: Req, -) -> Result<(WebSocket, Response<()>)> { +pub fn connect(request: Req) -> Result<(WebSocket, Response)> { connect_with_config(request, None) } @@ -164,7 +163,7 @@ pub fn client_with_config( request: Req, stream: Stream, config: Option, -) -> StdResult<(WebSocket, Response<()>), HandshakeError>> +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: IntoClientRequest, @@ -180,7 +179,7 @@ where pub fn client( request: Req, stream: Stream, -) -> StdResult<(WebSocket, Response<()>), HandshakeError>> +) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, Req: IntoClientRequest, @@ -193,12 +192,12 @@ where /// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and /// `http::Request<()>`. pub trait IntoClientRequest { - /// Convert into a `Request<()>` that can be used for a client connection. - fn into_client_request(self) -> Result>; + /// Convert into a `Request` that can be used for a client connection. + fn into_client_request(self) -> Result; } impl<'a> IntoClientRequest for &'a str { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { let uri: Uri = self.parse()?; Ok(Request::get(uri).body(())?) @@ -206,7 +205,7 @@ impl<'a> IntoClientRequest for &'a str { } impl<'a> IntoClientRequest for &'a String { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { let uri: Uri = self.parse()?; Ok(Request::get(uri).body(())?) @@ -214,7 +213,7 @@ impl<'a> IntoClientRequest for &'a String { } impl IntoClientRequest for String { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { let uri: Uri = self.parse()?; Ok(Request::get(uri).body(())?) @@ -222,19 +221,19 @@ impl IntoClientRequest for String { } impl<'a> IntoClientRequest for &'a Uri { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { Ok(Request::get(self.clone()).body(())?) } } impl IntoClientRequest for Uri { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { Ok(Request::get(self).body(())?) } } impl<'a> IntoClientRequest for &'a Url { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { let uri: Uri = self.as_str().parse()?; Ok(Request::get(uri).body(())?) @@ -242,22 +241,22 @@ impl<'a> IntoClientRequest for &'a Url { } impl IntoClientRequest for Url { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { let uri: Uri = self.as_str().parse()?; Ok(Request::get(uri).body(())?) } } -impl IntoClientRequest for Request<()> { - fn into_client_request(self) -> Result> { +impl IntoClientRequest for Request { + fn into_client_request(self) -> Result { Ok(self) } } impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> { - fn into_client_request(self) -> Result> { + fn into_client_request(self) -> Result { use crate::handshake::headers::FromHttparse; - Request::<()>::from_httparse(self) + Request::from_httparse(self) } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index a0bb951..e6c8ad5 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -3,7 +3,7 @@ use std::io::{Read, Write}; use std::marker::PhantomData; -use http::{HeaderMap, Request, Response, StatusCode}; +use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; @@ -13,6 +13,12 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; +/// Client request type. +pub type Request = HttpRequest<()>; + +/// Client response type. +pub type Response = HttpResponse<()>; + /// Client handshake role. #[derive(Debug)] pub struct ClientHandshake { @@ -25,7 +31,7 @@ impl ClientHandshake { /// Initiate a client handshake. pub fn start( stream: S, - request: Request<()>, + request: Request, config: Option, ) -> Result> { if request.method() != http::Method::GET { @@ -93,9 +99,9 @@ impl ClientHandshake { } impl HandshakeRole for ClientHandshake { - type IncomingData = Response<()>; + type IncomingData = Response; type InternalStream = S; - type FinalResult = (WebSocket, Response<()>); + type FinalResult = (WebSocket, Response); fn stage_finished( &mut self, finish: StageResult, @@ -127,7 +133,7 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response<()>) -> Result<()> { + pub fn verify_response(&self, response: &Response) -> Result<()> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) if response.status() != StatusCode::SWITCHING_PROTOCOLS { @@ -194,7 +200,7 @@ impl VerifyData { } } -impl TryParse for Response<()> { +impl TryParse for Response { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Response::new(&mut hbuffer); @@ -205,7 +211,7 @@ impl TryParse for Response<()> { } } -impl<'h, 'b: 'h> FromHttparse> for Response<()> { +impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( diff --git a/src/handshake/server.rs b/src/handshake/server.rs index c2b7af4..f3e937e 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -4,7 +4,7 @@ use std::io::{self, Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::{HeaderMap, Request, Response, StatusCode}; +use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; @@ -14,8 +14,17 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; +/// Server request type. +pub type Request = HttpRequest<()>; + +/// Server response type. +pub type Response = HttpResponse<()>; + +/// Server error response type. +pub type ErrorResponse = HttpResponse>; + /// Create a response for the request. -pub fn create_response(request: &Request<()>) -> Result> { +pub fn create_response(request: &Request) -> Result { if request.method() != http::Method::GET { return Err(Error::Protocol("Method is not GET".into())); } @@ -78,7 +87,7 @@ pub fn create_response(request: &Request<()>) -> Result> { } // Assumes that this is a valid response -fn write_response(w: &mut dyn io::Write, response: &Response) -> Result<()> { +fn write_response(w: &mut dyn io::Write, response: &HttpResponse) -> Result<()> { writeln!( w, "{version:?} {status} {reason}\r", @@ -96,7 +105,7 @@ fn write_response(w: &mut dyn io::Write, response: &Response) -> Result<() Ok(()) } -impl TryParse for Request<()> { +impl TryParse for Request { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Request::new(&mut hbuffer); @@ -107,7 +116,7 @@ impl TryParse for Request<()> { } } -impl<'h, 'b: 'h> FromHttparse> for Request<()> { +impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { return Err(Error::Protocol("Method is not GET".into())); @@ -145,20 +154,20 @@ pub trait Callback: Sized { /// Returning an error resulting in rejecting the incoming connection. fn on_request( self, - request: &Request<()>, - response: Response<()>, - ) -> StdResult, Response>>; + request: &Request, + response: Response, + ) -> StdResult; } impl Callback for F where - F: FnOnce(&Request<()>, Response<()>) -> StdResult, Response>>, + F: FnOnce(&Request, Response) -> StdResult, { fn on_request( self, - request: &Request<()>, - response: Response<()>, - ) -> StdResult, Response>> { + request: &Request, + response: Response, + ) -> StdResult { self(request, response) } } @@ -170,9 +179,9 @@ pub struct NoCallback; impl Callback for NoCallback { fn on_request( self, - _request: &Request<()>, - response: Response<()>, - ) -> StdResult, Response>> { + _request: &Request, + response: Response, + ) -> StdResult { Ok(response) } } @@ -213,7 +222,7 @@ impl ServerHandshake { } impl HandshakeRole for ServerHandshake { - type IncomingData = Request<()>; + type IncomingData = Request; type InternalStream = S; type FinalResult = WebSocket; From f659af44939042eb0514838f77c73de3500741ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 15:06:41 +0100 Subject: [PATCH 6/9] Update examples to compile again --- examples/callback-error.rs | 14 +++++++------- examples/client.rs | 4 ++-- examples/server.rs | 21 +++++++++------------ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/examples/callback-error.rs b/examples/callback-error.rs index fdb5b8a..357d343 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -2,19 +2,19 @@ use std::net::TcpListener; use std::thread::spawn; use tungstenite::accept_hdr; -use tungstenite::handshake::server::{ErrorResponse, Request}; +use tungstenite::handshake::server::{Request, Response}; use tungstenite::http::StatusCode; fn main() { let server = TcpListener::bind("127.0.0.1:3012").unwrap(); for stream in server.incoming() { spawn(move || { - let callback = |_req: &Request| { - Err(ErrorResponse { - error_code: StatusCode::FORBIDDEN, - headers: None, - body: Some("Access denied".into()), - }) + let callback = |_req: &Request, _resp| { + let resp = Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Some("Access denied".into())) + .unwrap(); + Err(resp) }; accept_hdr(stream.unwrap(), callback).unwrap_err(); }); diff --git a/examples/client.rs b/examples/client.rs index e3200d2..7938cfb 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,9 +8,9 @@ fn main() { connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect"); println!("Connected to the server"); - println!("Response HTTP code: {}", response.code); + println!("Response HTTP code: {}", response.status()); println!("Response contains the following headers:"); - for &(ref header, _ /*value*/) in response.headers.iter() { + for (ref header, _value) in response.headers() { println!("* {}", header); } diff --git a/examples/server.rs b/examples/server.rs index 70ba186..def2f45 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -2,30 +2,27 @@ use std::net::TcpListener; use std::thread::spawn; use tungstenite::accept_hdr; -use tungstenite::handshake::server::Request; +use tungstenite::handshake::server::{Request, Response}; fn main() { env_logger::init(); let server = TcpListener::bind("127.0.0.1:3012").unwrap(); for stream in server.incoming() { spawn(move || { - let callback = |req: &Request| { + let callback = |req: &Request, mut response: Response| { println!("Received a new ws handshake"); - println!("The request's path is: {}", req.path); + println!("The request's path is: {}", req.uri().path()); println!("The request's headers are:"); - for &(ref header, _ /* value */) in req.headers.iter() { + for (ref header, _value) in req.headers() { println!("* {}", header); } // Let's add an additional header to our response to the client. - let extra_headers = vec![ - (String::from("MyCustomHeader"), String::from(":)")), - ( - String::from("SOME_TUNGSTENITE_HEADER"), - String::from("header_value"), - ), - ]; - Ok(Some(extra_headers)) + let headers = response.headers_mut(); + headers.append("MyCustomHeader", ":)".parse().unwrap()); + headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); + + Ok(response) }; let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); From 7a4779b6f6e106f79125e38c4095749775fbf234 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 15:06:55 +0100 Subject: [PATCH 7/9] Run everything through rustfmt --- src/protocol/frame/mod.rs | 8 ++++++-- src/protocol/mod.rs | 11 +++++++---- tests/no_send_after_close.rs | 7 +++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index f93f911..6756f0a 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -12,7 +12,7 @@ pub use self::frame::{Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; use log::*; -use std::io::{Read, Write, Error as IoError, ErrorKind as IoErrorKind}; +use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -199,7 +199,11 @@ impl FrameCodec { let len = stream.write(&self.out_buffer)?; if len == 0 { // This is the same as "Connection reset by peer" - return Err(IoError::new(IoErrorKind::ConnectionReset, "Connection reset while sending").into()) + return Err(IoError::new( + IoErrorKind::ConnectionReset, + "Connection reset while sending", + ) + .into()); } self.out_buffer.drain(0..len); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5eddf75..be43824 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -280,7 +280,9 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol("Sending after closing is not allowed".into())); + return Err(Error::Protocol( + "Sending after closing is not allowed".into(), + )); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -378,7 +380,9 @@ impl WebSocketContext { { if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { if !self.state.can_read() { - return Err(Error::Protocol("Remote sent frame after having sent a Close Frame".into())); + return Err(Error::Protocol( + "Remote sent frame after having sent a Close Frame".into(), + )); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -600,8 +604,7 @@ impl WebSocketState { /// close frame, so we should still pass those to client code, hence ClosedByUs is valid. fn can_read(&self) -> bool { match self { - WebSocketState::Active | - WebSocketState::ClosedByUs => true, + WebSocketState::Active | WebSocketState::ClosedByUs => true, _ => false, } } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index 3993e81..d8e20e5 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -39,13 +39,12 @@ fn test_no_send_after_close() { client_handler.close(None).unwrap(); // send close to client - let err = client_handler - .write_message(Message::Text("Hello WebSocket".into())); + let err = client_handler.write_message(Message::Text("Hello WebSocket".into())); - assert!( err.is_err() ); + assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => { assert_eq!( "Sending after closing is not allowed", s )} + Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), e => panic!("unexpected error: {:?}", e), } From 88760b8b5912dd4beca329c3acdff9a32bcdf1d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 15:20:19 +0100 Subject: [PATCH 8/9] Fix various clippy warnings --- examples/autobahn-client.rs | 4 ++-- src/handshake/client.rs | 6 +++--- src/handshake/headers.rs | 6 +++--- src/handshake/server.rs | 4 ++-- src/protocol/message.rs | 2 +- src/protocol/mod.rs | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 5ef0c24..0962d9d 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -3,7 +3,7 @@ use url::Url; use tungstenite::{connect, Error, Message, Result}; -const AGENT: &'static str = "Tungstenite"; +const AGENT: &str = "Tungstenite"; fn get_case_count() -> Result { let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; @@ -47,7 +47,7 @@ fn main() { let total = get_case_count().unwrap(); - for case in 1..(total + 1) { + for case in 1..=total { if let Err(e) = run_test(case) { match e { Error::Protocol(_) => {} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index e6c8ad5..cc956b8 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -257,13 +257,13 @@ mod tests { assert_eq!(k2.len(), 24); assert!(k1.ends_with("==")); assert!(k2.ends_with("==")); - assert!(k1[..22].find("=").is_none()); - assert!(k2[..22].find("=").is_none()); + assert!(k1[..22].find('=').is_none()); + assert!(k2[..22].find('=').is_none()); } #[test] fn response_parsing() { - const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); assert_eq!( diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index ba12954..0c51008 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -48,7 +48,7 @@ mod tests { #[test] fn headers() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; @@ -60,7 +60,7 @@ mod tests { #[test] fn headers_iter() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ @@ -75,7 +75,7 @@ mod tests { #[test] fn headers_incomplete() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; let hdr = HeaderMap::try_parse(DATA).unwrap(); diff --git a/src/handshake/server.rs b/src/handshake/server.rs index f3e937e..f665a0b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -295,7 +295,7 @@ mod tests { #[test] fn request_parsing() { - const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; + const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); assert_eq!(req.uri().path(), "/script.ws"); assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); @@ -303,7 +303,7 @@ mod tests { #[test] fn request_replying() { - const DATA: &'static [u8] = b"\ + const DATA: &[u8] = b"\ GET /script.ws HTTP/1.1\r\n\ Host: foo.com\r\n\ Connection: upgrade\r\n\ diff --git a/src/protocol/message.rs b/src/protocol/message.rs index ba00765..5df4ba0 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -343,7 +343,7 @@ mod tests { #[test] fn display() { - let t = Message::text(format!("test")); + let t = Message::text("test".to_owned()); assert_eq!(t.to_string(), "test".to_owned()); let bin = Message::binary(vec![0, 1, 3, 4, 241]); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index be43824..3997803 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -592,7 +592,7 @@ enum WebSocketState { impl WebSocketState { /// Tell if we're allowed to process normal messages. - fn is_active(&self) -> bool { + fn is_active(self) -> bool { match self { WebSocketState::Active => true, _ => false, @@ -602,7 +602,7 @@ impl WebSocketState { /// Tell if we should process incoming data. Note that if we send a close frame /// but the remote hasn't confirmed, they might have sent data before they receive our /// close frame, so we should still pass those to client code, hence ClosedByUs is valid. - fn can_read(&self) -> bool { + fn can_read(self) -> bool { match self { WebSocketState::Active | WebSocketState::ClosedByUs => true, _ => false, @@ -610,7 +610,7 @@ impl WebSocketState { } /// Check if the state is active, return error if not. - fn check_active(&self) -> Result<()> { + fn check_active(self) -> Result<()> { match self { WebSocketState::Terminated => Err(Error::AlreadyClosed), _ => Ok(()), From e1a5153f405051aeed01acac301400306fbba812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Sun, 24 Nov 2019 18:17:13 +0100 Subject: [PATCH 9/9] Bump version to 0.10.0 because of API changes --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 34b8330..ec0b5c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"] license = "MIT/Apache-2.0" readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" -documentation = "https://docs.rs/tungstenite/0.9.3" +documentation = "https://docs.rs/tungstenite/0.10.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.9.3" +version = "0.10.0" edition = "2018" [features]