Skip to content

Commit

Permalink
Added support for cubic congestion control (RFC9438)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aperence committed Dec 14, 2023
1 parent bbe4c4d commit ca0fe88
Show file tree
Hide file tree
Showing 6 changed files with 809 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class QuicConfiguration:
"""
The name of the congestion control algorithm to use.
Currently supported algorithms: `"reno"`.
Currently supported algorithms: `"reno", `"cubic"`.
"""

connection_id_length: int = 8
Expand Down
10 changes: 8 additions & 2 deletions src/aioquic/quic/congestion/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Dict, Iterable, Optional, Protocol
from typing import Any, Dict, Iterable, Optional, Protocol

from ..packet_builder import QuicSentPacket

Expand All @@ -21,7 +21,7 @@ def __init__(self, *, max_datagram_size: int) -> None:
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size

@abc.abstractmethod
def on_packet_acked(self, *, packet: QuicSentPacket) -> None:
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
... # pragma: no cover

@abc.abstractmethod
Expand All @@ -40,6 +40,12 @@ def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> N
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
... # pragma: no cover

def get_log_data(self) -> Dict[str, Any]:
data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
if self.ssthresh is not None:
data["ssthresh"] = self.ssthresh
return data


class QuicCongestionControlFactory(Protocol):
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl:
Expand Down
212 changes: 212 additions & 0 deletions src/aioquic/quic/congestion/cubic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from typing import Any, Dict, Iterable

from ..packet_builder import QuicSentPacket
from .base import (
K_INITIAL_WINDOW,
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)

# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity


def better_cube_root(x: float) -> float:
if x < 0:
# avoid precision errors that make the cube root returns an imaginary number
return -((-x) ** (1.0 / 3.0))
else:
return (x) ** (1.0 / 3.0)


class CubicCongestionControl(QuicCongestionControl):
"""
Cubic congestion control implementation for aioquic
"""

def __init__(self, max_datagram_size: int) -> None:
super().__init__(max_datagram_size=max_datagram_size)
# increase by one segment
self.additive_increase_factor: int = max_datagram_size
self._max_datagram_size: int = max_datagram_size
self._congestion_recovery_start_time = 0.0

self._rtt_monitor = QuicRttMonitor()

self.rtt = 0.02 # starting RTT is considered to be 20ms

self.reset()

self.last_ack = 0.0

def W_cubic(self, t) -> int:
W_max_segments = self._W_max / self._max_datagram_size
target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
return int(target_segments * self._max_datagram_size)

def is_reno_friendly(self, t) -> bool:
return self.W_cubic(t) < self._W_est

def is_concave(self) -> bool:
return self.congestion_window < self._W_max

def reset(self) -> None:
self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
self.ssthresh = None

self._first_slow_start = True
self._starting_congestion_avoidance = False
self.K: float = 0.0
self._W_est = 0
self._cwnd_epoch = 0
self._t_epoch = 0.0
self._W_max = self.congestion_window

def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
self.last_ack = packet.sent_time

if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
if self._first_slow_start and not self._starting_congestion_avoidance:
# exiting slow start without having a loss
self._first_slow_start = False
self._W_max = self.congestion_window
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
# calculate K
W_max_segments = self._W_max / self._max_datagram_size
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
self.K = better_cube_root(
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
)

# initialize the variables used at start of congestion avoidance
if self._starting_congestion_avoidance:
self._starting_congestion_avoidance = False
self._first_slow_start = False
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
# calculate K
W_max_segments = self._W_max / self._max_datagram_size
cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
self.K = better_cube_root(
(W_max_segments - cwnd_epoch_segments) / K_CUBIC_C
)

self._W_est = int(
self._W_est
+ self.additive_increase_factor
* (packet.sent_bytes / self.congestion_window)
)

t = now - self._t_epoch

target: int = 0
W_cubic = self.W_cubic(t + self.rtt)
if W_cubic < self.congestion_window:
target = self.congestion_window
elif W_cubic > 1.5 * self.congestion_window:
target = int(self.congestion_window * 1.5)
else:
target = W_cubic

if self.is_reno_friendly(t):
# reno friendly region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
self.congestion_window = self._W_est
elif self.is_concave():
# concave region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
self.congestion_window = int(
self.congestion_window
+ (
(target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
)
else:
# convex region of cubic
# (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region)
self.congestion_window = int(
self.congestion_window
+ (
(target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
)

def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
if self.last_ack == 0.0:
return
elapsed_idle = packet.sent_time - self.last_ack
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
self.reset()

def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes

def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time

# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now

# Normal congestion handle, can't be used in same time as fast convergence
# self._W_max = self.congestion_window

# fast convergence
if self._W_max is not None and self.congestion_window < self._W_max:
self._W_max = int(
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
)
else:
self._W_max = self.congestion_window

# normal congestion MD
flight_size = self.bytes_in_flight
new_ssthresh = max(
int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
K_MINIMUM_WINDOW * self._max_datagram_size,
)
self.ssthresh = new_ssthresh
self.congestion_window = max(
self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
)

# restart a new congestion avoidance phase
self._starting_congestion_avoidance = True

def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
self.rtt = rtt
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
rtt=rtt, now=now
):
self.ssthresh = self.congestion_window

def get_log_data(self) -> Dict[str, Any]:
data = super().get_log_data()

data["cubic-wmax"] = int(self._W_max)

return data


register_congestion_control("cubic", CubicCongestionControl)
2 changes: 1 addition & 1 deletion src/aioquic/quic/congestion/reno.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, *, max_datagram_size: int) -> None:
self._congestion_stash = 0
self._rtt_monitor = QuicRttMonitor()

def on_packet_acked(self, *, packet: QuicSentPacket) -> None:
def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes

# don't increase window in congestion recovery
Expand Down
11 changes: 3 additions & 8 deletions src/aioquic/quic/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from typing import Any, Callable, Dict, Iterable, List, Optional

from .congestion import reno # noqa
from .congestion import cubic, reno # noqa
from .congestion.base import K_GRANULARITY, create_congestion_control
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
Expand Down Expand Up @@ -199,7 +199,7 @@ def on_ack_received(
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet=packet)
self._cc.on_packet_acked(packet=packet, now=now)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time

Expand Down Expand Up @@ -334,12 +334,7 @@ def _get_loss_space(self) -> Optional[QuicPacketSpace]:
return loss_space

def _log_metrics_updated(self, log_rtt=False) -> None:
data: Dict[str, Any] = {
"bytes_in_flight": self._cc.bytes_in_flight,
"cwnd": self._cc.congestion_window,
}
if self._cc.ssthresh is not None:
data["ssthresh"] = self._cc.ssthresh
data: Dict[str, Any] = self._cc.get_log_data()

if log_rtt:
data.update(
Expand Down
Loading

0 comments on commit ca0fe88

Please sign in to comment.