From 18cc0786ed22e9030a83ad52df8cb9594c4c4a02 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 29 Sep 2023 21:43:29 -0700 Subject: [PATCH] feat: OnCongestionEventEx for BBR --- congestion/interface.go | 23 ++++++++++++++ internal/ackhandler/cc_adapter.go | 7 +++++ internal/ackhandler/sent_packet_handler.go | 35 +++++++++++++++++++--- internal/congestion/interface.go | 6 ++++ 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/congestion/interface.go b/congestion/interface.go index 6b29bc47137..0f4ea169244 100644 --- a/congestion/interface.go +++ b/congestion/interface.go @@ -11,6 +11,28 @@ type ( PacketNumber protocol.PacketNumber ) +// Expose some constants from protocol that congestion control algorithms may need. +const ( + InitialPacketSizeIPv4 = protocol.InitialPacketSizeIPv4 + InitialPacketSizeIPv6 = protocol.InitialPacketSizeIPv6 + MinPacingDelay = protocol.MinPacingDelay + MaxPacketBufferSize = protocol.MaxPacketBufferSize + MinInitialPacketSize = protocol.MinInitialPacketSize + MaxCongestionWindowPackets = protocol.MaxCongestionWindowPackets + PacketsPerConnectionID = protocol.PacketsPerConnectionID +) + +type AckedPacketInfo struct { + PacketNumber PacketNumber + BytesAcked ByteCount + ReceivedTime time.Time +} + +type LostPacketInfo struct { + PacketNumber PacketNumber + BytesLost ByteCount +} + type CongestionControl interface { SetRTTStatsProvider(provider RTTStatsProvider) TimeUntilSend(bytesInFlight ByteCount) time.Time @@ -20,6 +42,7 @@ type CongestionControl interface { MaybeExitSlowStart() OnPacketAcked(number PacketNumber, ackedBytes ByteCount, priorInFlight ByteCount, eventTime time.Time) OnCongestionEvent(number PacketNumber, lostBytes ByteCount, priorInFlight ByteCount) + OnCongestionEventEx(priorInFlight ByteCount, eventTime time.Time, ackedPackets []AckedPacketInfo, lostPackets []LostPacketInfo) OnRetransmissionTimeout(packetsRetransmitted bool) SetMaxDatagramSize(size ByteCount) InSlowStart() bool diff --git a/internal/ackhandler/cc_adapter.go b/internal/ackhandler/cc_adapter.go index 79af97070d9..fa0cebf6ccd 100644 --- a/internal/ackhandler/cc_adapter.go +++ b/internal/ackhandler/cc_adapter.go @@ -4,9 +4,12 @@ import ( "time" "github.com/quic-go/quic-go/congestion" + cgInternal "github.com/quic-go/quic-go/internal/congestion" "github.com/quic-go/quic-go/internal/protocol" ) +var _ cgInternal.SendAlgorithmEx = &ccAdapter{} + type ccAdapter struct { CC congestion.CongestionControl } @@ -39,6 +42,10 @@ func (a *ccAdapter) OnCongestionEvent(number protocol.PacketNumber, lostBytes pr a.CC.OnCongestionEvent(congestion.PacketNumber(number), congestion.ByteCount(lostBytes), congestion.ByteCount(priorInFlight)) } +func (a *ccAdapter) OnCongestionEventEx(priorInFlight protocol.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + a.CC.OnCongestionEventEx(congestion.ByteCount(priorInFlight), eventTime, ackedPackets, lostPackets) +} + func (a *ccAdapter) OnRetransmissionTimeout(packetsRetransmitted bool) { a.CC.OnRetransmissionTimeout(packetsRetransmitted) } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index d9d3dbbd893..db33e9b4c39 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -354,6 +354,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } + ackedPacketsInfo := []congestionExt.AckedPacketInfo{} + lostPacketsInfo := []congestionExt.LostPacketInfo{} + // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) @@ -364,13 +367,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - if err := h.detectLostPackets(rcvTime, encLevel); err != nil { + if lostPacketsInfo, err = h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } var acked1RTTPacket bool for _, p := range ackedPackets { if p.includedInBytesInFlight && !p.declaredLost { cc.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + ackedPacketsInfo = append(ackedPacketsInfo, congestionExt.AckedPacketInfo{ + PacketNumber: congestionExt.PacketNumber(p.PacketNumber), + BytesAcked: congestionExt.ByteCount(p.Length), + }) } if p.EncryptionLevel == protocol.Encryption1RTT { acked1RTTPacket = true @@ -378,6 +385,12 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.removeFromBytesInFlight(p) putPacket(p) } + + if cex, ok := h.congestion.(congestion.SendAlgorithmEx); ok && + (len(ackedPacketsInfo) != 0 || len(lostPacketsInfo) != 0) { + cex.OnCongestionEventEx(priorInFlight, rcvTime, ackedPacketsInfo, lostPacketsInfo) + } + // After this point, we must not use ackedPackets any longer! // We've already returned the buffers. ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side. @@ -611,7 +624,8 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } } -func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]congestionExt.LostPacketInfo, error) { + lostPackets := []congestionExt.LostPacketInfo{} pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -627,7 +641,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E cc := h.getCongestionControl() priorInFlight := h.bytesInFlight - return pnSpace.history.Iterate(func(p *packet) (bool, error) { + err := pnSpace.history.Iterate(func(p *packet) (bool, error) { if p.PacketNumber > pnSpace.largestAcked { return false, nil } @@ -670,6 +684,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if !p.IsPathMTUProbePacket { cc.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) } + lostPackets = append(lostPackets, congestionExt.LostPacketInfo{ + PacketNumber: congestionExt.PacketNumber(p.PacketNumber), + BytesLost: congestionExt.ByteCount(p.Length), + }) if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { h.ecnTracker.LostPacket(p.PacketNumber) } @@ -677,10 +695,12 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } return true, nil }) + return lostPackets, err } func (h *sentPacketHandler) OnLossDetectionTimeout() error { defer h.setLossDetectionTimer() + priorInFlight := h.bytesInFlight earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -690,7 +710,14 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel) + lostPacketsInfo, err := h.detectLostPackets(time.Now(), encLevel) + + if cex, ok := h.congestion.(congestion.SendAlgorithmEx); ok && + len(lostPacketsInfo) != 0 { + cex.OnCongestionEventEx(priorInFlight, time.Now(), nil, lostPacketsInfo) + } + + return err } // PTO diff --git a/internal/congestion/interface.go b/internal/congestion/interface.go index 881f453b69a..da052159e66 100644 --- a/internal/congestion/interface.go +++ b/internal/congestion/interface.go @@ -3,6 +3,7 @@ package congestion import ( "time" + "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/protocol" ) @@ -26,3 +27,8 @@ type SendAlgorithmWithDebugInfos interface { InRecovery() bool GetCongestionWindow() protocol.ByteCount } + +type SendAlgorithmEx interface { + SendAlgorithmWithDebugInfos + OnCongestionEventEx(priorInFlight protocol.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) +}