Refactors extension handling

pull/144/head
SirCipher 5 years ago
parent 42e6fd8e68
commit 13cfb4db65
  1. 25
      src/client.rs
  2. 651
      src/extensions/deflate.rs
  3. 170
      src/extensions/mod.rs
  4. 19
      src/extensions/uncompressed.rs
  5. 61
      src/handshake/client.rs
  6. 38
      src/handshake/server.rs
  7. 111
      src/protocol/mod.rs
  8. 23
      src/server.rs
  9. 7
      tests/connection_reset.rs

@ -66,8 +66,6 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
@ -88,13 +86,12 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req, Ext>(
pub fn connect_with_config<Req>(
request: Req,
config: Option<WebSocketConfig<Ext>>,
) -> Result<(WebSocket<AutoStream, Ext>, Response)>
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)>
where
Req: IntoClientRequest,
Ext: WebSocketExtension,
{
let request: Request = request.into_client_request()?;
let uri = request.uri();
@ -128,9 +125,7 @@ where
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream, UncompressedExt>, Response)> {
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
}
@ -167,15 +162,14 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<Stream, Req, Ext>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig<Ext>>,
) -> StdResult<(WebSocket<Stream, Ext>, Response), HandshakeError<ClientHandshake<Stream, Ext>>>
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: IntoClientRequest,
Ext: WebSocketExtension,
{
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}
@ -188,10 +182,7 @@ where
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<
(WebSocket<Stream, UncompressedExt>, Response),
HandshakeError<ClientHandshake<Stream, UncompressedExt>>,
>
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: IntoClientRequest,

@ -231,41 +231,39 @@ impl DeflateExt {
enabled: false,
config,
fragment_buffer: FragmentBuffer::new(config.max_message_size),
inflator: Inflator::new(),
deflator: Deflator::new(Compression::fast()),
inflator: Inflator::new(config.max_window_bits),
deflator: Deflator::new(config.compression_level, config.max_window_bits),
uncompressed_extension: UncompressedExt::new(Some(config.max_message_size())),
}
}
}
fn parse_window_parameter<'a>(
&mut self,
mut param_iter: impl Iterator<Item = &'a str>,
) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() {
match window_bits_str.trim().parse() {
Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
if window_bits != self.config.max_window_bits() {
self.config.max_window_bits = window_bits;
Ok(Some(window_bits))
} else {
Ok(None)
}
fn parse_window_parameter<'a>(
mut param_iter: impl Iterator<Item = &'a str>,
max_window_bits: u8,
) -> Result<Option<u8>, String> {
if let Some(window_bits_str) = param_iter.next() {
match window_bits_str.trim().parse() {
Ok(window_bits) => {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
if window_bits != max_window_bits {
Ok(Some(window_bits))
} else {
Err(format!("Invalid window parameter: {}", window_bits))
Ok(None)
}
} else {
Err(format!("Invalid window parameter: {}", window_bits))
}
Err(e) => Err(e.to_string()),
}
} else {
Ok(None)
Err(e) => Err(e.to_string()),
}
} else {
Ok(None)
}
}
fn decline<T>(&mut self, res: &mut Response<T>) {
self.enabled = false;
res.headers_mut().remove(EXT_IDENT);
}
fn decline<T>(res: &mut Response<T>) {
res.headers_mut().remove(EXT_IDENT);
}
/// A permessage-deflate extension error.
@ -298,328 +296,332 @@ impl Display for DeflateExtensionError {
}
}
impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
impl From<InvalidHeaderValue> for DeflateExtensionError {
fn from(e: InvalidHeaderValue) -> Self {
DeflateExtensionError::NegotiationError(e.to_string())
}
}
impl Default for DeflateExt {
fn default() -> Self {
DeflateExt::new(Default::default())
}
}
impl WebSocketExtension for DeflateExt {
type Error = DeflateExtensionError;
fn new(max_message_size: Option<usize>) -> Self {
DeflateExt::new(DeflateConfig {
max_message_size: max_message_size.unwrap_or_else(usize::max_value),
..Default::default()
})
}
fn enabled(&self) -> bool {
self.enabled
}
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_IDENT);
let DeflateConfig {
max_window_bits,
request_no_context_takeover,
..
} = self.config;
if max_window_bits < LZ77_MAX_WINDOW_SIZE {
header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits
))
} else {
header_value.push_str("; client_max_window_bits")
}
if request_no_context_takeover {
header_value.push_str("; server_no_context_takeover")
}
request.headers_mut().append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&header_value).unwrap(),
);
request
}
fn on_receive_request<T>(
&mut self,
request: &Request<T>,
response: &mut Response<T>,
) -> Result<(), Self::Error> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len());
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => response_str.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if server_takeover {
self.decline(response);
} else {
server_takeover = true;
if self.config.accept_no_context_takeover() {
self.config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
///
pub fn on_response<T>(
response: &Response<T>,
config: &mut DeflateConfig,
) -> Result<bool, DeflateExtensionError> {
let mut extension_name = false;
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_window_bits = false;
let mut client_max_window_bits = false;
let mut enabled = false;
let DeflateConfig {
max_window_bits,
accept_no_context_takeover,
compress_reset,
decompress_reset,
..
} = config;
for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() {
match header.to_str() {
Ok(header) => {
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
if extension_name {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate"
)));
} else {
enabled = true;
extension_name = true;
}
"client_no_context_takeover" => {
if client_takeover {
self.decline(response);
}
"server_no_context_takeover" => {
if server_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_no_context_takeover"
)));
} else {
server_takeover = true;
*decompress_reset = true;
}
}
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_no_context_takeover"
)));
} else {
client_takeover = true;
if *accept_no_context_takeover {
*compress_reset = true;
} else {
client_takeover = true;
self.config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
return Err(DeflateExtensionError::NegotiationError(format!(
"The client requires context takeover."
)));
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
self.decline(response);
} else {
server_max_bits = true;
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
response_str.push_str("; ");
response_str.push_str(param)
}
Ok(None) => {}
Err(_) => {
self.decline(response);
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_max_window_bits"
)));
} else {
server_max_window_bits = true;
match parse_window_parameter(
param.split("=").skip(1),
*max_window_bits,
) {
Ok(Some(bits)) => {
*max_window_bits = bits;
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"server_max_window_bits parameter error: {}",
e
),
))
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
self.decline(response);
} else {
client_max_bits = true;
match self.parse_window_parameter(param.split('=').skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator::new_with_window_bits(bits);
response_str.push_str("; ");
response_str.push_str(param);
continue;
}
Ok(None) => {}
Err(_) => {
self.decline(response);
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_max_window_bits"
)));
} else {
client_max_window_bits = true;
match parse_window_parameter(
param.split("=").skip(1),
*max_window_bits,
) {
Ok(Some(bits)) => {
*max_window_bits = bits;
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"client_max_window_bits parameter error: {}",
e
),
))
}
response_str.push_str("; ");
response_str.push_str(&format!(
"client_max_window_bits={}",
self.config.max_window_bits()
))
}
}
_ => {
self.decline(response);
}
}
p => {
return Err(DeflateExtensionError::NegotiationError(format!(
"Unknown permessage-deflate parameter: {}",
p
)));
}
}
}
}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse extension parameter: {}",
e
)));
}
}
}
if !response_str.contains("client_no_context_takeover")
&& self.config.request_no_context_takeover()
{
self.config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
self.config.max_window_bits()
))
}
if !response_str.contains("client_max_window_bits")
&& self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
continue;
}
Ok(enabled)
}
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
///
pub fn on_request<T>(mut request: Request<T>, config: &DeflateConfig) -> Request<T> {
let mut header_value = String::from(EXT_IDENT);
self.enabled = true;
let DeflateConfig {
max_window_bits,
request_no_context_takeover,
..
} = config;
Ok(())
}
Err(e) => {
self.enabled = false;
Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}",
e,
)))
}
};
}
if *max_window_bits < LZ77_MAX_WINDOW_SIZE {
header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits
))
} else {
header_value.push_str("; client_max_window_bits")
}
self.decline(response);
Ok(())
if *request_no_context_takeover {
header_value.push_str("; server_no_context_takeover")
}
fn on_response<T>(&mut self, response: &Response<T>) -> Result<(), Self::Error> {
let mut extension_name = false;
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_window_bits = false;
let mut client_max_window_bits = false;
request.headers_mut().append(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&header_value).unwrap(),
);
for header in response.headers().get_all(SEC_WEBSOCKET_EXTENSIONS).iter() {
match header.to_str() {
Ok(header) => {
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => {
if extension_name {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: permessage-deflate"
)));
} else {
self.enabled = true;
extension_name = true;
request
}
///
pub fn on_receive_request<T>(
request: &Request<T>,
response: &mut Response<T>,
config: &mut DeflateConfig,
) -> Result<(), DeflateExtensionError> {
for header in request.headers().get_all(SEC_WEBSOCKET_EXTENSIONS) {
return match header.to_str() {
Ok(header) => {
let mut response_str = String::with_capacity(header.len());
let mut server_takeover = false;
let mut client_takeover = false;
let mut server_max_bits = false;
let mut client_max_bits = false;
for param in header.split(';') {
match param.trim().to_lowercase().as_str() {
"permessage-deflate" => response_str.push_str("permessage-deflate"),
"server_no_context_takeover" => {
if server_takeover {
decline(response);
} else {
server_takeover = true;
if config.accept_no_context_takeover() {
config.compress_reset = true;
response_str.push_str("; server_no_context_takeover");
}
}
"server_no_context_takeover" => {
if server_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_no_context_takeover"
)));
} else {
server_takeover = true;
self.config.decompress_reset = true;
}
}
"client_no_context_takeover" => {
if client_takeover {
decline(response);
} else {
client_takeover = true;
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
"client_no_context_takeover" => {
if client_takeover {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_no_context_takeover"
)));
} else {
client_takeover = true;
if self.config.accept_no_context_takeover() {
self.config.compress_reset = true;
} else {
return Err(DeflateExtensionError::NegotiationError(
format!("The client requires context takeover."),
));
}
param if param.starts_with("server_max_window_bits") => {
if server_max_bits {
decline(response);
} else {
server_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.max_window_bits,
) {
Ok(Some(bits)) => {
config.max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(param)
}
}
}
param if param.starts_with("server_max_window_bits") => {
if server_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: server_max_window_bits"
)));
} else {
server_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.inflator = Inflator::new_with_window_bits(bits);
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"server_max_window_bits parameter error: {}",
e
),
))
}
Ok(None) => {}
Err(_) => {
decline(response);
}
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_window_bits {
return Err(DeflateExtensionError::NegotiationError(format!(
"Duplicate extension parameter: client_max_window_bits"
)));
} else {
client_max_window_bits = true;
match self.parse_window_parameter(param.split("=").skip(1)) {
Ok(Some(bits)) => {
self.deflator = Deflator::new_with_window_bits(
self.config.compression_level,
bits,
);
}
Ok(None) => {}
Err(e) => {
return Err(DeflateExtensionError::NegotiationError(
format!(
"client_max_window_bits parameter error: {}",
e
),
))
}
}
param if param.starts_with("client_max_window_bits") => {
if client_max_bits {
decline(response);
} else {
client_max_bits = true;
match parse_window_parameter(
param.split('=').skip(1),
config.max_window_bits,
) {
Ok(Some(bits)) => {
config.max_window_bits = bits;
response_str.push_str("; ");
response_str.push_str(param);
continue;
}
Ok(None) => {}
Err(_) => {
decline(response);
}
}
response_str.push_str("; ");
response_str.push_str(&format!(
"client_max_window_bits={}",
config.max_window_bits()
))
}
p => {
return Err(DeflateExtensionError::NegotiationError(format!(
"Unknown permessage-deflate parameter: {}",
p
)));
}
}
_ => {
decline(response);
}
}
}
Err(e) => {
self.enabled = false;
return Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse extension parameter: {}",
e
)));
if !response_str.contains("client_no_context_takeover")
&& config.request_no_context_takeover()
{
config.decompress_reset = true;
response_str.push_str("; client_no_context_takeover");
}
if !response_str.contains("server_max_window_bits") {
response_str.push_str("; ");
response_str.push_str(&format!(
"server_max_window_bits={}",
config.max_window_bits()
))
}
if !response_str.contains("client_max_window_bits")
&& config.max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
continue;
}
response.headers_mut().insert(
SEC_WEBSOCKET_EXTENSIONS,
HeaderValue::from_str(&response_str)?,
);
Ok(())
}
}
Err(e) => Err(DeflateExtensionError::NegotiationError(format!(
"Failed to parse request header: {}",
e,
))),
};
}
decline(response);
Ok(())
}
impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
Ok(())
impl From<InvalidHeaderValue> for DeflateExtensionError {
fn from(e: InvalidHeaderValue) -> Self {
DeflateExtensionError::NegotiationError(e.to_string())
}
}
impl Default for DeflateExt {
fn default() -> Self {
DeflateExt::new(Default::default())
}
}
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, Self::Error> {
impl WebSocketExtension for DeflateExt {
fn on_send_frame(&mut self, mut frame: Frame) -> Result<Frame, crate::Error> {
if self.enabled {
if let OpCode::Data(_) = frame.header().opcode {
let mut compressed = Vec::with_capacity(frame.payload().len());
@ -640,7 +642,7 @@ impl WebSocketExtension for DeflateExt {
Ok(frame)
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
let r = if self.enabled && (!self.fragment_buffer.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragment_buffer
@ -696,20 +698,20 @@ impl WebSocketExtension for DeflateExt {
match r {
Ok(msg) => Ok(msg),
Err(e) => Err(DeflateExtensionError::DeflateError(e.to_string())),
Err(e) => Err(crate::Error::ExtensionError(e.to_string().into())),
}
}
}
impl From<DecompressError> for DeflateExtensionError {
impl From<DecompressError> for crate::Error {
fn from(e: DecompressError) -> Self {
DeflateExtensionError::InflateError(e.to_string())
crate::Error::ExtensionError(e.to_string().into())
}
}
impl From<CompressError> for DeflateExtensionError {
impl From<CompressError> for crate::Error {
fn from(e: CompressError) -> Self {
DeflateExtensionError::DeflateError(e.to_string())
crate::Error::ExtensionError(e.to_string().into())
}
}
@ -719,13 +721,7 @@ struct Deflator {
}
impl Deflator {
fn new(compresion: Compression) -> Deflator {
Deflator {
compress: Compress::new(compresion, false),
}
}
fn new_with_window_bits(compression: Compression, mut window_size: u8) -> Deflator {
fn new(compression: Compression, mut window_size: u8) -> Deflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
@ -790,13 +786,7 @@ struct Inflator {
}
impl Inflator {
fn new() -> Inflator {
Inflator {
decompress: Decompress::new(false),
}
}
fn new_with_window_bits(mut window_size: u8) -> Inflator {
fn new(mut window_size: u8) -> Inflator {
// https://github.com/madler/zlib/blob/cacf7f1d4e3d44d871b605da3b647f07d718623f/deflate.c#L303
if window_size == 8 {
window_size = 9;
@ -888,11 +878,10 @@ impl FragmentBuffer {
*fragments_len += frame.payload().len();
if *fragments_len > *max_len || frame.len() > *max_len - *fragments_len {
return Err(format!(
Err(format!(
"Message too big: {} + {} > {}",
fragments_len, fragments_len, max_len
)
.into());
))
} else {
fragments.push(frame);
Ok(())

@ -2,8 +2,15 @@
use http::{Request, Response};
#[cfg(feature = "deflate")]
use crate::extensions::deflate::{DeflateConfig, DeflateExt};
use crate::extensions::uncompressed::UncompressedExt;
use crate::protocol::frame::Frame;
use crate::protocol::WebSocketConfig;
use crate::Message;
use std::borrow::Cow;
use std::error::Error;
use std::fmt::{Display, Formatter};
/// A permessage-deflate WebSocket extension (RFC 7692).
#[cfg(feature = "deflate")]
@ -11,46 +18,155 @@ pub mod deflate;
/// An uncompressed message handler for a WebSocket.
pub mod uncompressed;
///
#[derive(Copy, Clone, Debug)]
pub enum WsCompression {
///
None(Option<usize>),
///
#[cfg(feature = "deflate")]
Deflate(DeflateConfig),
}
/// A trait for defining WebSocket extensions for both WebSocket clients and servers. Extensions
/// may be stacked by nesting them inside one another.
pub trait WebSocketExtension {
/// An error type that the extension produces.
type Error: Into<crate::Error>;
/// Called when a frame is about to be sent.
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, crate::Error> {
Ok(frame)
}
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
/// type `OpCode::Data`.
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error>;
}
/// Constructs a new WebSocket extension that will permit messages of the provided size.
fn new(max_message_size: Option<usize>) -> Self;
/// A WebSocket extension that is either `DeflateExt` or `UncompressedExt`.
#[derive(Debug)]
pub enum CompressionSwitcher {
///
#[cfg(feature = "deflate")]
Compressed(DeflateExt),
///
Uncompressed(UncompressedExt),
}
/// Returns whether or not the extension is enabled.
fn enabled(&self) -> bool {
false
impl CompressionSwitcher {
///
pub fn from_config(config: WsCompression) -> CompressionSwitcher {
match config {
WsCompression::None(size) => {
CompressionSwitcher::Uncompressed(UncompressedExt::new(size))
}
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => {
CompressionSwitcher::Compressed(DeflateExt::new(config))
}
}
}
}
/// For WebSocket clients, this will be called when a `Request` is being constructed.
fn on_make_request<T>(&mut self, request: Request<T>) -> Request<T> {
request
impl Default for CompressionSwitcher {
fn default() -> Self {
CompressionSwitcher::Uncompressed(UncompressedExt::default())
}
}
#[derive(Debug)]
///
pub struct CompressionError(String);
/// For WebSocket server, this will be called when a `Request` has been received.
fn on_receive_request<T>(
&mut self,
_request: &Request<T>,
_response: &mut Response<T>,
) -> Result<(), Self::Error> {
Ok(())
impl Error for CompressionError {}
impl From<CompressionError> for crate::Error {
fn from(e: CompressionError) -> Self {
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
/// For WebSocket clients, this will be called when a response from the server has been
/// received. If an error is produced, then subsequent calls to `rsv1()` should return `false`.
fn on_response<T>(&mut self, _response: &Response<T>) -> Result<(), Self::Error> {
Ok(())
impl Display for CompressionError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CompressionError")
.field("error", &self.0)
.finish()
}
}
/// Called when a frame is about to be sent.
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, Self::Error> {
Ok(frame)
impl WebSocketExtension for CompressionSwitcher {
fn on_send_frame(&mut self, frame: Frame) -> Result<Frame, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_send_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_send_frame(frame),
}
}
/// Called when a frame has been received and unmasked. The frame provided frame will be of the
/// type `OpCode::Data`.
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error>;
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
match self {
CompressionSwitcher::Uncompressed(ext) => ext.on_receive_frame(frame),
#[cfg(feature = "deflate")]
CompressionSwitcher::Compressed(ext) => ext.on_receive_frame(frame),
}
}
}
///
pub fn build_compression_headers<T>(
request: Request<T>,
config: &mut Option<WebSocketConfig>,
) -> Request<T> {
match config {
Some(ref mut config) => match &config.compression {
WsCompression::None(_) => request,
#[cfg(feature = "deflate")]
WsCompression::Deflate(config) => deflate::on_request(request, config),
},
None => request,
}
}
///
pub fn verify_compression_resp_headers<T>(
_response: &Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
let result = deflate::on_response(_response, deflate_config)
.map_err(|e| CompressionError(e.to_string()));
match result {
Ok(true) => Ok(()),
Ok(false) => {
config.compression =
WsCompression::None(Some(deflate_config.max_message_size()));
Ok(())
}
Err(e) => Err(e),
}
}
},
None => Ok(()),
}
}
///
pub fn verify_compression_req_headers<T>(
_request: &Request<T>,
_response: &mut Response<T>,
config: &mut Option<WebSocketConfig>,
) -> Result<(), CompressionError> {
match config {
Some(ref mut config) => match &mut config.compression {
WsCompression::None(_) => Ok(()),
#[cfg(feature = "deflate")]
WsCompression::Deflate(ref mut deflate_config) => {
deflate::on_receive_request(_request, _response, deflate_config)
.map_err(|e| CompressionError(e.to_string()))
}
},
None => Ok(()),
}
}

@ -2,8 +2,8 @@ use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::{Error, Message};
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::{Error, Message};
/// An uncompressed message handler for a WebSocket.
#[derive(Debug)]
@ -16,7 +16,7 @@ impl Default for UncompressedExt {
fn default() -> Self {
UncompressedExt {
incomplete: None,
max_message_size: Some(MAX_MESSAGE_SIZE)
max_message_size: Some(MAX_MESSAGE_SIZE),
}
}
}
@ -33,20 +33,7 @@ impl UncompressedExt {
}
impl WebSocketExtension for UncompressedExt {
type Error = Error;
fn new(max_message_size: Option<usize>) -> Self {
UncompressedExt {
incomplete: None,
max_message_size,
}
}
fn enabled(&self) -> bool {
true
}
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, Self::Error> {
fn on_receive_frame(&mut self, frame: Frame) -> Result<Option<Message>, crate::Error> {
let fin = frame.header().is_final;
let hdr = frame.header();

@ -11,7 +11,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::extensions::{build_compression_headers, verify_compression_resp_headers};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request type.
@ -22,25 +22,21 @@ pub type Response = HttpResponse<()>;
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S, Extension>
where
Extension: WebSocketExtension,
{
pub struct ClientHandshake<S> {
verify_data: VerifyData,
config: Option<Option<WebSocketConfig<Extension>>>,
config: Option<Option<WebSocketConfig>>,
_marker: PhantomData<S>,
}
impl<Stream, Ext> ClientHandshake<Stream, Ext>
impl<Stream> ClientHandshake<Stream>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Initiate a client handshake.
pub fn start(
stream: Stream,
request: Request,
mut config: Option<WebSocketConfig<Ext>>,
mut config: Option<WebSocketConfig>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
@ -81,14 +77,13 @@ where
}
}
impl<Stream, Ext> HandshakeRole for ClientHandshake<Stream, Ext>
impl<Stream> HandshakeRole for ClientHandshake<Stream>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
type IncomingData = Response;
type InternalStream = Stream;
type FinalResult = (WebSocket<Stream, Ext>, Response);
type FinalResult = (WebSocket<Stream>, Response);
fn stage_finished(
&mut self,
@ -115,18 +110,12 @@ where
}
/// Generate client request.
fn generate_request<Ext>(
fn generate_request(
request: Request,
key: &str,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<Vec<u8>>
where
Ext: WebSocketExtension,
{
let request = match config {
Some(ref mut config) => config.encoder.on_make_request(request),
None => request,
};
config: &mut Option<WebSocketConfig>,
) -> Result<Vec<u8>> {
let request = build_compression_headers(request, config);
let mut req = Vec::new();
let uri = request.uri();
@ -183,14 +172,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response<Ext>(
pub fn verify_response(
&self,
response: &Response,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<()>
where
Ext: WebSocketExtension,
{
config: &mut Option<WebSocketConfig>,
) -> Result<()> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -246,11 +232,7 @@ impl VerifyData {
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if let Some(config) = config {
if let Err(e) = config.encoder.on_response(response) {
return Err(e.into());
}
}
verify_compression_resp_headers(response, config)?;
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -308,7 +290,6 @@ mod tests {
use super::super::machine::TryParse;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use crate::extensions::uncompressed::UncompressedExt;
#[test]
fn random_keys() {
@ -338,9 +319,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -359,9 +338,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -380,9 +357,7 @@ mod tests {
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
let request = generate_request(request, key, &mut Some(Default::default())).unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -12,7 +12,7 @@ use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::extensions::WebSocketExtension;
use crate::extensions::verify_compression_req_headers;
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Server request type.
@ -191,43 +191,35 @@ impl Callback for NoCallback {
/// Server handshake role.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C, Ext>
where
Ext: WebSocketExtension,
{
pub struct ServerHandshake<S, C> {
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<C>,
/// WebSocket configuration.
config: Option<Option<WebSocketConfig<Ext>>>,
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl<S, C, Ext> ServerHandshake<S, C, Ext>
impl<S, C> ServerHandshake<S, C>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(
stream: S,
callback: C,
config: Option<WebSocketConfig<Ext>>,
) -> MidHandshake<Self> {
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake {
callback: Some(callback),
config: Some(config),
config,
error_code: None,
_marker: PhantomData,
},
@ -235,15 +227,14 @@ where
}
}
impl<S, C, Ext> HandshakeRole for ServerHandshake<S, C, Ext>
impl<S, C> HandshakeRole for ServerHandshake<S, C>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S, Ext>;
type FinalResult = WebSocket<S>;
fn stage_finished(
&mut self,
@ -260,12 +251,7 @@ where
}
let mut response = create_response(&request)?;
if let Some(ref mut config) = self.config.as_mut().unwrap() {
if let Err(e) = config.encoder.on_receive_request(&request, &mut response) {
return Err(e.into());
}
}
verify_compression_req_headers(&request, &mut response, &mut self.config)?;
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&request, response)
@ -305,11 +291,7 @@ where
return Err(Error::Http(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.take().unwrap(),
);
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
ProcessingResult::Done(websocket)
}
}

@ -16,8 +16,7 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::IncompleteMessage;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::extensions::{CompressionSwitcher, WebSocketExtension, WsCompression};
use crate::util::NonBlockingResult;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;
@ -33,10 +32,7 @@ pub enum Role {
/// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig<E = UncompressedExt>
where
E: WebSocketExtension,
{
pub struct WebSocketConfig {
/// The size of the send queue. You can use it to turn on/off the backpressure features. `None`
/// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue.
@ -46,34 +42,16 @@ where
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
/// Per-message compression strategy.
pub encoder: E,
/// A per-message compression configuration.
pub compression: WsCompression,
}
impl<E> Default for WebSocketConfig<E>
where
E: WebSocketExtension,
{
impl Default for WebSocketConfig {
fn default() -> Self {
WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: E::new(Some(MAX_MESSAGE_SIZE)),
}
}
}
impl<E> WebSocketConfig<E>
where
E: WebSocketExtension,
{
/// Creates a `WebSocketConfig` instance using the default configuration and the provided
/// encoder for new connections.
pub fn default_with_encoder(encoder: E) -> WebSocketConfig<E> {
WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder,
compression: WsCompression::None(Some(MAX_MESSAGE_SIZE)),
}
}
}
@ -83,30 +61,20 @@ where
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)]
pub struct WebSocket<Stream, Ext>
where
Ext: WebSocketExtension,
{
pub struct WebSocket<Stream> {
/// The underlying socket.
socket: Stream,
/// The context for managing a WebSocket.
context: WebSocketContext<Ext>,
context: WebSocketContext,
}
impl<Stream, Ext> WebSocket<Stream, Ext>
where
Ext: WebSocketExtension,
{
impl<Stream> WebSocket<Stream> {
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(
stream: Stream,
role: Role,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
@ -122,7 +90,7 @@ where
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<Ext>>,
config: Option<WebSocketConfig>,
) -> Self {
WebSocket {
socket: stream,
@ -141,12 +109,12 @@ where
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
self.context.set_config(set_func)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
pub fn get_config(&self) -> &WebSocketConfig {
self.context.get_config()
}
@ -166,10 +134,9 @@ where
}
}
impl<Stream, Ext> WebSocket<Stream, Ext>
impl<Stream> WebSocket<Stream>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Read a message from stream, if possible.
///
@ -253,10 +220,7 @@ where
/// A context for managing WebSocket stream.
#[derive(Debug)]
pub struct WebSocketContext<Ext = UncompressedExt>
where
Ext: WebSocketExtension,
{
pub struct WebSocketContext {
/// Server or client?
role: Role,
/// encoder/decoder of frame.
@ -270,16 +234,16 @@ where
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig<Ext>,
config: WebSocketConfig,
/// A per-message compression strategy.
decoder: CompressionSwitcher,
}
impl<Ext> WebSocketContext<Ext>
where
Ext: WebSocketExtension,
{
impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig<Ext>>) -> Self {
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
let config = config.unwrap_or_else(Default::default);
let decoder = CompressionSwitcher::from_config(config.compression);
WebSocketContext {
role,
@ -289,15 +253,12 @@ where
send_queue: VecDeque::new(),
pong: None,
config,
decoder,
}
}
/// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
@ -305,12 +266,12 @@ where
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
pub fn get_config(&self) -> &WebSocketConfig {
&self.config
}
@ -527,12 +488,8 @@ where
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
}
}
_ => match self.config.encoder.on_receive_frame(frame) {
Ok(r) => Ok(r),
Err(e) => Err(e.into()),
},
} // match opcode
_ => self.decoder.on_receive_frame(frame),
}
} else {
// Connection closed by peer
match replace(&mut self.state, WebSocketState::Terminated) {
@ -602,10 +559,7 @@ where
}
if frame.header().is_final {
frame = match self.config.encoder.on_send_frame(frame) {
Ok(frame) => frame,
Err(e) => return Err(e.into()),
};
frame = self.decoder.on_send_frame(frame)?;
}
trace!("Sending frame: {:?}", frame);
@ -682,7 +636,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WsCompression;
use std::io;
use std::io::Cursor;
@ -710,8 +664,7 @@ mod tests {
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(
@ -733,7 +686,7 @@ mod tests {
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: UncompressedExt::new(Some(10)),
compression: WsCompression::None(Some(10)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
@ -748,7 +701,7 @@ mod tests {
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: UncompressedExt::new(Some(2)),
compression: WsCompression::None(Some(2)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(

@ -7,8 +7,6 @@ use crate::handshake::HandshakeError;
use crate::protocol::{WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use std::io::{Read, Write};
/// Accept the given Stream as a WebSocket.
@ -20,13 +18,12 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<Stream, Ext>(
pub fn accept_with_config<Stream>(
stream: Stream,
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<Stream, Ext>, HandshakeError<ServerHandshake<Stream, NoCallback, Ext>>>
config: Option<WebSocketConfig>,
) -> Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, NoCallback>>>
where
Stream: Read + Write,
Ext: WebSocketExtension,
{
accept_hdr_with_config(stream, NoCallback, config)
}
@ -39,10 +36,7 @@ where
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<
WebSocket<S, UncompressedExt>,
HandshakeError<ServerHandshake<S, NoCallback, UncompressedExt>>,
> {
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_with_config(stream, None)
}
@ -54,15 +48,14 @@ pub fn accept<S: Read + Write>(
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_with_config<S, C, Ext>(
pub fn accept_hdr_with_config<S, C>(
stream: S,
callback: C,
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<S, Ext>, HandshakeError<ServerHandshake<S, C, Ext>>>
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
where
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
ServerHandshake::start(stream, callback, config).handshake()
}
@ -75,6 +68,6 @@ where
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S, UncompressedExt>, HandshakeError<ServerHandshake<S, C, UncompressedExt>>> {
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
accept_hdr_with_config(stream, callback, None)
}

@ -8,16 +8,15 @@ use std::time::Duration;
use native_tls::TlsStream;
use net2::TcpStreamExt;
use tungstenite::extensions::uncompressed::UncompressedExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock<Ext> = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>, Ext>;
type Sock = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where
CT: FnOnce(Sock<UncompressedExt>) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream, UncompressedExt>),
CT: FnOnce(Sock) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream>),
{
env_logger::try_init().ok();

Loading…
Cancel
Save