Skip to content

Commit

Permalink
TUN-8822: Prevent concurrent usage of ICMPDecoder
Browse files Browse the repository at this point in the history
## Summary
Some description...

Closes TUN-8822
  • Loading branch information
Gonçalo Garcia committed Dec 19, 2024
1 parent 9bc6cbd commit c690155
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 6 deletions.
24 changes: 18 additions & 6 deletions quic/v3/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ type datagramConn struct {
icmpRouter ingress.ICMPRouter
metrics Metrics
logger *zerolog.Logger

datagrams chan []byte
readErrors chan error
datagrams chan []byte
readErrors chan error

icmpEncoderPool sync.Pool // a pool of *packet.Encoder
icmpDecoder *packet.ICMPDecoder
icmpDecoderPool sync.Pool
}

func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
Expand All @@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
return packet.NewEncoder()
},
},
icmpDecoder: packet.NewICMPDecoder(),
icmpDecoderPool: sync.Pool{
New: func() any {
return packet.NewICMPDecoder()
},
},
}
}

Expand Down Expand Up @@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {

// Decode the provided ICMPDatagram as an ICMP packet
rawPacket := packet.RawPacket{Data: datagram.Payload}
icmp, err := c.icmpDecoder.Decode(rawPacket)
cachedDecoder := c.icmpDecoderPool.Get()
defer c.icmpDecoderPool.Put(cachedDecoder)
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
if !ok {
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
return
}

icmp, err := decoder.Decode(rawPacket)

if err != nil {
c.logger.Err(err).Msgf("unable to marshal icmp packet")
return
Expand Down
89 changes: 89 additions & 0 deletions quic/v3/muxer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"golang.org/x/net/icmp"
Expand Down Expand Up @@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) {
assertContextClosed(t, ctx, done, cancel)
}

// This test exists because decoding multiple packets in parallel with the same decoder
// instances causes inteference resulting in multiple different raw packets being decoded
// as the same decoded packet.
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
router := newMockICMPRouter()
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)

// Setup the muxer
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(errors.New("other error"))
done := make(chan error, 1)
go func() {
done <- conn.Serve(ctx)
}()

packetCount := 100
packets := make([]*packet.ICMP, 100)
ipTemplate := "10.0.0.%d"
for i := 1; i <= packetCount; i++ {
packets[i-1] = &packet.ICMP{
IP: &packet.IP{
Src: netip.MustParseAddr("192.168.1.1"),
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
Protocol: layers.IPProtocolICMPv4,
TTL: 20,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: 25821,
Seq: 58129,
Data: []byte("test"),
},
},
}
}

wg := sync.WaitGroup{}
var receivedPackets []*packet.ICMP
go func() {
for ctx.Err() == nil {
select {
case icmpPacket := <-router.recv:
receivedPackets = append(receivedPackets, icmpPacket)
wg.Done()
}
}
}()

for _, p := range packets {
// We increment here but only decrement when receiving the packet
wg.Add(1)
go func() {
datagram := newICMPDatagram(p)
quic.send <- datagram
}()
}

wg.Wait()

// If there were duplicates then we won't have the same number of IPs
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
for _, p := range receivedPackets {
packetSet[p.Dst] = p
}
assert.Equal(t, len(packetSet), len(packets))

// Sort the slice by last byte of IP address (the one we increment for each destination)
// and then check that we have one match for each packet sent
sort.Slice(receivedPackets, func(i, j int) bool {
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
})
for i, p := range receivedPackets {
assert.Equal(t, p.Dst, packets[i].Dst)
}

// Cancel the muxer Serve context and make sure it closes with the expected error
assertContextClosed(t, ctx, done, cancel)
}

func TestDatagramConnServe_RegisterTwice(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
Expand Down

0 comments on commit c690155

Please sign in to comment.