diff --git a/examples/http3_client.py b/examples/http3_client.py index e0000230d..1132db55b 100644 --- a/examples/http3_client.py +++ b/examples/http3_client.py @@ -556,8 +556,7 @@ async def main( elif args.congestion_control == "reno": configuration.congestion_control = RenoCongestionControl() else: - print("Invalid congestion control algorithm") - exit(127) + raise Exception("Invalid congestion control algorithm") if args.session_ticket: try: with open(args.session_ticket, "rb") as fp: diff --git a/examples/http3_server.py b/examples/http3_server.py index 09a106c6b..a244fb5f4 100644 --- a/examples/http3_server.py +++ b/examples/http3_server.py @@ -603,8 +603,7 @@ async def main( elif args.congestion_control == "reno": configuration.congestion_control = RenoCongestionControl() else: - print("Invalid congestion control algorithm") - exit(127) + raise Exception("Invalid congestion control algorithm") if uvloop is not None: uvloop.install() diff --git a/src/aioquic/quic/congestion/__init__.py b/src/aioquic/quic/congestion/__init__.py index bcf219f31..f7a5964b1 100644 --- a/src/aioquic/quic/congestion/__init__.py +++ b/src/aioquic/quic/congestion/__init__.py @@ -1 +1,147 @@ -from .congestion import * \ No newline at end of file +from ..packet_builder import QuicSentPacket +from typing import Iterable, Optional, Dict, Any +from datetime import datetime +from enum import Enum + +K_GRANULARITY = 0.001 # seconds + +# congestion control +K_INITIAL_WINDOW = 10 +K_MINIMUM_WINDOW = 2 +K_LOSS_REDUCTION_FACTOR = 0.5 + +class CongestionEvent(Enum): + ACK=0 + PACKET_SENT=1 + PACKET_EXPIRED=2 + PACKET_LOST=3 + RTT_MEASURED=4 + +class QuicRttMonitor: + """ + Roundtrip time monitor for HyStart. + """ + + def __init__(self) -> None: + self._increases = 0 + self._last_time = None + self._ready = False + self._size = 5 + + self._filtered_min: Optional[float] = None + + self._sample_idx = 0 + self._sample_max: Optional[float] = None + self._sample_min: Optional[float] = None + self._sample_time = 0.0 + self._samples = [0.0 for i in range(self._size)] + + def add_rtt(self, rtt: float) -> None: + self._samples[self._sample_idx] = rtt + self._sample_idx += 1 + + if self._sample_idx >= self._size: + self._sample_idx = 0 + self._ready = True + + if self._ready: + self._sample_max = self._samples[0] + self._sample_min = self._samples[0] + for sample in self._samples[1:]: + if sample < self._sample_min: + self._sample_min = sample + elif sample > self._sample_max: + self._sample_max = sample + + def is_rtt_increasing(self, rtt: float, now: float) -> bool: + if now > self._sample_time + K_GRANULARITY: + self.add_rtt(rtt) + self._sample_time = now + + if self._ready: + if self._filtered_min is None or self._filtered_min > self._sample_max: + self._filtered_min = self._sample_max + + delta = self._sample_min - self._filtered_min + if delta * 4 >= self._filtered_min: + self._increases += 1 + if self._increases >= self._size: + return True + elif delta > 0: + self._increases = 0 + return False + +class QuicCongestionControl: + + def __init__(self, *, max_datagram_size : int, callback=None, fixed_cwnd = 10*1024*1024) -> None: + self.callback = callback # a callback argument that is called when an event occurs + # 10 GB window or custom fixed size window (shouldn't be used in real network !, use a real CCA instead) + self._max_datagram_size = max_datagram_size + self.cwnd = fixed_cwnd + self.data_in_flight = 0 + + def set_recovery(self, recovery): + # recovery is a QuicPacketRecovery instance + self.recovery = recovery + + def on_packet_acked(self, packet: QuicSentPacket, now : float): + if self.callback: + self.callback(CongestionEvent.ACK, self) + if type(self) == QuicCongestionControl: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_SENT, self) + if type(self) == QuicCongestionControl: + # don't call this if it is a superclass that runs + self.data_in_flight += packet.sent_bytes + + def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_EXPIRED, self) + if type(self) == QuicCongestionControl: + for packet in packets: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_LOST, self) + if type(self) == QuicCongestionControl: + for packet in packets: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: + if self.callback: + self.callback(CongestionEvent.RTT_MEASURED, self) + + def get_congestion_window(self) -> int: + # return the cwnd in number of bytes + return self.cwnd + + def _set_congestion_window(self, value): + self.cwnd = value + + def get_ssthresh(self) -> Optional[int]: + pass + + def get_bytes_in_flight(self) -> int: + return self.data_in_flight + + def log_callback(self) -> Dict[str, Any]: + # a callback called when a recovery happens + # The data object will be saved in the log file, so feel free to add + # any attribute you want to track + data: Dict[str, Any] = { + "bytes_in_flight": self.get_bytes_in_flight(), + "cwnd": self.get_congestion_window(), + } + if self.get_ssthresh() is not None: + data["ssthresh"] = self.get_ssthresh() + + return data + + diff --git a/src/aioquic/quic/congestion/congestion.py b/src/aioquic/quic/congestion/congestion.py deleted file mode 100644 index f7a5964b1..000000000 --- a/src/aioquic/quic/congestion/congestion.py +++ /dev/null @@ -1,147 +0,0 @@ -from ..packet_builder import QuicSentPacket -from typing import Iterable, Optional, Dict, Any -from datetime import datetime -from enum import Enum - -K_GRANULARITY = 0.001 # seconds - -# congestion control -K_INITIAL_WINDOW = 10 -K_MINIMUM_WINDOW = 2 -K_LOSS_REDUCTION_FACTOR = 0.5 - -class CongestionEvent(Enum): - ACK=0 - PACKET_SENT=1 - PACKET_EXPIRED=2 - PACKET_LOST=3 - RTT_MEASURED=4 - -class QuicRttMonitor: - """ - Roundtrip time monitor for HyStart. - """ - - def __init__(self) -> None: - self._increases = 0 - self._last_time = None - self._ready = False - self._size = 5 - - self._filtered_min: Optional[float] = None - - self._sample_idx = 0 - self._sample_max: Optional[float] = None - self._sample_min: Optional[float] = None - self._sample_time = 0.0 - self._samples = [0.0 for i in range(self._size)] - - def add_rtt(self, rtt: float) -> None: - self._samples[self._sample_idx] = rtt - self._sample_idx += 1 - - if self._sample_idx >= self._size: - self._sample_idx = 0 - self._ready = True - - if self._ready: - self._sample_max = self._samples[0] - self._sample_min = self._samples[0] - for sample in self._samples[1:]: - if sample < self._sample_min: - self._sample_min = sample - elif sample > self._sample_max: - self._sample_max = sample - - def is_rtt_increasing(self, rtt: float, now: float) -> bool: - if now > self._sample_time + K_GRANULARITY: - self.add_rtt(rtt) - self._sample_time = now - - if self._ready: - if self._filtered_min is None or self._filtered_min > self._sample_max: - self._filtered_min = self._sample_max - - delta = self._sample_min - self._filtered_min - if delta * 4 >= self._filtered_min: - self._increases += 1 - if self._increases >= self._size: - return True - elif delta > 0: - self._increases = 0 - return False - -class QuicCongestionControl: - - def __init__(self, *, max_datagram_size : int, callback=None, fixed_cwnd = 10*1024*1024) -> None: - self.callback = callback # a callback argument that is called when an event occurs - # 10 GB window or custom fixed size window (shouldn't be used in real network !, use a real CCA instead) - self._max_datagram_size = max_datagram_size - self.cwnd = fixed_cwnd - self.data_in_flight = 0 - - def set_recovery(self, recovery): - # recovery is a QuicPacketRecovery instance - self.recovery = recovery - - def on_packet_acked(self, packet: QuicSentPacket, now : float): - if self.callback: - self.callback(CongestionEvent.ACK, self) - if type(self) == QuicCongestionControl: - # don't call this if it is a superclass that runs - self.data_in_flight -= packet.sent_bytes - - def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None: - if self.callback: - self.callback(CongestionEvent.PACKET_SENT, self) - if type(self) == QuicCongestionControl: - # don't call this if it is a superclass that runs - self.data_in_flight += packet.sent_bytes - - def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: - if self.callback: - self.callback(CongestionEvent.PACKET_EXPIRED, self) - if type(self) == QuicCongestionControl: - for packet in packets: - # don't call this if it is a superclass that runs - self.data_in_flight -= packet.sent_bytes - - def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: - if self.callback: - self.callback(CongestionEvent.PACKET_LOST, self) - if type(self) == QuicCongestionControl: - for packet in packets: - # don't call this if it is a superclass that runs - self.data_in_flight -= packet.sent_bytes - - def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: - if self.callback: - self.callback(CongestionEvent.RTT_MEASURED, self) - - def get_congestion_window(self) -> int: - # return the cwnd in number of bytes - return self.cwnd - - def _set_congestion_window(self, value): - self.cwnd = value - - def get_ssthresh(self) -> Optional[int]: - pass - - def get_bytes_in_flight(self) -> int: - return self.data_in_flight - - def log_callback(self) -> Dict[str, Any]: - # a callback called when a recovery happens - # The data object will be saved in the log file, so feel free to add - # any attribute you want to track - data: Dict[str, Any] = { - "bytes_in_flight": self.get_bytes_in_flight(), - "cwnd": self.get_congestion_window(), - } - if self.get_ssthresh() is not None: - data["ssthresh"] = self.get_ssthresh() - - return data - -