allow returning success response from accept callback

nextgraph
Niko PLP 2 years ago
parent 869a67ca0b
commit b9eae28e30
  1. 38
      src/handshake/server.rs

@ -156,23 +156,15 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it. /// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers. /// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection. /// Returning an error resulting in rejecting the incoming connection.
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse>;
self,
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse>;
} }
impl<F> Callback for F impl<F> Callback for F
where where
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>, F: FnOnce(&Request) -> StdResult<(), ErrorResponse>,
{ {
fn on_request( fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse> {
self, self(request)
request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
self(request, response)
} }
} }
@ -181,12 +173,8 @@ where
pub struct NoCallback; pub struct NoCallback;
impl Callback for NoCallback { impl Callback for NoCallback {
fn on_request( fn on_request(self, _request: &Request) -> StdResult<(), ErrorResponse> {
self, Ok(())
_request: &Request,
response: Response,
) -> StdResult<Response, ErrorResponse> {
Ok(response)
} }
} }
@ -240,24 +228,24 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
} }
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() { let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response) callback.on_request(&result)
} else { } else {
Ok(response) Ok(())
}; };
match callback_result { match callback_result {
Ok(response) => { Ok(_) => {
let response = create_response(&result)?;
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &response)?; write_response(&mut output, &response)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
} }
Err(resp) => { Err(resp) => {
if resp.status().is_success() { // if resp.status().is_success() {
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); // return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
} // }
self.error_response = Some(resp); self.error_response = Some(resp);
let resp = self.error_response.as_ref().unwrap(); let resp = self.error_response.as_ref().unwrap();

Loading…
Cancel
Save