Skip to content

Commit

Permalink
More principled error handling for invalid frames
Browse files Browse the repository at this point in the history
  • Loading branch information
djc authored and Ralith committed Sep 20, 2023
1 parent 0af891e commit f81c2fa
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 24 deletions.
26 changes: 13 additions & 13 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2192,7 +2192,15 @@ impl Connection {
return Ok(());
}
State::Closed(_) => {
for frame in frame::Iter::new(packet.payload.freeze()) {
for result in frame::Iter::new(packet.payload.freeze()) {
let frame = match result {
Ok(frame) => frame,
Err(err) => {
debug!("frame decoding error: {err:?}");
continue;
}
};

if let Frame::Padding = frame {
continue;
};
Expand Down Expand Up @@ -2433,7 +2441,8 @@ impl Connection {
debug_assert_ne!(packet.header.space(), SpaceId::Data);
let payload_len = packet.payload.len();
let mut ack_eliciting = false;
for frame in frame::Iter::new(packet.payload.freeze()) {
for result in frame::Iter::new(packet.payload.freeze()) {
let frame = result?;
let span = match frame {
Frame::Padding => continue,
_ => Some(trace_span!("frame", ty = %frame.ty())),
Expand All @@ -2458,11 +2467,6 @@ impl Connection {
self.state = State::Draining;
return Ok(());
}
Frame::Invalid { ty, reason } => {
let mut err = TransportError::FRAME_ENCODING_ERROR(reason);
err.frame = Some(ty);
return Err(err);
}
_ => {
let mut err =
TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake");
Expand Down Expand Up @@ -2495,7 +2499,8 @@ impl Connection {
let mut close = None;
let payload_len = payload.len();
let mut ack_eliciting = false;
for frame in frame::Iter::new(payload) {
for result in frame::Iter::new(payload) {
let frame = result?;
let span = match frame {
Frame::Padding => continue,
_ => Some(trace_span!("frame", ty = %frame.ty())),
Expand Down Expand Up @@ -2543,11 +2548,6 @@ impl Connection {
}
}
match frame {
Frame::Invalid { ty, reason } => {
let mut err = TransportError::FRAME_ENCODING_ERROR(reason);
err.frame = Some(ty);
return Err(err);
}
Frame::Crypto(frame) => {
self.read_crypto(SpaceId::Data, &frame, payload_len)?;
}
Expand Down
1 change: 0 additions & 1 deletion quinn-proto/src/connection/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ impl FrameStats {
Frame::AckFrequency(_) => self.ack_frequency += 1,
Frame::ImmediateAck => self.immediate_ack += 1,
Frame::HandshakeDone => self.handshake_done += 1,
Frame::Invalid { .. } => {}
}
}
}
Expand Down
30 changes: 22 additions & 8 deletions quinn-proto/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ pub(crate) enum Frame {
Datagram(Datagram),
AckFrequency(AckFrequency),
ImmediateAck,
Invalid { ty: Type, reason: &'static str },
HandshakeDone,
}

Expand Down Expand Up @@ -204,7 +203,6 @@ impl Frame {
Datagram(_) => Type(*DATAGRAM_TYS.start()),
AckFrequency(_) => Type::ACK_FREQUENCY,
ImmediateAck => Type::IMMEDIATE_ACK,
Invalid { ty, .. } => ty,
HandshakeDone => Type::HANDSHAKE_DONE,
}
}
Expand Down Expand Up @@ -734,25 +732,39 @@ impl Iter {
}

impl Iterator for Iter {
type Item = Frame;
type Item = Result<Frame, InvalidFrame>;
fn next(&mut self) -> Option<Self::Item> {
if !self.bytes.has_remaining() {
return None;
}
match self.try_next() {
Ok(x) => Some(x),
Ok(x) => Some(Ok(x)),
Err(e) => {
// Corrupt frame, skip it and everything that follows
self.bytes = io::Cursor::new(Bytes::new());
Some(Frame::Invalid {
ty: self.last_ty.unwrap(),
Some(Err(InvalidFrame {
ty: self.last_ty,
reason: e.reason(),
})
}))
}
}
}
}

#[derive(Debug)]
pub(crate) struct InvalidFrame {
pub(crate) ty: Option<Type>,
pub(crate) reason: &'static str,
}

impl From<InvalidFrame> for TransportError {
fn from(err: InvalidFrame) -> Self {
let mut te = Self::FRAME_ENCODING_ERROR(err.reason);
te.frame = err.ty;
te
}
}

fn scan_ack_blocks(buf: &mut io::Cursor<Bytes>, largest: u64, n: usize) -> Result<(), IterErr> {
let first_block = buf.get_var()?;
let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?;
Expand Down Expand Up @@ -910,7 +922,9 @@ mod test {
use assert_matches::assert_matches;

fn frames(buf: Vec<u8>) -> Vec<Frame> {
Iter::new(Bytes::from(buf)).collect::<Vec<_>>()
Iter::new(Bytes::from(buf))
.collect::<Result<Vec<_>, _>>()
.unwrap()
}

#[test]
Expand Down
5 changes: 3 additions & 2 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2256,8 +2256,9 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() {

// The ACK delay is properly calculated
assert_eq!(pair.client.captured_packets.len(), 1);
let mut frames =
frame::Iter::new(pair.client.captured_packets.remove(0).into()).collect::<Vec<_>>();
let mut frames = frame::Iter::new(pair.client.captured_packets.remove(0).into())
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(frames.len(), 1);
if let Frame::Ack(ack) = frames.remove(0) {
let ack_delay_exp = TransportParameters::default().ack_delay_exponent;
Expand Down

0 comments on commit f81c2fa

Please sign in to comment.