|
|
@ -30,7 +30,7 @@ pub type Request = HttpRequest<()>; |
|
|
|
pub type Response = HttpResponse<()>; |
|
|
|
pub type Response = HttpResponse<()>; |
|
|
|
|
|
|
|
|
|
|
|
/// Server error response type.
|
|
|
|
/// Server error response type.
|
|
|
|
pub type ErrorResponse = HttpResponse<Option<String>>; |
|
|
|
pub type ErrorResponse = HttpResponse<Option<Vec<u8>>>; |
|
|
|
|
|
|
|
|
|
|
|
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { |
|
|
|
fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> { |
|
|
|
if request.method() != http::Method::GET { |
|
|
|
if request.method() != http::Method::GET { |
|
|
@ -156,23 +156,15 @@ pub trait Callback: Sized { |
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
/// Called whenever the server read the request from the client and is ready to reply to it.
|
|
|
|
/// May return additional reply headers.
|
|
|
|
/// May return additional reply headers.
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
/// Returning an error resulting in rejecting the incoming connection.
|
|
|
|
fn on_request( |
|
|
|
fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse>; |
|
|
|
self, |
|
|
|
|
|
|
|
request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse>; |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
impl<F> Callback for F |
|
|
|
impl<F> Callback for F |
|
|
|
where |
|
|
|
where |
|
|
|
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>, |
|
|
|
F: FnOnce(&Request) -> StdResult<(), ErrorResponse>, |
|
|
|
{ |
|
|
|
{ |
|
|
|
fn on_request( |
|
|
|
fn on_request(self, request: &Request) -> StdResult<(), ErrorResponse> { |
|
|
|
self, |
|
|
|
self(request) |
|
|
|
request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse> { |
|
|
|
|
|
|
|
self(request, response) |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -181,12 +173,8 @@ where |
|
|
|
pub struct NoCallback; |
|
|
|
pub struct NoCallback; |
|
|
|
|
|
|
|
|
|
|
|
impl Callback for NoCallback { |
|
|
|
impl Callback for NoCallback { |
|
|
|
fn on_request( |
|
|
|
fn on_request(self, _request: &Request) -> StdResult<(), ErrorResponse> { |
|
|
|
self, |
|
|
|
Ok(()) |
|
|
|
_request: &Request, |
|
|
|
|
|
|
|
response: Response, |
|
|
|
|
|
|
|
) -> StdResult<Response, ErrorResponse> { |
|
|
|
|
|
|
|
Ok(response) |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@ -240,24 +228,24 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); |
|
|
|
return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
let response = create_response(&result)?; |
|
|
|
|
|
|
|
let callback_result = if let Some(callback) = self.callback.take() { |
|
|
|
let callback_result = if let Some(callback) = self.callback.take() { |
|
|
|
callback.on_request(&result, response) |
|
|
|
callback.on_request(&result) |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
Ok(response) |
|
|
|
Ok(()) |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
match callback_result { |
|
|
|
match callback_result { |
|
|
|
Ok(response) => { |
|
|
|
Ok(_) => { |
|
|
|
|
|
|
|
let response = create_response(&result)?; |
|
|
|
let mut output = vec![]; |
|
|
|
let mut output = vec![]; |
|
|
|
write_response(&mut output, &response)?; |
|
|
|
write_response(&mut output, &response)?; |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Err(resp) => { |
|
|
|
Err(resp) => { |
|
|
|
if resp.status().is_success() { |
|
|
|
// if resp.status().is_success() {
|
|
|
|
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); |
|
|
|
// return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
|
|
|
|
} |
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
self.error_response = Some(resp); |
|
|
|
self.error_response = Some(resp); |
|
|
|
let resp = self.error_response.as_ref().unwrap(); |
|
|
|
let resp = self.error_response.as_ref().unwrap(); |
|
|
@ -266,7 +254,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
write_response(&mut output, resp)?; |
|
|
|
write_response(&mut output, resp)?; |
|
|
|
|
|
|
|
|
|
|
|
if let Some(body) = resp.body() { |
|
|
|
if let Some(body) = resp.body() { |
|
|
|
output.extend_from_slice(body.as_bytes()); |
|
|
|
output.extend(body); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
|
ProcessingResult::Continue(HandshakeMachine::start_write(stream, output)) |
|
|
@ -279,7 +267,6 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> { |
|
|
|
debug!("Server handshake failed."); |
|
|
|
debug!("Server handshake failed."); |
|
|
|
|
|
|
|
|
|
|
|
let (parts, body) = err.into_parts(); |
|
|
|
let (parts, body) = err.into_parts(); |
|
|
|
let body = body.map(|b| b.as_bytes().to_vec()); |
|
|
|
|
|
|
|
return Err(Error::Http(http::Response::from_parts(parts, body))); |
|
|
|
return Err(Error::Http(http::Response::from_parts(parts, body))); |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
debug!("Server handshake done."); |
|
|
|
debug!("Server handshake done."); |
|
|
|