Resolves PR comments

pull/144/head
SirCipher 5 years ago
parent 2be3d31ab0
commit bd170a0de6
  1. 9
      Cargo.toml
  2. 2
      README.md
  3. 8
      examples/autobahn-client.rs
  4. 5
      fuzz/fuzz_targets/read_message_client.rs
  5. 0
      scripts/autobahn-client.sh
  6. 0
      scripts/autobahn-server.sh
  7. 25
      src/client.rs
  8. 2
      src/error.rs
  9. 72
      src/extensions/deflate.rs
  10. 2
      src/extensions/mod.rs
  11. 30
      src/extensions/uncompressed.rs
  12. 57
      src/handshake/client.rs
  13. 42
      src/handshake/server.rs
  14. 57
      src/protocol/mod.rs
  15. 31
      src/server.rs
  16. 8
      tests/connection_reset.rs

@ -16,7 +16,7 @@ edition = "2018"
default = ["tls"]
tls = ["native-tls"]
tls-vendored = ["native-tls", "native-tls/vendored"]
deflate = ["flate2"]
deflate = []
[dependencies]
base64 = "0.12.0"
@ -30,12 +30,7 @@ rand = "0.7.2"
sha-1 = "0.9"
url = "2.1.0"
utf-8 = "0.7.5"
[dependencies.flate2]
optional = true
default-features = false
version = "1.0"
features = ["zlib"]
flate2 = { version = "1.0", features = ["zlib"], default-features = false }
[dependencies.native-tls]
optional = true

@ -58,7 +58,7 @@ Features
Tungstenite provides a complete implementation of the WebSocket specification.
TLS is supported on all platforms using native-tls.
There is no support for permessage-deflate at the moment. It's planned.
Permessage-deflate.
Testing
-------

@ -3,17 +3,13 @@ use url::Url;
use tungstenite::client::connect_with_config;
use tungstenite::extensions::deflate::DeflateExt;
use tungstenite::extensions::uncompressed::PlainTextExt;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::{connect, Error, Message, Result, WebSocket};
use tungstenite::{connect, Error, Message, Result};
const AGENT: &str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _): (WebSocket<_, PlainTextExt>, _) = connect_with_config(
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
None,
)?;
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read_message()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())

@ -1,12 +1,11 @@
#![no_main]
#[macro_use]
extern crate libfuzzer_sys;
#[macro_use] extern crate libfuzzer_sys;
extern crate tungstenite;
use std::io;
use std::io::Cursor;
use tungstenite::protocol::Role;
use tungstenite::WebSocket;
use tungstenite::protocol::Role;
//use std::result::Result;
// FIXME: copypasted from tungstenite's protocol/mod.rs

@ -66,7 +66,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
@ -88,12 +88,13 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<Req: IntoClientRequest, E>(
pub fn connect_with_config<Req, Ext>(
request: Req,
config: Option<WebSocketConfig<E>>,
) -> Result<(WebSocket<AutoStream, E>, Response)>
config: Option<WebSocketConfig<Ext>>,
) -> Result<(WebSocket<AutoStream, Ext>, Response)>
where
E: WebSocketExtension,
Req: IntoClientRequest,
Ext: WebSocketExtension,
{
let request: Request = request.into_client_request()?;
let uri = request.uri();
@ -129,7 +130,7 @@ where
/// `connect` since it's the only function that uses native_tls.
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream, PlainTextExt>, Response)> {
) -> Result<(WebSocket<AutoStream, UncompressedExt>, Response)> {
connect_with_config(request, None)
}
@ -166,15 +167,15 @@ pub fn uri_mode(uri: &Uri) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<Stream, Req, E>(
pub fn client_with_config<Stream, Req, Ext>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig<E>>,
) -> StdResult<(WebSocket<Stream, E>, Response), HandshakeError<ClientHandshake<Stream, E>>>
config: Option<WebSocketConfig<Ext>>,
) -> StdResult<(WebSocket<Stream, Ext>, Response), HandshakeError<ClientHandshake<Stream, Ext>>>
where
Stream: Read + Write,
Req: IntoClientRequest,
E: WebSocketExtension,
Ext: WebSocketExtension,
{
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}
@ -188,8 +189,8 @@ pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<
(WebSocket<Stream, PlainTextExt>, Response),
HandshakeError<ClientHandshake<Stream, PlainTextExt>>,
(WebSocket<Stream, UncompressedExt>, Response),
HandshakeError<ClientHandshake<Stream, UncompressedExt>>,
>
where
Stream: Read + Write,

@ -68,7 +68,7 @@ pub enum Error {
/// HTTP format error.
HttpFormat(http::Error),
/// An error from a WebSocket extension.
ExtensionError(Box<dyn std::error::Error + Send + Sync>),
ExtensionError(Cow<'static, str>),
}
impl fmt::Display for Error {

@ -2,7 +2,7 @@
use std::fmt::{Display, Formatter};
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
@ -15,10 +15,20 @@ use flate2::{
};
use http::header::{InvalidHeaderValue, SEC_WEBSOCKET_EXTENSIONS};
use http::{HeaderValue, Request, Response};
use std::borrow::Cow;
use std::mem::replace;
use std::slice;
const EXT_NAME: &str = "permessage-deflate";
/// The WebSocket Extension Identifier as per the IANA registry.
const EXT_IDENT: &str = "permessage-deflate";
/// The minimum size of the LZ77 sliding window size.
const LZ77_MIN_WINDOW_SIZE: u8 = 9;
/// The maximum size of the LZ77 sliding window size. Absence of the `max_window_bits` parameter
/// indicates that the client can receive messages compressed using an LZ77 sliding window of up to
/// 32,768 bytes. RFC 7692 7.1.2.1.
const LZ77_MAX_WINDOW_SIZE: u8 = 15;
/// A permessage-deflate configuration.
#[derive(Clone, Copy, Debug)]
@ -33,9 +43,14 @@ pub struct DeflateConfig {
max_window_bits: u8,
/// Request that the server resets the LZ77 sliding window between messages - RFC 7692 7.1.1.1.
request_no_context_takeover: bool,
/// Whether to accept `no_context_takeover`.
accept_no_context_takeover: bool,
// Whether the compressor should be reset after usage.
compress_reset: bool,
// Whether the decompressor should be reset after usage.
decompress_reset: bool,
/// The active compression level. The integer here is typically on a scale of 0-9 where 0 means
/// "no compression" and 9 means "take as long as you'd like".
compression_level: Compression,
}
@ -91,7 +106,7 @@ impl DeflateConfig {
/// Sets the LZ77 sliding window size.
pub fn set_max_window_bits(&mut self, max_window_bits: u8) {
assert!((9u8..=15u8).contains(&max_window_bits));
assert!((LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits));
self.max_window_bits = max_window_bits;
}
@ -110,7 +125,7 @@ impl Default for DeflateConfig {
fn default() -> Self {
DeflateConfig {
max_message_size: Some(MAX_MESSAGE_SIZE),
max_window_bits: 15,
max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
compress_reset: false,
@ -135,7 +150,7 @@ impl Default for DeflateConfigBuilder {
fn default() -> Self {
DeflateConfigBuilder {
max_message_size: Some(MAX_MESSAGE_SIZE),
max_window_bits: 15,
max_window_bits: LZ77_MAX_WINDOW_SIZE,
request_no_context_takeover: false,
accept_no_context_takeover: true,
fragments_grow: true,
@ -154,7 +169,7 @@ impl DeflateConfigBuilder {
/// Sets the LZ77 sliding window size. Panics if the provided size is not in `9..=15`.
pub fn max_window_bits(mut self, max_window_bits: u8) -> DeflateConfigBuilder {
assert!(
(9u8..=15u8).contains(&max_window_bits),
(LZ77_MIN_WINDOW_SIZE..=LZ77_MAX_WINDOW_SIZE).contains(&max_window_bits),
"max window bits must be in range 9..=15"
);
self.max_window_bits = max_window_bits;
@ -207,26 +222,7 @@ pub struct DeflateExt {
/// The deflate compressor.
deflator: Deflator,
/// If this deflate extension is not used, messages will be forwarded to this extension.
uncompressed_extension: PlainTextExt,
}
impl Clone for DeflateExt {
fn clone(&self) -> Self {
DeflateExt {
enabled: self.enabled,
config: self.config,
fragments: vec![],
inflator: Inflator::new(),
deflator: Deflator::new(self.config.compression_level()),
uncompressed_extension: PlainTextExt::new(self.config.max_message_size()),
}
}
}
impl Default for DeflateExt {
fn default() -> Self {
DeflateExt::new(Default::default())
}
uncompressed_extension: UncompressedExt,
}
impl DeflateExt {
@ -238,7 +234,7 @@ impl DeflateExt {
fragments: vec![],
inflator: Inflator::new(),
deflator: Deflator::new(Compression::fast()),
uncompressed_extension: PlainTextExt::new(config.max_message_size()),
uncompressed_extension: UncompressedExt::new(config.max_message_size()),
}
}
@ -262,10 +258,10 @@ impl DeflateExt {
match window_bits_str.trim().parse() {
Ok(mut window_bits) => {
if window_bits == 8 {
window_bits = 9;
window_bits = LZ77_MIN_WINDOW_SIZE;
}
if window_bits >= 9 && window_bits <= 15 {
if window_bits >= LZ77_MIN_WINDOW_SIZE && window_bits <= LZ77_MAX_WINDOW_SIZE {
if window_bits != self.config.max_window_bits() {
Ok(Some(window_bits))
} else {
@ -284,7 +280,7 @@ impl DeflateExt {
fn decline<T>(&mut self, res: &mut Response<T>) {
self.enabled = false;
res.headers_mut().remove(EXT_NAME);
res.headers_mut().remove(EXT_IDENT);
}
}
@ -319,7 +315,7 @@ impl std::error::Error for DeflateExtensionError {}
impl From<DeflateExtensionError> for crate::Error {
fn from(e: DeflateExtensionError) -> Self {
crate::Error::ExtensionError(Box::new(e))
crate::Error::ExtensionError(Cow::from(e.to_string()))
}
}
@ -329,6 +325,12 @@ impl From<InvalidHeaderValue> for DeflateExtensionError {
}
}
impl Default for DeflateExt {
fn default() -> Self {
DeflateExt::new(Default::default())
}
}
impl WebSocketExtension for DeflateExt {
type Error = DeflateExtensionError;
@ -344,14 +346,14 @@ impl WebSocketExtension for DeflateExt {
}
fn on_make_request<T>(&mut self, mut request: Request<T>) -> Request<T> {
let mut header_value = String::from(EXT_NAME);
let mut header_value = String::from(EXT_IDENT);
let DeflateConfig {
max_window_bits,
request_no_context_takeover,
..
} = self.config;
if max_window_bits < 15 {
if max_window_bits < LZ77_MAX_WINDOW_SIZE {
header_value.push_str(&format!(
"; client_max_window_bits={}; server_max_window_bits={}",
max_window_bits, max_window_bits
@ -486,7 +488,7 @@ impl WebSocketExtension for DeflateExt {
}
if !response_str.contains("client_max_window_bits")
&& self.config.max_window_bits() < 15
&& self.config.max_window_bits() < LZ77_MAX_WINDOW_SIZE
{
continue;
}
@ -671,7 +673,7 @@ impl WebSocketExtension for DeflateExt {
if self.enabled && (!self.fragments.is_empty() || frame.header().rsv1) {
if !frame.header().is_final {
self.fragments.push(frame);
return Ok(None);
Ok(None)
} else {
let message = if let OpCode::Data(Data::Continue) = frame.header().opcode {
self.fragments.push(frame);

@ -13,7 +13,7 @@ pub mod uncompressed;
/// A trait for defining WebSocket extensions. Extensions may be stacked by nesting them inside
/// one another.
pub trait WebSocketExtension: Default + Clone {
pub trait WebSocketExtension {
/// An error type that the extension produces.
type Error: Into<crate::Error>;

@ -2,47 +2,31 @@ use crate::extensions::WebSocketExtension;
use crate::protocol::frame::coding::{Data, OpCode};
use crate::protocol::frame::Frame;
use crate::protocol::message::{IncompleteMessage, IncompleteMessageType};
use crate::protocol::MAX_MESSAGE_SIZE;
use crate::{Error, Message};
/// An uncompressed message handler for a WebSocket.
#[derive(Debug)]
pub struct PlainTextExt {
pub struct UncompressedExt {
incomplete: Option<IncompleteMessage>,
max_message_size: Option<usize>,
}
impl PlainTextExt {
/// Builds a new `PlainTextExt` that will permit a maximum message size of `max_message_size`
impl UncompressedExt {
/// Builds a new `UncompressedExt` that will permit a maximum message size of `max_message_size`
/// or will be unbounded if `None`.
pub fn new(max_message_size: Option<usize>) -> PlainTextExt {
PlainTextExt {
pub fn new(max_message_size: Option<usize>) -> UncompressedExt {
UncompressedExt {
incomplete: None,
max_message_size,
}
}
}
impl Clone for PlainTextExt {
fn clone(&self) -> Self {
Self::default()
}
}
impl Default for PlainTextExt {
fn default() -> Self {
PlainTextExt {
incomplete: None,
max_message_size: Some(MAX_MESSAGE_SIZE),
}
}
}
impl WebSocketExtension for PlainTextExt {
impl WebSocketExtension for UncompressedExt {
type Error = Error;
fn new(max_message_size: Option<usize>) -> Self {
PlainTextExt {
UncompressedExt {
incomplete: None,
max_message_size,
}

@ -22,24 +22,25 @@ pub type Response = HttpResponse<()>;
/// Client handshake role.
#[derive(Debug)]
pub struct ClientHandshake<S, E>
pub struct ClientHandshake<S, Extension>
where
E: WebSocketExtension,
Extension: WebSocketExtension,
{
verify_data: VerifyData,
config: Option<WebSocketConfig<E>>,
config: Option<Option<WebSocketConfig<Extension>>>,
_marker: PhantomData<S>,
}
impl<S: Read + Write, E> ClientHandshake<S, E>
impl<Stream, Ext> ClientHandshake<Stream, Ext>
where
E: WebSocketExtension,
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Initiate a client handshake.
pub fn start(
stream: S,
stream: Stream,
request: Request,
mut config: Option<WebSocketConfig<E>>,
mut config: Option<WebSocketConfig<Ext>>,
) -> Result<MidHandshake<Self>> {
if request.method() != http::Method::GET {
return Err(Error::Protocol(
@ -67,7 +68,7 @@ where
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData { accept_key },
config,
config: Some(config),
_marker: PhantomData,
}
};
@ -80,13 +81,14 @@ where
}
}
impl<S: Read + Write, E> HandshakeRole for ClientHandshake<S, E>
impl<Stream, Ext> HandshakeRole for ClientHandshake<Stream, Ext>
where
E: WebSocketExtension,
Stream: Read + Write,
Ext: WebSocketExtension,
{
type IncomingData = Response;
type InternalStream = S;
type FinalResult = (WebSocket<S, E>, Response);
type InternalStream = Stream;
type FinalResult = (WebSocket<Stream, Ext>, Response);
fn stage_finished(
&mut self,
@ -101,11 +103,11 @@ where
result,
tail,
} => {
self.verify_data
.verify_response(&result, &mut self.config)?;
let mut config = self.config.take().unwrap();
self.verify_data.verify_response(&result, &mut config)?;
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config.clone());
let websocket = WebSocket::from_partially_read(stream, tail, Role::Client, config);
ProcessingResult::Done((websocket, result))
}
})
@ -113,13 +115,13 @@ where
}
/// Generate client request.
fn generate_request<E>(
fn generate_request<Ext>(
request: Request,
key: &str,
config: &mut Option<WebSocketConfig<E>>,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<Vec<u8>>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
let request = match config {
Some(ref mut config) => config.encoder.on_make_request(request),
@ -181,13 +183,13 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response<E>(
pub fn verify_response<Ext>(
&self,
response: &Response,
config: &mut Option<WebSocketConfig<E>>,
config: &mut Option<WebSocketConfig<Ext>>,
) -> Result<()>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
@ -306,7 +308,7 @@ mod tests {
use super::super::machine::TryParse;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
#[test]
fn random_keys() {
@ -337,7 +339,8 @@ mod tests {
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<PlainTextExt>(request, key, &mut Some(Default::default())).unwrap();
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -357,7 +360,8 @@ mod tests {
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<PlainTextExt>(request, key, &mut Some(Default::default())).unwrap();
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}
@ -377,7 +381,8 @@ mod tests {
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
\r\n";
let request =
generate_request::<PlainTextExt>(request, key, &mut Some(Default::default())).unwrap();
generate_request::<UncompressedExt>(request, key, &mut Some(Default::default()))
.unwrap();
println!("Request: {}", String::from_utf8_lossy(&request));
assert_eq!(&request[..], &correct[..]);
}

@ -93,10 +93,9 @@ pub fn create_response(request: &Request) -> Result<Response> {
fn write_response<T>(w: &mut dyn io::Write, response: &HttpResponse<T>) -> Result<()> {
writeln!(
w,
"{version:?} {status} {reason}\r",
"{version:?} {status}\r",
version = response.version(),
status = response.status(),
reason = response.status().canonical_reason().unwrap_or(""),
status = response.status()
)?;
for (k, v) in response.headers() {
@ -192,37 +191,43 @@ impl Callback for NoCallback {
/// Server handshake role.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C, E>
pub struct ServerHandshake<S, C, Ext>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
/// Callback which is called whenever the server read the request from the client and is ready
/// to reply to it. The callback returns an optional headers which will be added to the reply
/// which the server sends to the user.
callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig<E>>,
config: Option<Option<WebSocketConfig<Ext>>>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type.
_marker: PhantomData<S>,
}
impl<S: Read + Write, C: Callback, E> ServerHandshake<S, C, E>
impl<S, C, Ext> ServerHandshake<S, C, Ext>
where
E: WebSocketExtension,
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
/// Start server handshake. `callback` specifies a custom callback which the user can pass to
/// the handshake, this callback will be called when the a websocket client connnects to the
/// server, you can specify the callback if you want to add additional header to the client
/// upon join based on the incoming headers.
pub fn start(stream: S, callback: C, config: Option<WebSocketConfig<E>>) -> MidHandshake<Self> {
pub fn start(
stream: S,
callback: C,
config: Option<WebSocketConfig<Ext>>,
) -> MidHandshake<Self> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake {
callback: Some(callback),
config,
config: Some(config),
error_code: None,
_marker: PhantomData,
},
@ -230,13 +235,15 @@ where
}
}
impl<S: Read + Write, C: Callback, E> HandshakeRole for ServerHandshake<S, C, E>
impl<S, C, Ext> HandshakeRole for ServerHandshake<S, C, Ext>
where
E: WebSocketExtension,
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S, E>;
type FinalResult = WebSocket<S, Ext>;
fn stage_finished(
&mut self,
@ -254,7 +261,7 @@ where
let mut response = create_response(&request)?;
if let Some(ref mut config) = self.config {
if let Some(ref mut config) = self.config.as_mut().unwrap() {
if let Err(e) = config.encoder.on_receive_request(&request, &mut response) {
return Err(e.into());
}
@ -298,8 +305,11 @@ where
return Err(Error::Http(StatusCode::from_u16(err)?));
} else {
debug!("Server handshake done.");
let websocket =
WebSocket::from_raw_socket(stream, Role::Server, self.config.clone());
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.take().unwrap(),
);
ProcessingResult::Done(websocket)
}
}

@ -16,7 +16,7 @@ use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::message::IncompleteMessage;
use crate::error::{Error, Result};
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use crate::util::NonBlockingResult;
@ -33,7 +33,7 @@ pub enum Role {
/// The configuration for WebSocket connection.
#[derive(Debug, Copy, Clone)]
pub struct WebSocketConfig<E = PlainTextExt>
pub struct WebSocketConfig<E = UncompressedExt>
where
E: WebSocketExtension,
{
@ -83,26 +83,30 @@ where
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
/// It may be created by calling `connect`, `accept` or `client` functions.
#[derive(Debug)]
pub struct WebSocket<Stream, E>
pub struct WebSocket<Stream, Ext>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
/// The underlying socket.
socket: Stream,
/// The context for managing a WebSocket.
context: WebSocketContext<E>,
context: WebSocketContext<Ext>,
}
impl<Stream, E> WebSocket<Stream, E>
impl<Stream, Ext> WebSocket<Stream, Ext>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig<E>>) -> Self {
pub fn from_raw_socket(
stream: Stream,
role: Role,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::new(role, config),
@ -118,7 +122,7 @@ where
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<E>>,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocket {
socket: stream,
@ -137,12 +141,12 @@ where
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<E>)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
self.context.set_config(set_func)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig<E> {
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
self.context.get_config()
}
@ -162,9 +166,10 @@ where
}
}
impl<Stream: Read + Write, E> WebSocket<Stream, E>
impl<Stream, Ext> WebSocket<Stream, Ext>
where
E: WebSocketExtension,
Stream: Read + Write,
Ext: WebSocketExtension,
{
/// Read a message from stream, if possible.
///
@ -248,9 +253,9 @@ where
/// A context for managing WebSocket stream.
#[derive(Debug)]
pub struct WebSocketContext<E = PlainTextExt>
pub struct WebSocketContext<Ext = UncompressedExt>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
/// Server or client?
role: Role,
@ -265,15 +270,15 @@ where
/// Send: an OOB pong message.
pong: Option<Frame>,
/// The configuration for the websocket session.
config: WebSocketConfig<E>,
config: WebSocketConfig<Ext>,
}
impl<E> WebSocketContext<E>
impl<Ext> WebSocketContext<Ext>
where
E: WebSocketExtension,
Ext: WebSocketExtension,
{
/// Create a WebSocket context that manages a post-handshake stream.
pub fn new(role: Role, config: Option<WebSocketConfig<E>>) -> Self {
pub fn new(role: Role, config: Option<WebSocketConfig<Ext>>) -> Self {
let config = config.unwrap_or_else(Default::default);
WebSocketContext {
@ -291,7 +296,7 @@ where
pub fn from_partially_read(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig<E>>,
config: Option<WebSocketConfig<Ext>>,
) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
@ -300,12 +305,12 @@ where
}
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<E>)) {
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig<Ext>)) {
set_func(&mut self.config)
}
/// Read the configuration.
pub fn get_config(&self) -> &WebSocketConfig<E> {
pub fn get_config(&self) -> &WebSocketConfig<Ext> {
&self.config
}
@ -672,7 +677,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
use std::io;
use std::io::Cursor;
@ -700,7 +705,7 @@ mod tests {
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket: WebSocket<_, PlainTextExt> =
let mut socket: WebSocket<_, UncompressedExt> =
WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
@ -723,7 +728,7 @@ mod tests {
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: PlainTextExt::new(Some(10)),
encoder: UncompressedExt::new(Some(10)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(
@ -738,7 +743,7 @@ mod tests {
let limit = WebSocketConfig {
max_send_queue: None,
max_frame_size: Some(16 << 20),
encoder: PlainTextExt::new(Some(2)),
encoder: UncompressedExt::new(Some(2)),
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(

@ -7,7 +7,7 @@ use crate::handshake::HandshakeError;
use crate::protocol::{WebSocket, WebSocketConfig};
use crate::extensions::uncompressed::PlainTextExt;
use crate::extensions::uncompressed::UncompressedExt;
use crate::extensions::WebSocketExtension;
use std::io::{Read, Write};
@ -20,12 +20,13 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write, E>(
stream: S,
config: Option<WebSocketConfig<E>>,
) -> Result<WebSocket<S, E>, HandshakeError<ServerHandshake<S, NoCallback, E>>>
pub fn accept_with_config<Stream, Ext>(
stream: Stream,
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<Stream, Ext>, HandshakeError<ServerHandshake<Stream, NoCallback, Ext>>>
where
E: WebSocketExtension,
Stream: Read + Write,
Ext: WebSocketExtension,
{
accept_hdr_with_config(stream, NoCallback, config)
}
@ -38,8 +39,10 @@ where
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<WebSocket<S, PlainTextExt>, HandshakeError<ServerHandshake<S, NoCallback, PlainTextExt>>>
{
) -> Result<
WebSocket<S, UncompressedExt>,
HandshakeError<ServerHandshake<S, NoCallback, UncompressedExt>>,
> {
accept_with_config(stream, None)
}
@ -51,13 +54,15 @@ pub fn accept<S: Read + Write>(
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr_with_config<S: Read + Write, C: Callback, E>(
pub fn accept_hdr_with_config<S, C, Ext>(
stream: S,
callback: C,
config: Option<WebSocketConfig<E>>,
) -> Result<WebSocket<S, E>, HandshakeError<ServerHandshake<S, C, E>>>
config: Option<WebSocketConfig<Ext>>,
) -> Result<WebSocket<S, Ext>, HandshakeError<ServerHandshake<S, C, Ext>>>
where
E: WebSocketExtension,
S: Read + Write,
C: Callback,
Ext: WebSocketExtension,
{
ServerHandshake::start(stream, callback, config).handshake()
}
@ -70,6 +75,6 @@ where
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S, PlainTextExt>, HandshakeError<ServerHandshake<S, C, PlainTextExt>>> {
) -> Result<WebSocket<S, UncompressedExt>, HandshakeError<ServerHandshake<S, C, UncompressedExt>>> {
accept_hdr_with_config(stream, callback, None)
}

@ -8,16 +8,16 @@ use std::time::Duration;
use native_tls::TlsStream;
use net2::TcpStreamExt;
use tungstenite::extensions::uncompressed::PlainTextExt;
use tungstenite::extensions::uncompressed::UncompressedExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use url::Url;
type Sock<E> = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>, E>;
type Sock<Ext> = WebSocket<Stream<TcpStream, TlsStream<TcpStream>>, Ext>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where
CT: FnOnce(Sock<PlainTextExt>) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream, PlainTextExt>),
CT: FnOnce(Sock<UncompressedExt>) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream, UncompressedExt>),
{
env_logger::try_init().ok();

Loading…
Cancel
Save