Merge pull request #7 from bluetech/misc-improvements

Miscellaneous improvements
pull/12/head
Alexey Galakhov 8 years ago committed by GitHub
commit 65248f159f
  1. 4
      Cargo.toml
  2. 4
      src/client.rs
  3. 2
      src/handshake/server.rs
  4. 3
      src/input_buffer.rs
  5. 118
      src/protocol/frame/frame.rs
  6. 15
      src/protocol/mod.rs

@ -20,7 +20,6 @@ base64 = "0.4.0"
byteorder = "1.0.0"
bytes = "0.4.1"
httparse = "1.2.1"
env_logger = "0.4.2"
log = "0.3.7"
rand = "0.3.15"
sha1 = "0.2.0"
@ -30,3 +29,6 @@ utf-8 = "0.7.0"
[dependencies.native-tls]
optional = true
version = "0.1.1"
[dev-dependencies]
env_logger = "0.4.2"

@ -57,7 +57,7 @@ fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
.map(StreamSwitcher::Tls)
}
}
}
@ -73,7 +73,7 @@ fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStrea
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr>
{
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?;
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -18,7 +18,7 @@ impl Request {
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\

@ -23,7 +23,8 @@ impl InputBuffer {
/// Reserve the given amount of space.
pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{
if self.inp_mut().remaining_mut() >= space {
let remaining = self.inp_mut().capacity() - self.inp_mut().len();
if remaining >= space {
// We have enough space right now.
Ok(())
} else {

@ -1,11 +1,10 @@
use std::fmt;
use std::mem::transmute;
use std::io::{Cursor, Read, Write};
use std::io::{Cursor, Read, Write, ErrorKind};
use std::default::Default;
use std::iter::FromIterator;
use std::string::{String, FromUtf8Error};
use std::result::Result as StdResult;
use byteorder::{ByteOrder, NetworkEndian};
use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian};
use bytes::BufMut;
use rand;
@ -14,15 +13,34 @@ use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode};
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
let iter = buf.iter_mut().zip(mask.iter().cycle());
for (byte, &key) in iter {
*byte ^= key
for (i, byte) in buf.iter_mut().enumerate() {
*byte ^= mask[i & 3];
}
}
/// Faster version of `apply_mask()` which operates on 4-byte blocks.
///
/// Safety: `buf` must be at least 4-bytes aligned.
unsafe fn apply_mask_aligned32(buf: &mut [u8], mask: &[u8; 4]) {
debug_assert_eq!(buf.as_ptr() as usize % 4, 0);
let mask_u32 = transmute(*mask);
let mut ptr = buf.as_mut_ptr() as *mut u32;
for _ in 0..(buf.len() / 4) {
*ptr ^= mask_u32;
ptr = ptr.offset(1);
}
// Possible last block with less than 4 bytes.
let last_block_start = buf.len() & !3;
let last_block = &mut buf[last_block_start..];
apply_mask(last_block, mask);
}
#[inline]
fn generate_mask() -> [u8; 4] {
unsafe { transmute(rand::random::<u32>()) }
rand::random()
}
/// A struct representing a WebSocket frame.
@ -175,7 +193,10 @@ impl Frame {
#[inline]
pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask))
// Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
Some(apply_mask_aligned32(&mut self.payload, &mask))
}
});
self.mask = None;
}
@ -252,10 +273,7 @@ impl Frame {
let u: u16 = code.into();
transmute(u.to_be())
};
Vec::from_iter(
raw[..].iter()
.chain(reason.as_bytes().iter())
.map(|&b| b))
[&raw[..], reason.as_bytes()].concat()
} else {
Vec::new()
};
@ -301,29 +319,24 @@ impl Frame {
let mut length = (second & 0x7F) as u64;
if length == 126 {
let mut length_bytes = [0u8; 2];
if try!(cursor.read(&mut length_bytes)) != 2 {
cursor.set_position(initial);
return Ok(None)
}
length = unsafe {
let mut wide: u16 = transmute(length_bytes);
wide = u16::from_be(wide);
wide
} as u64;
header_length += 2;
} else if length == 127 {
let mut length_bytes = [0u8; 8];
if try!(cursor.read(&mut length_bytes)) != 8 {
cursor.set_position(initial);
return Ok(None)
}
unsafe { length = transmute(length_bytes); }
length = u64::from_be(length);
header_length += 8;
if let Some(length_nbytes) = match length {
126 => Some(2),
127 => Some(8),
_ => None,
} {
match cursor.read_uint::<NetworkEndian>(length_nbytes) {
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
cursor.set_position(initial);
return Ok(None);
}
Err(err) => {
return Err(Error::from(err));
}
Ok(read) => {
length = read;
}
};
header_length += length_nbytes as u64;
}
trace!("Payload length: {}", length);
@ -407,18 +420,14 @@ impl Frame {
try!(w.write(&headers));
} else if self.payload.len() <= 65535 {
two |= 126;
let length_bytes: [u8; 2] = unsafe {
let short = self.payload.len() as u16;
transmute(short.to_be())
};
let mut length_bytes = [0u8; 2];
NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16);
let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers));
} else {
two |= 127;
let length_bytes: [u8; 8] = unsafe {
let long = self.payload.len() as u64;
transmute(long.to_be())
};
let mut length_bytes = [0u8; 8];
NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64);
let headers = [
one,
two,
@ -436,7 +445,10 @@ impl Frame {
if self.is_masked() {
let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask);
// Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
apply_mask_aligned32(&mut self.payload, &mask);
}
try!(w.write(&mask));
}
@ -490,6 +502,24 @@ mod tests {
use super::super::coding::{OpCode, Data};
use std::io::Cursor;
#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00,
];
let mut masked = unmasked.clone();
apply_mask(&mut masked, &mask);
let mut masked_aligned = unmasked.clone();
unsafe { apply_mask_aligned32(&mut masked_aligned, &mask) };
assert_eq!(masked, masked_aligned);
}
#[test]
fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![

@ -124,15 +124,12 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// This function guarantees that the close frame will be queued.
/// There is no need to call it again, just like write_message().
pub fn close(&mut self) -> Result<()> {
match self.state {
WebSocketState::Active => {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(None);
self.send_queue.push_back(frame);
}
_ => {
// already closed, nothing to do
}
if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(None);
self.send_queue.push_back(frame);
} else {
// Already closed, nothing to do.
}
self.write_pending()
}

Loading…
Cancel
Save