@ -1,7 +1,6 @@
//! Server handshake machine.
use std ::fmt ::Write as FmtWrite ;
use std ::io ::{ Read , Write } ;
use std ::io ::{ self , Read , Write } ;
use std ::marker ::PhantomData ;
use std ::result ::Result as StdResult ;
@ -15,31 +14,84 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate ::error ::{ Error , Result } ;
use crate ::protocol ::{ Role , WebSocket , WebSocketConfig } ;
/// Reply to the response.
fn reply ( request : & Request < ( ) > , extra_headers : Option < HeaderMap > ) -> Result < Vec < u8 > > {
/// Create a response for the request.
pub fn create_response ( request : & Request < ( ) > ) -> Result < Response < ( ) > > {
if request . method ( ) ! = http ::Method ::GET {
return Err ( Error ::Protocol ( "Method is not GET" . into ( ) ) ) ;
}
if request . version ( ) < http ::Version ::HTTP_11 {
return Err ( Error ::Protocol (
"HTTP version should be 1.1 or higher" . into ( ) ,
) ) ;
}
if ! request
. headers ( )
. get ( "Connection" )
. and_then ( | h | h . to_str ( ) . ok ( ) )
. map ( | h | h . eq_ignore_ascii_case ( "Upgrade" ) )
. unwrap_or ( false )
{
return Err ( Error ::Protocol (
"No \"Connection: upgrade\" in client request" . into ( ) ,
) ) ;
}
if ! request
. headers ( )
. get ( "Upgrade" )
. and_then ( | h | h . to_str ( ) . ok ( ) )
. map ( | h | h . eq_ignore_ascii_case ( "websocket" ) )
. unwrap_or ( false )
{
return Err ( Error ::Protocol (
"No \"Upgrade: websocket\" in client request" . into ( ) ,
) ) ;
}
if ! request
. headers ( )
. get ( "Sec-WebSocket-Version" )
. map ( | h | h = = "13" )
. unwrap_or ( false )
{
return Err ( Error ::Protocol (
"No \"Sec-WebSocket-Version: 13\" in client request" . into ( ) ,
) ) ;
}
let key = request
. headers ( )
. get ( "Sec-WebSocket-Key" )
. ok_or_else ( | | Error ::Protocol ( "Missing Sec-WebSocket-Key" . into ( ) ) ) ? ;
let mut reply = format! (
" \
HTTP / 1.1 101 Switching Protocols \ r \ n \
Connection : Upgrade \ r \ n \
Upgrade : websocket \ r \ n \
Sec - WebSocket - Accept : { } \ r \ n " ,
convert_key ( key . as_bytes ( ) ) ?
) ;
add_headers ( & mut reply , extra_headers . as_ref ( ) ) ? ;
Ok ( reply . into ( ) )
}
fn add_headers ( reply : & mut impl FmtWrite , extra_headers : Option < & HeaderMap > ) -> Result < ( ) > {
if let Some ( eh ) = extra_headers {
for ( k , v ) in eh {
writeln! ( reply , "{}: {}\r" , k , v . to_str ( ) ? ) . unwrap ( ) ;
let mut response = Response ::builder ( ) ;
response . status ( StatusCode ::SWITCHING_PROTOCOLS ) ;
response . version ( request . version ( ) ) ;
response . header ( "Connection" , "Upgrade" ) ;
response . header ( "Upgrade" , "websocket" ) ;
response . header ( "Sec-WebSocket-Accept" , convert_key ( key . as_bytes ( ) ) ? ) ;
Ok ( response . body ( ( ) ) ? )
}
// Assumes that this is a valid response
fn write_response < T > ( w : & mut dyn io ::Write , response : & Response < T > ) -> Result < ( ) > {
writeln! (
w ,
"{version:?} {status} {reason}\r" ,
version = response . version ( ) ,
status = response . status ( ) ,
reason = response . status ( ) . canonical_reason ( ) . unwrap_or ( "" ) ,
) ? ;
for ( k , v ) in response . headers ( ) {
writeln! ( w , "{}: {}\r" , k , v . to_str ( ) ? ) . unwrap ( ) ;
}
writeln! ( reply , "\r" ) . unwrap ( ) ;
writeln! ( w , "\r" ) ? ;
Ok ( ( ) )
}
@ -94,18 +146,20 @@ pub trait Callback: Sized {
fn on_request (
self ,
request : & Request < ( ) > ,
) -> StdResult < Option < HeaderMap > , Response < Option < String > > > ;
response : Response < ( ) > ,
) -> StdResult < Response < ( ) > , Response < Option < String > > > ;
}
impl < F > Callback for F
where
F : FnOnce ( & Request < ( ) > ) -> StdResult < Option < HeaderMap > , Response < Option < String > > > ,
F : FnOnce ( & Request < ( ) > , Response < ( ) > ) -> StdResult < Response < ( ) > , Response < Option < String > > > ,
{
fn on_request (
self ,
request : & Request < ( ) > ,
) -> StdResult < Option < HeaderMap > , Response < Option < String > > > {
self ( request )
response : Response < ( ) > ,
) -> StdResult < Response < ( ) > , Response < Option < String > > > {
self ( request , response )
}
}
@ -117,8 +171,9 @@ impl Callback for NoCallback {
fn on_request (
self ,
_request : & Request < ( ) > ,
) -> StdResult < Option < HeaderMap > , Response < Option < String > > > {
Ok ( None )
response : Response < ( ) > ,
) -> StdResult < Response < ( ) > , Response < Option < String > > > {
Ok ( response )
}
}
@ -176,16 +231,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err ( Error ::Protocol ( "Junk after client request" . into ( ) ) ) ;
}
let response = create_response ( & result ) ? ;
let callback_result = if let Some ( callback ) = self . callback . take ( ) {
callback . on_request ( & result )
callback . on_request ( & result , response )
} else {
Ok ( Non e)
Ok ( respons e)
} ;
match callback_result {
Ok ( extra_headers ) = > {
let response = reply ( & result , extra_headers ) ? ;
ProcessingResult ::Continue ( HandshakeMachine ::start_write ( stream , response ) )
Ok ( response ) = > {
let mut output = vec! [ ] ;
write_response ( & mut output , & response ) ? ;
ProcessingResult ::Continue ( HandshakeMachine ::start_write ( stream , output ) )
}
Err ( resp ) = > {
@ -196,17 +253,13 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
}
self . error_code = Some ( resp . status ( ) . as_u16 ( ) ) ;
let mut response = format! (
"{version:?} {status} {reason}\r\n" ,
version = resp . version ( ) ,
status = resp . status ( ) . as_u16 ( ) ,
reason = resp . status ( ) . canonical_reason ( ) . unwrap_or ( "" )
) ;
add_headers ( & mut response , Some ( resp . headers ( ) ) ) ? ;
let mut output = vec! [ ] ;
write_response ( & mut output , & resp ) ? ;
if let Some ( body ) = resp . body ( ) {
response + = & body ;
output . extend_from_slice ( body . as_bytes ( ) ) ;
}
ProcessingResult ::Continue ( HandshakeMachine ::start_write ( stream , response ) )
ProcessingResult ::Continue ( HandshakeMachine ::start_write ( stream , output ) )
}
}
}
@ -228,10 +281,8 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[ cfg(test) ]
mod tests {
use super ::super ::machine ::TryParse ;
use super ::reply ;
use super ::{ HeaderMap , Request } ;
use http ::header ::HeaderName ;
use http ::Response ;
use super ::create_response ;
use super ::Request ;
#[ test ]
fn request_parsing ( ) {
@ -252,27 +303,11 @@ mod tests {
Sec - WebSocket - Key : dGhlIHNhbXBsZSBub25jZQ = = \ r \ n \
\ r \ n " ;
let ( _ , req ) = Request ::try_parse ( DATA ) . unwrap ( ) . unwrap ( ) ;
let _ = reply ( & req , None ) . unwrap ( ) ;
let response = create_response ( & req ) . unwrap ( ) ;
let extra_headers = {
let mut headers = HeaderMap ::new ( ) ;
headers . insert (
HeaderName ::from_bytes ( & b" MyCustomHeader " [ .. ] ) . unwrap ( ) ,
"MyCustomValue" . parse ( ) . unwrap ( ) ,
) ;
headers . insert (
HeaderName ::from_bytes ( & b" MyVersion " [ .. ] ) . unwrap ( ) ,
"LOL" . parse ( ) . unwrap ( ) ,
) ;
headers
} ;
let reply = reply ( & req , Some ( extra_headers ) ) . unwrap ( ) ;
let ( _ , req ) = Response ::try_parse ( & reply ) . unwrap ( ) . unwrap ( ) ;
assert_eq! (
req . headers ( ) . get ( "MyCustomHeader " ) . unwrap ( ) ,
b" MyCustomValue " . as_ref ( )
response . headers ( ) . get ( "Sec-WebSocket-Accept" ) . unwrap ( ) ,
b" s3pPLMBiTxaQ9kYGzzhZRbK+xOo= " . as_ref ( )
) ;
assert_eq! ( req . headers ( ) . get ( "MyVersion" ) . unwrap ( ) , b" LOL " . as_ref ( ) ) ;
}
}