diff --git a/firewall.go b/firewall.go index 3a615de8a..f863566a1 100644 --- a/firewall.go +++ b/firewall.go @@ -284,7 +284,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -631,7 +631,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC if ft.UDP.match(p, incoming, c, caPool) { return true } - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } diff --git a/firewall/packet.go b/firewall/packet.go index 8954f4c47..b3cf1fb98 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -9,10 +9,11 @@ import ( type m map[string]interface{} const ( - ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - ProtoTCP = 6 - ProtoUDP = 17 - ProtoICMP = 1 + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` diff --git a/outside.go b/outside.go index f504bb406..91af76712 100644 --- a/outside.go +++ b/outside.go @@ -334,8 +334,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { return fmt.Errorf("ipv6 packet was too small") } fp.Protocol = uint8(proto) - fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } fp.Fragment = false return nil @@ -344,8 +349,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { return fmt.Errorf("ipv6 packet was too small") } fp.Protocol = uint8(proto) - fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } fp.Fragment = false return nil diff --git a/outside_test.go b/outside_test.go index aa5581f03..05537a474 100644 --- a/outside_test.go +++ b/outside_test.go @@ -5,6 +5,9 @@ import ( "net/netip" "testing" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" @@ -87,3 +90,55 @@ func Test_newPacket(t *testing.T) { assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } + +func Test_newPacket_v6(t *testing.T) { + p := &firewall.Packet{} + + ip := layers.IPv6{ + Version: 6, + NextHeader: firewall.ProtoUDP, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + err := udp.SetNetworkLayerForChecksum(&ip) + if err != nil { + panic(err) + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b := buffer.Bytes() + + //test incoming + err = newPacket(b, true, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.RemotePort, uint16(36123)) + assert.Equal(t, p.LocalPort, uint16(22)) + + //test outgoing + err = newPacket(b, false, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.LocalPort, uint16(36123)) + assert.Equal(t, p.RemotePort, uint16(22)) +}