Merge pull request #7 from bluetech/misc-improvements

Miscellaneous improvements
pull/12/head
Alexey Galakhov 8 years ago committed by GitHub
commit 65248f159f
  1. 4
      Cargo.toml
  2. 4
      src/client.rs
  3. 2
      src/handshake/server.rs
  4. 3
      src/input_buffer.rs
  5. 118
      src/protocol/frame/frame.rs
  6. 15
      src/protocol/mod.rs

@ -20,7 +20,6 @@ base64 = "0.4.0"
byteorder = "1.0.0" byteorder = "1.0.0"
bytes = "0.4.1" bytes = "0.4.1"
httparse = "1.2.1" httparse = "1.2.1"
env_logger = "0.4.2"
log = "0.3.7" log = "0.3.7"
rand = "0.3.15" rand = "0.3.15"
sha1 = "0.2.0" sha1 = "0.2.0"
@ -30,3 +29,6 @@ utf-8 = "0.7.0"
[dependencies.native-tls] [dependencies.native-tls]
optional = true optional = true
version = "0.1.1" version = "0.1.1"
[dev-dependencies]
env_logger = "0.4.2"

@ -57,7 +57,7 @@ fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream
TlsHandshakeError::Failure(f) => f.into(), TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"), TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
}) })
.map(|s| StreamSwitcher::Tls(s)) .map(StreamSwitcher::Tls)
} }
} }
} }
@ -73,7 +73,7 @@ fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStrea
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream> fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr> where A: Iterator<Item=SocketAddr>
{ {
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?; let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs { for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr); debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(raw_stream) = TcpStream::connect(addr) {

@ -18,7 +18,7 @@ impl Request {
/// Reply to the response. /// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> { pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key") let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?; .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\ let reply = format!("\
HTTP/1.1 101 Switching Protocols\r\n\ HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\

@ -23,7 +23,8 @@ impl InputBuffer {
/// Reserve the given amount of space. /// Reserve the given amount of space.
pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{ pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{
if self.inp_mut().remaining_mut() >= space { let remaining = self.inp_mut().capacity() - self.inp_mut().len();
if remaining >= space {
// We have enough space right now. // We have enough space right now.
Ok(()) Ok(())
} else { } else {

@ -1,11 +1,10 @@
use std::fmt; use std::fmt;
use std::mem::transmute; use std::mem::transmute;
use std::io::{Cursor, Read, Write}; use std::io::{Cursor, Read, Write, ErrorKind};
use std::default::Default; use std::default::Default;
use std::iter::FromIterator;
use std::string::{String, FromUtf8Error}; use std::string::{String, FromUtf8Error};
use std::result::Result as StdResult; use std::result::Result as StdResult;
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian};
use bytes::BufMut; use bytes::BufMut;
use rand; use rand;
@ -14,15 +13,34 @@ use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode}; use super::coding::{OpCode, Control, Data, CloseCode};
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
let iter = buf.iter_mut().zip(mask.iter().cycle()); for (i, byte) in buf.iter_mut().enumerate() {
for (byte, &key) in iter { *byte ^= mask[i & 3];
*byte ^= key
} }
} }
/// Faster version of `apply_mask()` which operates on 4-byte blocks.
///
/// Safety: `buf` must be at least 4-bytes aligned.
unsafe fn apply_mask_aligned32(buf: &mut [u8], mask: &[u8; 4]) {
debug_assert_eq!(buf.as_ptr() as usize % 4, 0);
let mask_u32 = transmute(*mask);
let mut ptr = buf.as_mut_ptr() as *mut u32;
for _ in 0..(buf.len() / 4) {
*ptr ^= mask_u32;
ptr = ptr.offset(1);
}
// Possible last block with less than 4 bytes.
let last_block_start = buf.len() & !3;
let last_block = &mut buf[last_block_start..];
apply_mask(last_block, mask);
}
#[inline] #[inline]
fn generate_mask() -> [u8; 4] { fn generate_mask() -> [u8; 4] {
unsafe { transmute(rand::random::<u32>()) } rand::random()
} }
/// A struct representing a WebSocket frame. /// A struct representing a WebSocket frame.
@ -175,7 +193,10 @@ impl Frame {
#[inline] #[inline]
pub fn remove_mask(&mut self) { pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| { self.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask)) // Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
Some(apply_mask_aligned32(&mut self.payload, &mask))
}
}); });
self.mask = None; self.mask = None;
} }
@ -252,10 +273,7 @@ impl Frame {
let u: u16 = code.into(); let u: u16 = code.into();
transmute(u.to_be()) transmute(u.to_be())
}; };
Vec::from_iter( [&raw[..], reason.as_bytes()].concat()
raw[..].iter()
.chain(reason.as_bytes().iter())
.map(|&b| b))
} else { } else {
Vec::new() Vec::new()
}; };
@ -301,29 +319,24 @@ impl Frame {
let mut length = (second & 0x7F) as u64; let mut length = (second & 0x7F) as u64;
if length == 126 { if let Some(length_nbytes) = match length {
let mut length_bytes = [0u8; 2]; 126 => Some(2),
if try!(cursor.read(&mut length_bytes)) != 2 { 127 => Some(8),
cursor.set_position(initial); _ => None,
return Ok(None) } {
} match cursor.read_uint::<NetworkEndian>(length_nbytes) {
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
length = unsafe { cursor.set_position(initial);
let mut wide: u16 = transmute(length_bytes); return Ok(None);
wide = u16::from_be(wide); }
wide Err(err) => {
} as u64; return Err(Error::from(err));
header_length += 2; }
} else if length == 127 { Ok(read) => {
let mut length_bytes = [0u8; 8]; length = read;
if try!(cursor.read(&mut length_bytes)) != 8 { }
cursor.set_position(initial); };
return Ok(None) header_length += length_nbytes as u64;
}
unsafe { length = transmute(length_bytes); }
length = u64::from_be(length);
header_length += 8;
} }
trace!("Payload length: {}", length); trace!("Payload length: {}", length);
@ -407,18 +420,14 @@ impl Frame {
try!(w.write(&headers)); try!(w.write(&headers));
} else if self.payload.len() <= 65535 { } else if self.payload.len() <= 65535 {
two |= 126; two |= 126;
let length_bytes: [u8; 2] = unsafe { let mut length_bytes = [0u8; 2];
let short = self.payload.len() as u16; NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16);
transmute(short.to_be())
};
let headers = [one, two, length_bytes[0], length_bytes[1]]; let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers)); try!(w.write(&headers));
} else { } else {
two |= 127; two |= 127;
let length_bytes: [u8; 8] = unsafe { let mut length_bytes = [0u8; 8];
let long = self.payload.len() as u64; NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64);
transmute(long.to_be())
};
let headers = [ let headers = [
one, one,
two, two,
@ -436,7 +445,10 @@ impl Frame {
if self.is_masked() { if self.is_masked() {
let mask = self.mask.take().unwrap(); let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask); // Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
apply_mask_aligned32(&mut self.payload, &mask);
}
try!(w.write(&mask)); try!(w.write(&mask));
} }
@ -490,6 +502,24 @@ mod tests {
use super::super::coding::{OpCode, Data}; use super::super::coding::{OpCode, Data};
use std::io::Cursor; use std::io::Cursor;
#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00,
];
let mut masked = unmasked.clone();
apply_mask(&mut masked, &mask);
let mut masked_aligned = unmasked.clone();
unsafe { apply_mask_aligned32(&mut masked_aligned, &mask) };
assert_eq!(masked, masked_aligned);
}
#[test] #[test]
fn parse() { fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![ let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![

@ -124,15 +124,12 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// This function guarantees that the close frame will be queued. /// This function guarantees that the close frame will be queued.
/// There is no need to call it again, just like write_message(). /// There is no need to call it again, just like write_message().
pub fn close(&mut self) -> Result<()> { pub fn close(&mut self) -> Result<()> {
match self.state { if let WebSocketState::Active = self.state {
WebSocketState::Active => { self.state = WebSocketState::ClosedByUs;
self.state = WebSocketState::ClosedByUs; let frame = Frame::close(None);
let frame = Frame::close(None); self.send_queue.push_back(frame);
self.send_queue.push_back(frame); } else {
} // Already closed, nothing to do.
_ => {
// already closed, nothing to do
}
} }
self.write_pending() self.write_pending()
} }

Loading…
Cancel
Save