Fix possible deadlock in handshake.

Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
pull/13/head
Alexey Galakhov 7 years ago
parent e9ad145db6
commit 22f7df0b46
  1. 50
      src/handshake/machine.rs

@ -40,50 +40,58 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
/// Perform a single handshake round. /// Perform a single handshake round.
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round."); trace!("Doing handshake round.");
Ok(match self.state { match self.state {
HandshakeState::Reading(mut buf) => { HandshakeState::Reading(mut buf) => {
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))?; .map_err(|_| Error::Capacity("Header too long".into()))?;
if let Some(_) = buf.read_from(&mut self.stream).no_block()? { match buf.read_from(&mut self.stream).no_block()? {
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? { Some(0) => {
buf.advance(size); Err(Error::Protocol("Handshake not finished".into()))
RoundResult::StageFinished(StageResult::DoneReading { }
result: obj, Some(_) => {
stream: self.stream, Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
tail: buf.into_vec(), buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}) })
} else { }
RoundResult::Incomplete(HandshakeMachine { None => {
Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf),
..self ..self
}) }))
} }
} else {
RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
} }
} }
HandshakeState::Writing(mut buf) => { HandshakeState::Writing(mut buf) => {
assert!(buf.has_remaining());
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? { if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? {
assert!(size > 0);
buf.advance(size); buf.advance(size);
if buf.has_remaining() { Ok(if buf.has_remaining() {
RoundResult::Incomplete(HandshakeMachine { RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Writing(buf), state: HandshakeState::Writing(buf),
..self ..self
}) })
} else { } else {
RoundResult::StageFinished(StageResult::DoneWriting(self.stream)) RoundResult::StageFinished(StageResult::DoneWriting(self.stream))
} })
} else { } else {
RoundResult::WouldBlock(HandshakeMachine { Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Writing(buf), state: HandshakeState::Writing(buf),
..self ..self
}) }))
} }
} }
}) }
} }
} }

Loading…
Cancel
Save