server: let callback return HTTP error messages

Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
pull/59/head
Alexey Galakhov 6 years ago
parent 8ed73fd28a
commit 6f132208ee
  1. 1
      Cargo.toml
  2. 24
      examples/callback-error.rs
  3. 86
      src/handshake/server.rs
  4. 2
      src/lib.rs

@ -19,6 +19,7 @@ tls = ["native-tls"]
base64 = "0.10.0" base64 = "0.10.0"
byteorder = "1.2.3" byteorder = "1.2.3"
bytes = "0.4.8" bytes = "0.4.8"
http = "0.1.17"
httparse = "1.3.1" httparse = "1.3.1"
input_buffer = "0.2.0" input_buffer = "0.2.0"
log = "0.4.2" log = "0.4.2"

@ -0,0 +1,24 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener;
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, ErrorResponse};
use tungstenite::http::StatusCode;
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |_req: &Request| {
Err(ErrorResponse {
error_code: StatusCode::FORBIDDEN,
headers: None,
body: Some("Access denied".into()),
})
};
accept_hdr(stream.unwrap(), callback).unwrap_err();
});
}
}

@ -3,9 +3,11 @@
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::result::Result as StdResult;
use httparse; use httparse;
use httparse::Status; use httparse::Status;
use http::StatusCode;
use error::{Error, Result}; use error::{Error, Result};
use protocol::{WebSocket, WebSocketConfig, Role}; use protocol::{WebSocket, WebSocketConfig, Role};
@ -35,16 +37,21 @@ impl Request {
Sec-WebSocket-Accept: {}\r\n", Sec-WebSocket-Accept: {}\r\n",
convert_key(key)? convert_key(key)?
); );
add_headers(&mut reply, extra_headers);
Ok(reply.into())
}
}
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
if let Some(eh) = extra_headers { if let Some(eh) = extra_headers {
for (k, v) in eh { for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap(); write!(reply, "{}: {}\r\n", k, v).unwrap();
} }
} }
write!(reply, "\r\n").unwrap(); write!(reply, "\r\n").unwrap();
Ok(reply.into())
}
} }
impl TryParse for Request { impl TryParse for Request {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> { fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -71,6 +78,30 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
} }
} }
/// Extra headers for responses.
pub type ExtraHeaders = Vec<(String, String)>;
/// An error response sent to the client.
#[derive(Debug)]
pub struct ErrorResponse {
/// HTTP error code.
pub error_code: StatusCode,
/// Extra response headers, if any.
pub headers: Option<ExtraHeaders>,
/// REsponse body, if any.
pub body: Option<String>,
}
impl From<StatusCode> for ErrorResponse {
fn from(error_code: StatusCode) -> Self {
ErrorResponse {
error_code,
headers: None,
body: None,
}
}
}
/// The callback trait. /// The callback trait.
/// ///
/// The callback is called when the server receives an incoming WebSocket /// The callback is called when the server receives an incoming WebSocket
@ -81,11 +112,11 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>>; fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
} }
impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>> { impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>> { fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
self(request) self(request)
} }
} }
@ -95,7 +126,7 @@ impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, S
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> Result<Option<Vec<(String, String)>>> { fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
Ok(None) Ok(None)
} }
} }
@ -110,6 +141,8 @@ pub struct ServerHandshake<S, C> {
callback: Option<C>, callback: Option<C>,
/// WebSocket configuration. /// WebSocket configuration.
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type. /// Internal stream type.
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -123,7 +156,12 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
trace!("Server handshake initiated."); trace!("Server handshake initiated.");
MidHandshake { MidHandshake {
machine: HandshakeMachine::start_read(stream), machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData }, role: ServerHandshake {
callback: Some(callback),
config,
error_code: None,
_marker: PhantomData
},
} }
} }
} }
@ -141,17 +179,40 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
if !tail.is_empty() { if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into())) return Err(Error::Protocol("Junk after client request".into()))
} }
let extra_headers = {
if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result)? callback.on_request(&result)
} else { } else {
None Ok(None)
}
}; };
match callback_result {
Ok(extra_headers) => {
let response = result.reply(extra_headers)?; let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
} }
Err(ErrorResponse { error_code, headers, body }) => {
self.error_code= Some(error_code.as_u16());
let mut response = format!(
"HTTP/1.1 {} {}\r\n",
error_code.as_str(),
error_code.canonical_reason().unwrap_or("")
);
add_headers(&mut response, headers);
if let Some(body) = body {
response += &body;
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
}
}
StageResult::DoneWriting(stream) => { StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_code.take() {
debug!("Server handshake failed.");
return Err(Error::Http(err));
} else {
debug!("Server handshake done."); debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket( let websocket = WebSocket::from_raw_socket(
stream, stream,
@ -160,6 +221,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
); );
ProcessingResult::Done(websocket) ProcessingResult::Done(websocket)
} }
}
}) })
} }
} }

@ -22,6 +22,8 @@ extern crate url;
extern crate utf8; extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls; #[cfg(feature="tls")] extern crate native_tls;
pub extern crate http;
pub mod error; pub mod error;
pub mod protocol; pub mod protocol;
pub mod client; pub mod client;

Loading…
Cancel
Save