From b9eae28e30be9c183278f041a39b8250e7ea9923 Mon Sep 17 00:00:00 2001 From: Niko PLP Date: Fri, 23 Jun 2023 01:14:13 +0300 Subject: [PATCH] allow returning success response from accept callback --- src/handshake/server.rs | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/handshake/server.rs b/src/handshake/server.rs index bc072ce..5edbbd2 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -156,23 +156,15 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request( - self, - request: &Request, - response: Response, - ) -> StdResult; + fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse>; } impl Callback for F where - F: FnOnce(&Request, Response) -> StdResult, + F: FnOnce(&Request) -> StdResult<(), ErrorResponse>, { - fn on_request( - self, - request: &Request, - response: Response, - ) -> StdResult { - self(request, response) + fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse> { + self(request) } } @@ -181,12 +173,8 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request( - self, - _request: &Request, - response: Response, - ) -> StdResult { - Ok(response) + fn on_request(self, _request: &Request) -> StdResult<(), ErrorResponse> { + Ok(()) } } @@ -240,24 +228,24 @@ impl HandshakeRole for ServerHandshake { return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); } - let response = create_response(&result)?; let callback_result = if let Some(callback) = self.callback.take() { - callback.on_request(&result, response) + callback.on_request(&result) } else { - Ok(response) + Ok(()) }; match callback_result { - Ok(response) => { + Ok(_) => { + let response = create_response(&result)?; let mut output = vec![]; write_response(&mut output, &response)?; ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) } Err(resp) => { - if resp.status().is_success() { - return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); - } + // if resp.status().is_success() { + // return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); + // } self.error_response = Some(resp); let resp = self.error_response.as_ref().unwrap();