Skip to content

Commit

Permalink
Improve padding of coalesced datagrams containing INITIAL
Browse files Browse the repository at this point in the history
Our previous padding algorithm padded all client-sent or ack-eliciting
INITIAL packets to a full datagram size. While this satisfies the
specification requirements, the downside is that it made it impossible
to coalesce any other packets after the INITIAL.

We now mostly defer the padding decision until the datagram is finalised
and perform padding by appending zeroes at the end of the datagram. As
an exception to this rule, in the presence of short-header packets we
insert the padding inside the packet.
  • Loading branch information
jlaine committed Jul 1, 2024
1 parent afe5525 commit c960814
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 56 deletions.
35 changes: 27 additions & 8 deletions src/aioquic/quic/packet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self._datagrams: List[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._datagram_needs_padding = False
self._packets: List[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
Expand Down Expand Up @@ -217,6 +218,7 @@ def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
self._datagram_needs_padding = False

# calculate header size
if packet_type != QuicPacketType.ONE_RTT:
Expand Down Expand Up @@ -270,15 +272,23 @@ def _end_packet(self) -> None:
- packet_size
)

# Padding for initial packets; see RFC 9000 section
# 14.1.
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if (
(self._is_client or self._packet.is_ack_eliciting)
and self._packet_type == QuicPacketType.INITIAL
and self.remaining_flight_space
and self.remaining_flight_space > padding_size
self._is_client or self._packet.is_ack_eliciting
) and self._packet_type == QuicPacketType.INITIAL:
self._datagram_needs_padding = True

# For datagrams containing 1-RTT data, we *must* apply the padding
# inside the packet, we cannot tack bytes onto the end of the
# datagram.
if (
self._datagram_needs_padding
and self._packet_type == QuicPacketType.ONE_RTT
):
padding_size = self.remaining_flight_space
if self.remaining_flight_space > padding_size:
padding_size = self.remaining_flight_space
self._datagram_needs_padding = False

# write padding
if padding_size > 0:
Expand Down Expand Up @@ -343,7 +353,7 @@ def _end_packet(self) -> None:
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes

# short header packets cannot be coalesced, we need a new datagram
# Short header packets cannot be coalesced, we need a new datagram.
if self._packet_type == QuicPacketType.ONE_RTT:
self._flush_current_datagram()

Expand All @@ -358,6 +368,15 @@ def _end_packet(self) -> None:
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if self._datagram_needs_padding:
extra_bytes = self._flight_capacity - self._buffer.tell()
if extra_bytes > 0:
self._buffer.push_bytes(bytes(extra_bytes))
self._datagram_flight_bytes += extra_bytes
datagram_bytes += extra_bytes

self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
Expand Down
57 changes: 23 additions & 34 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@
)

CLIENT_ADDR = ("1.2.3.4", 1234)
CLIENT_HANDSHAKE_DATAGRAM_SIZES = [1200]

SERVER_ADDR = ("2.3.4.5", 4433)
SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1200, 986]
SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1162]

HANDSHAKE_COMPLETED_EVENTS = [
events.HandshakeCompleted,
Expand Down Expand Up @@ -464,9 +465,8 @@ def test_connect_without_loss(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.425)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -475,7 +475,6 @@ def test_connect_without_loss(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.425)
Expand Down Expand Up @@ -529,9 +528,8 @@ def test_connect_with_loss_1(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.625)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -540,7 +538,6 @@ def test_connect_with_loss_1(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.625)
Expand Down Expand Up @@ -607,9 +604,8 @@ def test_connect_with_loss_2(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.525)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -618,7 +614,6 @@ def test_connect_with_loss_2(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.525)
Expand Down Expand Up @@ -683,9 +678,8 @@ def test_connect_with_loss_3(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.625)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -694,7 +688,6 @@ def test_connect_with_loss_3(self):

now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.625)
Expand Down Expand Up @@ -733,12 +726,11 @@ def test_connect_with_loss_4(self):
self.assertSentPackets(server, [1, 2, 0])
self.assertEvents(server, [events.ProtocolNegotiated])

# client only receives first two datagrams and sends ACKS
# client only receives the first datagram and sends ACKS
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 48])
self.assertEqual(datagram_sizes(items), [1200])
self.assertAlmostEqual(client.get_timer(), 0.325)
self.assertSentPackets(client, [0, 1, 0])
self.assertEvents(client, [events.ProtocolNegotiated])
Expand Down Expand Up @@ -821,9 +813,8 @@ def test_connect_with_loss_5(self):
now += TICK
client.receive_datagram(items[0][0], SERVER_ADDR, now=now)
client.receive_datagram(items[1][0], SERVER_ADDR, now=now)
client.receive_datagram(items[2][0], SERVER_ADDR, now=now)
items = client.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [1200, 327])
self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES)
self.assertAlmostEqual(client.get_timer(), 0.425)
self.assertSentPackets(client, [0, 1, 1])
self.assertEvents(
Expand All @@ -833,7 +824,6 @@ def test_connect_with_loss_5(self):
# server completes handshake, but HANDSHAKE_DONE is lost
now += TICK
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now)
server.receive_datagram(items[1][0], CLIENT_ADDR, now=now)
items = server.datagrams_to_send(now=now)
self.assertEqual(datagram_sizes(items), [229])
self.assertAlmostEqual(server.get_timer(), 0.425)
Expand Down Expand Up @@ -1066,7 +1056,7 @@ def save_session_ticket(ticket):
stream_id = client.get_next_available_stream_id()
client.send_stream_data(stream_id, b"hello")

self.assertEqual(roundtrip(client, server), (2, 2))
self.assertEqual(roundtrip(client, server), (1, 1))

event = server.next_event()
self.assertEqual(type(event), events.ProtocolNegotiated)
Expand Down Expand Up @@ -2787,7 +2777,7 @@ def test_send_max_data_blocked_by_cc(self):
with client_and_server() as (client, server):
# check congestion control
self.assertEqual(client._loss.bytes_in_flight, 0)
self.assertEqual(client._loss.congestion_window, 13423)
self.assertEqual(client._loss.congestion_window, 13536)

# artificially raise received data counter
client._local_max_data_used = client._local_max_data
Expand Down Expand Up @@ -3153,20 +3143,19 @@ def test_version_negotiation_ignore(self):
self.assertEqual(drop(client), 0)

def test_version_negotiation_ignore_server(self):
with client_and_server() as (client, server):
# The server does not reply to the version negotiation packet.
server.receive_datagram(
encode_quic_version_negotiation(
source_cid=server._peer_cid.cid,
destination_cid=server.host_cid,
supported_versions=[QuicProtocolVersion.VERSION_1],
),
CLIENT_ADDR,
now=time.time(),
)
self.assertEqual(drop(client), 0)
server = create_standalone_server(self)

self.assertPacketDropped(server, "unexpected_packet")
# Servers do not expect version negotiation packets.
server.receive_datagram(
encode_quic_version_negotiation(
source_cid=server._peer_cid.cid,
destination_cid=server.host_cid,
supported_versions=[QuicProtocolVersion.VERSION_1],
),
CLIENT_ADDR,
now=time.time(),
)
self.assertPacketDropped(server, "unexpected_packet")

def test_version_negotiation_ok(self):
client = create_standalone_client(
Expand Down
75 changes: 61 additions & 14 deletions tests/test_packet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_long_header_initial_client(self):
is_crypto_packet=True,
packet_number=0,
packet_type=QuicPacketType.INITIAL,
sent_bytes=1200,
sent_bytes=145,
)
],
)
Expand Down Expand Up @@ -134,7 +134,55 @@ def test_long_header_initial_client_2(self):
is_crypto_packet=True,
packet_number=1,
packet_type=QuicPacketType.INITIAL,
sent_bytes=1200,
sent_bytes=145,
),
],
)

# check builder
self.assertEqual(builder.packet_number, 2)

def test_long_header_initial_client_zero_rtt(self):
builder = create_builder(is_client=True)
crypto = create_crypto()

# INITIAL
builder.start_packet(QuicPacketType.INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1156)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(613))
self.assertFalse(builder.packet_is_empty)

# 0-RTT
builder.start_packet(QuicPacketType.ZERO_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 499)
buf = builder.start_frame(QuicFrameType.STREAM_BASE)
buf.push_bytes(bytes(100))
self.assertFalse(builder.packet_is_empty)

# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(datagram_sizes(datagrams), [1200])
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=QuicPacketType.INITIAL,
sent_bytes=658,
),
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=False,
packet_number=1,
packet_type=QuicPacketType.ZERO_RTT,
sent_bytes=144,
),
],
)
Expand Down Expand Up @@ -163,10 +211,10 @@ def test_long_header_initial_server(self):

# HANDSHAKE with CRYPTO
builder.start_packet(QuicPacketType.HANDSHAKE, crypto)
self.assertEqual(builder.remaining_flight_space, 1157)
self.assertEqual(builder.remaining_flight_space, 995)

buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(1156))
buf.push_bytes(bytes(994))
self.assertFalse(builder.packet_is_empty)

# HANDSHAKE with CRYPTO
Expand All @@ -183,7 +231,7 @@ def test_long_header_initial_server(self):

# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(datagram_sizes(datagrams), [1200, 1200, 844])
self.assertEqual(datagram_sizes(datagrams), [1200, 844])
self.assertEqual(
packets,
[
Expand All @@ -194,7 +242,7 @@ def test_long_header_initial_server(self):
is_crypto_packet=True,
packet_number=0,
packet_type=QuicPacketType.INITIAL,
sent_bytes=1200,
sent_bytes=162,
),
QuicSentPacket(
epoch=Epoch.HANDSHAKE,
Expand All @@ -203,7 +251,7 @@ def test_long_header_initial_server(self):
is_crypto_packet=True,
packet_number=1,
packet_type=QuicPacketType.HANDSHAKE,
sent_bytes=1200,
sent_bytes=1038,
),
QuicSentPacket(
epoch=Epoch.HANDSHAKE,
Expand Down Expand Up @@ -252,7 +300,7 @@ def test_long_header_initial_server_without_handshake(self):
is_crypto_packet=True,
packet_number=0,
packet_type=QuicPacketType.INITIAL,
sent_bytes=1200,
sent_bytes=145,
)
],
)
Expand Down Expand Up @@ -363,24 +411,23 @@ def test_long_header_then_long_header(self):

# HANDSHAKE
builder.start_packet(QuicPacketType.HANDSHAKE, crypto)
self.assertEqual(builder.remaining_flight_space, 1157)
self.assertEqual(builder.remaining_flight_space, 913)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(299))
self.assertFalse(builder.packet_is_empty)
self.assertEqual(builder.remaining_flight_space, 857)
self.assertEqual(builder.remaining_flight_space, 613)

# ONE_RTT
builder.start_packet(QuicPacketType.ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 830)
self.assertEqual(builder.remaining_flight_space, 586)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(299))
self.assertFalse(builder.packet_is_empty)

# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 2)
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 1200)
self.assertEqual(len(datagrams[1]), 670)
self.assertEqual(
packets,
[
Expand All @@ -391,7 +438,7 @@ def test_long_header_then_long_header(self):
is_crypto_packet=True,
packet_number=0,
packet_type=QuicPacketType.INITIAL,
sent_bytes=1200,
sent_bytes=244,
),
QuicSentPacket(
epoch=Epoch.HANDSHAKE,
Expand Down

0 comments on commit c960814

Please sign in to comment.