Compare commits

...

4 Commits

  1. 6
      Cargo.toml
  2. 43
      src/handshake/server.rs

@ -1,7 +1,7 @@
[package] [package]
name = "tungstenite" name = "ng-tungstenite"
description = "Lightweight stream-based WebSocket implementation" description = "fork of tungstenite for Nextgraph.org"
categories = ["web-programming::websocket", "network-programming"] categories = []
keywords = ["websocket", "io", "web"] keywords = ["websocket", "io", "web"]
authors = ["Alexey Galakhov", "Daniel Abramov"] authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"

@ -30,7 +30,7 @@ pub type Request = HttpRequest<()>;
pub type Response = HttpResponse<()>; pub type Response = HttpResponse<()>;
/// Server error response type. /// Server error response type.
pub type ErrorResponse = HttpResponse<Option<String>>; pub type ErrorResponse = HttpResponse<Option<Vec<u8>>>;
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
if request.method() != http::Method::GET { if request.method() != http::Method::GET {
@ -156,23 +156,15 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse>;
self,
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>, F: FnOnce(&Request) -> StdResult<(), ErrorResponse>,
{ {
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse> {
self, self(request)
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
self(request, response)
} }
} }
@ -181,12 +173,8 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request( fn on_request(self, _request: &Request) -> StdResult<(), ErrorResponse> {
self, Ok(())
_request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
Ok(response)
} }
} }
@ -240,24 +228,24 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result)
} else { } else {
Ok(response) Ok(())
}; };
match callback_result { match callback_result {
Ok(response) => { Ok(_) => {
let response = create_response(&result)?;
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &response)?; write_response(&mut output, &response)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
Err(resp) => { Err(resp) => {
if resp.status().is_success() { // if resp.status().is_success() {
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); // return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
} // }
self.error_response = Some(resp); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap(); let resp = self.error_response.as_ref().unwrap();
@ -266,7 +254,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
write_response(&mut output, resp)?; write_response(&mut output, resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend(body);
} }
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
@ -279,7 +267,6 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
debug!("Server handshake failed."); debug!("Server handshake failed.");
let (parts, body) = err.into_parts(); let (parts, body) = err.into_parts();
let body = body.map(|b| b.as_bytes().to_vec());
return Err(Error::Http(http::Response::from_parts(parts, body))); return Err(Error::Http(http::Response::from_parts(parts, body)));
} else { } else {
debug!("Server handshake done."); debug!("Server handshake done.");

Loading…
Cancel
Save