You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.0 KiB
137 lines
4.0 KiB
//! WebSocket handshake control.
|
|
|
|
pub mod client;
|
|
pub mod headers;
|
|
pub mod server;
|
|
|
|
#[allow(missing_docs)]
|
|
pub mod machine;
|
|
|
|
use std::{
|
|
error::Error as ErrorTrait,
|
|
fmt,
|
|
io::{Read, Write},
|
|
};
|
|
|
|
use sha1::{Digest, Sha1};
|
|
|
|
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
|
|
use crate::error::Error;
|
|
|
|
/// A WebSocket handshake.
|
|
#[derive(Debug)]
|
|
pub struct MidHandshake<Role: HandshakeRole> {
|
|
role: Role,
|
|
machine: HandshakeMachine<Role::InternalStream>,
|
|
}
|
|
|
|
impl<Role: HandshakeRole> MidHandshake<Role> {
|
|
/// Allow access to machine
|
|
pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
|
|
&self.machine
|
|
}
|
|
|
|
/// Allow mutable access to machine
|
|
pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
|
|
&mut self.machine
|
|
}
|
|
|
|
/// Restarts the handshake process.
|
|
pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
|
|
let mut mach = self.machine;
|
|
loop {
|
|
mach = match mach.single_round()? {
|
|
RoundResult::WouldBlock(m) => {
|
|
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
|
|
}
|
|
RoundResult::Incomplete(m) => m,
|
|
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
|
|
ProcessingResult::Continue(m) => m,
|
|
ProcessingResult::Done(result) => return Ok(result),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A handshake result.
|
|
pub enum HandshakeError<Role: HandshakeRole> {
|
|
/// Handshake was interrupted (would block).
|
|
Interrupted(MidHandshake<Role>),
|
|
/// Handshake failed.
|
|
Failure(Error),
|
|
}
|
|
|
|
/// Debug formatting that elides the in-progress handshake state.
///
/// `MidHandshake` only implements `Debug` when `Role` does, and this impl does not
/// require that, so the interrupted variant is printed as a placeholder.
impl<Role> fmt::Debug for HandshakeError<Role>
where
    Role: HandshakeRole,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Interrupted(_) => f.write_str("HandshakeError::Interrupted(...)"),
            Self::Failure(err) => write!(f, "HandshakeError::Failure({:?})", err),
        }
    }
}
|
|
|
|
/// Human-readable description of why the handshake stopped.
impl<Role> fmt::Display for HandshakeError<Role>
where
    Role: HandshakeRole,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Not a real failure: the handshake just needs to be resumed.
            Self::Interrupted(_) => f.write_str("Interrupted handshake (WouldBlock)"),
            // Show the wrapped error's own message unchanged.
            Self::Failure(err) => write!(f, "{}", err),
        }
    }
}
|
|
|
|
/// Marks `HandshakeError` as a standard error type, relying on the default methods.
impl<Role> ErrorTrait for HandshakeError<Role> where Role: HandshakeRole {}
|
|
|
|
impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
|
|
fn from(err: Error) -> Self {
|
|
HandshakeError::Failure(err)
|
|
}
|
|
}
|
|
|
|
/// Handshake role.
|
|
pub trait HandshakeRole {
|
|
#[doc(hidden)]
|
|
type IncomingData: TryParse;
|
|
#[doc(hidden)]
|
|
type InternalStream: Read + Write;
|
|
#[doc(hidden)]
|
|
type FinalResult;
|
|
#[doc(hidden)]
|
|
fn stage_finished(
|
|
&mut self,
|
|
finish: StageResult<Self::IncomingData, Self::InternalStream>,
|
|
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
|
|
}
|
|
|
|
/// Stage processing result.
|
|
#[doc(hidden)]
|
|
#[derive(Debug)]
|
|
pub enum ProcessingResult<Stream, FinalResult> {
|
|
Continue(HandshakeMachine<Stream>),
|
|
Done(FinalResult),
|
|
}
|
|
|
|
/// Derive the `Sec-WebSocket-Accept` response header from a `Sec-WebSocket-Key` request header.
|
|
///
|
|
/// This function can be used to perform a handshake before passing a raw TCP stream to
|
|
/// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
|
|
pub fn derive_accept_key(request_key: &[u8]) -> String {
|
|
// ... field is constructed by concatenating /key/ ...
|
|
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
|
|
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
let mut sha1 = Sha1::default();
|
|
sha1.update(request_key);
|
|
sha1.update(WS_GUID);
|
|
base64::encode(&sha1.finalize())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// Checks the worked example from RFC 6455.
    #[test]
    fn key_conversion() {
        let client_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(client_key), expected_accept);
    }
}
|
|
|