Skip to content

Commit

Permalink
Move congestion control code to allow other algorithms
Browse files Browse the repository at this point in the history
Use keyword arguments wherever possible to make it explicit what is
being passed around.

Co-authored-by: Aperence <[email protected]>
  • Loading branch information
jlaine and Aperence committed Dec 13, 2023
1 parent 041de71 commit bbe4c4d
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 268 deletions.
7 changes: 7 additions & 0 deletions examples/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ async def main(
"CHACHA20_POLY1305_SHA256`"
),
)
parser.add_argument(
"--congestion-control-algorithm",
type=str,
default="reno",
help="use the specified congestion control algorithm",
)
parser.add_argument(
"-d", "--data", type=str, help="send the specified data in a POST request"
)
Expand Down Expand Up @@ -527,6 +533,7 @@ async def main(
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN,
congestion_control_algorithm=args.congestion_control_algorithm,
max_datagram_size=args.max_datagram_size,
)
if args.ca_certs:
Expand Down
7 changes: 7 additions & 0 deletions examples/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ async def main(
required=True,
help="load the TLS certificate from the specified file",
)
parser.add_argument(
"--congestion-control-algorithm",
type=str,
default="reno",
help="use the specified congestion control algorithm",
)
parser.add_argument(
"--host",
type=str,
Expand Down Expand Up @@ -584,6 +590,7 @@ async def main(

configuration = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
congestion_control_algorithm=args.congestion_control_algorithm,
is_client=False,
max_datagram_frame_size=65536,
max_datagram_size=args.max_datagram_size,
Expand Down
7 changes: 7 additions & 0 deletions src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ class QuicConfiguration:
A list of supported ALPN protocols.
"""

congestion_control_algorithm: str = "reno"
"""
The name of the congestion control algorithm to use.
Currently supported algorithms: `"reno"`.
"""

connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
Expand Down
Empty file.
126 changes: 126 additions & 0 deletions src/aioquic/quic/congestion/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import abc
from typing import Dict, Iterable, Optional, Protocol

from ..packet_builder import QuicSentPacket

K_GRANULARITY = 0.001 # seconds
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2


class QuicCongestionControl(abc.ABC):
"""
Base class for congestion control implementations.
"""

bytes_in_flight: int = 0
congestion_window: int = 0
ssthresh: Optional[int] = None

def __init__(self, *, max_datagram_size: int) -> None:
self.congestion_window = K_INITIAL_WINDOW * max_datagram_size

@abc.abstractmethod
def on_packet_acked(self, *, packet: QuicSentPacket) -> None:
... # pragma: no cover

@abc.abstractmethod
def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
... # pragma: no cover

@abc.abstractmethod
def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
... # pragma: no cover

@abc.abstractmethod
def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
... # pragma: no cover

@abc.abstractmethod
def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
... # pragma: no cover


class QuicCongestionControlFactory(Protocol):
def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl:
... # pragma: no cover


class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""

def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5

self._filtered_min: Optional[float] = None

self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
self._samples = [0.0 for i in range(self._size)]

def add_rtt(self, *, rtt: float) -> None:
self._samples[self._sample_idx] = rtt
self._sample_idx += 1

if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True

if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample

def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt=rtt)
self._sample_time = now

if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max

delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False


_factories: Dict[str, QuicCongestionControlFactory] = {}


def create_congestion_control(
name: str, *, max_datagram_size: int
) -> QuicCongestionControl:
"""
Create an instance of the `name` congestion control algorithm.
"""
try:
factory = _factories[name]
except KeyError:
raise Exception(f"Unknown congestion control algorithm: {name}")
return factory(max_datagram_size=max_datagram_size)


def register_congestion_control(
name: str, factory: QuicCongestionControlFactory
) -> None:
"""
Register a congestion control algorithm named `name`.
"""
_factories[name] = factory
77 changes: 77 additions & 0 deletions src/aioquic/quic/congestion/reno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Iterable

from ..packet_builder import QuicSentPacket
from .base import (
K_MINIMUM_WINDOW,
QuicCongestionControl,
QuicRttMonitor,
register_congestion_control,
)

K_LOSS_REDUCTION_FACTOR = 0.5


class RenoCongestionControl(QuicCongestionControl):
"""
New Reno congestion control.
"""

def __init__(self, *, max_datagram_size: int) -> None:
super().__init__(max_datagram_size=max_datagram_size)
self._max_datagram_size = max_datagram_size
self._congestion_recovery_start_time = 0.0
self._congestion_stash = 0
self._rtt_monitor = QuicRttMonitor()

def on_packet_acked(self, *, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes

# don't increase window in congestion recovery
if packet.sent_time <= self._congestion_recovery_start_time:
return

if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
self._congestion_stash += packet.sent_bytes
count = self._congestion_stash // self.congestion_window
if count:
self._congestion_stash -= count * self.congestion_window
self.congestion_window += count * self._max_datagram_size

def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes

def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes

def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time

# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
self.congestion_window = max(
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
K_MINIMUM_WINDOW * self._max_datagram_size,
)
self.ssthresh = self.congestion_window

# TODO : collapse congestion window if persistent congestion

def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
now=now, rtt=rtt
):
self.ssthresh = self.congestion_window


register_congestion_control("reno", RenoCongestionControl)
6 changes: 4 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from . import events
from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from .congestion.base import K_GRANULARITY
from .crypto import CryptoError, CryptoPair, KeyUnavailableError
from .logger import QuicLoggerTrace
from .packet import (
Expand Down Expand Up @@ -61,7 +62,7 @@
QuicPacketBuilder,
QuicPacketBuilderStop,
)
from .recovery import K_GRANULARITY, QuicPacketRecovery, QuicPacketSpace
from .recovery import QuicPacketRecovery, QuicPacketSpace
from .stream import FinalSizeError, QuicStream, StreamFinishedError

logger = logging.getLogger("quic")
Expand Down Expand Up @@ -383,6 +384,7 @@ def __init__(

# loss recovery
self._loss = QuicPacketRecovery(
congestion_control_algorithm=configuration.congestion_control_algorithm,
initial_rtt=configuration.initial_rtt,
max_datagram_size=self._max_datagram_size,
peer_completed_address_validation=not self._is_client,
Expand Down Expand Up @@ -1508,10 +1510,10 @@ def _handle_ack_frame(
self._loss.peer_completed_address_validation = True

self._loss.on_ack_received(
space=self._spaces[context.epoch],
ack_rangeset=ack_rangeset,
ack_delay=ack_delay,
now=context.time,
space=self._spaces[context.epoch],
)

def _handle_connection_close_frame(
Expand Down
Loading

0 comments on commit bbe4c4d

Please sign in to comment.