diff --git a/examples/http3_client.py b/examples/http3_client.py index fe55d22db..8c29b982a 100644 --- a/examples/http3_client.py +++ b/examples/http3_client.py @@ -445,6 +445,12 @@ async def main( "CHACHA20_POLY1305_SHA256`" ), ) + parser.add_argument( + "--congestion-control-algorithm", + type=str, + default="reno", + help="use the specified congestion control algorithm", + ) parser.add_argument( "-d", "--data", type=str, help="send the specified data in a POST request" ) @@ -527,6 +533,7 @@ async def main( configuration = QuicConfiguration( is_client=True, alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN, + congestion_control_algorithm=args.congestion_control_algorithm, max_datagram_size=args.max_datagram_size, ) if args.ca_certs: diff --git a/examples/http3_server.py b/examples/http3_server.py index a8fd1a9e7..75d2b8664 100644 --- a/examples/http3_server.py +++ b/examples/http3_server.py @@ -514,6 +514,12 @@ async def main( required=True, help="load the TLS certificate from the specified file", ) + parser.add_argument( + "--congestion-control-algorithm", + type=str, + default="reno", + help="use the specified congestion control algorithm", + ) parser.add_argument( "--host", type=str, @@ -584,6 +590,7 @@ async def main( configuration = QuicConfiguration( alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], + congestion_control_algorithm=args.congestion_control_algorithm, is_client=False, max_datagram_frame_size=65536, max_datagram_size=args.max_datagram_size, diff --git a/src/aioquic/quic/configuration.py b/src/aioquic/quic/configuration.py index d1f184c66..0bf919612 100644 --- a/src/aioquic/quic/configuration.py +++ b/src/aioquic/quic/configuration.py @@ -26,6 +26,13 @@ class QuicConfiguration: A list of supported ALPN protocols. """ + congestion_control_algorithm: str = "reno" + """ + The name of the congestion control algorithm to use. + + Currently supported algorithms: `"reno"`. + """ + connection_id_length: int = 8 """ The length in bytes of local connection IDs. diff --git a/src/aioquic/quic/congestion/__init__.py b/src/aioquic/quic/congestion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/aioquic/quic/congestion/base.py b/src/aioquic/quic/congestion/base.py new file mode 100644 index 000000000..df6ca0028 --- /dev/null +++ b/src/aioquic/quic/congestion/base.py @@ -0,0 +1,126 @@ +import abc +from typing import Dict, Iterable, Optional, Protocol + +from ..packet_builder import QuicSentPacket + +K_GRANULARITY = 0.001 # seconds +K_INITIAL_WINDOW = 10 +K_MINIMUM_WINDOW = 2 + + +class QuicCongestionControl(abc.ABC): + """ + Base class for congestion control implementations. + """ + + bytes_in_flight: int = 0 + congestion_window: int = 0 + ssthresh: Optional[int] = None + + def __init__(self, *, max_datagram_size: int) -> None: + self.congestion_window = K_INITIAL_WINDOW * max_datagram_size + + @abc.abstractmethod + def on_packet_acked(self, *, packet: QuicSentPacket) -> None: + ... # pragma: no cover + + @abc.abstractmethod + def on_packet_sent(self, *, packet: QuicSentPacket) -> None: + ... # pragma: no cover + + @abc.abstractmethod + def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: + ... # pragma: no cover + + @abc.abstractmethod + def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None: + ... # pragma: no cover + + @abc.abstractmethod + def on_rtt_measurement(self, *, now: float, rtt: float) -> None: + ... # pragma: no cover + + +class QuicCongestionControlFactory(Protocol): + def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: + ... # pragma: no cover + + +class QuicRttMonitor: + """ + Roundtrip time monitor for HyStart. + """ + + def __init__(self) -> None: + self._increases = 0 + self._last_time = None + self._ready = False + self._size = 5 + + self._filtered_min: Optional[float] = None + + self._sample_idx = 0 + self._sample_max: Optional[float] = None + self._sample_min: Optional[float] = None + self._sample_time = 0.0 + self._samples = [0.0 for i in range(self._size)] + + def add_rtt(self, *, rtt: float) -> None: + self._samples[self._sample_idx] = rtt + self._sample_idx += 1 + + if self._sample_idx >= self._size: + self._sample_idx = 0 + self._ready = True + + if self._ready: + self._sample_max = self._samples[0] + self._sample_min = self._samples[0] + for sample in self._samples[1:]: + if sample < self._sample_min: + self._sample_min = sample + elif sample > self._sample_max: + self._sample_max = sample + + def is_rtt_increasing(self, *, now: float, rtt: float) -> bool: + if now > self._sample_time + K_GRANULARITY: + self.add_rtt(rtt=rtt) + self._sample_time = now + + if self._ready: + if self._filtered_min is None or self._filtered_min > self._sample_max: + self._filtered_min = self._sample_max + + delta = self._sample_min - self._filtered_min + if delta * 4 >= self._filtered_min: + self._increases += 1 + if self._increases >= self._size: + return True + elif delta > 0: + self._increases = 0 + return False + + +_factories: Dict[str, QuicCongestionControlFactory] = {} + + +def create_congestion_control( + name: str, *, max_datagram_size: int +) -> QuicCongestionControl: + """ + Create an instance of the `name` congestion control algorithm. + """ + try: + factory = _factories[name] + except KeyError: + raise Exception(f"Unknown congestion control algorithm: {name}") + return factory(max_datagram_size=max_datagram_size) + + +def register_congestion_control( + name: str, factory: QuicCongestionControlFactory +) -> None: + """ + Register a congestion control algorithm named `name`. + """ + _factories[name] = factory diff --git a/src/aioquic/quic/congestion/reno.py b/src/aioquic/quic/congestion/reno.py new file mode 100644 index 000000000..3bd0ee1c0 --- /dev/null +++ b/src/aioquic/quic/congestion/reno.py @@ -0,0 +1,77 @@ +from typing import Iterable + +from ..packet_builder import QuicSentPacket +from .base import ( + K_MINIMUM_WINDOW, + QuicCongestionControl, + QuicRttMonitor, + register_congestion_control, +) + +K_LOSS_REDUCTION_FACTOR = 0.5 + + +class RenoCongestionControl(QuicCongestionControl): + """ + New Reno congestion control. + """ + + def __init__(self, *, max_datagram_size: int) -> None: + super().__init__(max_datagram_size=max_datagram_size) + self._max_datagram_size = max_datagram_size + self._congestion_recovery_start_time = 0.0 + self._congestion_stash = 0 + self._rtt_monitor = QuicRttMonitor() + + def on_packet_acked(self, *, packet: QuicSentPacket) -> None: + self.bytes_in_flight -= packet.sent_bytes + + # don't increase window in congestion recovery + if packet.sent_time <= self._congestion_recovery_start_time: + return + + if self.ssthresh is None or self.congestion_window < self.ssthresh: + # slow start + self.congestion_window += packet.sent_bytes + else: + # congestion avoidance + self._congestion_stash += packet.sent_bytes + count = self._congestion_stash // self.congestion_window + if count: + self._congestion_stash -= count * self.congestion_window + self.congestion_window += count * self._max_datagram_size + + def on_packet_sent(self, *, packet: QuicSentPacket) -> None: + self.bytes_in_flight += packet.sent_bytes + + def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + + def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None: + lost_largest_time = 0.0 + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + lost_largest_time = packet.sent_time + + # start a new congestion event if packet was sent after the + # start of the previous congestion recovery period. + if lost_largest_time > self._congestion_recovery_start_time: + self._congestion_recovery_start_time = now + self.congestion_window = max( + int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), + K_MINIMUM_WINDOW * self._max_datagram_size, + ) + self.ssthresh = self.congestion_window + + # TODO : collapse congestion window if persistent congestion + + def on_rtt_measurement(self, *, now: float, rtt: float) -> None: + # check whether we should exit slow start + if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( + now=now, rtt=rtt + ): + self.ssthresh = self.congestion_window + + +register_congestion_control("reno", RenoCongestionControl) diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index a879fcd1d..9cf14c236 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -28,6 +28,7 @@ ) from . import events from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration +from .congestion.base import K_GRANULARITY from .crypto import CryptoError, CryptoPair, KeyUnavailableError from .logger import QuicLoggerTrace from .packet import ( @@ -61,7 +62,7 @@ QuicPacketBuilder, QuicPacketBuilderStop, ) -from .recovery import K_GRANULARITY, QuicPacketRecovery, QuicPacketSpace +from .recovery import QuicPacketRecovery, QuicPacketSpace from .stream import FinalSizeError, QuicStream, StreamFinishedError logger = logging.getLogger("quic") @@ -383,6 +384,7 @@ def __init__( # loss recovery self._loss = QuicPacketRecovery( + congestion_control_algorithm=configuration.congestion_control_algorithm, initial_rtt=configuration.initial_rtt, max_datagram_size=self._max_datagram_size, peer_completed_address_validation=not self._is_client, @@ -1508,10 +1510,10 @@ def _handle_ack_frame( self._loss.peer_completed_address_validation = True self._loss.on_ack_received( - space=self._spaces[context.epoch], ack_rangeset=ack_rangeset, ack_delay=ack_delay, now=context.time, + space=self._spaces[context.epoch], ) def _handle_connection_close_frame( diff --git a/src/aioquic/quic/recovery.py b/src/aioquic/quic/recovery.py index 2a95c8590..3c165a7ec 100644 --- a/src/aioquic/quic/recovery.py +++ b/src/aioquic/quic/recovery.py @@ -2,22 +2,18 @@ import math from typing import Any, Callable, Dict, Iterable, List, Optional +from .congestion import reno # noqa +from .congestion.base import K_GRANULARITY, create_congestion_control from .logger import QuicLoggerTrace from .packet_builder import QuicDeliveryState, QuicSentPacket from .rangeset import RangeSet # loss detection K_PACKET_THRESHOLD = 3 -K_GRANULARITY = 0.001 # seconds K_TIME_THRESHOLD = 9 / 8 K_MICRO_SECOND = 0.000001 K_SECOND = 1.0 -# congestion control -K_INITIAL_WINDOW = 10 -K_MINIMUM_WINDOW = 2 -K_LOSS_REDUCTION_FACTOR = 0.5 - class QuicPacketSpace: def __init__(self) -> None: @@ -82,71 +78,6 @@ def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: self.bucket_time = self.bucket_max -class QuicCongestionControl: - """ - New Reno congestion control. - """ - - def __init__(self, *, max_datagram_size: int) -> None: - self._max_datagram_size = max_datagram_size - self.bytes_in_flight = 0 - self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size - self._congestion_recovery_start_time = 0.0 - self._congestion_stash = 0 - self._rtt_monitor = QuicRttMonitor() - self.ssthresh: Optional[int] = None - - def on_packet_acked(self, packet: QuicSentPacket) -> None: - self.bytes_in_flight -= packet.sent_bytes - - # don't increase window in congestion recovery - if packet.sent_time <= self._congestion_recovery_start_time: - return - - if self.ssthresh is None or self.congestion_window < self.ssthresh: - # slow start - self.congestion_window += packet.sent_bytes - else: - # congestion avoidance - self._congestion_stash += packet.sent_bytes - count = self._congestion_stash // self.congestion_window - if count: - self._congestion_stash -= count * self.congestion_window - self.congestion_window += count * self._max_datagram_size - - def on_packet_sent(self, packet: QuicSentPacket) -> None: - self.bytes_in_flight += packet.sent_bytes - - def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: - for packet in packets: - self.bytes_in_flight -= packet.sent_bytes - - def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: - lost_largest_time = 0.0 - for packet in packets: - self.bytes_in_flight -= packet.sent_bytes - lost_largest_time = packet.sent_time - - # start a new congestion event if packet was sent after the - # start of the previous congestion recovery period. - if lost_largest_time > self._congestion_recovery_start_time: - self._congestion_recovery_start_time = now - self.congestion_window = max( - int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), - K_MINIMUM_WINDOW * self._max_datagram_size, - ) - self.ssthresh = self.congestion_window - - # TODO : collapse congestion window if persistent congestion - - def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: - # check whether we should exit slow start - if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( - latest_rtt, now - ): - self.ssthresh = self.congestion_window - - class QuicPacketRecovery: """ Packet loss and congestion controller. @@ -154,6 +85,8 @@ class QuicPacketRecovery: def __init__( self, + *, + congestion_control_algorithm: str, initial_rtt: float, max_datagram_size: int, peer_completed_address_validation: bool, @@ -181,7 +114,9 @@ def __init__( self._time_of_last_sent_ack_eliciting_packet = 0.0 # congestion control - self._cc = QuicCongestionControl(max_datagram_size=max_datagram_size) + self._cc = create_congestion_control( + congestion_control_algorithm, max_datagram_size=max_datagram_size + ) self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size) @property @@ -196,7 +131,7 @@ def discard_space(self, space: QuicPacketSpace) -> None: assert space in self.spaces self._cc.on_packets_expired( - filter(lambda x: x.in_flight, space.sent_packets.values()) + packets=filter(lambda x: x.in_flight, space.sent_packets.values()) ) space.sent_packets.clear() @@ -237,10 +172,11 @@ def get_probe_timeout(self) -> float: def on_ack_received( self, - space: QuicPacketSpace, + *, ack_rangeset: RangeSet, ack_delay: float, now: float, + space: QuicPacketSpace, ) -> None: """ Update metrics as the result of an ACK being received. @@ -263,7 +199,7 @@ def on_ack_received( is_ack_eliciting = True space.ack_eliciting_in_flight -= 1 if packet.in_flight: - self._cc.on_packet_acked(packet) + self._cc.on_packet_acked(packet=packet) largest_newly_acked = packet_number largest_sent_time = packet.sent_time @@ -302,7 +238,7 @@ def on_ack_received( ) # inform congestion controller - self._cc.on_rtt_measurement(latest_rtt, now=now) + self._cc.on_rtt_measurement(now=now, rtt=latest_rtt) self._pacer.update_rate( congestion_window=self._cc.congestion_window, smoothed_rtt=self._rtt_smoothed, @@ -311,7 +247,7 @@ def on_ack_received( else: log_rtt = False - self._detect_loss(space, now=now) + self._detect_loss(now=now, space=space) # reset PTO count self._pto_count = 0 @@ -319,15 +255,15 @@ def on_ack_received( if self._quic_logger is not None: self._log_metrics_updated(log_rtt=log_rtt) - def on_loss_detection_timeout(self, now: float) -> None: + def on_loss_detection_timeout(self, *, now: float) -> None: loss_space = self._get_loss_space() if loss_space is not None: - self._detect_loss(loss_space, now=now) + self._detect_loss(now=now, space=loss_space) else: self._pto_count += 1 self.reschedule_data(now=now) - def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None: + def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None: space.sent_packets[packet.packet_number] = packet if packet.is_ack_eliciting: @@ -337,12 +273,12 @@ def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None self._time_of_last_sent_ack_eliciting_packet = packet.sent_time # add packet to bytes in flight - self._cc.on_packet_sent(packet) + self._cc.on_packet_sent(packet=packet) if self._quic_logger is not None: self._log_metrics_updated() - def reschedule_data(self, now: float) -> None: + def reschedule_data(self, *, now: float) -> None: """ Schedule some data for retransmission. """ @@ -353,7 +289,7 @@ def reschedule_data(self, now: float) -> None: filter(lambda i: i.is_crypto_packet, space.sent_packets.values()) ) if packets: - self._on_packets_lost(packets, space=space, now=now) + self._on_packets_lost(now=now, packets=packets, space=space) crypto_scheduled = True if crypto_scheduled and self._logger is not None: self._logger.debug("Scheduled CRYPTO data for retransmission") @@ -361,7 +297,7 @@ def reschedule_data(self, now: float) -> None: # ensure an ACK-elliciting packet is sent self._send_probe() - def _detect_loss(self, space: QuicPacketSpace, now: float) -> None: + def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None: """ Check whether any packets should be declared lost. """ @@ -386,7 +322,7 @@ def _detect_loss(self, space: QuicPacketSpace, now: float) -> None: if space.loss_time is None or space.loss_time > packet_loss_time: space.loss_time = packet_loss_time - self._on_packets_lost(lost_packets, space=space, now=now) + self._on_packets_lost(now=now, packets=lost_packets, space=space) def _get_loss_space(self) -> Optional[QuicPacketSpace]: loss_space = None @@ -420,7 +356,7 @@ def _log_metrics_updated(self, log_rtt=False) -> None: ) def _on_packets_lost( - self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float + self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace ) -> None: lost_packets_cc = [] for packet in packets: @@ -449,65 +385,10 @@ def _on_packets_lost( # inform congestion controller if lost_packets_cc: - self._cc.on_packets_lost(lost_packets_cc, now=now) + self._cc.on_packets_lost(now=now, packets=lost_packets_cc) self._pacer.update_rate( congestion_window=self._cc.congestion_window, smoothed_rtt=self._rtt_smoothed, ) if self._quic_logger is not None: self._log_metrics_updated() - - -class QuicRttMonitor: - """ - Roundtrip time monitor for HyStart. - """ - - def __init__(self) -> None: - self._increases = 0 - self._last_time = None - self._ready = False - self._size = 5 - - self._filtered_min: Optional[float] = None - - self._sample_idx = 0 - self._sample_max: Optional[float] = None - self._sample_min: Optional[float] = None - self._sample_time = 0.0 - self._samples = [0.0 for i in range(self._size)] - - def add_rtt(self, rtt: float) -> None: - self._samples[self._sample_idx] = rtt - self._sample_idx += 1 - - if self._sample_idx >= self._size: - self._sample_idx = 0 - self._ready = True - - if self._ready: - self._sample_max = self._samples[0] - self._sample_min = self._samples[0] - for sample in self._samples[1:]: - if sample < self._sample_min: - self._sample_min = sample - elif sample > self._sample_max: - self._sample_max = sample - - def is_rtt_increasing(self, rtt: float, now: float) -> bool: - if now > self._sample_time + K_GRANULARITY: - self.add_rtt(rtt) - self._sample_time = now - - if self._ready: - if self._filtered_min is None or self._filtered_min > self._sample_max: - self._filtered_min = self._sample_max - - delta = self._sample_min - self._filtered_min - if delta * 4 >= self._filtered_min: - self._increases += 1 - if self._increases >= self._size: - return True - elif delta > 0: - self._increases = 0 - return False diff --git a/tests/test_recovery.py b/tests/test_recovery.py index 086016e59..389ec3615 100644 --- a/tests/test_recovery.py +++ b/tests/test_recovery.py @@ -1,20 +1,16 @@ -import math from unittest import TestCase -from aioquic import tls -from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT -from aioquic.quic.packet_builder import QuicSentPacket -from aioquic.quic.rangeset import RangeSet -from aioquic.quic.recovery import ( - QuicPacketPacer, - QuicPacketRecovery, - QuicPacketSpace, - QuicRttMonitor, -) +from aioquic.quic.congestion.base import QuicRttMonitor, create_congestion_control +from aioquic.quic.recovery import QuicPacketPacer -def send_probe(): - pass +class QuicCongestionControlTest(TestCase): + def test_create_unknown_congestion_control(self): + with self.assertRaises(Exception) as cm: + create_congestion_control("bogus", max_datagram_size=1280) + self.assertEqual( + str(cm.exception), "Unknown congestion control algorithm: bogus" + ) class QuicPacketPacerTest(TestCase): @@ -61,117 +57,6 @@ def test_with_measurement(self): self.assertAlmostEqual(self.pacer.next_send_time(now=1.00015), 1.0002) -class QuicPacketRecoveryTest(TestCase): - def setUp(self): - self.INITIAL_SPACE = QuicPacketSpace() - self.HANDSHAKE_SPACE = QuicPacketSpace() - self.ONE_RTT_SPACE = QuicPacketSpace() - - self.recovery = QuicPacketRecovery( - initial_rtt=0.1, - max_datagram_size=1280, - peer_completed_address_validation=True, - send_probe=send_probe, - ) - self.recovery.spaces = [ - self.INITIAL_SPACE, - self.HANDSHAKE_SPACE, - self.ONE_RTT_SPACE, - ] - - def test_discard_space(self): - self.recovery.discard_space(self.INITIAL_SPACE) - - def test_on_ack_received_ack_eliciting(self): - packet = QuicSentPacket( - epoch=tls.Epoch.ONE_RTT, - in_flight=True, - is_ack_eliciting=True, - is_crypto_packet=False, - packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, - sent_bytes=1280, - sent_time=0.0, - ) - space = self.ONE_RTT_SPACE - - #  packet sent - self.recovery.on_packet_sent(packet, space) - self.assertEqual(self.recovery.bytes_in_flight, 1280) - self.assertEqual(space.ack_eliciting_in_flight, 1) - self.assertEqual(len(space.sent_packets), 1) - - # packet ack'd - self.recovery.on_ack_received( - space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0 - ) - self.assertEqual(self.recovery.bytes_in_flight, 0) - self.assertEqual(space.ack_eliciting_in_flight, 0) - self.assertEqual(len(space.sent_packets), 0) - - # check RTT - self.assertTrue(self.recovery._rtt_initialized) - self.assertEqual(self.recovery._rtt_latest, 10.0) - self.assertEqual(self.recovery._rtt_min, 10.0) - self.assertEqual(self.recovery._rtt_smoothed, 10.0) - - def test_on_ack_received_non_ack_eliciting(self): - packet = QuicSentPacket( - epoch=tls.Epoch.ONE_RTT, - in_flight=True, - is_ack_eliciting=False, - is_crypto_packet=False, - packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, - sent_bytes=1280, - sent_time=123.45, - ) - space = self.ONE_RTT_SPACE - - #  packet sent - self.recovery.on_packet_sent(packet, space) - self.assertEqual(self.recovery.bytes_in_flight, 1280) - self.assertEqual(space.ack_eliciting_in_flight, 0) - self.assertEqual(len(space.sent_packets), 1) - - # packet ack'd - self.recovery.on_ack_received( - space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0 - ) - self.assertEqual(self.recovery.bytes_in_flight, 0) - self.assertEqual(space.ack_eliciting_in_flight, 0) - self.assertEqual(len(space.sent_packets), 0) - - # check RTT - self.assertFalse(self.recovery._rtt_initialized) - self.assertEqual(self.recovery._rtt_latest, 0.0) - self.assertEqual(self.recovery._rtt_min, math.inf) - self.assertEqual(self.recovery._rtt_smoothed, 0.0) - - def test_on_packet_lost_crypto(self): - packet = QuicSentPacket( - epoch=tls.Epoch.INITIAL, - in_flight=True, - is_ack_eliciting=True, - is_crypto_packet=True, - packet_number=0, - packet_type=PACKET_TYPE_INITIAL, - sent_bytes=1280, - sent_time=0.0, - ) - space = self.INITIAL_SPACE - - self.recovery.on_packet_sent(packet, space) - self.assertEqual(self.recovery.bytes_in_flight, 1280) - self.assertEqual(space.ack_eliciting_in_flight, 1) - self.assertEqual(len(space.sent_packets), 1) - - self.recovery._detect_loss(space, now=1.0) - self.assertEqual(self.recovery.bytes_in_flight, 0) - self.assertEqual(space.ack_eliciting_in_flight, 0) - self.assertEqual(len(space.sent_packets), 0) - - class QuicRttMonitorTest(TestCase): def test_monitor(self): monitor = QuicRttMonitor() diff --git a/tests/test_recovery_reno.py b/tests/test_recovery_reno.py new file mode 100644 index 000000000..87822a74a --- /dev/null +++ b/tests/test_recovery_reno.py @@ -0,0 +1,130 @@ +import math +from unittest import TestCase + +from aioquic import tls +from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from aioquic.quic.packet_builder import QuicSentPacket +from aioquic.quic.rangeset import RangeSet +from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace + + +def send_probe(): + pass + + +class QuicPacketRecoveryRenoTest(TestCase): + def setUp(self): + self.INITIAL_SPACE = QuicPacketSpace() + self.HANDSHAKE_SPACE = QuicPacketSpace() + self.ONE_RTT_SPACE = QuicPacketSpace() + + self.recovery = QuicPacketRecovery( + congestion_control_algorithm="reno", + initial_rtt=0.1, + max_datagram_size=1280, + peer_completed_address_validation=True, + send_probe=send_probe, + ) + self.recovery.spaces = [ + self.INITIAL_SPACE, + self.HANDSHAKE_SPACE, + self.ONE_RTT_SPACE, + ] + + def test_discard_space(self): + self.recovery.discard_space(self.INITIAL_SPACE) + + def test_on_ack_received_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + self.recovery.on_ack_received( + ack_rangeset=RangeSet([range(0, 1)]), + ack_delay=0.0, + now=10.0, + space=space, + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertTrue(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 10.0) + self.assertEqual(self.recovery._rtt_min, 10.0) + self.assertEqual(self.recovery._rtt_smoothed, 10.0) + + def test_on_ack_received_non_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=False, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=123.45, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + self.recovery.on_ack_received( + ack_rangeset=RangeSet([range(0, 1)]), + ack_delay=0.0, + now=10.0, + space=space, + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertFalse(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 0.0) + self.assertEqual(self.recovery._rtt_min, math.inf) + self.assertEqual(self.recovery._rtt_smoothed, 0.0) + + def test_on_packet_lost_crypto(self): + packet = QuicSentPacket( + epoch=tls.Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.INITIAL_SPACE + + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) + + self.recovery._detect_loss(space=space, now=1.0) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0)