Skip to content

Commit

Permalink
moved congestion.py to __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Aperence committed Dec 12, 2023
1 parent 67b60d8 commit f7bc5af
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 152 deletions.
3 changes: 1 addition & 2 deletions examples/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,7 @@ async def main(
elif args.congestion_control == "reno":
configuration.congestion_control = RenoCongestionControl()
else:
print("Invalid congestion control algorithm")
exit(127)
raise Exception("Invalid congestion control algorithm")
if args.session_ticket:
try:
with open(args.session_ticket, "rb") as fp:
Expand Down
3 changes: 1 addition & 2 deletions examples/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,7 @@ async def main(
elif args.congestion_control == "reno":
configuration.congestion_control = RenoCongestionControl()
else:
print("Invalid congestion control algorithm")
exit(127)
raise Exception("Invalid congestion control algorithm")

if uvloop is not None:
uvloop.install()
Expand Down
148 changes: 147 additions & 1 deletion src/aioquic/quic/congestion/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,147 @@
from .congestion import *
from ..packet_builder import QuicSentPacket
from typing import Iterable, Optional, Dict, Any
from datetime import datetime
from enum import Enum

K_GRANULARITY = 0.001 # seconds

# congestion control
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
K_LOSS_REDUCTION_FACTOR = 0.5

class CongestionEvent(Enum):
ACK=0
PACKET_SENT=1
PACKET_EXPIRED=2
PACKET_LOST=3
RTT_MEASURED=4

class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""

def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5

self._filtered_min: Optional[float] = None

self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
self._samples = [0.0 for i in range(self._size)]

def add_rtt(self, rtt: float) -> None:
self._samples[self._sample_idx] = rtt
self._sample_idx += 1

if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True

if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample

def is_rtt_increasing(self, rtt: float, now: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt)
self._sample_time = now

if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max

delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False

class QuicCongestionControl:

def __init__(self, *, max_datagram_size : int, callback=None, fixed_cwnd = 10*1024*1024) -> None:
self.callback = callback # a callback argument that is called when an event occurs
# 10 GB window or custom fixed size window (shouldn't be used in real network !, use a real CCA instead)
self._max_datagram_size = max_datagram_size
self.cwnd = fixed_cwnd
self.data_in_flight = 0

def set_recovery(self, recovery):
# recovery is a QuicPacketRecovery instance
self.recovery = recovery

def on_packet_acked(self, packet: QuicSentPacket, now : float):
if self.callback:
self.callback(CongestionEvent.ACK, self)
if type(self) == QuicCongestionControl:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_SENT, self)
if type(self) == QuicCongestionControl:
# don't call this if it is a superclass that runs
self.data_in_flight += packet.sent_bytes

def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_EXPIRED, self)
if type(self) == QuicCongestionControl:
for packet in packets:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_LOST, self)
if type(self) == QuicCongestionControl:
for packet in packets:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
if self.callback:
self.callback(CongestionEvent.RTT_MEASURED, self)

def get_congestion_window(self) -> int:
# return the cwnd in number of bytes
return self.cwnd

def _set_congestion_window(self, value):
self.cwnd = value

def get_ssthresh(self) -> Optional[int]:
pass

def get_bytes_in_flight(self) -> int:
return self.data_in_flight

def log_callback(self) -> Dict[str, Any]:
# a callback called when a recovery happens
# The data object will be saved in the log file, so feel free to add
# any attribute you want to track
data: Dict[str, Any] = {
"bytes_in_flight": self.get_bytes_in_flight(),
"cwnd": self.get_congestion_window(),
}
if self.get_ssthresh() is not None:
data["ssthresh"] = self.get_ssthresh()

return data


147 changes: 0 additions & 147 deletions src/aioquic/quic/congestion/congestion.py

This file was deleted.

0 comments on commit f7bc5af

Please sign in to comment.