Skip to content

Commit

Permalink
Cert V2 firewall fixes (#1242)
Browse files Browse the repository at this point in the history
* combine icmp and icmpv6 handling in the firewall

* correct ipv6 port number interpretation

* add unit test, fix orientation of ports

* gofmt
  • Loading branch information
JackDoanRivian authored Oct 10, 2024
1 parent a96c02a commit c257a75
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 10 deletions.
4 changes: 2 additions & 2 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
fp = ft.TCP
case firewall.ProtoUDP:
fp = ft.UDP
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
fp = ft.ICMP
case firewall.ProtoAny:
fp = ft.AnyProto
Expand Down Expand Up @@ -631,7 +631,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
Expand Down
9 changes: 5 additions & 4 deletions firewall/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
type m map[string]interface{}

const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
ProtoICMPv6 = 58

PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment`
Expand Down
18 changes: 14 additions & 4 deletions outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil

Expand All @@ -344,8 +349,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil

Expand Down
55 changes: 55 additions & 0 deletions outside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"net/netip"
"testing"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"

"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
Expand Down Expand Up @@ -87,3 +90,55 @@ func Test_newPacket(t *testing.T) {
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}

func Test_newPacket_v6(t *testing.T) {
p := &firewall.Packet{}

ip := layers.IPv6{
Version: 6,
NextHeader: firewall.ProtoUDP,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}

udp := layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}

buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
if err != nil {
panic(err)
}
b := buffer.Bytes()

//test incoming
err = newPacket(b, true, p)

assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.RemotePort, uint16(36123))
assert.Equal(t, p.LocalPort, uint16(22))

//test outgoing
err = newPacket(b, false, p)

assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.LocalPort, uint16(36123))
assert.Equal(t, p.RemotePort, uint16(22))
}

0 comments on commit c257a75

Please sign in to comment.