diff --git a/Cargo.toml b/Cargo.toml index 34b8330..ec0b5c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"] license = "MIT/Apache-2.0" readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" -documentation = "https://docs.rs/tungstenite/0.9.3" +documentation = "https://docs.rs/tungstenite/0.10.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.9.3" +version = "0.10.0" edition = "2018" [features] diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 5ef0c24..0962d9d 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -3,7 +3,7 @@ use url::Url; use tungstenite::{connect, Error, Message, Result}; -const AGENT: &'static str = "Tungstenite"; +const AGENT: &str = "Tungstenite"; fn get_case_count() -> Result { let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?; @@ -47,7 +47,7 @@ fn main() { let total = get_case_count().unwrap(); - for case in 1..(total + 1) { + for case in 1..=total { if let Err(e) = run_test(case) { match e { Error::Protocol(_) => {} diff --git a/examples/callback-error.rs b/examples/callback-error.rs index fdb5b8a..357d343 100644 --- a/examples/callback-error.rs +++ b/examples/callback-error.rs @@ -2,19 +2,19 @@ use std::net::TcpListener; use std::thread::spawn; use tungstenite::accept_hdr; -use tungstenite::handshake::server::{ErrorResponse, Request}; +use tungstenite::handshake::server::{Request, Response}; use tungstenite::http::StatusCode; fn main() { let server = TcpListener::bind("127.0.0.1:3012").unwrap(); for stream in server.incoming() { spawn(move || { - let callback = |_req: &Request| { - Err(ErrorResponse { - error_code: StatusCode::FORBIDDEN, - headers: None, - body: Some("Access denied".into()), - }) + let callback = |_req: &Request, _resp| { + let resp = Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Some("Access denied".into())) + .unwrap(); + Err(resp) }; accept_hdr(stream.unwrap(), callback).unwrap_err(); }); diff --git a/examples/client.rs b/examples/client.rs index e3200d2..7938cfb 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,9 +8,9 @@ fn main() { connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect"); println!("Connected to the server"); - println!("Response HTTP code: {}", response.code); + println!("Response HTTP code: {}", response.status()); println!("Response contains the following headers:"); - for &(ref header, _ /*value*/) in response.headers.iter() { + for (ref header, _value) in response.headers() { println!("* {}", header); } diff --git a/examples/server.rs b/examples/server.rs index 70ba186..def2f45 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -2,30 +2,27 @@ use std::net::TcpListener; use std::thread::spawn; use tungstenite::accept_hdr; -use tungstenite::handshake::server::Request; +use tungstenite::handshake::server::{Request, Response}; fn main() { env_logger::init(); let server = TcpListener::bind("127.0.0.1:3012").unwrap(); for stream in server.incoming() { spawn(move || { - let callback = |req: &Request| { + let callback = |req: &Request, mut response: Response| { println!("Received a new ws handshake"); - println!("The request's path is: {}", req.path); + println!("The request's path is: {}", req.uri().path()); println!("The request's headers are:"); - for &(ref header, _ /* value */) in req.headers.iter() { + for (ref header, _value) in req.headers() { println!("* {}", header); } // Let's add an additional header to our response to the client. - let extra_headers = vec![ - (String::from("MyCustomHeader"), String::from(":)")), - ( - String::from("SOME_TUNGSTENITE_HEADER"), - String::from("header_value"), - ), - ]; - Ok(Some(extra_headers)) + let headers = response.headers_mut(); + headers.append("MyCustomHeader", ":)".parse().unwrap()); + headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); + + Ok(response) }; let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap(); diff --git a/src/client.rs b/src/client.rs index e35a24d..200ca53 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,10 +4,12 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; +use http::Uri; use log::*; + use url::Url; -use crate::handshake::client::Response; +use crate::handshake::client::{Request, Response}; use crate::protocol::WebSocketConfig; #[cfg(feature = "tls")] @@ -64,7 +66,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::error::{Error, Result}; -use crate::handshake::client::{ClientHandshake, Request}; +use crate::handshake::client::ClientHandshake; use crate::handshake::HandshakeError; use crate::protocol::WebSocket; use crate::stream::{Mode, NoDelay}; @@ -84,37 +86,23 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config<'t, Req: Into>>( +pub fn connect_with_config( request: Req, config: Option, ) -> Result<(WebSocket, Response)> { - let request: Request = request.into(); - let mode = url_mode(&request.url)?; + let request: Request = request.into_client_request()?; + let uri = request.uri(); + let mode = uri_mode(uri)?; let host = request - .url + .uri() .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = request - .url - .port_or_known_default() - .ok_or_else(|| Error::Url("No port number in the URL".into()))?; - let addrs; - let addr; - let addrs = match host { - url::Host::Domain(domain) => { - addrs = (domain, port).to_socket_addrs()?; - addrs.as_slice() - } - url::Host::Ipv4(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - url::Host::Ipv6(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - }; - let mut stream = connect_to_some(addrs, &request.url, mode)?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, @@ -134,35 +122,33 @@ pub fn connect_with_config<'t, Req: Into>>( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>( - request: Req, -) -> Result<(WebSocket, Response)> { +pub fn connect(request: Req) -> Result<(WebSocket, Response)> { connect_with_config(request, None) } -fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result { - let domain = url - .host_str() +fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { + let domain = uri + .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { - debug!("Trying to contact {} at {}...", url, addr); + debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { return Ok(stream); } } } - Err(Error::Url(format!("Unable to connect to {}", url).into())) + Err(Error::Url(format!("Unable to connect to {}", uri).into())) } /// Get the mode of the given URL. /// /// This function may be used to ease the creation of custom TLS streams /// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`. -pub fn url_mode(url: &Url) -> Result { - match url.scheme() { - "ws" => Ok(Mode::Plain), - "wss" => Ok(Mode::Tls), +pub fn uri_mode(uri: &Uri) -> Result { + match uri.scheme_str() { + Some("ws") => Ok(Mode::Plain), + Some("wss") => Ok(Mode::Tls), _ => Err(Error::Url("URL scheme not supported".into())), } } @@ -173,16 +159,16 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client_with_config<'t, Stream, Req>( +pub fn client_with_config( request: Req, stream: Stream, config: Option, ) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { - ClientHandshake::start(stream, request.into(), config).handshake() + ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake() } /// Do the client handshake over the given stream. @@ -190,13 +176,87 @@ where /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>( +pub fn client( request: Req, stream: Stream, ) -> StdResult<(WebSocket, Response), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { client_with_config(request, stream, None) } + +/// Trait for converting various types into HTTP requests used for a client connection. +/// +/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and +/// `http::Request<()>`. +pub trait IntoClientRequest { + /// Convert into a `Request` that can be used for a client connection. + fn into_client_request(self) -> Result; +} + +impl<'a> IntoClientRequest for &'a str { + fn into_client_request(self) -> Result { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a String { + fn into_client_request(self) -> Result { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for String { + fn into_client_request(self) -> Result { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Uri { + fn into_client_request(self) -> Result { + Ok(Request::get(self.clone()).body(())?) + } +} + +impl IntoClientRequest for Uri { + fn into_client_request(self) -> Result { + Ok(Request::get(self).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Url { + fn into_client_request(self) -> Result { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Url { + fn into_client_request(self) -> Result { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Request { + fn into_client_request(self) -> Result { + Ok(self) + } +} + +impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> { + fn into_client_request(self) -> Result { + use crate::handshake::headers::FromHttparse; + Request::from_httparse(self) + } +} diff --git a/src/error.rs b/src/error.rs index 8629efc..ac2de26 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::result; use std::str; use std::string; +use http; use httparse; use crate::protocol::Message; @@ -45,7 +46,7 @@ pub enum Error { /// connection when it really shouldn't anymore, so this really indicates a programmer /// error on your part. AlreadyClosed, - /// Input-output error. Appart from WouldBlock, these are generally errors with the + /// Input-output error. Apart from WouldBlock, these are generally errors with the /// underlying connection and you should probably consider them fatal. Io(io::Error), #[cfg(feature = "tls")] @@ -61,10 +62,12 @@ pub enum Error { SendQueueFull(Message), /// UTF coding error Utf8, - /// Invlid URL. + /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. - Http(u16), + Http(http::StatusCode), + /// HTTP format error. + HttpFormat(http::Error), } impl fmt::Display for Error { @@ -80,7 +83,8 @@ impl fmt::Display for Error { Error::SendQueueFull(_) => write!(f, "Send queue is full"), Error::Utf8 => write!(f, "UTF-8 encoding error"), Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(code) => write!(f, "HTTP code: {}", code), + Error::Http(code) => write!(f, "HTTP error: {}", code), + Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), } } } @@ -99,6 +103,7 @@ impl ErrorTrait for Error { Error::Utf8 => "", Error::Url(ref msg) => msg.borrow(), Error::Http(_) => "", + Error::HttpFormat(ref err) => err.description(), } } } @@ -121,6 +126,42 @@ impl From for Error { } } +impl From for Error { + fn from(err: http::header::InvalidHeaderValue) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::header::InvalidHeaderName) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(_: http::header::ToStrError) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(err: http::uri::InvalidUri) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::status::InvalidStatusCode) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::Error) -> Self { + Error::HttpFormat(err) + } +} + #[cfg(feature = "tls")] impl From for Error { fn from(err: tls::Error) -> Self { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 7e23af4..cc956b8 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,69 +1,23 @@ //! Client handshake machine. -use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; +use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use url::Url; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Client request. -#[derive(Debug)] -pub struct Request<'t> { - /// `ws://` or `wss://` URL to connect to. - pub url: Url, - /// Extra HTTP headers to append to the request. - pub extra_headers: Option, Cow<'t, str>)>>, -} - -impl<'t> Request<'t> { - /// Returns the GET part of the request. - fn get_path(&self) -> String { - if let Some(query) = self.url.query() { - format!("{path}?{query}", path = self.url.path(), query = query) - } else { - self.url.path().into() - } - } - - /// Returns the host part of the request. - fn get_host(&self) -> String { - let host = self.url.host_str().expect("Bug: URL without host"); - if let Some(port) = self.url.port() { - format!("{host}:{port}", host = host, port = port) - } else { - host.into() - } - } - - /// Adds a WebSocket protocol to the request. - pub fn add_protocol(&mut self, protocol: Cow<'t, str>) { - self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol); - } +/// Client request type. +pub type Request = HttpRequest<()>; - /// Adds a custom header to the request. - pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { - let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new); - headers.push((name, value)); - self.extra_headers = Some(headers); - } -} - -impl From for Request<'static> { - fn from(value: Url) -> Self { - Request { - url: value, - extra_headers: None, - } - } -} +/// Client response type. +pub type Response = HttpResponse<()>; /// Client handshake role. #[derive(Debug)] @@ -79,29 +33,49 @@ impl ClientHandshake { stream: S, request: Request, config: Option, - ) -> MidHandshake { + ) -> Result> { + if request.method() != http::Method::GET { + return Err(Error::Protocol( + "Invalid HTTP method, only GET supported".into(), + )); + } + + if request.version() < http::Version::HTTP_11 { + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); + } + + // Check the URI scheme: only ws or wss are supported + let _ = crate::client::uri_mode(request.uri())?; + let key = generate_key(); let machine = { let mut req = Vec::new(); + let uri = request.uri(); write!( req, "\ - GET {path} HTTP/1.1\r\n\ + GET {path} {version:?}\r\n\ Host: {host}\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: {key}\r\n", - host = request.get_host(), - path = request.get_path(), + version = request.version(), + host = uri + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?, + path = uri + .path_and_query() + .ok_or_else(|| Error::Url("No path/query in URL".into()))? + .as_str(), key = key ) .unwrap(); - if let Some(eh) = request.extra_headers { - for (k, v) in eh { - writeln!(req, "{}: {}\r", k, v).unwrap(); - } + for (k, v) in request.headers() { + writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); } writeln!(req, "\r").unwrap(); HandshakeMachine::start_write(stream, req) @@ -117,10 +91,10 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { + Ok(MidHandshake { role: client, machine, - } + }) } } @@ -162,16 +136,20 @@ impl VerifyData { pub fn verify_response(&self, response: &Response) -> Result<()> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.code != 101 { - return Err(Error::Http(response.code)); + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.status())); } + let headers = response.headers(); + // 2. If the response lacks an |Upgrade| header field or the |Upgrade| // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers - .header_is_ignore_case("Upgrade", "websocket") + if !headers + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Upgrade: websocket\" in server reply".into(), @@ -181,9 +159,11 @@ impl VerifyData { // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers - .header_is_ignore_case("Connection", "Upgrade") + if !headers + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Connection: upgrade\" in server reply".into(), @@ -193,9 +173,10 @@ impl VerifyData { // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !response - .headers - .header_is("Sec-WebSocket-Accept", &self.accept_key) + if !headers + .get("Sec-WebSocket-Accept") + .map(|h| h == &self.accept_key) + .unwrap_or(false) { return Err(Error::Protocol( "Key mismatch in Sec-WebSocket-Accept".into(), @@ -219,15 +200,6 @@ impl VerifyData { } } -/// Server response. -#[derive(Debug)] -pub struct Response { - /// HTTP response code of the response. - pub code: u16, - /// Received headers. - pub headers: Headers, -} - impl TryParse for Response { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; @@ -246,10 +218,17 @@ impl<'h, 'b: 'h> FromHttparse> for Response { "HTTP version should be 1.1 or higher".into(), )); } - Ok(Response { - code: raw.code.expect("Bug: no HTTP response code"), - headers: Headers::from_httparse(raw.headers)?, - }) + + let headers = HeaderMap::from_httparse(raw.headers)?; + + let mut response = Response::new(()); + *response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?; + *response.headers_mut() = headers; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *response.version_mut() = http::Version::HTTP_11; + + Ok(response) } } @@ -278,18 +257,18 @@ mod tests { assert_eq!(k2.len(), 24); assert!(k1.ends_with("==")); assert!(k2.ends_with("==")); - assert!(k1[..22].find("=").is_none()); - assert!(k2[..22].find("=").is_none()); + assert!(k1[..22].find('=').is_none()); + assert!(k2[..22].find('=').is_none()); } #[test] fn response_parsing() { - const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); - assert_eq!(resp.code, 200); + assert_eq!(resp.status(), http::StatusCode::OK); assert_eq!( - resp.headers.find_first("Content-Type"), - Some(&b"text/html"[..]) + resp.headers().get("Content-Type").unwrap(), + &b"text/html"[..], ); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 097b22b..0c51008 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,8 +1,7 @@ //! HTTP Request and response header handling. -use std::slice; -use std::str::from_utf8; - +use http; +use http::header::{HeaderMap, HeaderName, HeaderValue}; use httparse; use httparse::Status; @@ -12,90 +11,31 @@ use crate::error::Result; /// Limit for the number of header lines. pub const MAX_HEADERS: usize = 124; -/// HTTP request or response headers. -#[derive(Debug)] -pub struct Headers { - data: Vec<(String, Box<[u8]>)>, +/// Trait to convert raw objects into HTTP parseables. +pub(crate) trait FromHttparse: Sized { + /// Convert raw object into parsed HTTP headers. + fn from_httparse(raw: T) -> Result; } -impl Headers { - /// Get first header with the given name, if any. - pub fn find_first(&self, name: &str) -> Option<&[u8]> { - self.find(name).next() - } - - /// Iterate over all headers with the given name. - pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { - HeadersIter { - name, - iter: self.data.iter(), +impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap { + fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { + let mut headers = HeaderMap::new(); + for h in raw { + headers.append( + HeaderName::from_bytes(h.name.as_bytes())?, + HeaderValue::from_bytes(h.value)?, + ); } - } - - /// Check if the given header has the given value. - pub fn header_is(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .map(|v| v == value.as_bytes()) - .unwrap_or(false) - } - - /// Check if the given header has the given value (case-insensitive). - pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .ok_or(()) - .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) - .map(|val| val.eq_ignore_ascii_case(value)) - .unwrap_or(false) - } - - /// Allows to iterate over available headers. - pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { - self.data.iter() - } -} -/// The iterator over headers. -#[derive(Debug)] -pub struct HeadersIter<'name, 'headers> { - name: &'name str, - iter: slice::Iter<'headers, (String, Box<[u8]>)>, -} - -impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { - type Item = &'headers [u8]; - fn next(&mut self) -> Option { - while let Some(&(ref name, ref value)) = self.iter.next() { - if name.eq_ignore_ascii_case(self.name) { - return Some(value); - } - } - None + Ok(headers) } } - -/// Trait to convert raw objects into HTTP parseables. -pub trait FromHttparse: Sized { - /// Convert raw object into parsed HTTP headers. - fn from_httparse(raw: T) -> Result; -} - -impl TryParse for Headers { +impl TryParse for HeaderMap { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Status::Partial => None, - Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), - }) - } -} - -impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { - fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { - Ok(Headers { - data: raw - .iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), + Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)), }) } } @@ -104,45 +44,41 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { mod tests { use super::super::machine::TryParse; - use super::Headers; + use super::HeaderMap; #[test] fn headers() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); - assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); - assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); - - assert!(hdr.header_is("upgrade", "websocket")); - assert!(!hdr.header_is("upgrade", "Websocket")); - assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]); + assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]); + assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]); } #[test] fn headers_iter() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Sec-WebSocket-Extensions: permessage-deflate\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - let mut iter = hdr.find("Sec-WebSocket-Extensions"); - assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); - assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter(); + assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]); + assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]); assert_eq!(iter.next(), None); } #[test] fn headers_incomplete() { - const DATA: &'static [u8] = b"Host: foo.com\r\n\ + const DATA: &[u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; - let hdr = Headers::try_parse(DATA).unwrap(); + let hdr = HeaderMap::try_parse(DATA).unwrap(); assert!(hdr.is_none()); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index d3624c3..f665a0b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -1,56 +1,108 @@ //! Server handshake machine. -use std::fmt::Write as FmtWrite; -use std::io::{Read, Write}; +use std::io::{self, Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::StatusCode; +use http::{HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Request from the client. -#[derive(Debug)] -pub struct Request { - /// Path part of the URL. - pub path: String, - /// HTTP headers. - pub headers: Headers, -} +/// Server request type. +pub type Request = HttpRequest<()>; -impl Request { - /// Reply to the response. - pub fn reply(&self, extra_headers: Option>) -> Result> { - let key = self - .headers - .find_first("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; - let mut reply = format!( - "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Connection: Upgrade\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: {}\r\n", - convert_key(key)? - ); - add_headers(&mut reply, extra_headers); - Ok(reply.into()) +/// Server response type. +pub type Response = HttpResponse<()>; + +/// Server error response type. +pub type ErrorResponse = HttpResponse>; + +/// Create a response for the request. +pub fn create_response(request: &Request) -> Result { + if request.method() != http::Method::GET { + return Err(Error::Protocol("Method is not GET".into())); + } + + if request.version() < http::Version::HTTP_11 { + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); + } + + if !request + .headers() + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Connection: upgrade\" in client request".into(), + )); + } + + if !request + .headers() + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Upgrade: websocket\" in client request".into(), + )); + } + + if !request + .headers() + .get("Sec-WebSocket-Version") + .map(|h| h == "13") + .unwrap_or(false) + { + return Err(Error::Protocol( + "No \"Sec-WebSocket-Version: 13\" in client request".into(), + )); } + + let key = request + .headers() + .get("Sec-WebSocket-Key") + .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + + let mut response = Response::builder(); + + response.status(StatusCode::SWITCHING_PROTOCOLS); + response.version(request.version()); + response.header("Connection", "Upgrade"); + response.header("Upgrade", "websocket"); + response.header("Sec-WebSocket-Accept", convert_key(key.as_bytes())?); + + Ok(response.body(())?) } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) { - if let Some(eh) = extra_headers { - for (k, v) in eh { - writeln!(reply, "{}: {}\r", k, v).unwrap(); - } +// Assumes that this is a valid response +fn write_response(w: &mut dyn io::Write, response: &HttpResponse) -> Result<()> { + writeln!( + w, + "{version:?} {status} {reason}\r", + version = response.version(), + status = response.status(), + reason = response.status().canonical_reason().unwrap_or(""), + )?; + + for (k, v) in response.headers() { + writeln!(w, "{}: {}\r", k, v.to_str()?).unwrap(); } - writeln!(reply, "\r").unwrap(); + + writeln!(w, "\r")?; + + Ok(()) } impl TryParse for Request { @@ -69,39 +121,24 @@ impl<'h, 'b: 'h> FromHttparse> for Request { if raw.method.expect("Bug: no method in header") != "GET" { return Err(Error::Protocol("Method is not GET".into())); } + if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( "HTTP version should be 1.1 or higher".into(), )); } - Ok(Request { - path: raw.path.expect("Bug: no path in header").into(), - headers: Headers::from_httparse(raw.headers)?, - }) - } -} -/// Extra headers for responses. -pub type ExtraHeaders = Vec<(String, String)>; + let headers = HeaderMap::from_httparse(raw.headers)?; -/// An error response sent to the client. -#[derive(Debug)] -pub struct ErrorResponse { - /// HTTP error code. - pub error_code: StatusCode, - /// Extra response headers, if any. - pub headers: Option, - /// Response body, if any. - pub body: Option, -} + let mut request = Request::new(()); + *request.method_mut() = http::Method::GET; + *request.headers_mut() = headers; + *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *request.version_mut() = http::Version::HTTP_11; -impl From for ErrorResponse { - fn from(error_code: StatusCode) -> Self { - ErrorResponse { - error_code, - headers: None, - body: None, - } + Ok(request) } } @@ -115,15 +152,23 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request( + self, + request: &Request, + response: Response, + ) -> StdResult; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request, Response) -> StdResult, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { - self(request) + fn on_request( + self, + request: &Request, + response: Response, + ) -> StdResult { + self(request, response) } } @@ -132,8 +177,12 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { - Ok(None) + fn on_request( + self, + _request: &Request, + response: Response, + ) -> StdResult { + Ok(response) } } @@ -191,34 +240,35 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol("Junk after client request".into())); } + let response = create_response(&result)?; let callback_result = if let Some(callback) = self.callback.take() { - callback.on_request(&result) + callback.on_request(&result, response) } else { - Ok(None) + Ok(response) }; match callback_result { - Ok(extra_headers) => { - let response = result.reply(extra_headers)?; - ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + Ok(response) => { + let mut output = vec![]; + write_response(&mut output, &response)?; + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } - Err(ErrorResponse { - error_code, - headers, - body, - }) => { - self.error_code = Some(error_code.as_u16()); - let mut response = format!( - "HTTP/1.1 {} {}\r\n", - error_code.as_str(), - error_code.canonical_reason().unwrap_or("") - ); - add_headers(&mut response, headers); - if let Some(body) = body { - response += &body; + Err(resp) => { + if resp.status().is_success() { + return Err(Error::Protocol( + "Custom response must not be successful".into(), + )); } - ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + + self.error_code = Some(resp.status().as_u16()); + + let mut output = vec![]; + write_response(&mut output, &resp)?; + if let Some(body) = resp.body() { + output.extend_from_slice(body.as_bytes()); + } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } } } @@ -226,7 +276,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(err)); + return Err(Error::Http(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); @@ -239,21 +289,21 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::super::client::Response; use super::super::machine::TryParse; + use super::create_response; use super::Request; #[test] fn request_parsing() { - const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; + const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - assert_eq!(req.path, "/script.ws"); - assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); + assert_eq!(req.uri().path(), "/script.ws"); + assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]); } #[test] fn request_replying() { - const DATA: &'static [u8] = b"\ + const DATA: &[u8] = b"\ GET /script.ws HTTP/1.1\r\n\ Host: foo.com\r\n\ Connection: upgrade\r\n\ @@ -262,21 +312,11 @@ mod tests { Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ \r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); - let _ = req.reply(None).unwrap(); - - let extra_headers = Some(vec![ - ( - String::from("MyCustomHeader"), - String::from("MyCustomValue"), - ), - (String::from("MyVersion"), String::from("LOL")), - ]); - let reply = req.reply(extra_headers).unwrap(); - let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); + let response = create_response(&req).unwrap(); + assert_eq!( - req.headers.find_first("MyCustomHeader"), - Some(b"MyCustomValue".as_ref()) + response.headers().get("Sec-WebSocket-Accept").unwrap(), + b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref() ); - assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); } } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index f93f911..6756f0a 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -12,7 +12,7 @@ pub use self::frame::{Frame, FrameHeader}; use crate::error::{Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; use log::*; -use std::io::{Read, Write, Error as IoError, ErrorKind as IoErrorKind}; +use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -199,7 +199,11 @@ impl FrameCodec { let len = stream.write(&self.out_buffer)?; if len == 0 { // This is the same as "Connection reset by peer" - return Err(IoError::new(IoErrorKind::ConnectionReset, "Connection reset while sending").into()) + return Err(IoError::new( + IoErrorKind::ConnectionReset, + "Connection reset while sending", + ) + .into()); } self.out_buffer.drain(0..len); } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index ba00765..5df4ba0 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -343,7 +343,7 @@ mod tests { #[test] fn display() { - let t = Message::text(format!("test")); + let t = Message::text("test".to_owned()); assert_eq!(t.to_string(), "test".to_owned()); let bin = Message::binary(vec![0, 1, 3, 4, 241]); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5eddf75..3997803 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -280,7 +280,9 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol("Sending after closing is not allowed".into())); + return Err(Error::Protocol( + "Sending after closing is not allowed".into(), + )); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -378,7 +380,9 @@ impl WebSocketContext { { if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? { if !self.state.can_read() { - return Err(Error::Protocol("Remote sent frame after having sent a Close Frame".into())); + return Err(Error::Protocol( + "Remote sent frame after having sent a Close Frame".into(), + )); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -588,7 +592,7 @@ enum WebSocketState { impl WebSocketState { /// Tell if we're allowed to process normal messages. - fn is_active(&self) -> bool { + fn is_active(self) -> bool { match self { WebSocketState::Active => true, _ => false, @@ -598,16 +602,15 @@ impl WebSocketState { /// Tell if we should process incoming data. Note that if we send a close frame /// but the remote hasn't confirmed, they might have sent data before they receive our /// close frame, so we should still pass those to client code, hence ClosedByUs is valid. - fn can_read(&self) -> bool { + fn can_read(self) -> bool { match self { - WebSocketState::Active | - WebSocketState::ClosedByUs => true, + WebSocketState::Active | WebSocketState::ClosedByUs => true, _ => false, } } /// Check if the state is active, return error if not. - fn check_active(&self) -> Result<()> { + fn check_active(self) -> Result<()> { match self { WebSocketState::Terminated => Err(Error::AlreadyClosed), _ => Ok(()), diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index 3993e81..d8e20e5 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -39,13 +39,12 @@ fn test_no_send_after_close() { client_handler.close(None).unwrap(); // send close to client - let err = client_handler - .write_message(Message::Text("Hello WebSocket".into())); + let err = client_handler.write_message(Message::Text("Hello WebSocket".into())); - assert!( err.is_err() ); + assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => { assert_eq!( "Sending after closing is not allowed", s )} + Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), e => panic!("unexpected error: {:?}", e), }