Merge pull request #43 from snapview/devel

Limiting and config improvements
pull/48/head
Alexey Galakhov 6 years ago committed by GitHub
commit 59f8d9c402
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      src/handshake/client.rs
  2. 2
      src/handshake/headers.rs
  3. 4
      src/handshake/machine.rs
  4. 2
      src/handshake/mod.rs
  5. 597
      src/protocol/frame/frame.rs
  6. 14
      src/protocol/frame/mask.rs
  7. 93
      src/protocol/frame/mod.rs
  8. 31
      src/protocol/message.rs
  9. 97
      src/protocol/mod.rs

@ -52,7 +52,7 @@ impl<'t> Request<'t> {
/// Adds a custom header to the request. /// Adds a custom header to the request.
pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) {
let mut headers = self.extra_headers.take().unwrap_or(vec![]); let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new);
headers.push((name, value)); headers.push((name, value));
self.extra_headers = Some(headers); self.extra_headers = Some(headers);
} }
@ -113,7 +113,7 @@ impl<S: Read + Write> ClientHandshake<S> {
}; };
trace!("Client handshake initiated."); trace!("Client handshake initiated.");
MidHandshake { role: client, machine: machine } MidHandshake { role: client, machine }
} }
} }

@ -28,7 +28,7 @@ impl Headers {
/// Iterate over all headers with the given name. /// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter { HeadersIter {
name: name, name,
iter: self.data.iter() iter: self.data.iter()
} }
} }

@ -16,14 +16,14 @@ impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer. /// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self { pub fn start_read(stream: Stream) -> Self {
HandshakeMachine { HandshakeMachine {
stream: stream, stream,
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)), state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
} }
} }
/// Start writing data to the peer. /// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine { HandshakeMachine {
stream: stream, stream,
state: HandshakeState::Writing(Cursor::new(data.into())), state: HandshakeState::Writing(Cursor::new(data.into())),
} }
} }

@ -110,7 +110,7 @@ pub enum ProcessingResult<Stream, FinalResult> {
fn convert_key(input: &[u8]) -> Result<String, Error> { fn convert_key(input: &[u8]) -> Result<String, Error> {
// ... field is constructed by concatenating /key/ ... // ... field is constructed by concatenating /key/ ...
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut sha1 = Sha1::default(); let mut sha1 = Sha1::default();
sha1.input(input); sha1.input(input);
sha1.input(WS_GUID); sha1.input(WS_GUID);

@ -35,159 +35,242 @@ impl<'t> fmt::Display for CloseFrame<'t> {
} }
} }
/// A struct representing a WebSocket frame. /// A struct representing a WebSocket frame header.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Frame { pub struct FrameHeader {
finished: bool, /// Indicates that the frame is the last one of a possibly fragmented message.
rsv1: bool, pub is_final: bool,
rsv2: bool, /// Reserved for protocol extensions.
rsv3: bool, pub rsv1: bool,
opcode: OpCode, /// Reserved for protocol extensions.
pub rsv2: bool,
/// Reserved for protocol extensions.
pub rsv3: bool,
/// WebSocket protocol opcode.
pub opcode: OpCode,
/// A frame mask, if any.
mask: Option<[u8; 4]>, mask: Option<[u8; 4]>,
payload: Vec<u8>,
} }
impl Frame { impl Default for FrameHeader {
fn default() -> Self {
FrameHeader {
is_final: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: OpCode::Control(Control::Close),
mask: None,
}
}
}
/// Get the length of the frame. impl FrameHeader {
/// This is the length of the header + the length of the payload. /// Parse a header from an input stream.
#[inline] /// Returns `None` if insufficient data and does not consume anything in this case.
pub fn len(&self) -> usize { /// Payload size is returned along with the header.
let mut header_length = 2; pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
let payload_len = self.payload().len(); let initial = cursor.position();
if payload_len > 125 { match Self::parse_internal(cursor) {
if payload_len <= u16::max_value() as usize { ret @ Ok(None) => {
header_length += 2; cursor.set_position(initial);
} else { ret
header_length += 8;
} }
ret => ret
} }
}
/// Get the size of the header formatted with given payload length.
pub fn len(&self, length: u64) -> usize {
2
+ LengthFormat::for_length(length).extra_bytes()
+ if self.mask.is_some() { 4 } else { 0 }
}
if self.is_masked() { /// Format a header for given payload size.
header_length += 4; pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
let code: u8 = self.opcode.into();
let one = {
code
| if self.is_final { 0x80 } else { 0 }
| if self.rsv1 { 0x40 } else { 0 }
| if self.rsv2 { 0x20 } else { 0 }
| if self.rsv3 { 0x10 } else { 0 }
};
let lenfmt = LengthFormat::for_length(length);
let two = {
lenfmt.length_byte()
| if self.mask.is_some() { 0x80 } else { 0 }
};
output.write_all(&[one, two])?;
match lenfmt {
LengthFormat::U8(_) => (),
LengthFormat::U16 => output.write_u16::<NetworkEndian>(length as u16)?,
LengthFormat::U64 => output.write_u64::<NetworkEndian>(length)?,
} }
header_length + payload_len if let Some(ref mask) = self.mask {
} output.write_all(mask)?
}
/// Test whether the frame is a final frame. Ok(())
#[inline]
pub fn is_final(&self) -> bool {
self.finished
} }
/// Test whether the first reserved bit is set. /// Generate a random frame mask and store this in the header.
#[inline] ///
pub fn has_rsv1(&self) -> bool { /// Of course this does not change frame contents. It just generates a mask.
self.rsv1 pub(crate) fn set_random_mask(&mut self) {
self.mask = Some(generate_mask())
} }
}
/// Test whether the second reserved bit is set. impl FrameHeader {
#[inline] /// Internal parse engine.
pub fn has_rsv2(&self) -> bool { /// Returns `None` if insufficient data.
self.rsv2 /// Payload size is returned along with the header.
} fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
let (first, second) = {
let mut head = [0u8; 2];
if cursor.read(&mut head)? != 2 {
return Ok(None)
}
trace!("Parsed headers {:?}", head);
(head[0], head[1])
};
/// Test whether the third reserved bit is set. trace!("First: {:b}", first);
#[inline] trace!("Second: {:b}", second);
pub fn has_rsv3(&self) -> bool {
self.rsv3
}
/// Get the OpCode of the frame. let is_final = first & 0x80 != 0;
#[inline]
pub fn opcode(&self) -> OpCode {
self.opcode
}
/// Get a reference to the frame's payload. let rsv1 = first & 0x40 != 0;
#[inline] let rsv2 = first & 0x20 != 0;
pub fn payload(&self) -> &Vec<u8> { let rsv3 = first & 0x10 != 0;
&self.payload
}
// Test whether the frame is masked. let opcode = OpCode::from(first & 0x0F);
#[doc(hidden)] trace!("Opcode: {:?}", opcode);
#[inline]
pub fn is_masked(&self) -> bool {
self.mask.is_some()
}
// Get an optional reference to the frame's mask. let masked = second & 0x80 != 0;
#[doc(hidden)] trace!("Masked: {:?}", masked);
#[allow(dead_code)]
#[inline]
pub fn mask(&self) -> Option<&[u8; 4]> {
self.mask.as_ref()
}
/// Make this frame a final frame. let length = {
#[allow(dead_code)] let length_byte = second & 0x7F;
#[inline] let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
pub fn set_final(&mut self, is_final: bool) -> &mut Frame { if length_length > 0 {
self.finished = is_final; match cursor.read_uint::<NetworkEndian>(length_length) {
self Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
return Ok(None);
}
Err(err) => {
return Err(err.into());
}
Ok(read) => read
}
} else {
length_byte as u64
}
};
let mask = if masked {
let mut mask_bytes = [0u8; 4];
if cursor.read(&mut mask_bytes)? != 4 {
return Ok(None)
} else {
Some(mask_bytes)
}
} else {
None
};
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
}
_ => ()
}
let hdr = FrameHeader {
is_final,
rsv1,
rsv2,
rsv3,
opcode,
mask,
};
Ok(Some((hdr, length)))
} }
}
/// A struct representing a WebSocket frame.
#[derive(Debug, Clone)]
pub struct Frame {
header: FrameHeader,
payload: Vec<u8>,
}
/// Set the first reserved bit. impl Frame {
/// Get the length of the frame.
/// This is the length of the header + the length of the payload.
#[inline] #[inline]
pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { pub fn len(&self) -> usize {
self.rsv1 = has_rsv1; let length = self.payload.len();
self self.header.len(length as u64) + length
} }
/// Set the second reserved bit. /// Get a reference to the frame's header.
#[inline] #[inline]
pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { pub fn header(&self) -> &FrameHeader {
self.rsv2 = has_rsv2; &self.header
self
} }
/// Set the third reserved bit. /// Get a mutable reference to the frame's header.
#[inline] #[inline]
pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { pub fn header_mut(&mut self) -> &mut FrameHeader {
self.rsv3 = has_rsv3; &mut self.header
self
} }
/// Set the OpCode. /// Get a reference to the frame's payload.
#[allow(dead_code)]
#[inline] #[inline]
pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { pub fn payload(&self) -> &Vec<u8> {
self.opcode = opcode; &self.payload
self
} }
/// Edit the frame's payload. /// Get a mutable reference to the frame's payload.
#[allow(dead_code)]
#[inline] #[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> { pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload &mut self.payload
} }
// Generate a new mask for this frame. /// Test whether the frame is masked.
//
// This method simply generates and stores the mask. It does not change the payload data.
// Instead, the payload data will be masked with the generated mask when the frame is sent
// to the other endpoint.
#[doc(hidden)]
#[inline] #[inline]
pub fn set_mask(&mut self) -> &mut Frame { pub(crate) fn is_masked(&self) -> bool {
self.mask = Some(generate_mask()); self.header.mask.is_some()
self
} }
// This method unmasks the payload and should only be called on frames that are actually /// Generate a random mask for the frame.
// masked. In other words, those frames that have just been received from a client endpoint. ///
#[doc(hidden)] /// This just generates a mask, payload is not changed. The actual masking is performed
/// either on `format()` or on `apply_mask()` call.
#[inline] #[inline]
pub fn remove_mask(&mut self) { pub(crate) fn set_random_mask(&mut self) {
self.mask.and_then(|mask| { self.header.set_random_mask()
Some(apply_mask(&mut self.payload, &mask)) }
});
self.mask = None; /// This method unmasks the payload and should only be called on frames that are actually
/// masked. In other words, those frames that have just been received from a client endpoint.
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask)
}
} }
/// Consume the frame into its payload as binary. /// Consume the frame into its payload as binary.
@ -204,7 +287,7 @@ impl Frame {
/// Consume the frame into a closing frame. /// Consume the frame into a closing frame.
#[inline] #[inline]
pub fn into_close(self) -> Result<Option<CloseFrame<'static>>> { pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() { match self.payload.len() {
0 => Ok(None), 0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())), 1 => Err(Error::Protocol("Invalid close sequence".into())),
@ -213,24 +296,26 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code: code, reason: text.into() })) Ok(Some(CloseFrame { code, reason: text.into() }))
} }
} }
} }
/// Create a new data frame. /// Create a new data frame.
#[inline] #[inline]
pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame { pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(match code { debug_assert!(match opcode {
OpCode::Data(_) => true, OpCode::Data(_) => true,
_ => false, _ => false,
}, "Invalid opcode for data frame."); }, "Invalid opcode for data frame.");
Frame { Frame {
finished: finished, header: FrameHeader {
opcode: code, is_final,
opcode,
.. FrameHeader::default()
},
payload: data, payload: data,
.. Frame::default()
} }
} }
@ -238,9 +323,11 @@ impl Frame {
#[inline] #[inline]
pub fn pong(data: Vec<u8>) -> Frame { pub fn pong(data: Vec<u8>) -> Frame {
Frame { Frame {
opcode: OpCode::Control(Control::Pong), header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
.. FrameHeader::default()
},
payload: data, payload: data,
.. Frame::default()
} }
} }
@ -248,9 +335,11 @@ impl Frame {
#[inline] #[inline]
pub fn ping(data: Vec<u8>) -> Frame { pub fn ping(data: Vec<u8>) -> Frame {
Frame { Frame {
opcode: OpCode::Control(Control::Ping), header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
.. FrameHeader::default()
},
payload: data, payload: data,
.. Frame::default()
} }
} }
@ -267,199 +356,28 @@ impl Frame {
}; };
Frame { Frame {
payload: payload, header: FrameHeader::default(),
.. Frame::default() payload,
} }
} }
/// Parse the input stream into a frame. /// Create a frame from given header and data.
pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> { pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
let size = cursor.get_ref().len() as u64 - cursor.position(); Frame {
let initial = cursor.position(); header,
trace!("Position in buffer {}", initial); payload,
let mut head = [0u8; 2];
if try!(cursor.read(&mut head)) != 2 {
cursor.set_position(initial);
return Ok(None)
}
trace!("Parsed headers {:?}", head);
let first = head[0];
let second = head[1];
trace!("First: {:b}", first);
trace!("Second: {:b}", second);
let finished = first & 0x80 != 0;
let rsv1 = first & 0x40 != 0;
let rsv2 = first & 0x20 != 0;
let rsv3 = first & 0x10 != 0;
let opcode = OpCode::from(first & 0x0F);
trace!("Opcode: {:?}", opcode);
let masked = second & 0x80 != 0;
trace!("Masked: {:?}", masked);
let mut header_length = 2;
let mut length = (second & 0x7F) as u64;
if let Some(length_nbytes) = match length {
126 => Some(2),
127 => Some(8),
_ => None,
} {
match cursor.read_uint::<NetworkEndian>(length_nbytes) {
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
cursor.set_position(initial);
return Ok(None);
}
Err(err) => {
return Err(Error::from(err));
}
Ok(read) => {
length = read;
}
};
header_length += length_nbytes as u64;
}
trace!("Payload length: {}", length);
let mask = if masked {
let mut mask_bytes = [0u8; 4];
if try!(cursor.read(&mut mask_bytes)) != 4 {
cursor.set_position(initial);
return Ok(None)
} else {
header_length += 4;
Some(mask_bytes)
}
} else {
None
};
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
}
_ => ()
}
// Make sure `length` is not too big (fits into `usize`).
if length > usize::max_value() as u64 {
return Err(Error::Capacity(format!("Message length too big: {}", length).into()));
}
if size < header_length || size - header_length < length {
cursor.set_position(initial);
return Ok(None)
}
// Size is checked above, so it won't be truncated here.
let mut data = Vec::with_capacity(length as usize);
if length > 0 {
try!(cursor.take(length).read_to_end(&mut data));
} }
debug_assert_eq!(data.len() as u64, length);
let frame = Frame {
finished: finished,
rsv1: rsv1,
rsv2: rsv2,
rsv3: rsv3,
opcode: opcode,
mask: mask,
payload: data,
};
Ok(Some(frame))
} }
/// Write a frame out to a buffer /// Write a frame out to a buffer
pub fn format<W>(mut self, w: &mut W) -> Result<()> pub fn format(mut self, output: &mut impl Write) -> Result<()> {
where W: Write self.header.format(self.payload.len() as u64, output)?;
{ self.apply_mask();
let mut one = 0u8; output.write_all(self.payload())?;
let code: u8 = self.opcode.into();
if self.is_final() {
one |= 0x80;
}
if self.has_rsv1() {
one |= 0x40;
}
if self.has_rsv2() {
one |= 0x20;
}
if self.has_rsv3() {
one |= 0x10;
}
one |= code;
let mut two = 0u8;
if self.is_masked() {
two |= 0x80;
}
if self.payload.len() < 126 {
two |= self.payload.len() as u8;
let headers = [one, two];
try!(w.write(&headers));
} else if self.payload.len() <= 65535 {
two |= 126;
let mut length_bytes = [0u8; 2];
NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16);
let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers));
} else {
two |= 127;
let mut length_bytes = [0u8; 8];
NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64);
let headers = [
one,
two,
length_bytes[0],
length_bytes[1],
length_bytes[2],
length_bytes[3],
length_bytes[4],
length_bytes[5],
length_bytes[6],
length_bytes[7],
];
try!(w.write(&headers));
}
if self.is_masked() {
let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask);
try!(w.write(&mask));
}
try!(w.write(&self.payload));
Ok(()) Ok(())
} }
} }
impl Default for Frame {
fn default() -> Frame {
Frame {
finished: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: OpCode::Control(Control::Close),
mask: None,
payload: Vec::new(),
}
}
}
impl fmt::Display for Frame { impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, write!(f,
@ -472,11 +390,11 @@ length: {}
payload length: {} payload length: {}
payload: 0x{} payload: 0x{}
", ",
self.finished, self.header.is_final,
self.rsv1, self.header.rsv1,
self.rsv2, self.header.rsv2,
self.rsv3, self.header.rsv3,
self.opcode, self.header.opcode,
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
@ -484,6 +402,57 @@ payload: 0x{}
} }
} }
/// Handling of the length format.
enum LengthFormat {
U8(u8),
U16,
U64,
}
impl LengthFormat {
/// Get length format for a given data size.
#[inline]
fn for_length(length: u64) -> Self {
if length < 126 {
LengthFormat::U8(length as u8)
} else if length < 65536 {
LengthFormat::U16
} else {
LengthFormat::U64
}
}
/// Get the size of length encoding.
#[inline]
fn extra_bytes(&self) -> usize {
match *self {
LengthFormat::U8(_) => 0,
LengthFormat::U16 => 2,
LengthFormat::U64 => 8,
}
}
/// Encode the givem length.
#[inline]
fn length_byte(&self) -> u8 {
match *self {
LengthFormat::U8(b) => b,
LengthFormat::U16 => 126,
LengthFormat::U64 => 127,
}
}
/// Get length format for a given length byte.
#[inline]
fn for_byte(byte: u8) -> Self {
match byte & 0x7F {
126 => LengthFormat::U16,
127 => LengthFormat::U64,
b => LengthFormat::U8(b)
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -496,7 +465,11 @@ mod tests {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![ let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
]); ]);
let frame = Frame::parse(&mut raw).unwrap().unwrap(); let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]); assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
} }
@ -515,12 +488,4 @@ mod tests {
assert!(view.contains("payload:")); assert!(view.contains("payload:"));
} }
#[test]
fn parse_overflow() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]);
let _ = Frame::parse(&mut raw); // should not crash
}
} }

@ -11,14 +11,14 @@ pub fn generate_mask() -> [u8; 4] {
/// Mask/unmask a frame. /// Mask/unmask a frame.
#[inline] #[inline]
pub fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) {
apply_mask_fast32(buf, mask) apply_mask_fast32(buf, mask)
} }
/// A safe unoptimized mask application. /// A safe unoptimized mask application.
#[inline] #[inline]
#[allow(dead_code)] #[allow(dead_code)]
fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) {
for (i, byte) in buf.iter_mut().enumerate() { for (i, byte) in buf.iter_mut().enumerate() {
*byte ^= mask[i & 3]; *byte ^= mask[i & 3];
} }
@ -27,7 +27,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
/// Faster version of `apply_mask()` which operates on 4-byte blocks. /// Faster version of `apply_mask()` which operates on 4-byte blocks.
#[inline] #[inline]
#[allow(dead_code)] #[allow(dead_code)]
fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
let mask_u32: u32 = unsafe { let mask_u32: u32 = unsafe {
read_unaligned(mask.as_ptr() as *const u32) read_unaligned(mask.as_ptr() as *const u32)
}; };
@ -101,10 +101,10 @@ mod tests {
// Check masking with proper alignment. // Check masking with proper alignment.
{ {
let mut masked = unmasked.clone(); let mut masked = unmasked.clone();
apply_mask_fallback(&mut masked, &mask); apply_mask_fallback(&mut masked, mask);
let mut masked_fast = unmasked.clone(); let mut masked_fast = unmasked.clone();
apply_mask_fast32(&mut masked_fast, &mask); apply_mask_fast32(&mut masked_fast, mask);
assert_eq!(masked, masked_fast); assert_eq!(masked, masked_fast);
} }
@ -112,10 +112,10 @@ mod tests {
// Check masking without alignment. // Check masking without alignment.
{ {
let mut masked = unmasked.clone(); let mut masked = unmasked.clone();
apply_mask_fallback(&mut masked[1..], &mask); apply_mask_fallback(&mut masked[1..], mask);
let mut masked_fast = unmasked.clone(); let mut masked_fast = unmasked.clone();
apply_mask_fast32(&mut masked_fast[1..], &mask); apply_mask_fast32(&mut masked_fast[1..], mask);
assert_eq!(masked, masked_fast); assert_eq!(masked, masked_fast);
} }

@ -5,7 +5,7 @@ pub mod coding;
mod frame; mod frame;
mod mask; mod mask;
pub use self::frame::Frame; pub use self::frame::{Frame, FrameHeader};
pub use self::frame::CloseFrame; pub use self::frame::CloseFrame;
use std::io::{Read, Write}; use std::io::{Read, Write};
@ -16,36 +16,47 @@ use error::{Error, Result};
/// A reader and writer for WebSocket frames. /// A reader and writer for WebSocket frames.
#[derive(Debug)] #[derive(Debug)]
pub struct FrameSocket<Stream> { pub struct FrameSocket<Stream> {
/// The underlying network stream.
stream: Stream, stream: Stream,
/// Buffer to read data from the stream.
in_buffer: InputBuffer, in_buffer: InputBuffer,
/// Buffer to send packets to the network.
out_buffer: Vec<u8>, out_buffer: Vec<u8>,
/// Header and remaining size of the incoming packet being processed.
header: Option<(FrameHeader, u64)>,
} }
impl<Stream> FrameSocket<Stream> { impl<Stream> FrameSocket<Stream> {
/// Create a new frame socket. /// Create a new frame socket.
pub fn new(stream: Stream) -> Self { pub fn new(stream: Stream) -> Self {
FrameSocket { FrameSocket {
stream: stream, stream,
in_buffer: InputBuffer::with_capacity(MIN_READ), in_buffer: InputBuffer::with_capacity(MIN_READ),
out_buffer: Vec::new(), out_buffer: Vec::new(),
header: None,
} }
} }
/// Create a new frame socket from partially read data. /// Create a new frame socket from partially read data.
pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self { pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
FrameSocket { FrameSocket {
stream: stream, stream,
in_buffer: InputBuffer::from_partially_read(part), in_buffer: InputBuffer::from_partially_read(part),
out_buffer: Vec::new(), out_buffer: Vec::new(),
header: None,
} }
} }
/// Extract a stream from the socket. /// Extract a stream from the socket.
pub fn into_inner(self) -> (Stream, Vec<u8>) { pub fn into_inner(self) -> (Stream, Vec<u8>) {
(self.stream, self.in_buffer.into_vec()) (self.stream, self.in_buffer.into_vec())
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream { pub fn get_ref(&self) -> &Stream {
&self.stream &self.stream
} }
/// Returns a mutable reference to the inner stream. /// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream { pub fn get_mut(&mut self) -> &mut Stream {
&mut self.stream &mut self.stream
@ -56,13 +67,41 @@ impl<Stream> FrameSocket<Stream>
where Stream: Read where Stream: Read
{ {
/// Read a frame from stream. /// Read a frame from stream.
pub fn read_frame(&mut self) -> Result<Option<Frame>> { pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
loop { let max_size = max_size.unwrap_or_else(usize::max_value);
if let Some(frame) = Frame::parse(&mut self.in_buffer.as_cursor_mut())? {
trace!("received frame {}", frame); let payload = loop {
return Ok(Some(frame)); {
let cursor = self.in_buffer.as_cursor_mut();
if self.header.is_none() {
self.header = FrameHeader::parse(cursor)?;
}
if let Some((_, ref length)) = self.header {
let length = *length;
// Enforce frame size limit early and make sure `length`
// is not too big (fits into `usize`).
if length > max_size as u64 {
return Err(Error::Capacity(
format!("Message length too big: {} > {}", length, max_size).into()
))
}
let input_size = cursor.get_ref().len() as u64 - cursor.position();
if length <= input_size {
// No truncation here since `length` is checked above
let mut payload = Vec::with_capacity(length as usize);
if length > 0 {
cursor.take(length).read_to_end(&mut payload)?;
}
break payload
}
}
} }
// No full frames in buffer.
// Not enough data in buffer.
let size = self.in_buffer.prepare_reserve(MIN_READ) let size = self.in_buffer.prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) .with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
@ -71,7 +110,13 @@ impl<Stream> FrameSocket<Stream>
trace!("no frame received"); trace!("no frame received");
return Ok(None) return Ok(None)
} }
} };
let (header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, payload);
trace!("received frame {}", frame);
Ok(Some(frame))
} }
} }
@ -118,11 +163,11 @@ mod tests {
]); ]);
let mut sock = FrameSocket::new(raw); let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]); vec![0x03, 0x02, 0x01]);
assert!(sock.read_frame().unwrap().is_none()); assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner(); let (_, rest) = sock.into_inner();
assert_eq!(rest, vec![0x99]); assert_eq!(rest, vec![0x99]);
@ -134,7 +179,7 @@ mod tests {
0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]); ]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(sock.read_frame().unwrap().unwrap().into_data(), assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
} }
@ -155,4 +200,24 @@ mod tests {
]); ]);
} }
#[test]
fn parse_overflow() {
let raw = Cursor::new(vec![
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]);
let mut sock = FrameSocket::new(raw);
let _ = sock.read_frame(None); // should not crash
}
#[test]
fn size_limit_hit() {
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(),
"Space limit exceeded: Message length too big: 7 > 5"
);
}
} }

@ -3,7 +3,7 @@ use std::fmt;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::str; use std::str;
use error::Result; use error::{Result, Error};
mod string_collect { mod string_collect {
@ -26,6 +26,11 @@ mod string_collect {
} }
} }
pub fn len(&self) -> usize {
self.data.len()
.saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
}
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> { pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
let mut input: &[u8] = tail.as_ref(); let mut input: &[u8] = tail.as_ref();
@ -105,8 +110,29 @@ impl IncompleteMessage {
} }
} }
} }
/// Get the current filled size of the buffer.
pub fn len(&self) -> usize {
match self.collector {
IncompleteMessageCollector::Text(ref t) => t.len(),
IncompleteMessageCollector::Binary(ref b) => b.len(),
}
}
/// Add more data to an existing message. /// Add more data to an existing message.
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> { pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
// Always have a max size. This ensures an error in case of concatenating two buffers
// of more than `usize::max_value()` bytes in total.
let max_size = size_limit.unwrap_or_else(usize::max_value);
let my_size = self.len();
let portion_size = tail.as_ref().len();
// Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity(
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into()
))
}
match self.collector { match self.collector {
IncompleteMessageCollector::Binary(ref mut v) => { IncompleteMessageCollector::Binary(ref mut v) => {
v.extend(tail.as_ref()); v.extend(tail.as_ref());
@ -117,6 +143,7 @@ impl IncompleteMessage {
} }
} }
} }
/// Convert an incomplete message into a complete one. /// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> { pub fn complete(self) -> Result<Message> {
match self.collector { match self.collector {

@ -33,12 +33,23 @@ pub struct WebSocketConfig {
/// means here that the size of the queue is unlimited. The default value is the unlimited /// means here that the size of the queue is unlimited. The default value is the unlimited
/// queue. /// queue.
pub max_send_queue: Option<usize>, pub max_send_queue: Option<usize>,
/// The maximum size of a message. `None` means no size limit. The default value is 64 megabytes
/// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user.
pub max_message_size: Option<usize>,
/// The maximum size of a single message frame. `None` means no size limit. The limit is for
/// frame payload NOT including the frame header. The default value is 16 megabytes which should
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
pub max_frame_size: Option<usize>,
} }
impl Default for WebSocketConfig { impl Default for WebSocketConfig {
fn default() -> Self { fn default() -> Self {
WebSocketConfig { WebSocketConfig {
max_send_queue: None, max_send_queue: None,
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
} }
} }
} }
@ -98,6 +109,13 @@ impl<Stream> WebSocket<Stream> {
self.socket.get_mut() self.socket.get_mut()
} }
/// Change the configuration.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config)
}
}
impl<Stream> WebSocket<Stream> {
/// Convert a frame socket into a WebSocket. /// Convert a frame socket into a WebSocket.
fn from_frame_socket( fn from_frame_socket(
socket: FrameSocket<Stream>, socket: FrameSocket<Stream>,
@ -105,13 +123,13 @@ impl<Stream> WebSocket<Stream> {
config: Option<WebSocketConfig> config: Option<WebSocketConfig>
) -> Self { ) -> Self {
WebSocket { WebSocket {
role: role, role,
socket: socket, socket,
state: WebSocketState::Active, state: WebSocketState::Active,
incomplete: None, incomplete: None,
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
pong: None, pong: None,
config: config.unwrap_or_else(|| WebSocketConfig::default()), config: config.unwrap_or_else(WebSocketConfig::default),
} }
} }
} }
@ -145,10 +163,14 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// Note that only the last pong frame is stored to be sent, and only the /// Note that only the last pong frame is stored to be sent, and only the
/// most recent pong frame is sent if multiple pong frames are queued. /// most recent pong frame is sent if multiple pong frames are queued.
pub fn write_message(&mut self, message: Message) -> Result<()> { pub fn write_message(&mut self, message: Message) -> Result<()> {
// Try to make some room for the new message
self.write_pending().no_block()?;
if let Some(max_send_queue) = self.config.max_send_queue { if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
// Try to make some room for the new message.
// Do not return here if write would block, ignore WouldBlock silently
// since we must queue the message anyway.
self.write_pending().no_block()?;
}
if self.send_queue.len() >= max_send_queue { if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message)); return Err(Error::SendQueueFull(message));
} }
@ -167,8 +189,9 @@ impl<Stream: Read + Write> WebSocket<Stream> {
return self.write_pending() return self.write_pending()
} }
}; };
self.send_queue.push_back(frame); self.send_queue.push_back(frame);
Ok(()) self.write_pending()
} }
/// Close the connection. /// Close the connection.
@ -229,15 +252,18 @@ impl<Stream: Read + Write> WebSocket<Stream> {
impl<Stream: Read + Write> WebSocket<Stream> { impl<Stream: Read + Write> WebSocket<Stream> {
/// Try to decode one message frame. May return None. /// Try to decode one message frame. May return None.
fn read_message_frame(&mut self) -> Result<Option<Message>> { fn read_message_frame(&mut self) -> Result<Option<Message>> {
if let Some(mut frame) = self.socket.read_frame()? { if let Some(mut frame) = self.socket.read_frame(self.config.max_frame_size)? {
// MUST be 0 unless an extension is negotiated that defines meanings // MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of // for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero // the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket // value, the receiving endpoint MUST _Fail the WebSocket
// Connection_. // Connection_.
if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { {
return Err(Error::Protocol("Reserved bits are non-zero".into())) let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()))
}
} }
match self.role { match self.role {
@ -245,7 +271,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
if frame.is_masked() { if frame.is_masked() {
// A server MUST remove masking for data frames received from a client // A server MUST remove masking for data frames received from a client
// as described in Section 5.3. (RFC 6455) // as described in Section 5.3. (RFC 6455)
frame.remove_mask() frame.apply_mask()
} else { } else {
// The server MUST close the connection upon receiving a // The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455) // frame that is not masked. (RFC 6455)
@ -260,13 +286,13 @@ impl<Stream: Read + Write> WebSocket<Stream> {
} }
} }
match frame.opcode() { match frame.header().opcode {
OpCode::Control(ctl) => { OpCode::Control(ctl) => {
match ctl { match ctl {
// All control frames MUST have a payload length of 125 bytes or less // All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455) // and MUST NOT be fragmented. (RFC 6455)
_ if !frame.is_final() => { _ if !frame.header().is_final => {
Err(Error::Protocol("Fragmented control frame".into())) Err(Error::Protocol("Fragmented control frame".into()))
} }
_ if frame.payload().len() > 125 => { _ if frame.payload().len() > 125 => {
@ -299,12 +325,11 @@ impl<Stream: Read + Write> WebSocket<Stream> {
} }
OpCode::Data(data) => { OpCode::Data(data) => {
let fin = frame.is_final(); let fin = frame.header().is_final;
match data { match data {
OpData::Continue => { OpData::Continue => {
if let Some(ref mut msg) = self.incomplete { if let Some(ref mut msg) = self.incomplete {
// TODO if msg too big msg.extend(frame.into_data(), self.config.max_message_size)?;
msg.extend(frame.into_data())?;
} else { } else {
return Err(Error::Protocol("Continue frame but nothing to continue".into())) return Err(Error::Protocol("Continue frame but nothing to continue".into()))
} }
@ -327,7 +352,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
_ => panic!("Bug: message is not text nor binary"), _ => panic!("Bug: message is not text nor binary"),
}; };
let mut m = IncompleteMessage::new(message_type); let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data())?; m.extend(frame.into_data(), self.config.max_message_size)?;
m m
}; };
if fin { if fin {
@ -414,7 +439,7 @@ impl<Stream: Read + Write> WebSocket<Stream> {
Role::Client => { Role::Client => {
// 5. If the data is being sent by the client, the frame(s) MUST be // 5. If the data is being sent by the client, the frame(s) MUST be
// masked as defined in Section 5.3. (RFC 6455) // masked as defined in Section 5.3. (RFC 6455)
frame.set_mask(); frame.set_random_mask();
} }
} }
let res = self.socket.write_frame(frame); let res = self.socket.write_frame(frame);
@ -470,7 +495,7 @@ impl WebSocketState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{WebSocket, Role, Message}; use super::{WebSocket, Role, Message, WebSocketConfig};
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;
@ -512,4 +537,38 @@ mod tests {
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
} }
#[test]
fn size_limiting_text_fragmented() {
let incoming = Cursor::new(vec![
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_message_size: Some(10),
.. WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 7 + 6 > 10"
);
}
#[test]
fn size_limiting_binary() {
let incoming = Cursor::new(vec![
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let limit = WebSocketConfig {
max_message_size: Some(2),
.. WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2"
);
}
} }

Loading…
Cancel
Save