Skip to content

Commit

Permalink
Add support for version_information transport parameter
Browse files Browse the repository at this point in the history
This parameter is defined in RFC 9368 and is used for compatible version
negotiation.

We also ensure that if parsing a parameter results in a shorter read
than the parameter's length, we raise a `ValueError` not an
`AssertionError`.
  • Loading branch information
jlaine committed Jun 29, 2024
1 parent 70dd040 commit a59d9ad
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: pip install check-manifest mypy ruff types-certifi types-pyopenssl
- name: Run linters
run: |
ruff .
ruff check .
ruff format --check --diff .
mypy src tests
check-manifest
Expand Down
52 changes: 49 additions & 3 deletions src/aioquic/quic/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ class QuicPreferredAddress:
stateless_reset_token: bytes


@dataclass
class QuicVersionInformation:
chosen_version: int
available_versions: List[int]


@dataclass
class QuicTransportParameters:
original_destination_connection_id: Optional[bytes] = None
Expand All @@ -352,6 +358,7 @@ class QuicTransportParameters:
active_connection_id_limit: Optional[int] = None
initial_source_connection_id: Optional[bytes] = None
retry_source_connection_id: Optional[bytes] = None
version_information: Optional[QuicVersionInformation] = None
max_datagram_frame_size: Optional[int] = None
quantum_readiness: Optional[bytes] = None

Expand All @@ -374,6 +381,8 @@ class QuicTransportParameters:
0x0E: ("active_connection_id_limit", int),
0x0F: ("initial_source_connection_id", bytes),
0x10: ("retry_source_connection_id", bytes),
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
0x11: ("version_information", QuicVersionInformation),
# extensions
0x0020: ("max_datagram_frame_size", int),
0x0C37: ("quantum_readiness", bytes),
Expand Down Expand Up @@ -425,27 +434,62 @@ def push_quic_preferred_address(
buf.push_bytes(preferred_address.stateless_reset_token)


def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
chosen_version = buf.pull_uint32()
available_versions = []
for i in range(length // 4 - 1):
available_versions.append(buf.pull_uint32())

# If an endpoint receives a Chosen Version equal to zero, or any Available Version
# equal to zero, it MUST treat it as a parsing failure.
#
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
if chosen_version == 0 or 0 in available_versions:
raise ValueError("Version Information must not contain version 0")

return QuicVersionInformation(
chosen_version=chosen_version,
available_versions=available_versions,
)


def push_quic_version_information(
buf: Buffer, version_information: QuicVersionInformation
) -> None:
buf.push_uint32(version_information.chosen_version)
for version in version_information.available_versions:
buf.push_uint32(version)


def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
params = QuicTransportParameters()
while not buf.eof():
param_id = buf.pull_uint_var()
param_len = buf.pull_uint_var()
param_start = buf.tell()
if param_id in PARAMS:
# parse known parameter
# Parse known parameter.
param_name, param_type = PARAMS[param_id]
if param_type == int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type == bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type == QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
elif param_type == QuicVersionInformation:
setattr(
params,
param_name,
pull_quic_version_information(buf, param_len),
)
else:
setattr(params, param_name, True)
else:
# skip unknown parameter
# Skip unknown parameter.
buf.pull_bytes(param_len)
assert buf.tell() == param_start + param_len

if buf.tell() != param_start + param_len:
raise ValueError("Transport parameter length does not match")

return params

Expand All @@ -463,6 +507,8 @@ def push_quic_transport_parameters(
param_buf.push_bytes(param_value)
elif param_type == QuicPreferredAddress:
push_quic_preferred_address(param_buf, param_value)
elif param_type == QuicVersionInformation:
push_quic_version_information(param_buf, param_value)
buf.push_uint_var(param_id)
buf.push_uint_var(param_buf.tell())
buf.push_bytes(param_buf.data)
Expand Down
74 changes: 73 additions & 1 deletion tests/test_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
QuicPreferredAddress,
QuicProtocolVersion,
QuicTransportParameters,
QuicVersionInformation,
decode_packet_number,
encode_quic_retry,
encode_quic_version_negotiation,
Expand Down Expand Up @@ -316,6 +317,25 @@ def test_params_disable_active_migration(self):
push_quic_transport_parameters(buf, params)
self.assertEqual(buf.data, data)

def test_params_max_ack_delay(self):
data = binascii.unhexlify("0b010a")

# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(buf)
self.assertEqual(params, QuicTransportParameters(max_ack_delay=10))

# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(buf, params)
self.assertEqual(buf.data, data)

def test_params_max_ack_delay_length_mismatch(self):
buf = Buffer(data=binascii.unhexlify("0b020a"))
with self.assertRaises(ValueError) as cm:
pull_quic_transport_parameters(buf)
self.assertEqual(str(cm.exception), "Transport parameter length does not match")

def test_params_preferred_address(self):
data = binascii.unhexlify(
"0d3b8ba27b8611532400890200000000f03c91fffe69a45411531262c4518d6"
Expand All @@ -338,7 +358,7 @@ def test_params_preferred_address(self):
)

# serialize
buf = Buffer(capacity=1000)
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(buf, params)
self.assertEqual(buf.data, data)

Expand All @@ -350,6 +370,58 @@ def test_params_unknown(self):
params = pull_quic_transport_parameters(buf)
self.assertEqual(params, QuicTransportParameters())

def test_params_version_information(self):
data = binascii.unhexlify("110c00000001000000016b3343cf")

# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(buf)
self.assertEqual(
params,
QuicTransportParameters(
version_information=QuicVersionInformation(
chosen_version=QuicProtocolVersion.VERSION_1,
available_versions=[
QuicProtocolVersion.VERSION_1,
QuicProtocolVersion.VERSION_2,
],
),
),
)

# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(buf, params)
self.assertEqual(buf.data, data)

def test_params_version_information_available_version_0(self):
buf = Buffer(data=binascii.unhexlify("11080000000100000000"))
with self.assertRaises(ValueError) as cm:
pull_quic_transport_parameters(buf)
self.assertEqual(
str(cm.exception), "Version Information must not contain version 0"
)

def test_params_version_information_chosen_version_0(self):
buf = Buffer(data=binascii.unhexlify("110400000000"))
with self.assertRaises(ValueError) as cm:
pull_quic_transport_parameters(buf)
self.assertEqual(
str(cm.exception), "Version Information must not contain version 0"
)

def test_params_version_information_length_not_divisible_by_four(self):
buf = Buffer(data=binascii.unhexlify("11050000000100"))
with self.assertRaises(ValueError) as cm:
pull_quic_transport_parameters(buf)
self.assertEqual(str(cm.exception), "Transport parameter length does not match")

def test_params_version_information_truncated(self):
buf = Buffer(data=binascii.unhexlify("110800000000"))
with self.assertRaises(ValueError) as cm:
pull_quic_transport_parameters(buf)
self.assertEqual(str(cm.exception), "Read out of bounds")

def test_preferred_address_ipv4_only(self):
data = binascii.unhexlify(
"8ba27b8611530000000000000000000000000000000000001262c4518d63013"
Expand Down

0 comments on commit a59d9ad

Please sign in to comment.