diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 6fbcf12..d41cde8 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,10 +1,10 @@ use std::fmt; use std::mem::transmute; -use std::io::{Cursor, Read, Write}; +use std::io::{Cursor, Read, Write, ErrorKind}; use std::default::Default; use std::string::{String, FromUtf8Error}; use std::result::Result as StdResult; -use byteorder::{ByteOrder, NetworkEndian}; +use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian}; use bytes::BufMut; use rand; @@ -319,29 +319,24 @@ impl Frame { let mut length = (second & 0x7F) as u64; - if length == 126 { - let mut length_bytes = [0u8; 2]; - if try!(cursor.read(&mut length_bytes)) != 2 { - cursor.set_position(initial); - return Ok(None) - } - - length = unsafe { - let mut wide: u16 = transmute(length_bytes); - wide = u16::from_be(wide); - wide - } as u64; - header_length += 2; - } else if length == 127 { - let mut length_bytes = [0u8; 8]; - if try!(cursor.read(&mut length_bytes)) != 8 { - cursor.set_position(initial); - return Ok(None) - } - - unsafe { length = transmute(length_bytes); } - length = u64::from_be(length); - header_length += 8; + if let Some(length_nbytes) = match length { + 126 => Some(2), + 127 => Some(8), + _ => None, + } { + match cursor.read_uint::(length_nbytes) { + Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { + cursor.set_position(initial); + return Ok(None); + } + Err(err) => { + return Err(Error::from(err)); + } + Ok(read) => { + length = read; + } + }; + header_length += length_nbytes as u64; } trace!("Payload length: {}", length); @@ -425,18 +420,14 @@ impl Frame { try!(w.write(&headers)); } else if self.payload.len() <= 65535 { two |= 126; - let length_bytes: [u8; 2] = unsafe { - let short = self.payload.len() as u16; - transmute(short.to_be()) - }; + let mut length_bytes = [0u8; 2]; + NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16); let headers = [one, two, length_bytes[0], length_bytes[1]]; try!(w.write(&headers)); } else { two |= 127; - let length_bytes: [u8; 8] = unsafe { - let long = self.payload.len() as u64; - transmute(long.to_be()) - }; + let mut length_bytes = [0u8; 8]; + NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64); let headers = [ one, two,