diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index ba9eff7..61220d3 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -5,6 +5,7 @@ pub mod tls; use std::ascii::AsciiExt; use std::str::from_utf8; +use std::slice; use base64; use bytes::Buf; @@ -76,9 +77,15 @@ impl Headers { /// Get first header with the given name, if any. pub fn find_first(&self, name: &str) -> Option<&[u8]> { - self.data.iter() - .find(|&&(ref n, _)| n.eq_ignore_ascii_case(name)) - .map(|&(_, ref v)| v.as_ref()) + self.find(name).next() + } + + /// Iterate over all headers with the given name. + pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { + HeadersIter { + name: name, + iter: self.data.iter() + } } /// Check if the given header has the given value. @@ -103,6 +110,25 @@ impl Headers { } +/// The iterator over headers. +pub struct HeadersIter<'name, 'headers> { + name: &'name str, + iter: slice::Iter<'headers, (String, Box<[u8]>)>, +} + +impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { + type Item = &'headers [u8]; + fn next(&mut self) -> Option { + while let Some(&(ref name, ref value)) = self.iter.next() { + if name.eq_ignore_ascii_case(self.name) { + return Some(value) + } + } + None + } +} + + /// Trait to read HTTP parseable objects. trait Httparse: Sized { fn httparse(buf: &[u8]) -> Result>; @@ -159,7 +185,10 @@ mod tests { #[test] fn headers() { const data: &'static [u8] = - b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"; + b"Host: foo.com\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + \r\n"; let mut inp = Cursor::new(data); let hdr = Headers::parse(&mut inp).unwrap().unwrap(); assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); @@ -171,10 +200,29 @@ mod tests { assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); } + #[test] + fn headers_iter() { + const data: &'static [u8] = + b"Host: foo.com\r\n\ + Sec-WebSocket-Extensions: permessage-deflate\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ + Upgrade: websocket\r\n\ + \r\n"; + let mut inp = Cursor::new(data); + let hdr = Headers::parse(&mut inp).unwrap().unwrap(); + let mut iter = hdr.find("Sec-WebSocket-Extensions"); + assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); + assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); + assert_eq!(iter.next(), None); + } + #[test] fn headers_incomplete() { const data: &'static [u8] = - b"Host: foo.com\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"; + b"Host: foo.com\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n"; let mut inp = Cursor::new(data); let hdr = Headers::parse(&mut inp).unwrap(); assert!(hdr.is_none());