From ca0fe88a434920c4baf71fd668111e9a88c5d5ed Mon Sep 17 00:00:00 2001 From: Aperence Date: Wed, 13 Dec 2023 16:46:55 +0100 Subject: [PATCH] Added support for cubic congestion control (RFC9438) --- src/aioquic/quic/configuration.py | 2 +- src/aioquic/quic/congestion/base.py | 10 +- src/aioquic/quic/congestion/cubic.py | 212 ++++++++++ src/aioquic/quic/congestion/reno.py | 2 +- src/aioquic/quic/recovery.py | 11 +- tests/test_recovery_cubic.py | 584 +++++++++++++++++++++++++++ 6 files changed, 809 insertions(+), 12 deletions(-) create mode 100644 src/aioquic/quic/congestion/cubic.py create mode 100644 tests/test_recovery_cubic.py diff --git a/src/aioquic/quic/configuration.py b/src/aioquic/quic/configuration.py index 0bf919612..80738ea94 100644 --- a/src/aioquic/quic/configuration.py +++ b/src/aioquic/quic/configuration.py @@ -30,7 +30,7 @@ class QuicConfiguration: """ The name of the congestion control algorithm to use. - Currently supported algorithms: `"reno"`. + Currently supported algorithms: `"reno", `"cubic"`. """ connection_id_length: int = 8 diff --git a/src/aioquic/quic/congestion/base.py b/src/aioquic/quic/congestion/base.py index df6ca0028..0ae25a1c3 100644 --- a/src/aioquic/quic/congestion/base.py +++ b/src/aioquic/quic/congestion/base.py @@ -1,5 +1,5 @@ import abc -from typing import Dict, Iterable, Optional, Protocol +from typing import Any, Dict, Iterable, Optional, Protocol from ..packet_builder import QuicSentPacket @@ -21,7 +21,7 @@ def __init__(self, *, max_datagram_size: int) -> None: self.congestion_window = K_INITIAL_WINDOW * max_datagram_size @abc.abstractmethod - def on_packet_acked(self, *, packet: QuicSentPacket) -> None: + def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ... # pragma: no cover @abc.abstractmethod @@ -40,6 +40,12 @@ def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> N def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ... # pragma: no cover + def get_log_data(self) -> Dict[str, Any]: + data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight} + if self.ssthresh is not None: + data["ssthresh"] = self.ssthresh + return data + class QuicCongestionControlFactory(Protocol): def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: diff --git a/src/aioquic/quic/congestion/cubic.py b/src/aioquic/quic/congestion/cubic.py new file mode 100644 index 000000000..c7a774033 --- /dev/null +++ b/src/aioquic/quic/congestion/cubic.py @@ -0,0 +1,212 @@ +from typing import Any, Dict, Iterable + +from ..packet_builder import QuicSentPacket +from .base import ( + K_INITIAL_WINDOW, + K_MINIMUM_WINDOW, + QuicCongestionControl, + QuicRttMonitor, + register_congestion_control, +) + +# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions) +K_CUBIC_C = 0.4 +K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7 +K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity + + +def better_cube_root(x: float) -> float: + if x < 0: + # avoid precision errors that make the cube root returns an imaginary number + return -((-x) ** (1.0 / 3.0)) + else: + return (x) ** (1.0 / 3.0) + + +class CubicCongestionControl(QuicCongestionControl): + """ + Cubic congestion control implementation for aioquic + """ + + def __init__(self, max_datagram_size: int) -> None: + super().__init__(max_datagram_size=max_datagram_size) + # increase by one segment + self.additive_increase_factor: int = max_datagram_size + self._max_datagram_size: int = max_datagram_size + self._congestion_recovery_start_time = 0.0 + + self._rtt_monitor = QuicRttMonitor() + + self.rtt = 0.02 # starting RTT is considered to be 20ms + + self.reset() + + self.last_ack = 0.0 + + def W_cubic(self, t) -> int: + W_max_segments = self._W_max / self._max_datagram_size + target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments) + return int(target_segments * self._max_datagram_size) + + def is_reno_friendly(self, t) -> bool: + return self.W_cubic(t) < self._W_est + + def is_concave(self) -> bool: + return self.congestion_window < self._W_max + + def reset(self) -> None: + self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size + self.ssthresh = None + + self._first_slow_start = True + self._starting_congestion_avoidance = False + self.K: float = 0.0 + self._W_est = 0 + self._cwnd_epoch = 0 + self._t_epoch = 0.0 + self._W_max = self.congestion_window + + def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: + self.bytes_in_flight -= packet.sent_bytes + self.last_ack = packet.sent_time + + if self.ssthresh is None or self.congestion_window < self.ssthresh: + # slow start + self.congestion_window += packet.sent_bytes + else: + # congestion avoidance + if self._first_slow_start and not self._starting_congestion_avoidance: + # exiting slow start without having a loss + self._first_slow_start = False + self._W_max = self.congestion_window + self._t_epoch = now + self._cwnd_epoch = self.congestion_window + self._W_est = self._cwnd_epoch + # calculate K + W_max_segments = self._W_max / self._max_datagram_size + cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size + self.K = better_cube_root( + (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C + ) + + # initialize the variables used at start of congestion avoidance + if self._starting_congestion_avoidance: + self._starting_congestion_avoidance = False + self._first_slow_start = False + self._t_epoch = now + self._cwnd_epoch = self.congestion_window + self._W_est = self._cwnd_epoch + # calculate K + W_max_segments = self._W_max / self._max_datagram_size + cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size + self.K = better_cube_root( + (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C + ) + + self._W_est = int( + self._W_est + + self.additive_increase_factor + * (packet.sent_bytes / self.congestion_window) + ) + + t = now - self._t_epoch + + target: int = 0 + W_cubic = self.W_cubic(t + self.rtt) + if W_cubic < self.congestion_window: + target = self.congestion_window + elif W_cubic > 1.5 * self.congestion_window: + target = int(self.congestion_window * 1.5) + else: + target = W_cubic + + if self.is_reno_friendly(t): + # reno friendly region of cubic + # (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region) + self.congestion_window = self._W_est + elif self.is_concave(): + # concave region of cubic + # (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region) + self.congestion_window = int( + self.congestion_window + + ( + (target - self.congestion_window) + * (self._max_datagram_size / self.congestion_window) + ) + ) + else: + # convex region of cubic + # (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region) + self.congestion_window = int( + self.congestion_window + + ( + (target - self.congestion_window) + * (self._max_datagram_size / self.congestion_window) + ) + ) + + def on_packet_sent(self, *, packet: QuicSentPacket) -> None: + self.bytes_in_flight += packet.sent_bytes + if self.last_ack == 0.0: + return + elapsed_idle = packet.sent_time - self.last_ack + if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME: + self.reset() + + def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + + def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None: + lost_largest_time = 0.0 + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + lost_largest_time = packet.sent_time + + # start a new congestion event if packet was sent after the + # start of the previous congestion recovery period. + if lost_largest_time > self._congestion_recovery_start_time: + self._congestion_recovery_start_time = now + + # Normal congestion handle, can't be used in same time as fast convergence + # self._W_max = self.congestion_window + + # fast convergence + if self._W_max is not None and self.congestion_window < self._W_max: + self._W_max = int( + self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2 + ) + else: + self._W_max = self.congestion_window + + # normal congestion MD + flight_size = self.bytes_in_flight + new_ssthresh = max( + int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR), + K_MINIMUM_WINDOW * self._max_datagram_size, + ) + self.ssthresh = new_ssthresh + self.congestion_window = max( + self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size + ) + + # restart a new congestion avoidance phase + self._starting_congestion_avoidance = True + + def on_rtt_measurement(self, *, now: float, rtt: float) -> None: + self.rtt = rtt + # check whether we should exit slow start + if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( + rtt=rtt, now=now + ): + self.ssthresh = self.congestion_window + + def get_log_data(self) -> Dict[str, Any]: + data = super().get_log_data() + + data["cubic-wmax"] = int(self._W_max) + + return data + + +register_congestion_control("cubic", CubicCongestionControl) diff --git a/src/aioquic/quic/congestion/reno.py b/src/aioquic/quic/congestion/reno.py index 3bd0ee1c0..0ccf0792a 100644 --- a/src/aioquic/quic/congestion/reno.py +++ b/src/aioquic/quic/congestion/reno.py @@ -23,7 +23,7 @@ def __init__(self, *, max_datagram_size: int) -> None: self._congestion_stash = 0 self._rtt_monitor = QuicRttMonitor() - def on_packet_acked(self, *, packet: QuicSentPacket) -> None: + def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: self.bytes_in_flight -= packet.sent_bytes # don't increase window in congestion recovery diff --git a/src/aioquic/quic/recovery.py b/src/aioquic/quic/recovery.py index 3c165a7ec..6ee6593cf 100644 --- a/src/aioquic/quic/recovery.py +++ b/src/aioquic/quic/recovery.py @@ -2,7 +2,7 @@ import math from typing import Any, Callable, Dict, Iterable, List, Optional -from .congestion import reno # noqa +from .congestion import cubic, reno # noqa from .congestion.base import K_GRANULARITY, create_congestion_control from .logger import QuicLoggerTrace from .packet_builder import QuicDeliveryState, QuicSentPacket @@ -199,7 +199,7 @@ def on_ack_received( is_ack_eliciting = True space.ack_eliciting_in_flight -= 1 if packet.in_flight: - self._cc.on_packet_acked(packet=packet) + self._cc.on_packet_acked(packet=packet, now=now) largest_newly_acked = packet_number largest_sent_time = packet.sent_time @@ -334,12 +334,7 @@ def _get_loss_space(self) -> Optional[QuicPacketSpace]: return loss_space def _log_metrics_updated(self, log_rtt=False) -> None: - data: Dict[str, Any] = { - "bytes_in_flight": self._cc.bytes_in_flight, - "cwnd": self._cc.congestion_window, - } - if self._cc.ssthresh is not None: - data["ssthresh"] = self._cc.ssthresh + data: Dict[str, Any] = self._cc.get_log_data() if log_rtt: data.update( diff --git a/tests/test_recovery_cubic.py b/tests/test_recovery_cubic.py new file mode 100644 index 000000000..c71c7c8a5 --- /dev/null +++ b/tests/test_recovery_cubic.py @@ -0,0 +1,584 @@ +import math +from unittest import TestCase + +from aioquic import tls +from aioquic.quic.congestion.base import K_INITIAL_WINDOW, K_MINIMUM_WINDOW +from aioquic.quic.congestion.cubic import ( + K_CUBIC_C, + K_CUBIC_LOSS_REDUCTION_FACTOR, + CubicCongestionControl, + better_cube_root, +) +from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from aioquic.quic.packet_builder import QuicSentPacket +from aioquic.quic.rangeset import RangeSet +from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace + + +def send_probe(): + pass + + +def W_cubic(t, K, W_max): + return K_CUBIC_C * (t - K) ** 3 + (W_max) + + +class QuicPacketRecoveryCubicTest(TestCase): + def setUp(self): + self.INITIAL_SPACE = QuicPacketSpace() + self.HANDSHAKE_SPACE = QuicPacketSpace() + self.ONE_RTT_SPACE = QuicPacketSpace() + + self.recovery = QuicPacketRecovery( + congestion_control_algorithm="cubic", + initial_rtt=0.1, + max_datagram_size=1280, + peer_completed_address_validation=True, + send_probe=send_probe, + ) + self.recovery.spaces = [ + self.INITIAL_SPACE, + self.HANDSHAKE_SPACE, + self.ONE_RTT_SPACE, + ] + + def test_better_cube_root(self): + self.assertAlmostEqual(better_cube_root(8), 2) + self.assertAlmostEqual(better_cube_root(-8), -2) + self.assertAlmostEqual(better_cube_root(0), 0) + self.assertAlmostEqual(better_cube_root(27), 3) + + def test_discard_space(self): + self.recovery.discard_space(self.INITIAL_SPACE) + + def test_on_ack_received_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + self.recovery.on_ack_received( + ack_rangeset=RangeSet([range(0, 1)]), + ack_delay=0.0, + now=10.0, + space=space, + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertTrue(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 10.0) + self.assertEqual(self.recovery._rtt_min, 10.0) + self.assertEqual(self.recovery._rtt_smoothed, 10.0) + + def test_on_ack_received_non_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=False, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=123.45, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + self.recovery.on_ack_received( + ack_rangeset=RangeSet([range(0, 1)]), + ack_delay=0.0, + now=10.0, + space=space, + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertFalse(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 0.0) + self.assertEqual(self.recovery._rtt_min, math.inf) + self.assertEqual(self.recovery._rtt_smoothed, 0.0) + + def test_on_packet_lost_crypto(self): + packet = QuicSentPacket( + epoch=tls.Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.INITIAL_SPACE + + self.recovery.on_packet_sent(packet=packet, space=space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) + + self.recovery._detect_loss(space=space, now=1.0) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + def test_packet_expired(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + + cubic = CubicCongestionControl(1440) + cubic.on_packet_sent(packet=packet) + + cubic.on_packets_expired(packets=[packet]) + + self.assertEqual(cubic.bytes_in_flight, 0) + + def test_log_data(self): + cubic = CubicCongestionControl(1440) + + self.assertEqual( + cubic.get_log_data(), + { + "cwnd": cubic.congestion_window, + "bytes_in_flight": cubic.bytes_in_flight, + "cubic-wmax": cubic._W_max, + }, + ) + + cubic._W_max = 5000 + cubic.ssthresh = 5000 + + self.assertEqual( + cubic.get_log_data(), + { + "cwnd": cubic.congestion_window, + "ssthresh": cubic.ssthresh, + "bytes_in_flight": cubic.bytes_in_flight, + "cubic-wmax": cubic._W_max, + }, + ) + + def test_congestion_avoidance(self): + """ + Check if the cubic implementation respects the mathematical + formula defined in the rfc 9438 + """ + + max_datagram_size = 1440 + + n = 400 # number of ms to check + + W_max = 5 # starting W_max + K = better_cube_root(W_max * (1 - K_CUBIC_LOSS_REDUCTION_FACTOR) / K_CUBIC_C) + cwnd = W_max * K_CUBIC_LOSS_REDUCTION_FACTOR + + correct = [] + + test_range = range(n) + + for i in test_range: + correct.append(W_cubic(i / 1000, K, W_max) * max_datagram_size) + + cubic = CubicCongestionControl(max_datagram_size) + cubic.rtt = 0 + cubic._W_max = W_max * max_datagram_size + cubic._starting_congestion_avoidance = True + cubic.congestion_window = cwnd * max_datagram_size + cubic.ssthresh = cubic.congestion_window + cubic._W_est = 0 + + results = [] + for i in test_range: + cwnd = cubic.congestion_window // max_datagram_size # number of segments + + # simulate the reception of cwnd packets (a full window of acks) + for _ in range(int(cwnd)): + packet = QuicSentPacket(None, True, True, True, 0, 0) + packet.sent_bytes = 0 # won't affect results + + cubic.on_packet_acked(packet=packet, now=(i / 1000)) + + results.append(cubic.congestion_window) + + for i in test_range: + # check if it is almost equal to the value of W_cubic + self.assertTrue( + correct[i] * 0.99 <= results[i] <= 1.01 * correct[i], + f"Error at {i}ms, Result={results[i]}, Expected={correct[i]}", + ) + + def test_reset_idle(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=10.0, + ) + + max_datagram_size = 1440 + + cubic = CubicCongestionControl(1440) + # set last received at time 1 + cubic.last_ack = 1 + + # receive a packet after 9s of idle time + cubic.on_packet_sent(packet=packet) + + cubic.on_packets_expired(packets=[packet]) + + self.assertEqual(cubic.congestion_window, K_INITIAL_WINDOW * max_datagram_size) + + self.assertIsNone(cubic.ssthresh) + + self.assertTrue(cubic._first_slow_start) + self.assertFalse(cubic._starting_congestion_avoidance) + self.assertEqual(cubic.K, 0.0) + self.assertEqual(cubic._W_est, 0) + self.assertEqual(cubic._cwnd_epoch, 0) + self.assertEqual(cubic._t_epoch, 0.0) + + self.assertEqual(cubic._W_max, K_INITIAL_WINDOW * max_datagram_size) + + def test_reno_friendly_region(self): + cubic = CubicCongestionControl(1440) + cubic._W_max = 5000 # set the target number of bytes to 5000 + cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion + cubic._starting_congestion_avoidance = False + cubic._first_slow_start = False + cubic.ssthresh = 2880 + cubic._t_epoch = 5 + + # set an arbitrarily high W_est, + # meaning that cubic would underperform compared to reno + cubic._W_est = 100000 + + # calculate K + W_max_segments = cubic._W_max / cubic._max_datagram_size + cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size + cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + + previous_cwnd = cubic.congestion_window + + cubic.on_packet_acked(now=10, packet=packet) + + # congestion window should be equal to W_est (Reno estimated window) + self.assertAlmostEqual( + cubic.congestion_window, + 100000 + + cubic.additive_increase_factor * (packet.sent_bytes / previous_cwnd), + ) + + def test_convex_region(self): + cubic = CubicCongestionControl(1440) + cubic._W_max = 5000 # set the target number of bytes to 5000 + cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion + cubic._starting_congestion_avoidance = False + cubic._first_slow_start = False + cubic.ssthresh = 2880 + cubic._t_epoch = 5 + + cubic._W_est = 0 + + # calculate K + W_max_segments = cubic._W_max / cubic._max_datagram_size + cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size + cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + + previous_cwnd = cubic.congestion_window + + cubic.on_packet_acked(now=10, packet=packet) + + # elapsed time + basic rtt + target = int(previous_cwnd * 1.5) + + expected = int( + previous_cwnd + + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) + ) + + # congestion window should be equal to W_est (Reno estimated window) + self.assertAlmostEqual(cubic.congestion_window, expected) + + def test_concave_region(self): + cubic = CubicCongestionControl(1440) + cubic._W_max = 25000 # set the target number of bytes to 25000 + cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion + cubic._starting_conges2ion_avoidance = False + cubic._first_slow_start = False + cubic.ssthresh = 2880 + cubic._t_epoch = 5 + + cubic._W_est = 0 + + # calculate K + W_max_segments = cubic._W_max / cubic._max_datagram_size + cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size + cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + + previous_cwnd = cubic.congestion_window + + cubic.on_packet_acked(now=6, packet=packet) + + # elapsed time + basic rtt + target = cubic.W_cubic(1 + 0.02) + + expected = int( + previous_cwnd + + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) + ) + + self.assertAlmostEqual(cubic.congestion_window, expected) + + def test_increasing_rtt(self): + cubic = CubicCongestionControl(1440) + + # get some low rtt + for i in range(10): + cubic.on_rtt_measurement(now=i + 1, rtt=1) + + # rtt increase (because of congestion for example) + for i in range(10): + cubic.on_rtt_measurement(now=100 + i, rtt=1000) + + self.assertEqual(cubic.ssthresh, cubic.congestion_window) + + def test_increasing_rtt_exiting_slow_start(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=200.0, + ) + + cubic = CubicCongestionControl(1440) + + # get some low rtt + for i in range(10): + cubic.on_rtt_measurement(now=i + 1, rtt=1) + + # rtt increase (because of congestion for example) + for i in range(10): + cubic.on_rtt_measurement(now=100 + i, rtt=1000) + + previous_cwnd = cubic.congestion_window + + self.assertFalse(cubic._starting_congestion_avoidance) + + cubic.on_packet_acked(packet=packet, now=220) + + self.assertFalse(cubic._first_slow_start) + self.assertEqual(cubic._W_max, previous_cwnd) + self.assertEqual(cubic._t_epoch, 220) + self.assertEqual(cubic._cwnd_epoch, previous_cwnd) + self.assertEqual( + cubic._W_est, + previous_cwnd + + cubic.additive_increase_factor * (packet.sent_bytes / previous_cwnd), + ) + + # calculate K + W_max_segments = previous_cwnd / cubic._max_datagram_size + cwnd_epoch_segments = previous_cwnd / cubic._max_datagram_size + K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + self.assertEqual(cubic.K, K) + + def test_packet_lost(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=200.0, + ) + + packet2 = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=240.0, + ) + + cubic = CubicCongestionControl(1440) + + previous_cwnd = cubic.congestion_window + + cubic.on_packets_lost(now=210, packets=[packet]) + + self.assertEqual(cubic._congestion_recovery_start_time, 210) + + self.assertEqual(cubic._W_max, previous_cwnd) + self.assertEqual(cubic.ssthresh, K_MINIMUM_WINDOW * cubic._max_datagram_size) + self.assertEqual( + cubic.congestion_window, K_MINIMUM_WINDOW * cubic._max_datagram_size + ) + self.assertTrue(cubic._starting_congestion_avoidance) + + previous_cwnd = cubic.congestion_window + W_max = cubic._W_max + + cubic.on_packet_acked(now=250, packet=packet) + + self.assertFalse(cubic._starting_congestion_avoidance) + self.assertFalse(cubic._first_slow_start) + self.assertEqual(cubic._t_epoch, 250) + self.assertEqual(cubic._cwnd_epoch, previous_cwnd) + self.assertEqual( + cubic._W_est, + previous_cwnd + + cubic.additive_increase_factor * (packet2.sent_bytes / previous_cwnd), + ) + # calculate K + W_max_segments = W_max / cubic._max_datagram_size + cwnd_epoch_segments = previous_cwnd / cubic._max_datagram_size + K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + self.assertEqual(cubic.K, K) + + def test_lost_with_W_max(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=200.0, + ) + + cubic = CubicCongestionControl(1440) + + cubic._W_max = 100000 + + previous_cwnd = cubic.congestion_window + + cubic.on_packets_lost(now=210, packets=[packet]) + + # test when W_max was much more than cwnd + # and a loss occur + self.assertEqual( + cubic._W_max, previous_cwnd * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2 + ) + + def test_cwnd_target(self): + cubic = CubicCongestionControl(1440) + cubic._W_max = 25000 # set the target number of bytes to 25000 + cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion + cubic._starting_conges2ion_avoidance = False + cubic._first_slow_start = False + cubic.ssthresh = 2880 + cubic._t_epoch = 5 + cubic.congestion_window = 100000 + + cubic._W_est = 0 + + # calculate K + W_max_segments = cubic._W_max / cubic._max_datagram_size + cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size + cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) + + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + + previous_cwnd = cubic.congestion_window + + cubic.on_packet_acked(now=6, packet=packet) + + # elapsed time + basic rtt + target = previous_cwnd + + expected = int( + previous_cwnd + + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) + ) + + self.assertAlmostEqual(cubic.congestion_window, expected)