diff --git a/Cargo.toml b/Cargo.toml index 8a733cd..7262766 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ tls = ["native-tls"] base64 = "0.10.0" byteorder = "1.2.3" bytes = "0.4.8" +http = "0.1.17" httparse = "1.3.1" input_buffer = "0.2.0" log = "0.4.2" diff --git a/examples/callback-error.rs b/examples/callback-error.rs new file mode 100644 index 0000000..3072460 --- /dev/null +++ b/examples/callback-error.rs @@ -0,0 +1,24 @@ +extern crate tungstenite; + +use std::thread::spawn; +use std::net::TcpListener; + +use tungstenite::accept_hdr; +use tungstenite::handshake::server::{Request, ErrorResponse}; +use tungstenite::http::StatusCode; + +fn main() { + let server = TcpListener::bind("127.0.0.1:3012").unwrap(); + for stream in server.incoming() { + spawn(move || { + let callback = |_req: &Request| { + Err(ErrorResponse { + error_code: StatusCode::FORBIDDEN, + headers: None, + body: Some("Access denied".into()), + }) + }; + accept_hdr(stream.unwrap(), callback).unwrap_err(); + }); + } +} diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 11a8ffb..94649e1 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -3,9 +3,11 @@ use std::fmt::Write as FmtWrite; use std::io::{Read, Write}; use std::marker::PhantomData; +use std::result::Result as StdResult; use httparse; use httparse::Status; +use http::StatusCode; use error::{Error, Result}; use protocol::{WebSocket, WebSocketConfig, Role}; @@ -35,16 +37,21 @@ impl Request { Sec-WebSocket-Accept: {}\r\n", convert_key(key)? ); - if let Some(eh) = extra_headers { - for (k, v) in eh { - write!(reply, "{}: {}\r\n", k, v).unwrap(); - } - } - write!(reply, "\r\n").unwrap(); + add_headers(&mut reply, extra_headers); Ok(reply.into()) } } +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) { + if let Some(eh) = extra_headers { + for (k, v) in eh { + write!(reply, "{}: {}\r\n", k, v).unwrap(); + } + } + write!(reply, "\r\n").unwrap(); +} + + impl TryParse for Request { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; @@ -71,6 +78,30 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request { } } +/// Extra headers for responses. +pub type ExtraHeaders = Vec<(String, String)>; + +/// An error response sent to the client. +#[derive(Debug)] +pub struct ErrorResponse { + /// HTTP error code. + pub error_code: StatusCode, + /// Extra response headers, if any. + pub headers: Option<ExtraHeaders>, + /// REsponse body, if any. + pub body: Option<String>, +} + +impl From<StatusCode> for ErrorResponse { + fn from(error_code: StatusCode) -> Self { + ErrorResponse { + error_code, + headers: None, + body: None, + } + } +} + /// The callback trait. /// /// The callback is called when the server receives an incoming WebSocket @@ -81,11 +112,11 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>>; + fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>; } -impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>> { - fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>> { +impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { + fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { self(request) } } @@ -95,7 +126,7 @@ impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, S pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> Result<Option<Vec<(String, String)>>> { + fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> { Ok(None) } } @@ -110,6 +141,8 @@ pub struct ServerHandshake<S, C> { callback: Option<C>, /// WebSocket configuration. config: Option<WebSocketConfig>, + /// Error code/flag. If set, an error will be returned after sending response to the client. + error_code: Option<u16>, /// Internal stream type. _marker: PhantomData<S>, } @@ -123,7 +156,12 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> { trace!("Server handshake initiated."); MidHandshake { machine: HandshakeMachine::start_read(stream), - role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData }, + role: ServerHandshake { + callback: Some(callback), + config, + error_code: None, + _marker: PhantomData + }, } } } @@ -141,24 +179,48 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { if !tail.is_empty() { return Err(Error::Protocol("Junk after client request".into())) } - let extra_headers = { - if let Some(callback) = self.callback.take() { - callback.on_request(&result)? - } else { - None - } + + let callback_result = if let Some(callback) = self.callback.take() { + callback.on_request(&result) + } else { + Ok(None) }; - let response = result.reply(extra_headers)?; - ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + + match callback_result { + Ok(extra_headers) => { + let response = result.reply(extra_headers)?; + ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + } + + Err(ErrorResponse { error_code, headers, body }) => { + self.error_code= Some(error_code.as_u16()); + let mut response = format!( + "HTTP/1.1 {} {}\r\n", + error_code.as_str(), + error_code.canonical_reason().unwrap_or("") + ); + add_headers(&mut response, headers); + if let Some(body) = body { + response += &body; + } + ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) + } + } } + StageResult::DoneWriting(stream) => { - debug!("Server handshake done."); - let websocket = WebSocket::from_raw_socket( - stream, - Role::Server, - self.config.clone(), - ); - ProcessingResult::Done(websocket) + if let Some(err) = self.error_code.take() { + debug!("Server handshake failed."); + return Err(Error::Http(err)); + } else { + debug!("Server handshake done."); + let websocket = WebSocket::from_raw_socket( + stream, + Role::Server, + self.config.clone(), + ); + ProcessingResult::Done(websocket) + } } }) } diff --git a/src/lib.rs b/src/lib.rs index 47dfa54..547c454 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,8 @@ extern crate url; extern crate utf8; #[cfg(feature="tls")] extern crate native_tls; +pub extern crate http; + pub mod error; pub mod protocol; pub mod client;