diff --git a/pkg/agent/controller/networkpolicy/reject_test.go b/pkg/agent/controller/networkpolicy/reject_test.go index 5a36f4af8e0..17cda725790 100644 --- a/pkg/agent/controller/networkpolicy/reject_test.go +++ b/pkg/agent/controller/networkpolicy/reject_test.go @@ -20,13 +20,17 @@ import ( "testing" "antrea.io/libOpenflow/openflow15" + "antrea.io/libOpenflow/protocol" + "antrea.io/libOpenflow/util" "antrea.io/ofnet/ofctrl" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "antrea.io/antrea/pkg/agent/config" "antrea.io/antrea/pkg/agent/interfacestore" + testing3 "antrea.io/antrea/pkg/agent/interfacestore/testing" "antrea.io/antrea/pkg/agent/openflow" + testing2 "antrea.io/antrea/pkg/agent/openflow/testing" binding "antrea.io/antrea/pkg/ovs/openflow" mocks "antrea.io/antrea/pkg/ovs/openflow/testing" ) @@ -390,6 +394,7 @@ func Test_getRejectPacketOutMutateFunc(t *testing.T) { }) conntrackTableID := openflow.ConntrackTable.GetID() l3ForwardingTableID := openflow.L3ForwardingTable.GetID() + l2ForwardingTableID := openflow.L2ForwardingCalcTable.GetID() ctrl := gomock.NewController(t) type args struct { rejectType RejectType @@ -451,7 +456,7 @@ func Test_getRejectPacketOutMutateFunc(t *testing.T) { }, }, { - name: "RejectLocalToRemoteFlexibleIPAMSrc", + name: "RejectLocalToRemoteNoFlexibleIPAM", args: args{ rejectType: RejectPodLocalToRemote, nodeType: config.K8sNode, @@ -465,6 +470,21 @@ func Test_getRejectPacketOutMutateFunc(t *testing.T) { builder.EXPECT().AddResubmitAction(nil, &l3ForwardingTableID).Return(builder) }, }, + { + name: "RejectLocalToRemoteNoFlexibleIPAM,ExternalNode", + args: args{ + rejectType: RejectPodLocalToRemote, + nodeType: config.ExternalNode, + isFlexibleIPAMSrc: false, + isFlexibleIPAMDst: false, + ctZone: 1, + }, + prepareFunc: func(builder *mocks.MockPacketOutBuilder) { + builder.EXPECT().AddLoadRegMark(openflow.GeneratedRejectPacketOutRegMark).Return(builder) + builder.EXPECT().AddLoadRegMark(binding.NewRegMark(openflow.CtZoneField, 1)).Return(builder) + builder.EXPECT().AddResubmitAction(nil, &l2ForwardingTableID).Return(builder) + }, + }, { name: "RejectServiceRemoteToLocalFlexibleIPAMDst", args: args{ @@ -502,3 +522,244 @@ func Test_getRejectPacketOutMutateFunc(t *testing.T) { }) } } + +func Test_handleRejectRequest(t *testing.T) { + fakeMac1 := "aa:aa:aa:aa:aa:aa" + fakeMac2 := "bb:bb:bb:bb:bb:bb" + fakeIPv41 := "1.1.1.1" + fakeIPv42 := "2.2.2.2" + fakeIPv61 := "2001::1" + fakeIPv62 := "2001::2" + fakePort1 := uint16(8080) + fakePort2 := uint16(80) + fakeTCPSeqNum := uint32(1) + fakeOFPort1 := int32(1) + fakeOFPort2 := int32(2) + fakeSIface := &interfacestore.InterfaceConfig{ + OVSPortConfig: &interfacestore.OVSPortConfig{ + OFPort: fakeOFPort2, + }, + MAC: net.HardwareAddr([]byte{0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb}), + } + fakeDIface := &interfacestore.InterfaceConfig{ + OVSPortConfig: &interfacestore.OVSPortConfig{ + OFPort: fakeOFPort1, + }, + MAC: net.HardwareAddr([]byte{0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa}), + } + fakeGWIface := &interfacestore.InterfaceConfig{ + OVSPortConfig: &interfacestore.OVSPortConfig{ + OFPort: fakeOFPort1, + }, + MAC: net.HardwareAddr([]byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}), + } + type args struct { + srcMAC string + dstMAC string + srcIP string + dstIP string + isIPv6 bool + isTCP bool + srcPort uint16 + dstPort uint16 + tcpSeqNum uint32 + matches []openflow15.MatchField + antreaProxyEnable bool + } + // Only test RejectPodLocal and RejectServiceRemoteToExternal two reject types, + // considering other types branches are covered by the UTs above. + tests := []struct { + name string + args args + expectFunc func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) + }{ + { + name: "IPv4,TCP,antreaProxy,podLocal", + args: args{ + srcMAC: fakeMac1, + dstMAC: fakeMac2, + srcIP: fakeIPv41, + dstIP: fakeIPv42, + isIPv6: false, + isTCP: true, + srcPort: fakePort1, + dstPort: fakePort2, + tcpSeqNum: fakeTCPSeqNum, + matches: []openflow15.MatchField{*openflow15.NewInPortField(uint32(fakeOFPort1))}, + antreaProxyEnable: true, + }, + expectFunc: func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) { + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeSIface, true) + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeDIface, true) + mockClient.EXPECT().SendTCPPacketOut(fakeMac2, fakeMac1, fakeIPv42, fakeIPv41, gomock.Any(), gomock.Any(), false, fakePort2, fakePort1, gomock.Any(), fakeTCPSeqNum+1, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + }, + { + name: "IPv4,TCP,antreaProxyDisable,serviceRemoteToExternal", + args: args{ + srcMAC: fakeMac1, + dstMAC: fakeMac2, + srcIP: fakeIPv41, + dstIP: fakeIPv42, + isIPv6: false, + isTCP: true, + srcPort: fakePort1, + dstPort: fakePort2, + tcpSeqNum: fakeTCPSeqNum, + matches: []openflow15.MatchField{{ + Class: openflow15.OXM_CLASS_PACKET_REGS, + Field: 2, + HasMask: false, + Length: 0, + ExperimenterID: 0, + Value: &openflow15.ByteArrayField{ + Data: []byte{0, 2, 0, 0, 0, 0, 0, 0}, + Length: 64, + }, + Mask: nil, + }}, + antreaProxyEnable: true, + }, + expectFunc: func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) { + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(nil, false) + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(nil, false) + mockClient.EXPECT().SendTCPPacketOut(fakeMac2, openflow.GlobalVirtualMAC.String(), fakeIPv42, fakeIPv41, gomock.Any(), gomock.Any(), false, fakePort2, fakePort1, gomock.Any(), fakeTCPSeqNum+1, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + }, + { + name: "IPv4,UDP,antreaProxy,podLocal", + args: args{ + srcMAC: fakeMac1, + dstMAC: fakeMac2, + srcIP: fakeIPv41, + dstIP: fakeIPv42, + isIPv6: false, + isTCP: false, + srcPort: fakePort1, + dstPort: fakePort2, + matches: []openflow15.MatchField{*openflow15.NewInPortField(uint32(fakeOFPort1))}, + antreaProxyEnable: true, + }, + expectFunc: func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) { + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeSIface, true) + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeDIface, true) + mockClient.EXPECT().SendICMPPacketOut(fakeMac2, fakeMac1, fakeIPv42, fakeIPv41, gomock.Any(), gomock.Any(), false, uint8(3), uint8(10), gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + }, + { + name: "IPv4,TCP,antreaProxyDisable,podLocal", + args: args{ + srcMAC: fakeMac1, + dstMAC: fakeMac2, + srcIP: fakeIPv41, + dstIP: fakeIPv42, + isIPv6: false, + isTCP: true, + srcPort: fakePort1, + dstPort: fakePort2, + tcpSeqNum: fakeTCPSeqNum, + antreaProxyEnable: false, + }, + expectFunc: func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) { + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeSIface, true) + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeDIface, true) + mockIStore.EXPECT().GetInterfacesByType(gomock.Any()).Return([]*interfacestore.InterfaceConfig{fakeGWIface}) + mockClient.EXPECT().SendTCPPacketOut(fakeMac2, fakeMac1, fakeIPv42, fakeIPv41, gomock.Any(), gomock.Any(), false, fakePort2, fakePort1, gomock.Any(), fakeTCPSeqNum+1, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + }, + { + name: "IPv6,TCP,antreaProxy,podLocal", + args: args{ + srcMAC: fakeMac1, + dstMAC: fakeMac2, + srcIP: fakeIPv61, + dstIP: fakeIPv62, + isIPv6: true, + isTCP: true, + srcPort: fakePort1, + dstPort: fakePort2, + tcpSeqNum: fakeTCPSeqNum, + antreaProxyEnable: true, + }, + expectFunc: func(mockClient *testing2.MockClient, mockIStore *testing3.MockInterfaceStore) { + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeSIface, true) + mockIStore.EXPECT().GetInterfaceByIP(gomock.Any()).Return(fakeDIface, true) + mockClient.EXPECT().SendTCPPacketOut(fakeMac2, fakeMac1, fakeIPv62, fakeIPv61, gomock.Any(), gomock.Any(), true, fakePort2, fakePort1, gomock.Any(), fakeTCPSeqNum+1, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + }, + } + ctrl := gomock.NewController(t) + controller, _, _ := newTestController() + mockClient := testing2.NewMockClient(ctrl) + mockIStore := testing3.NewMockInterfaceStore(ctrl) + controller.ofClient = mockClient + controller.ifaceStore = mockIStore + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + controller.antreaProxyEnabled = tt.args.antreaProxyEnable + tt.expectFunc(mockClient, mockIStore) + assert.NoError(t, controller.rejectRequest(genPacketIn(tt.args.srcMAC, tt.args.dstMAC, tt.args.srcIP, tt.args.dstIP, tt.args.srcPort, tt.args.dstPort, tt.args.tcpSeqNum, tt.args.isIPv6, tt.args.isTCP, tt.args.matches))) + }) + } +} + +func genPacketIn(srcMac, dstMac, srcIP, dstIP string, srcPort, dstPort uint16, seqNum uint32, isIPv6, isTCP bool, matches []openflow15.MatchField) *ofctrl.PacketIn { + pktIn := openflow15.NewPacketIn() + for i := range matches { + pktIn.Match.AddField(matches[i]) + } + var ipData util.Message + var proto uint8 + if isTCP { + proto = protocol.Type_TCP + tcpPacket := &protocol.TCP{ + PortSrc: srcPort, + PortDst: dstPort, + SeqNum: seqNum, + } + b, _ := tcpPacket.MarshalBinary() + ipData = util.NewBuffer(b) + } else { + proto = protocol.Type_UDP + udpPacket := &protocol.UDP{ + PortSrc: srcPort, + PortDst: dstPort, + } + b, _ := udpPacket.MarshalBinary() + ipData = util.NewBuffer(b) + } + var ethData util.Message + var ethType uint16 + if isIPv6 { + ethType = protocol.IPv6_MSG + ipPacket := &protocol.IPv6{ + NWSrc: net.ParseIP(srcIP), + NWDst: net.ParseIP(dstIP), + NextHeader: proto, + Data: ipData, + } + b, _ := ipPacket.MarshalBinary() + ethData = util.NewBuffer(b) + } else { + ethType = protocol.IPv4_MSG + ipPacket := &protocol.IPv4{ + NWSrc: net.ParseIP(srcIP), + NWDst: net.ParseIP(dstIP), + Protocol: proto, + Data: ipData, + } + b, _ := ipPacket.MarshalBinary() + ethData = util.NewBuffer(b) + } + ethernetPkt := protocol.NewEthernet() + hwSrc, _ := net.ParseMAC(srcMac) + hwDst, _ := net.ParseMAC(dstMac) + ethernetPkt.HWSrc = hwSrc + ethernetPkt.HWDst = hwDst + ethernetPkt.Ethertype = ethType + ethernetPkt.Data = ethData + pktBytes, _ := ethernetPkt.MarshalBinary() + pktIn.Data = util.NewBuffer(pktBytes) + return &ofctrl.PacketIn{PacketIn: pktIn} +}