diff --git a/hack/update-codegen-dockerized.sh b/hack/update-codegen-dockerized.sh index 204a5201325..34a0467155e 100755 --- a/hack/update-codegen-dockerized.sh +++ b/hack/update-codegen-dockerized.sh @@ -59,6 +59,7 @@ MOCKGEN_TARGETS=( "pkg/agent/util/iptables Interface testing mock_iptables_linux.go" # Must specify linux.go suffix, otherwise compilation would fail on windows platform as source file has linux build tag. "pkg/agent/util/netlink Interface testing mock_netlink_linux.go" "pkg/agent/wireguard Interface testing mock_wireguard.go" + "pkg/agent/util/winnet Interface testing mock_net_windows.go" "pkg/antctl AntctlClient ." "pkg/controller/networkpolicy EndpointQuerier,PolicyRuleQuerier testing" "pkg/controller/querier ControllerQuerier testing" diff --git a/pkg/agent/agent_windows.go b/pkg/agent/agent_windows.go index 3bac5b53003..1429b95c6da 100644 --- a/pkg/agent/agent_windows.go +++ b/pkg/agent/agent_windows.go @@ -30,6 +30,7 @@ import ( "antrea.io/antrea/pkg/agent/interfacestore" "antrea.io/antrea/pkg/agent/util" antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" + "antrea.io/antrea/pkg/agent/util/winnet" "antrea.io/antrea/pkg/apis/crd/v1alpha1" "antrea.io/antrea/pkg/ovs/ovsconfig" "antrea.io/antrea/pkg/ovs/ovsctl" @@ -37,8 +38,9 @@ import ( ) var ( + winnetUtil winnet.Interface = &winnet.Handle{} // setInterfaceMTU is meant to be overridden for testing - setInterfaceMTU = util.SetInterfaceMTU + setInterfaceMTU = winnetUtil.SetNetAdapterMTU // setInterfaceARPAnnounce is meant to be overridden for testing. setInterfaceARPAnnounce = func(ifaceName string, value int) error { return nil } @@ -57,11 +59,11 @@ func (i *Initializer) prepareHNSNetworkAndOVSExtension() error { hnsNetwork, err := hcsshim.GetHNSNetworkByName(util.LocalHNSNetwork) if err == nil { // Enable OVS Extension on the HNS Network. - if err = util.EnableHNSNetworkExtension(hnsNetwork.Id, util.OVSExtensionID); err != nil { + if err = util.EnableHNSNetworkExtension(hnsNetwork.Id, winnet.OVSExtensionID); err != nil { return err } // Enable RSC for existing vSwitch. - if err = util.EnableRSCOnVSwitch(util.LocalHNSNetwork); err != nil { + if err = winnetUtil.EnableRSCOnVSwitch(util.LocalHNSNetwork); err != nil { return err } // Save the uplink adapter name to check if the OVS uplink port has been created in prepareOVSBridge stage. @@ -87,7 +89,7 @@ func (i *Initializer) prepareHNSNetworkAndOVSExtension() error { // adapter was not found" if no adapter is provided and no physical adapter is available on the host. // If the discovered adapter is virtual, it likely means the physical adapter is already attached to another // HNSNetwork. For example, docker may create HNSNetworks which attach to the physical adapter. - isVirtual, err := util.IsVirtualAdapter(adapter.Name) + isVirtual, err := winnetUtil.IsVirtualNetAdapter(adapter.Name) if err != nil { return err } @@ -107,7 +109,7 @@ func (i *Initializer) prepareHNSNetworkAndOVSExtension() error { klog.InfoS("No default gateway found on interface", "interface", adapter.Name) } i.nodeConfig.UplinkNetConfig.Gateway = defaultGW - dnsServers, err := util.GetDNServersByInterfaceIndex(adapter.Index) + dnsServers, err := winnetUtil.GetDNServersByNetAdapterIndex(adapter.Index) if err != nil { return err } @@ -128,12 +130,12 @@ func (i *Initializer) prepareHNSNetworkAndOVSExtension() error { func (i *Initializer) prepareVMNetworkAndOVSExtension() error { klog.V(2).Info("Setting up VM network") // Check whether VM Switch is created - exists, err := util.VMSwitchExists() + exists, err := winnetUtil.VMSwitchExists(util.LocalVMSwitch) if err != nil { return err } if exists { - vmSwitchIFName, err := util.GetVMSwitchInterfaceName() + vmSwitchIFName, err := winnetUtil.GetVMSwitchNetAdapterName(util.LocalVMSwitch) if err != nil { return err } @@ -168,20 +170,29 @@ func (i *Initializer) prepareVMNetworkAndOVSExtension() error { }() klog.V(2).InfoS("Creating VM switch", "uplinkIFName", uplinkIFName) - if err = util.CreateVMSwitch(uplinkIFName); err != nil { + if err = winnetUtil.AddVMSwitch(uplinkIFName, util.LocalVMSwitch); err != nil { return fmt.Errorf("failed to create VM switch for interface %s: %v", uplinkIFName, err) } + enabled, err := winnetUtil.IsVMSwitchOVSExtensionEnabled(util.LocalVMSwitch) + if err != nil { + return err + } + if !enabled { + if err = winnetUtil.EnableVMSwitchOVSExtension(util.LocalVMSwitch); err != nil { + return err + } + } defer func() { if !success { - if err = util.RemoveVMSwitch(); err != nil { + if err = winnetUtil.RemoveVMSwitch(util.LocalVMSwitch); err != nil { klog.ErrorS(err, "Failed to remove VMSwitch") } } }() uplinkMACStr := strings.Replace(uplinkIface.HardwareAddr.String(), ":", "", -1) - if err = util.RenameVMNetworkAdapter(util.LocalVMSwitch, uplinkMACStr, hostIFName, true); err != nil { + if err = winnetUtil.RenameVMNetworkAdapter(util.LocalVMSwitch, uplinkMACStr, hostIFName, true); err != nil { return fmt.Errorf("failed to rename VMNetworkAdapter as %s: %v", hostIFName, err) } @@ -298,7 +309,7 @@ func (i *Initializer) prepareOVSBridgeOnHNSNetwork() error { // connection are routed to the selected backend Pod via the bridge interface; if we do not enable IP forwarding on // the bridge interface, the packet will be discarded on the bridge interface as the destination of the packet // is not the Node. - if err = util.EnableIPForwarding(brName); err != nil { + if err = winnetUtil.EnableIPForwarding(brName); err != nil { return err } // Set the uplink with "no-flood" config, so that the IP of local Pods and "antrea-gw0" will not be leaked to the @@ -395,11 +406,11 @@ func (i *Initializer) saveHostRoutes() error { // IPv6 is not supported on Windows currently. Please refer to https://github.com/antrea-io/antrea/issues/5162 // for more information. family := antreasyscall.AF_INET - filter := &util.Route{ + filter := &winnet.Route{ LinkIndex: i.nodeConfig.UplinkNetConfig.Index, GatewayAddress: net.ParseIP(i.nodeConfig.UplinkNetConfig.Gateway), } - routes, err := util.RouteListFiltered(family, filter, util.RT_FILTER_IF|util.RT_FILTER_GW) + routes, err := winnetUtil.RouteListFiltered(family, filter, winnet.RT_FILTER_IF|winnet.RT_FILTER_GW) if err != nil { return err } diff --git a/pkg/agent/cniserver/interface_configuration_windows.go b/pkg/agent/cniserver/interface_configuration_windows.go index 58547cc0097..0a94f2ced40 100644 --- a/pkg/agent/cniserver/interface_configuration_windows.go +++ b/pkg/agent/cniserver/interface_configuration_windows.go @@ -34,6 +34,7 @@ import ( "k8s.io/klog/v2" "antrea.io/antrea/pkg/agent/util" + "antrea.io/antrea/pkg/agent/util/winnet" cnipb "antrea.io/antrea/pkg/apis/cni/v1beta1" "antrea.io/antrea/pkg/ovs/ovsconfig" ) @@ -45,7 +46,6 @@ const ( var ( getHnsNetworkByNameFunc = hcsshim.GetHNSNetworkByName listHnsEndpointFunc = hcsshim.HNSListEndpointRequest - setInterfaceMTUFunc = util.SetInterfaceMTU hostInterfaceExistsFunc = util.HostInterfaceExists getNetInterfaceAddrsFunc = getNetInterfaceAddrs createHnsEndpointFunc = createHnsEndpoint @@ -60,6 +60,7 @@ var ( type ifConfigurator struct { hnsNetwork *hcsshim.HNSNetwork epCache *sync.Map + winnet winnet.Interface } // disableTXChecksumOffload is ignored on Windows. @@ -80,6 +81,7 @@ func newInterfaceConfigurator(ovsDatapathType ovsconfig.OVSDatapathType, isOvsHa return &ifConfigurator{ hnsNetwork: hnsNetwork, epCache: epCache, + winnet: &winnet.Handle{}, }, nil } @@ -177,8 +179,8 @@ func (ic *ifConfigurator) configureContainerLink( // CmdAdd request is returned; 2) for Docker runtime, the interface is created after hcsshim.HotAttachEndpoint, // and the hcsshim call is not synchronized from the observation. return ic.addPostInterfaceCreateHook(infraContainerID, epName, containerAccess, func() error { - ifaceName := util.VirtualAdapterName(epName) - if err := setInterfaceMTUFunc(ifaceName, mtu); err != nil { + ifaceName := winnet.VirtualAdapterName(epName) + if err := ic.winnet.SetNetAdapterMTU(ifaceName, mtu); err != nil { return fmt.Errorf("failed to configure MTU on container interface '%s': %v", ifaceName, err) } return nil @@ -342,7 +344,7 @@ func (ic *ifConfigurator) checkContainerInterface( containerIface.Sandbox, sandboxID) } hnsEP := strings.Split(containerIface.Name, "_")[0] - containerIfaceName := util.VirtualAdapterName(hnsEP) + containerIfaceName := winnet.VirtualAdapterName(hnsEP) intf, err := getNetInterfaceByNameFunc(containerIfaceName) if err != nil { klog.Errorf("Failed to get container %s interface: %v", containerID, err) diff --git a/pkg/agent/cniserver/server_windows_test.go b/pkg/agent/cniserver/server_windows_test.go index 4a149d3ecff..70487b083e8 100644 --- a/pkg/agent/cniserver/server_windows_test.go +++ b/pkg/agent/cniserver/server_windows_test.go @@ -41,6 +41,7 @@ import ( routetest "antrea.io/antrea/pkg/agent/route/testing" agenttypes "antrea.io/antrea/pkg/agent/types" "antrea.io/antrea/pkg/agent/util" + winnettest "antrea.io/antrea/pkg/agent/util/winnet/testing" cnipb "antrea.io/antrea/pkg/apis/cni/v1beta1" ovsconfigtest "antrea.io/antrea/pkg/ovs/ovsconfig/testing" "antrea.io/antrea/pkg/util/channel" @@ -49,6 +50,8 @@ import ( var ( containerMACStr = "23:34:56:23:22:45" dnsSearches = []string{"a.b.c.d"} + + mockWinnet *winnettest.MockInterface ) func TestUpdateResultDNSConfig(t *testing.T) { @@ -258,6 +261,7 @@ func newMockCNIServer(t *testing.T, controller *gomock.Controller, podUpdateNoti mockOVSBridgeClient = ovsconfigtest.NewMockOVSBridgeClient(controller) mockOFClient = openflowtest.NewMockClient(controller) mockRoute = routetest.NewMockInterface(controller) + mockWinnet = winnettest.NewMockInterface(controller) ifaceStore = interfacestore.NewInterfaceStore() cniServer := newCNIServer(t) cniServer.routeClient = mockRoute @@ -266,6 +270,7 @@ func newMockCNIServer(t *testing.T, controller *gomock.Controller, podUpdateNoti gateway := &config.GatewayConfig{Name: "", IPv4: gwIPv4, MAC: gwMAC} cniServer.nodeConfig = &config.NodeConfig{Name: "node1", PodIPv4CIDR: nodePodCIDRv4, GatewayConfig: gateway} cniServer.podConfigurator, _ = newPodConfigurator(mockOVSBridgeClient, mockOFClient, mockRoute, ifaceStore, gwMAC, "system", false, false, podUpdateNotifier) + cniServer.podConfigurator.ifConfigurator.(*ifConfigurator).winnet = mockWinnet return cniServer } @@ -289,10 +294,6 @@ func prepareSetup(t *testing.T, ipamType string, name string, containerID, infra } func TestCmdAdd(t *testing.T) { - controller := gomock.NewController(t) - ipamType := "windows-test" - ipamMock := ipamtest.NewMockIPAMDriver(controller) - ipam.ResetIPAMDriver(ipamType, ipamMock) oriIPAMResult := &ipam.IPAMResult{Result: *ipamResult} ctx := context.TODO() @@ -300,7 +301,6 @@ func TestCmdAdd(t *testing.T) { defer mockHostInterfaceExists()() defer mockGetHnsNetworkByName()() - defer mockSetInterfaceMTU(nil)() for _, tc := range []struct { name string @@ -360,6 +360,11 @@ func TestCmdAdd(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + ipamType := "windows-test" + ipamMock := ipamtest.NewMockIPAMDriver(controller) + ipam.ResetIPAMDriver(ipamType, ipamMock) + isDocker := isDockerContainer(tc.netns) testUtil := newHnsTestUtil(generateUUID(), tc.existingHnsEndpoints, isDocker, tc.isAttached, tc.hnsEndpointCreateErr, tc.endpointAttachErr) testUtil.setFunctions() @@ -379,6 +384,9 @@ func TestCmdAdd(t *testing.T) { if tc.ipamDel { ipamMock.EXPECT().Del(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).Times(1) } + if tc.endpointAttachErr == nil { + mockWinnet.EXPECT().SetNetAdapterMTU(gomock.Any(), gomock.Any()).Times(1) + } ovsPortID := generateUUID() if tc.connectOVS { mockOVSBridgeClient.EXPECT().CreatePort(ovsPortName, ovsPortName, gomock.Any()).Return(ovsPortID, nil).Times(1) @@ -431,10 +439,6 @@ func TestCmdAdd(t *testing.T) { } func TestCmdDel(t *testing.T) { - controller := gomock.NewController(t) - ipamType := "windows-test" - ipamMock := ipamtest.NewMockIPAMDriver(controller) - ipam.ResetIPAMDriver(ipamType, ipamMock) ctx := context.TODO() containerID := "261a1970-5b6c-11ed-8caf-000c294e5d03" @@ -442,7 +446,6 @@ func TestCmdDel(t *testing.T) { defer mockHostInterfaceExists()() defer mockGetHnsNetworkByName()() - defer mockSetInterfaceMTU(nil)() for _, tc := range []struct { name string @@ -476,6 +479,11 @@ func TestCmdDel(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + ipamType := "windows-test" + ipamMock := ipamtest.NewMockIPAMDriver(controller) + ipam.ResetIPAMDriver(ipamType, ipamMock) + isDocker := isDockerContainer(tc.netns) requestMsg, ovsPortName := prepareSetup(t, ipamType, testPodNameA, containerID, containerID, tc.netns, nil) hnsEndpoint := getHnsEndpoint(generateUUID(), ovsPortName) @@ -530,10 +538,6 @@ func TestCmdDel(t *testing.T) { } func TestCmdCheck(t *testing.T) { - controller := gomock.NewController(t) - ipamType := "windows-test" - ipamMock := ipamtest.NewMockIPAMDriver(controller) - ipam.ResetIPAMDriver(ipamType, ipamMock) ctx := context.TODO() containerNetns := generateUUID() @@ -544,7 +548,6 @@ func TestCmdCheck(t *testing.T) { defer mockHostInterfaceExists()() defer mockGetHnsNetworkByName()() - defer mockSetInterfaceMTU(nil)() defer mockListHnsEndpoint(nil, nil)() defer mockGetNetInterfaceAddrs(containerIPNet, nil)() defer mockGetHnsEndpointByName(generateUUID(), mac)() @@ -661,6 +664,11 @@ func TestCmdCheck(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + ipamType := "windows-test" + ipamMock := ipamtest.NewMockIPAMDriver(controller) + ipam.ResetIPAMDriver(ipamType, ipamMock) + defer mockGetNetInterfaceByName(tc.netInterface)() cniserver := newMockCNIServer(t, controller, channel.NewSubscribableChannel("podUpdate", 100)) requestMsg, _ := prepareSetup(t, ipamType, tc.podName, tc.containerID, tc.containerID, tc.netns, tc.prevResult) @@ -864,16 +872,3 @@ func mockListHnsEndpoint(endpoints []hcsshim.HNSEndpoint, listError error) func( listHnsEndpointFunc = originalListHnsEndpoint } } - -func mockSetInterfaceMTU(setMTUError error) func() { - originalSetInterfaceMTU := setInterfaceMTUFunc - setInterfaceMTUFunc = func(ifaceName string, mtu int) error { - if setMTUError == nil { - hostIfaces.Store(ifaceName, true) - } - return setMTUError - } - return func() { - setInterfaceMTUFunc = originalSetInterfaceMTU - } -} diff --git a/pkg/agent/externalnode/external_node_controller_windows.go b/pkg/agent/externalnode/external_node_controller_windows.go index 39ee3d840e1..be156b19872 100644 --- a/pkg/agent/externalnode/external_node_controller_windows.go +++ b/pkg/agent/externalnode/external_node_controller_windows.go @@ -21,9 +21,12 @@ import ( "antrea.io/antrea/pkg/agent/config" "antrea.io/antrea/pkg/agent/util" + "antrea.io/antrea/pkg/agent/util/winnet" "antrea.io/antrea/pkg/signals" ) +var winnetUtil winnet.Interface = &winnet.Handle{} + // moveIFConfigurations returns nil for single interface case, as it relies // on Windows New-VMSwitch command to create a host network adapter and copy // the uplink adapter configurations to host adapter. @@ -45,7 +48,7 @@ func (c *ExternalNodeController) removeExternalNodeConfig() error { klog.ErrorS(ovsErr, "Failed to delete OVS bridge") } - if err := util.RemoveVMSwitch(); err != nil { + if err := winnetUtil.RemoveVMSwitch(util.LocalVMSwitch); err != nil { return fmt.Errorf("failed to delete VM Switch, err: %v", err) } // Antrea Agent initializer creates a VM Switch corresponding to an diff --git a/pkg/agent/nodeportlocal/rules/netnat_rule.go b/pkg/agent/nodeportlocal/rules/netnat_rule.go index 6f7ee202f7b..02a15de4c38 100644 --- a/pkg/agent/nodeportlocal/rules/netnat_rule.go +++ b/pkg/agent/nodeportlocal/rules/netnat_rule.go @@ -25,6 +25,7 @@ import ( "antrea.io/antrea/pkg/agent/route" "antrea.io/antrea/pkg/agent/util" + "antrea.io/antrea/pkg/agent/util/winnet" binding "antrea.io/antrea/pkg/ovs/openflow" ) @@ -39,13 +40,15 @@ func InitRules() PodPortRules { } type netnatRules struct { - name string + name string + winnet winnet.Interface } -// NewNetNatRules retruns a new instance of netnatRules. +// NewNetNatRules returns a new instance of netnatRules. func NewNetNatRules() *netnatRules { nnRule := netnatRules{ - name: antreaNatNPL, + name: antreaNatNPL, + winnet: &winnet.Handle{}, } return &nnRule } @@ -61,7 +64,7 @@ func (nn *netnatRules) Init() error { // initRules creates or reuses NetNat table as NPL rule instance on Windows. func (nn *netnatRules) initRules() error { nn.DeleteAllRules() - if err := util.NewNetNat(antreaNatNPL, route.PodCIDRIPv4); err != nil { + if err := nn.winnet.AddNetNat(antreaNatNPL, route.PodCIDRIPv4); err != nil { return err } klog.InfoS("Successfully created NetNat rule", "name", antreaNatNPL, "CIDR", route.PodCIDRIPv4) @@ -70,7 +73,7 @@ func (nn *netnatRules) initRules() error { // AddRule appends a NetNatStaticMapping rule. func (nn *netnatRules) AddRule(nodePort int, podIP string, podPort int, protocol string) error { - netNatStaticMapping := &util.NetNatStaticMapping{ + netNatStaticMapping := &winnet.NetNatStaticMapping{ Name: antreaNatNPL, ExternalIP: net.ParseIP("0.0.0.0"), ExternalPort: util.PortToUint16(nodePort), @@ -78,7 +81,7 @@ func (nn *netnatRules) AddRule(nodePort int, podIP string, podPort int, protocol InternalPort: util.PortToUint16(podPort), Protocol: binding.Protocol(protocol), } - if err := util.ReplaceNetNatStaticMapping(netNatStaticMapping); err != nil { + if err := nn.winnet.ReplaceNetNatStaticMapping(netNatStaticMapping); err != nil { return err } klog.InfoS("Successfully added NetNatStaticMapping", "NetNatStaticMapping", netNatStaticMapping) @@ -97,7 +100,7 @@ func (nn *netnatRules) AddAllRules(nplList []PodNodePort) error { // DeleteRule deletes a specific NPL rule from NetNatStaticMapping table func (nn *netnatRules) DeleteRule(nodePort int, podIP string, podPort int, protocol string) error { - netNatStaticMapping := &util.NetNatStaticMapping{ + netNatStaticMapping := &winnet.NetNatStaticMapping{ Name: antreaNatNPL, ExternalIP: net.ParseIP("0.0.0.0"), ExternalPort: util.PortToUint16(nodePort), @@ -105,7 +108,7 @@ func (nn *netnatRules) DeleteRule(nodePort int, podIP string, podPort int, proto InternalPort: util.PortToUint16(podPort), Protocol: binding.Protocol(protocol), } - if err := util.RemoveNetNatStaticMappingByNPLTuples(netNatStaticMapping); err != nil { + if err := nn.winnet.RemoveNetNatStaticMapping(netNatStaticMapping); err != nil { return err } klog.InfoS("Successfully deleted NetNatStaticMapping", "NetNatStaticMapping", netNatStaticMapping) @@ -114,7 +117,7 @@ func (nn *netnatRules) DeleteRule(nodePort int, podIP string, podPort int, proto // DeleteAllRules deletes the NetNatStaticMapping table in the node func (nn *netnatRules) DeleteAllRules() error { - if err := util.RemoveNetNatStaticMappingByNAME(antreaNatNPL); err != nil { + if err := nn.winnet.RemoveNetNatStaticMappingsByNetNat(antreaNatNPL); err != nil { return err } klog.InfoS("Successfully deleted all NPL NetNatStaticMapping rules", "NatName", antreaNatNPL) diff --git a/pkg/agent/route/route_linux_test.go b/pkg/agent/route/route_linux_test.go index ab266186509..631129e426f 100644 --- a/pkg/agent/route/route_linux_test.go +++ b/pkg/agent/route/route_linux_test.go @@ -1031,8 +1031,6 @@ func TestAddRoutes(t *testing.T) { func TestDeleteRoutes(t *testing.T) { tests := []struct { name string - networkConfig *config.NetworkConfig - nodeConfig *config.NodeConfig podCIDR *net.IPNet existingNodeRoutes map[string][]*netlink.Route existingNodeNeighbors map[string]*netlink.Neigh @@ -1077,8 +1075,6 @@ func TestDeleteRoutes(t *testing.T) { mockIPSet := ipsettest.NewMockInterface(ctrl) c := &Client{netlink: mockNetlink, ipset: mockIPSet, - networkConfig: tt.networkConfig, - nodeConfig: tt.nodeConfig, nodeRoutes: sync.Map{}, nodeNeighbors: sync.Map{}, } @@ -1567,15 +1563,10 @@ func TestAddExternalIPRoute(t *testing.T) { tests := []struct { name string externalIPs []string - serviceRoutes map[string]*netlink.Route expectedCalls func(mockNetlink *netlinktest.MockInterfaceMockRecorder) }{ { - name: "IPv4", - serviceRoutes: map[string]*netlink.Route{ - externalIPv4Addr1: ipv4Route1, - externalIPv4Addr2: ipv4Route2, - }, + name: "IPv4", externalIPs: []string{externalIPv4Addr1, externalIPv4Addr2}, expectedCalls: func(mockNetlink *netlinktest.MockInterfaceMockRecorder) { mockNetlink.RouteReplace(ipv4Route1) @@ -1583,11 +1574,7 @@ func TestAddExternalIPRoute(t *testing.T) { }, }, { - name: "IPv6", - serviceRoutes: map[string]*netlink.Route{ - externalIPv6Addr1: ipv6Route1, - externalIPv6Addr2: ipv6Route2, - }, + name: "IPv6", externalIPs: []string{externalIPv6Addr1, externalIPv6Addr2}, expectedCalls: func(mockNetlink *netlinktest.MockInterfaceMockRecorder) { mockNetlink.RouteReplace(ipv6Route1) diff --git a/pkg/agent/route/route_windows.go b/pkg/agent/route/route_windows.go index 9c279a9c65c..8fa585265a5 100644 --- a/pkg/agent/route/route_windows.go +++ b/pkg/agent/route/route_windows.go @@ -33,6 +33,7 @@ import ( "antrea.io/antrea/pkg/agent/util" antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" "antrea.io/antrea/pkg/agent/util/winfirewall" + "antrea.io/antrea/pkg/agent/util/winnet" binding "antrea.io/antrea/pkg/ovs/openflow" iputil "antrea.io/antrea/pkg/util/ip" ) @@ -56,12 +57,13 @@ var ( type Client struct { nodeConfig *config.NodeConfig networkConfig *config.NetworkConfig + winnet winnet.Interface // nodeRoutes caches ip routes to remote Pods. It's a map of podCIDR to routes. - nodeRoutes *sync.Map + nodeRoutes sync.Map // serviceRoutes caches ip routes about Services. - serviceRoutes *sync.Map + serviceRoutes sync.Map // netNatStaticMappings caches Windows NetNat for NodePort. - netNatStaticMappings *sync.Map + netNatStaticMappings sync.Map fwClient *winfirewall.Client bridgeInfIndex int noSNAT bool @@ -79,14 +81,12 @@ func NewClient(networkConfig *config.NetworkConfig, multicastEnabled bool, serviceCIDRProvider servicecidr.Interface) (*Client, error) { return &Client{ - networkConfig: networkConfig, - nodeRoutes: &sync.Map{}, - serviceRoutes: &sync.Map{}, - netNatStaticMappings: &sync.Map{}, - fwClient: winfirewall.NewClient(), - noSNAT: noSNAT, - proxyAll: proxyAll, - serviceCIDRProvider: serviceCIDRProvider, + networkConfig: networkConfig, + winnet: &winnet.Handle{}, + fwClient: winfirewall.NewClient(), + noSNAT: noSNAT, + proxyAll: proxyAll, + serviceCIDRProvider: serviceCIDRProvider, }, nil } @@ -109,11 +109,11 @@ func (c *Client) Initialize(nodeConfig *config.NodeConfig, done func()) error { // to the host network stack from the host gateway interface, and its dst MAC could be resolved to the right one. // At last, the packet is sent back to OVS from the bridge Interface, and the OpenFlow entries will output it to // the uplink interface directly. - if err := util.EnableIPForwarding(nodeConfig.GatewayConfig.Name); err != nil { + if err := c.winnet.EnableIPForwarding(nodeConfig.GatewayConfig.Name); err != nil { return err } if !c.noSNAT { - err := util.NewNetNat(antreaNat, nodeConfig.PodIPv4CIDR) + err := c.winnet.AddNetNat(antreaNat, nodeConfig.PodIPv4CIDR) if err != nil { return err } @@ -124,7 +124,7 @@ func (c *Client) Initialize(nodeConfig *config.NodeConfig, done func()) error { return fmt.Errorf("failed to initialize Service IP routes: %v", err) } // For NodePort Service, a NetNatStaticMapping is needed. - if err := util.NewNetNat(antreaNatNodePort, virtualNodePortDNATIPv4Net); err != nil { + if err := c.winnet.AddNetNat(antreaNatNodePort, virtualNodePortDNATIPv4Net); err != nil { return err } } @@ -193,7 +193,7 @@ func (c *Client) Reconcile(podCIDRs []string) error { if c.proxyAll && c.isServiceRoute(&routes[i]) { continue } - err = util.RemoveNetRoute(&routes[i]) + err = c.winnet.RemoveNetRoute(&routes[i]) if err != nil { return err } @@ -205,9 +205,9 @@ func (c *Client) Reconcile(podCIDRs []string) error { // It overrides the routes if they already exist, without error. func (c *Client) AddRoutes(podCIDR *net.IPNet, nodeName string, peerNodeIP, peerGwIP net.IP) error { obj, found := c.nodeRoutes.Load(podCIDR.String()) - route := &util.Route{ + route := &winnet.Route{ DestinationSubnet: podCIDR, - RouteMetric: util.MetricDefault, + RouteMetric: winnet.MetricDefault, } if c.networkConfig.NeedsTunnelToPeer(peerNodeIP, c.nodeConfig.NodeTransportIPv4Addr) { route.LinkIndex = c.nodeConfig.GatewayConfig.LinkIndex @@ -222,13 +222,13 @@ func (c *Client) AddRoutes(podCIDR *net.IPNet, nodeName string, peerNodeIP, peer // Use host default route inside the Node. if found { - existingRoute := obj.(*util.Route) + existingRoute := obj.(*winnet.Route) if existingRoute.GatewayAddress.Equal(route.GatewayAddress) { klog.V(4).Infof("Route with destination %s already exists on %s (%s)", podCIDR.String(), nodeName, peerNodeIP) return nil } // Remove the existing route entry if the gateway address is not as expected. - if err := util.RemoveNetRoute(existingRoute); err != nil { + if err := c.winnet.RemoveNetRoute(existingRoute); err != nil { klog.Errorf("Failed to delete existing route entry with destination %s gateway %s on %s (%s)", podCIDR.String(), peerGwIP.String(), nodeName, peerNodeIP) return err } @@ -238,7 +238,7 @@ func (c *Client) AddRoutes(podCIDR *net.IPNet, nodeName string, peerNodeIP, peer return nil } - if err := util.ReplaceNetRoute(route); err != nil { + if err := c.winnet.ReplaceNetRoute(route); err != nil { return err } @@ -256,8 +256,8 @@ func (c *Client) DeleteRoutes(podCIDR *net.IPNet) error { return nil } - rt := obj.(*util.Route) - if err := util.RemoveNetRoute(rt); err != nil { + rt := obj.(*winnet.Route) + if err := c.winnet.RemoveNetRoute(rt); err != nil { return err } c.nodeRoutes.Delete(podCIDR.String()) @@ -272,13 +272,13 @@ func (c *Client) addVirtualServiceIPRoute(isIPv6 bool) error { svcIP := config.VirtualServiceIPv4 neigh := generateNeigh(svcIP, linkIndex) - if err := util.ReplaceNetNeighbor(neigh); err != nil { + if err := c.winnet.ReplaceNetNeighbor(neigh); err != nil { return fmt.Errorf("failed to add new IP neighbour for %s: %w", svcIP, err) } klog.InfoS("Added virtual Service IP neighbor", "neighbor", neigh) - route := generateRoute(virtualServiceIPv4Net, net.IPv4zero, linkIndex, util.MetricHigh) - if err := util.ReplaceNetRoute(route); err != nil { + route := generateRoute(virtualServiceIPv4Net, net.IPv4zero, linkIndex, winnet.MetricHigh) + if err := c.winnet.ReplaceNetRoute(route); err != nil { return fmt.Errorf("failed to install route for virtual Service IP %s: %w", svcIP.String(), err) } c.serviceRoutes.Store(svcIP.String(), route) @@ -290,12 +290,12 @@ func (c *Client) addVirtualServiceIPRoute(isIPv6 bool) error { func (c *Client) addServiceCIDRRoute(serviceCIDR *net.IPNet) error { linkIndex := c.nodeConfig.GatewayConfig.LinkIndex gw := config.VirtualServiceIPv4 - metric := util.MetricHigh + metric := winnet.MetricHigh oldServiceCIDRRoute, serviceCIDRRouteExists := c.serviceRoutes.Load(serviceIPv4CIDRKey) // Generate a route with the new ClusterIP CIDR and install it. route := generateRoute(serviceCIDR, gw, linkIndex, metric) - if err := util.ReplaceNetRoute(route); err != nil { + if err := c.winnet.ReplaceNetRoute(route); err != nil { return fmt.Errorf("failed to install a new Service CIDR route: %w", err) } @@ -304,11 +304,11 @@ func (c *Client) addServiceCIDRRoute(serviceCIDR *net.IPNet) error { c.serviceRoutes.Store(serviceIPv4CIDRKey, route) // Collect stale routes. - var staleRoutes []*util.Route + var staleRoutes []*winnet.Route // If current destination CIDR is not nil, the route with current destination CIDR should be uninstalled since // a new route with a newly calculated destination CIDR has been installed. if serviceCIDRRouteExists { - staleRoutes = append(staleRoutes, oldServiceCIDRRoute.(*util.Route)) + staleRoutes = append(staleRoutes, oldServiceCIDRRoute.(*winnet.Route)) } else { routes, err := c.listIPRoutesOnGW() if err != nil { @@ -335,7 +335,7 @@ func (c *Client) addServiceCIDRRoute(serviceCIDR *net.IPNet) error { // Remove stale routes. for _, rt := range staleRoutes { - if err := util.RemoveNetRoute(rt); err != nil { + if err := c.winnet.RemoveNetRoute(rt); err != nil { return fmt.Errorf("failed to delete stale Service CIDR route %s: %w", rt.String(), err) } else { klog.V(4).InfoS("Deleted stale Service CIDR route successfully", "route", rt) @@ -351,8 +351,8 @@ func (c *Client) addVirtualNodePortDNATIPRoute(isIPv6 bool) error { vIP := config.VirtualNodePortDNATIPv4 gw := config.VirtualServiceIPv4 - route := generateRoute(virtualNodePortDNATIPv4Net, gw, linkIndex, util.MetricHigh) - if err := util.ReplaceNetRoute(route); err != nil { + route := generateRoute(virtualNodePortDNATIPv4Net, gw, linkIndex, winnet.MetricHigh) + if err := c.winnet.ReplaceNetRoute(route); err != nil { return fmt.Errorf("failed to install route for NodePort DNAT IP %s: %w", vIP.String(), err) } c.serviceRoutes.Store(vIP.String(), route) @@ -392,30 +392,30 @@ func (c *Client) syncIPInfra() { } func (c *Client) syncRoute() error { - restoreRoute := func(route *util.Route) bool { - if err := util.ReplaceNetRoute(route); err != nil { + restoreRoute := func(route *winnet.Route) bool { + if err := c.winnet.ReplaceNetRoute(route); err != nil { klog.ErrorS(err, "Failed to sync route", "Route", route) return false } return true } c.nodeRoutes.Range(func(_, v interface{}) bool { - route := v.(*util.Route) + route := v.(*winnet.Route) return restoreRoute(route) }) if c.proxyAll { c.serviceRoutes.Range(func(_, v interface{}) bool { - route := v.(*util.Route) + route := v.(*winnet.Route) return restoreRoute(route) }) } // The route is installed automatically by the kernel when the address is configured on the interface. If the route // is deleted manually by mistake, we restore it. - gwAutoconfRoute := &util.Route{ + gwAutoconfRoute := &winnet.Route{ LinkIndex: c.nodeConfig.GatewayConfig.LinkIndex, DestinationSubnet: c.nodeConfig.PodIPv4CIDR, GatewayAddress: net.IPv4zero, - RouteMetric: util.MetricDefault, + RouteMetric: winnet.MetricDefault, } restoreRoute(gwAutoconfRoute) @@ -423,13 +423,13 @@ func (c *Client) syncRoute() error { } func (c *Client) syncNetNatStaticMapping() error { - if err := util.NewNetNat(antreaNatNodePort, virtualNodePortDNATIPv4Net); err != nil { + if err := c.winnet.AddNetNat(antreaNatNodePort, virtualNodePortDNATIPv4Net); err != nil { return err } c.netNatStaticMappings.Range(func(_, v interface{}) bool { - mapping := v.(*util.NetNatStaticMapping) - if err := util.ReplaceNetNatStaticMapping(mapping); err != nil { + mapping := v.(*winnet.NetNatStaticMapping) + if err := c.winnet.ReplaceNetNatStaticMapping(mapping); err != nil { klog.ErrorS(err, "Failed to add netNatStaticMapping", "netNatStaticMapping", mapping) return false } @@ -439,7 +439,7 @@ func (c *Client) syncNetNatStaticMapping() error { return nil } -func (c *Client) isServiceRoute(route *util.Route) bool { +func (c *Client) isServiceRoute(route *winnet.Route) bool { // If the gateway IP or the destination IP is the virtual Service IP, then it is a route added by AntreaProxy. if route.DestinationSubnet != nil && route.DestinationSubnet.IP.Equal(config.VirtualServiceIPv4) || route.GatewayAddress != nil && route.GatewayAddress.Equal(config.VirtualServiceIPv4) { @@ -448,10 +448,10 @@ func (c *Client) isServiceRoute(route *util.Route) bool { return false } -func (c *Client) listIPRoutesOnGW() ([]util.Route, error) { +func (c *Client) listIPRoutesOnGW() ([]winnet.Route, error) { family := antreasyscall.AF_INET - filter := &util.Route{LinkIndex: c.nodeConfig.GatewayConfig.LinkIndex} - return util.RouteListFiltered(family, filter, util.RT_FILTER_IF) + filter := &winnet.Route{LinkIndex: c.nodeConfig.GatewayConfig.LinkIndex} + return c.winnet.RouteListFiltered(family, filter, winnet.RT_FILTER_IF) } // initFwRules adds Windows Firewall rules to accept the traffic that is sent to or from local Pods. @@ -477,7 +477,7 @@ func (c *Client) DeleteSNATRule(mark uint32) error { // TODO: nodePortAddresses is not supported currently. func (c *Client) AddNodePort(nodePortAddresses []net.IP, port uint16, protocol binding.Protocol) error { - netNatStaticMapping := &util.NetNatStaticMapping{ + netNatStaticMapping := &winnet.NetNatStaticMapping{ Name: antreaNatNodePort, ExternalIP: net.ParseIP("0.0.0.0"), ExternalPort: port, @@ -485,7 +485,7 @@ func (c *Client) AddNodePort(nodePortAddresses []net.IP, port uint16, protocol b InternalPort: port, Protocol: protocol, } - if err := util.ReplaceNetNatStaticMapping(netNatStaticMapping); err != nil { + if err := c.winnet.ReplaceNetNatStaticMapping(netNatStaticMapping); err != nil { return err } c.netNatStaticMappings.Store(fmt.Sprintf("%d-%s", port, protocol), netNatStaticMapping) @@ -500,8 +500,8 @@ func (c *Client) DeleteNodePort(nodePortAddresses []net.IP, port uint16, protoco klog.V(2).InfoS("Didn't find corresponding NetNatStaticMapping for NodePort", "port", port, "protocol", protocol) return nil } - netNatStaticMapping := obj.(*util.NetNatStaticMapping) - if err := util.RemoveNetNatStaticMapping(netNatStaticMapping); err != nil { + netNatStaticMapping := obj.(*winnet.NetNatStaticMapping) + if err := c.winnet.RemoveNetNatStaticMapping(netNatStaticMapping); err != nil { return err } c.netNatStaticMappings.Delete(key) @@ -514,11 +514,11 @@ func (c *Client) AddExternalIPRoute(externalIP net.IP) error { externalIPStr := externalIP.String() linkIndex := c.nodeConfig.GatewayConfig.LinkIndex gw := config.VirtualServiceIPv4 - metric := util.MetricHigh + metric := winnet.MetricHigh svcIPNet := util.NewIPNet(externalIP) route := generateRoute(svcIPNet, gw, linkIndex, metric) - if err := util.ReplaceNetRoute(route); err != nil { + if err := c.winnet.ReplaceNetRoute(route); err != nil { return fmt.Errorf("failed to install route for external IP %s: %w", externalIPStr, err) } c.serviceRoutes.Store(externalIPStr, route) @@ -534,7 +534,7 @@ func (c *Client) DeleteExternalIPRoute(externalIP net.IP) error { klog.V(2).InfoS("Didn't find route for external IP", "IP", externalIPStr) return nil } - if err := util.RemoveNetRoute(route.(*util.Route)); err != nil { + if err := c.winnet.RemoveNetRoute(route.(*winnet.Route)); err != nil { return fmt.Errorf("failed to delete route for external IP %s: %w", externalIPStr, err) } c.serviceRoutes.Delete(externalIPStr) @@ -550,8 +550,8 @@ func (c *Client) DeleteLocalAntreaFlexibleIPAMPodRule(podAddresses []net.IP) err return nil } -func generateRoute(ipNet *net.IPNet, gw net.IP, linkIndex int, metric int) *util.Route { - return &util.Route{ +func generateRoute(ipNet *net.IPNet, gw net.IP, linkIndex int, metric int) *winnet.Route { + return &winnet.Route{ DestinationSubnet: ipNet, GatewayAddress: gw, RouteMetric: metric, @@ -559,8 +559,8 @@ func generateRoute(ipNet *net.IPNet, gw net.IP, linkIndex int, metric int) *util } } -func generateNeigh(ip net.IP, linkIndex int) *util.Neighbor { - return &util.Neighbor{ +func generateNeigh(ip net.IP, linkIndex int) *winnet.Neighbor { + return &winnet.Neighbor{ LinkIndex: linkIndex, IPAddress: ip, LinkLayerAddress: openflow.GlobalVirtualMAC, diff --git a/pkg/agent/route/route_windows_test.go b/pkg/agent/route/route_windows_test.go index 516e019838b..0f1648c1002 100644 --- a/pkg/agent/route/route_windows_test.go +++ b/pkg/agent/route/route_windows_test.go @@ -18,101 +18,424 @@ package route import ( + "fmt" "net" "sync" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "k8s.io/klog/v2" + "go.uber.org/mock/gomock" "antrea.io/antrea/pkg/agent/config" - "antrea.io/antrea/pkg/agent/util" + servicecidrtesting "antrea.io/antrea/pkg/agent/servicecidr/testing" antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" + "antrea.io/antrea/pkg/agent/util/winnet" + winnettesting "antrea.io/antrea/pkg/agent/util/winnet/testing" + "antrea.io/antrea/pkg/ovs/openflow" + "antrea.io/antrea/pkg/util/ip" ) var ( - // Leverage loopback interface for testing. - hostGateway = "Loopback Pseudo-Interface 1" - gwLink = getNetLinkIndex("Loopback Pseudo-Interface 1") - nodeConfig = &config.NodeConfig{ - OVSBridge: "Loopback Pseudo-Interface 1", - GatewayConfig: &config.GatewayConfig{ - Name: hostGateway, - LinkIndex: gwLink, - }, + externalIPv4Addr1 = "1.1.1.1" + externalIPv4Addr2 = "1.1.1.2" + externalIPv4Addr1WithPrefix = externalIPv4Addr1 + "/32" + externalIPv4Addr2WithPrefix = externalIPv4Addr2 + "/32" + + ipv4Route1 = generateRoute(ip.MustParseCIDR(externalIPv4Addr1WithPrefix), config.VirtualServiceIPv4, 10, winnet.MetricHigh) + ipv4Route2 = generateRoute(ip.MustParseCIDR(externalIPv4Addr2WithPrefix), config.VirtualServiceIPv4, 10, winnet.MetricHigh) + + nodePort = uint16(30000) + protocol = openflow.ProtocolTCP + nodePortNetNatStaticMapping = &winnet.NetNatStaticMapping{ + Name: antreaNatNodePort, + ExternalIP: net.ParseIP("0.0.0.0"), + ExternalPort: nodePort, + InternalIP: config.VirtualNodePortDNATIPv4, + InternalPort: nodePort, + Protocol: protocol, } ) -func getNetLinkIndex(dev string) int { - link, err := net.InterfaceByName(dev) - if err != nil { - klog.Fatalf("cannot find dev %s: %v", dev, err) +func TestSyncRoutes(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + + nodeRoute1 := &winnet.Route{DestinationSubnet: ip.MustParseCIDR("192.168.1.0/24"), GatewayAddress: net.ParseIP("1.1.1.1")} + nodeRoute2 := &winnet.Route{DestinationSubnet: ip.MustParseCIDR("192.168.2.0/24"), GatewayAddress: net.ParseIP("1.1.1.2")} + serviceRoute1 := &winnet.Route{DestinationSubnet: ip.MustParseCIDR("169.254.0.253/32"), LinkIndex: 10} + serviceRoute2 := &winnet.Route{DestinationSubnet: ip.MustParseCIDR("169.254.0.252/32"), GatewayAddress: net.ParseIP("169.254.0.253")} + mockWinnet.EXPECT().ReplaceNetRoute(nodeRoute1) + mockWinnet.EXPECT().ReplaceNetRoute(nodeRoute2) + mockWinnet.EXPECT().ReplaceNetRoute(serviceRoute1) + mockWinnet.EXPECT().ReplaceNetRoute(serviceRoute2) + mockWinnet.EXPECT().ReplaceNetRoute(&winnet.Route{ + LinkIndex: 10, + DestinationSubnet: ip.MustParseCIDR("192.168.0.0/24"), + GatewayAddress: net.IPv4zero, + RouteMetric: winnet.MetricDefault, + }) + + c := &Client{ + winnet: mockWinnet, + proxyAll: true, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{LinkIndex: 10, IPv4: net.ParseIP("192.168.0.1")}, + PodIPv4CIDR: ip.MustParseCIDR("192.168.0.0/24"), + }, + } + c.nodeRoutes.Store("192.168.1.0/24", nodeRoute1) + c.nodeRoutes.Store("192.168.2.0/24", nodeRoute2) + c.serviceRoutes.Store("169.254.0.253/32", serviceRoute1) + c.serviceRoutes.Store("169.254.0.252/32", serviceRoute2) + + assert.NoError(t, c.syncRoute()) +} + +func TestInitServiceIPRoutes(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + mockServiceCIDRProvider := servicecidrtesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + networkConfig: &config.NetworkConfig{ + TrafficEncapMode: config.TrafficEncapModeEncap, + IPv4Enabled: true, + }, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{Name: "antrea-gw0", LinkIndex: 10}, + }, + serviceCIDRProvider: mockServiceCIDRProvider, + } + mockWinnet.EXPECT().ReplaceNetRoute(generateRoute(virtualServiceIPv4Net, net.IPv4zero, 10, winnet.MetricHigh)) + mockWinnet.EXPECT().ReplaceNetRoute(generateRoute(virtualNodePortDNATIPv4Net, config.VirtualServiceIPv4, 10, winnet.MetricHigh)) + mockWinnet.EXPECT().ReplaceNetNeighbor(generateNeigh(config.VirtualServiceIPv4, 10)) + mockServiceCIDRProvider.EXPECT().AddEventHandler(gomock.Any()) + assert.NoError(t, c.initServiceIPRoutes()) +} + +func TestReconcile(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + proxyAll: true, + networkConfig: &config.NetworkConfig{}, + nodeConfig: &config.NodeConfig{ + PodIPv4CIDR: ip.MustParseCIDR("192.168.10.0/24"), + GatewayConfig: &config.GatewayConfig{LinkIndex: 10}, + }, + } + + mockWinnet.EXPECT().RouteListFiltered(antreasyscall.AF_INET, &winnet.Route{LinkIndex: 10}, winnet.RT_FILTER_IF).Return([]winnet.Route{ + {DestinationSubnet: ip.MustParseCIDR("192.168.10.0/24"), LinkIndex: 10}, // local podCIDR, should not be deleted. + {DestinationSubnet: ip.MustParseCIDR("192.168.1.0/24"), LinkIndex: 10}, // existing podCIDR, should not be deleted. + {DestinationSubnet: ip.MustParseCIDR("169.254.0.253/32"), LinkIndex: 10}, // service route, should not be deleted. + {DestinationSubnet: ip.MustParseCIDR("192.168.11.0/24"), LinkIndex: 10}, // non-existing podCIDR, should be deleted. + }, nil) + + podCIDRs := []string{"192.168.0.0/24", "192.168.1.0/24"} + mockWinnet.EXPECT().RemoveNetRoute(&winnet.Route{DestinationSubnet: ip.MustParseCIDR("192.168.11.0/24"), LinkIndex: 10}) + assert.NoError(t, c.Reconcile(podCIDRs)) +} + +func TestAddRoutes(t *testing.T) { + ipv4, nodeTransPortIPv4Addr, _ := net.ParseCIDR("172.16.10.2/24") + nodeTransPortIPv4Addr.IP = ipv4 + + tests := []struct { + name string + networkConfig *config.NetworkConfig + nodeConfig *config.NodeConfig + podCIDR *net.IPNet + nodeName string + nodeIP net.IP + nodeGwIP net.IP + expectedNetUtilCalls func(mockWinnet *winnettesting.MockInterfaceMockRecorder) + }{ + { + name: "encap IPv4", + networkConfig: &config.NetworkConfig{ + TrafficEncapMode: config.TrafficEncapModeEncap, + IPv4Enabled: true, + }, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{ + Name: "antrea-gw0", + IPv4: net.ParseIP("1.1.1.1"), + LinkIndex: 10, + }, + NodeTransportIPv4Addr: nodeTransPortIPv4Addr, + }, + podCIDR: ip.MustParseCIDR("192.168.10.0/24"), + nodeName: "node0", + nodeIP: net.ParseIP("1.1.1.10"), + nodeGwIP: net.ParseIP("192.168.10.1"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) { + mockWinnet.ReplaceNetRoute(&winnet.Route{ + GatewayAddress: net.ParseIP("192.168.10.1"), + DestinationSubnet: ip.MustParseCIDR("192.168.10.0/24"), + LinkIndex: 10, + RouteMetric: winnet.MetricDefault, + }) + }, + }, + { + name: "noencap IPv4, direct routing", + networkConfig: &config.NetworkConfig{ + TrafficEncapMode: config.TrafficEncapModeNoEncap, + IPv4Enabled: true, + }, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{ + Name: "antrea-gw0", + IPv4: net.ParseIP("192.168.1.1"), + LinkIndex: 10, + }, + NodeTransportIPv4Addr: nodeTransPortIPv4Addr, + }, + podCIDR: ip.MustParseCIDR("192.168.10.0/24"), + nodeName: "node0", + nodeIP: net.ParseIP("172.16.10.3"), // In the same subnet as local Node IP. + nodeGwIP: net.ParseIP("192.168.10.1"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) { + mockWinnet.ReplaceNetRoute(&winnet.Route{ + GatewayAddress: net.ParseIP("172.16.10.3"), + DestinationSubnet: ip.MustParseCIDR("192.168.10.0/24"), + RouteMetric: winnet.MetricDefault, + }) + }, + }, + { + name: "noencap IPv4, no direct routing", + networkConfig: &config.NetworkConfig{ + TrafficEncapMode: config.TrafficEncapModeNoEncap, + IPv4Enabled: true, + }, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{ + Name: "antrea-gw0", + IPv4: net.ParseIP("192.168.1.1"), + LinkIndex: 10, + }, + NodeTransportIPv4Addr: nodeTransPortIPv4Addr, + }, + podCIDR: ip.MustParseCIDR("192.168.10.0/24"), + nodeName: "node0", + nodeIP: net.ParseIP("172.16.11.3"), // In different subnet from local Node IP. + nodeGwIP: net.ParseIP("192.168.10.1"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) {}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + netutil := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: netutil, + networkConfig: tt.networkConfig, + nodeConfig: tt.nodeConfig, + } + tt.expectedNetUtilCalls(netutil.EXPECT()) + assert.NoError(t, c.AddRoutes(tt.podCIDR, tt.nodeName, tt.nodeIP, tt.nodeGwIP)) + }) + } +} + +func TestDeleteRoutes(t *testing.T) { + existingNodeRoutes := map[string]*winnet.Route{ + "192.168.10.0/24": {GatewayAddress: net.ParseIP("172.16.10.3"), DestinationSubnet: ip.MustParseCIDR("192.168.10.0/24")}, + "192.168.11.0/24": {GatewayAddress: net.ParseIP("172.16.10.4"), DestinationSubnet: ip.MustParseCIDR("192.168.11.0/24")}, + } + podCIDR := ip.MustParseCIDR("192.168.10.0/24") + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + nodeRoutes: sync.Map{}, + } + for podCIDRStr, nodeRoute := range existingNodeRoutes { + c.nodeRoutes.Store(podCIDRStr, nodeRoute) + } + mockWinnet.EXPECT().RemoveNetRoute(&winnet.Route{GatewayAddress: net.ParseIP("172.16.10.3"), DestinationSubnet: ip.MustParseCIDR("192.168.10.0/24")}) + assert.NoError(t, c.DeleteRoutes(podCIDR)) +} + +func TestAddNodePort(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + } + mockWinnet.EXPECT().ReplaceNetNatStaticMapping(nodePortNetNatStaticMapping) + assert.NoError(t, c.AddNodePort(nil, nodePort, protocol)) +} + +func TestDeleteNodePort(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + } + c.netNatStaticMappings.Store(fmt.Sprintf("%d-%s", nodePort, protocol), nodePortNetNatStaticMapping) + mockWinnet.EXPECT().RemoveNetNatStaticMapping(nodePortNetNatStaticMapping) + assert.NoError(t, c.DeleteNodePort(nil, nodePort, protocol)) +} + +func TestAddServiceCIDRRoute(t *testing.T) { + nodeConfig := &config.NodeConfig{GatewayConfig: &config.GatewayConfig{LinkIndex: 10}} + tests := []struct { + name string + curServiceIPv4CIDR *net.IPNet + newServiceIPv4CIDR *net.IPNet + expectedNetUtilCalls func(mockWinnet *winnettesting.MockInterfaceMockRecorder) + }{ + { + name: "Add route for Service IPv4 CIDR", + curServiceIPv4CIDR: nil, + newServiceIPv4CIDR: ip.MustParseCIDR("10.96.0.1/32"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) { + mockWinnet.ReplaceNetRoute(&winnet.Route{ + DestinationSubnet: &net.IPNet{IP: net.ParseIP("10.96.0.1").To4(), Mask: net.CIDRMask(32, 32)}, + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + mockWinnet.RouteListFiltered(antreasyscall.AF_INET, &winnet.Route{LinkIndex: 10}, winnet.RT_FILTER_IF).Return([]winnet.Route{ + { + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/24"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }, + }, nil) + mockWinnet.RemoveNetRoute(&winnet.Route{ + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/24"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + }, + }, + { + name: "Add route for Service IPv4 CIDR and clean up stale routes", + curServiceIPv4CIDR: nil, + newServiceIPv4CIDR: ip.MustParseCIDR("10.96.0.0/28"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) { + mockWinnet.ReplaceNetRoute(&winnet.Route{ + DestinationSubnet: &net.IPNet{IP: net.ParseIP("10.96.0.0").To4(), Mask: net.CIDRMask(28, 32)}, + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + mockWinnet.RouteListFiltered(antreasyscall.AF_INET, &winnet.Route{LinkIndex: 10}, winnet.RT_FILTER_IF).Return([]winnet.Route{ + { + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/24"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }, + { + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/30"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }, + }, nil) + mockWinnet.RemoveNetRoute(&winnet.Route{ + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/24"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + mockWinnet.RemoveNetRoute(&winnet.Route{ + DestinationSubnet: ip.MustParseCIDR("10.96.0.0/30"), + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + }, + }, + { + name: "Update route for Service IPv4 CIDR", + curServiceIPv4CIDR: ip.MustParseCIDR("10.96.0.1/32"), + newServiceIPv4CIDR: ip.MustParseCIDR("10.96.0.0/28"), + expectedNetUtilCalls: func(mockWinnet *winnettesting.MockInterfaceMockRecorder) { + mockWinnet.ReplaceNetRoute(&winnet.Route{ + DestinationSubnet: &net.IPNet{IP: net.ParseIP("10.96.0.0").To4(), Mask: net.CIDRMask(28, 32)}, + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + mockWinnet.RemoveNetRoute(&winnet.Route{ + DestinationSubnet: &net.IPNet{IP: net.ParseIP("10.96.0.1").To4(), Mask: net.CIDRMask(32, 32)}, + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + nodeConfig: nodeConfig, + } + tt.expectedNetUtilCalls(mockWinnet.EXPECT()) + + if tt.curServiceIPv4CIDR != nil { + c.serviceRoutes.Store(serviceIPv4CIDRKey, &winnet.Route{ + DestinationSubnet: &net.IPNet{IP: net.ParseIP("10.96.0.1").To4(), Mask: net.CIDRMask(32, 32)}, + GatewayAddress: config.VirtualServiceIPv4, + RouteMetric: winnet.MetricHigh, + LinkIndex: 10, + }) + } + + assert.NoError(t, c.addServiceCIDRRoute(tt.newServiceIPv4CIDR)) + }) } - return link.Index } -func TestRouteOperation(t *testing.T) { - peerNodeIP1 := net.ParseIP("10.0.0.2") - peerNodeIP2 := net.ParseIP("10.0.0.3") - gwIP1 := net.ParseIP("192.168.2.1") - _, destCIDR1, _ := net.ParseCIDR("192.168.2.0/24") - dest2 := "192.168.3.0/24" - gwIP2 := net.ParseIP("192.168.3.1") - _, destCIDR2, _ := net.ParseCIDR(dest2) - - client, err := NewClient(&config.NetworkConfig{}, true, false, false, false, false, nil) - - require.Nil(t, err) - called := false - err = client.Initialize(nodeConfig, func() { called = true }) - require.Nil(t, err) - require.True(t, called) - - // Add initial routes. - err = client.AddRoutes(destCIDR1, "node1", peerNodeIP1, gwIP1) - require.Nil(t, err) - routes1, err := util.RouteListFiltered(antreasyscall.AF_INET, &util.Route{LinkIndex: gwLink, DestinationSubnet: destCIDR1}, util.RT_FILTER_IF|util.RT_FILTER_DST) - require.Nil(t, err) - assert.Equal(t, 1, len(routes1)) - - err = client.AddRoutes(destCIDR2, "node2", peerNodeIP2, gwIP2) - require.Nil(t, err) - routes2, err := util.RouteListFiltered(antreasyscall.AF_INET, &util.Route{LinkIndex: gwLink, DestinationSubnet: destCIDR2}, util.RT_FILTER_IF|util.RT_FILTER_DST) - require.Nil(t, err) - assert.Equal(t, 1, len(routes2)) - - err = client.Reconcile([]string{dest2}) - require.Nil(t, err) - - err = client.DeleteRoutes(destCIDR2) - require.Nil(t, err) - routes7, err := util.RouteListFiltered(antreasyscall.AF_INET, &util.Route{LinkIndex: gwLink, DestinationSubnet: destCIDR2}, util.RT_FILTER_IF|util.RT_FILTER_DST) - require.Nil(t, err) - assert.Equal(t, 0, len(routes7)) +func TestAddExternalIPRoute(t *testing.T) { + externalIPs := []string{externalIPv4Addr1, externalIPv4Addr2} + + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) + c := &Client{ + winnet: mockWinnet, + nodeConfig: &config.NodeConfig{ + GatewayConfig: &config.GatewayConfig{ + LinkIndex: 10, + }, + }, + } + mockWinnet.EXPECT().ReplaceNetRoute(ipv4Route1) + mockWinnet.EXPECT().ReplaceNetRoute(ipv4Route2) + + for _, externalIP := range externalIPs { + assert.NoError(t, c.AddExternalIPRoute(net.ParseIP(externalIP))) + } } -func TestAddAndDeleteExternalIPRoute(t *testing.T) { +func TestDeleteExternalIPRoute(t *testing.T) { + externalIPs := []string{externalIPv4Addr1, externalIPv4Addr2} + + ctrl := gomock.NewController(t) + mockWinnet := winnettesting.NewMockInterface(ctrl) c := &Client{ - nodeConfig: nodeConfig, - serviceRoutes: &sync.Map{}, - } - externalIP := net.ParseIP("1.1.1.1") - - assert.NoError(t, c.AddExternalIPRoute(externalIP)) - externalIPNet := util.NewIPNet(externalIP) - routes, err := util.RouteListFiltered(antreasyscall.AF_INET, &util.Route{LinkIndex: gwLink, DestinationSubnet: externalIPNet}, util.RT_FILTER_IF|util.RT_FILTER_DST) - require.Nil(t, err) - assert.Equal(t, 1, len(routes)) - - route, ok := c.serviceRoutes.Load(externalIP.String()) - assert.True(t, ok) - assert.EqualValues(t, routes[0], *route.(*util.Route)) - - assert.NoError(t, c.DeleteExternalIPRoute(externalIP)) - routes, err = util.RouteListFiltered(antreasyscall.AF_INET, &util.Route{LinkIndex: gwLink, DestinationSubnet: externalIPNet}, util.RT_FILTER_IF|util.RT_FILTER_DST) - require.Nil(t, err) - assert.Equal(t, 0, len(routes)) - _, ok = c.serviceRoutes.Load(externalIP.String()) - assert.False(t, ok) + winnet: mockWinnet, + } + for ipStr, route := range map[string]*winnet.Route{externalIPv4Addr1: ipv4Route1, externalIPv4Addr2: ipv4Route2} { + c.serviceRoutes.Store(ipStr, route) + } + + mockWinnet.EXPECT().RemoveNetRoute(ipv4Route1) + mockWinnet.EXPECT().RemoveNetRoute(ipv4Route2) + + for _, externalIP := range externalIPs { + assert.NoError(t, c.DeleteExternalIPRoute(net.ParseIP(externalIP))) + } } diff --git a/pkg/agent/util/net_windows.go b/pkg/agent/util/net_windows.go index 4930e072a4b..8e9c0ca534d 100644 --- a/pkg/agent/util/net_windows.go +++ b/pkg/agent/util/net_windows.go @@ -18,272 +18,46 @@ package util import ( - "bufio" "bytes" "context" "encoding/json" - "errors" "fmt" "net" - "os" - "runtime" - "strconv" "strings" - "syscall" "time" - "unsafe" "github.com/Microsoft/go-winio" "github.com/Microsoft/hcsshim" "github.com/containernetworking/plugins/pkg/ip" - "golang.org/x/sys/windows" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" - ps "antrea.io/antrea/pkg/agent/util/powershell" antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" - binding "antrea.io/antrea/pkg/ovs/openflow" - iputil "antrea.io/antrea/pkg/util/ip" + "antrea.io/antrea/pkg/agent/util/winnet" ) const ( - ContainerVNICPrefix = "vEthernet" - HNSNetworkType = "Transparent" - LocalHNSNetwork = "antrea-hnsnetwork" - OVSExtensionID = "583CC151-73EC-4A6A-8B47-578297AD7623" - ovsExtensionName = "Open vSwitch Extension" - namedPipePrefix = `\\.\pipe\` - commandRetryTimeout = 5 * time.Second - commandRetryInterval = time.Second - - MetricDefault = 256 - MetricHigh = 50 + LocalHNSNetwork = "antrea-hnsnetwork" + HNSNetworkType = "Transparent" + namedPipePrefix = `\\.\pipe\` AntreaNatName = "antrea-nat" LocalVMSwitch = "antrea-switch" - - // Filter masks are used to indicate the attributes used for route filtering. - RT_FILTER_IF uint64 = 1 << (1 + iota) - RT_FILTER_METRIC - RT_FILTER_DST - RT_FILTER_GW - - // IP_ADAPTER_DHCP_ENABLED is defined in the Win32 API document. - // https://learn.microsoft.com/en-us/windows/win32/api/iptypes/ns-iptypes-ip_adapter_addresses_lh - IP_ADAPTER_DHCP_ENABLED = 0x00000004 ) var ( - // Declared variables which are meant to be overridden for testing. - antreaNetIO = antreasyscall.NewNetIO() - getAdaptersAddresses = windows.GetAdaptersAddresses - runCommand = ps.RunCommand - getHNSNetworkByName = hcsshim.GetHNSNetworkByName - hnsNetworkRequest = hcsshim.HNSNetworkRequest - hnsNetworkCreate = (*hcsshim.HNSNetwork).Create - hnsNetworkDelete = (*hcsshim.HNSNetwork).Delete -) + winnetUtil winnet.Interface = &winnet.Handle{} -type Route struct { - LinkIndex int - DestinationSubnet *net.IPNet - GatewayAddress net.IP - RouteMetric int -} - -func (r *Route) String() string { - return fmt.Sprintf("LinkIndex: %d, DestinationSubnet: %s, GatewayAddress: %s, RouteMetric: %d", - r.LinkIndex, r.DestinationSubnet, r.GatewayAddress, r.RouteMetric) -} - -func (r *Route) Equal(x Route) bool { - return x.LinkIndex == r.LinkIndex && - x.DestinationSubnet != nil && - r.DestinationSubnet != nil && - iputil.IPNetEqual(x.DestinationSubnet, r.DestinationSubnet) && - x.GatewayAddress.Equal(r.GatewayAddress) -} - -func (r *Route) toMibIPForwardRow() *antreasyscall.MibIPForwardRow { - row := antreasyscall.NewIPForwardRow() - row.DestinationPrefix = *antreasyscall.NewAddressPrefixFromIPNet(r.DestinationSubnet) - row.NextHop = *antreasyscall.NewRawSockAddrInetFromIP(r.GatewayAddress) - row.Metric = uint32(r.RouteMetric) - row.Index = uint32(r.LinkIndex) - return row -} - -func routeFromIPForwardRow(row *antreasyscall.MibIPForwardRow) *Route { - destination := row.DestinationPrefix.IPNet() - gatewayAddr := row.NextHop.IP() - return &Route{ - DestinationSubnet: destination, - GatewayAddress: gatewayAddr, - LinkIndex: int(row.Index), - RouteMetric: int(row.Metric), - } -} - -type Neighbor struct { - LinkIndex int - IPAddress net.IP - LinkLayerAddress net.HardwareAddr - State string -} - -func (n Neighbor) String() string { - return fmt.Sprintf("LinkIndex: %d, IPAddress: %s, LinkLayerAddress: %s", n.LinkIndex, n.IPAddress, n.LinkLayerAddress) -} - -type NetNatStaticMapping struct { - Name string - ExternalIP net.IP - ExternalPort uint16 - InternalIP net.IP - InternalPort uint16 - Protocol binding.Protocol -} - -func (n NetNatStaticMapping) String() string { - return fmt.Sprintf("Name: %s, ExternalIP %s, ExternalPort: %d, InternalIP: %s, InternalPort: %d, Protocol: %s", n.Name, n.ExternalIP, n.ExternalPort, n.InternalIP, n.InternalPort, n.Protocol) -} + getHNSNetworkByName = hcsshim.GetHNSNetworkByName + hnsNetworkRequest = hcsshim.HNSNetworkRequest + hnsNetworkCreate = (*hcsshim.HNSNetwork).Create + hnsNetworkDelete = (*hcsshim.HNSNetwork).Delete +) func GetNSPath(containerNetNS string) (string, error) { return containerNetNS, nil } -// IsVirtualAdapter checks if the provided adapter is virtual. -func IsVirtualAdapter(name string) (bool, error) { - cmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Virtual | Format-Table -HideTableHeaders`, name) - out, err := runCommand(cmd) - if err != nil { - return false, err - } - isVirtual, err := strconv.ParseBool(strings.TrimSpace(out)) - if err != nil { - return false, err - } - return isVirtual, nil -} - -func GetHostInterfaceStatus(ifaceName string) (string, error) { - cmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Status | Format-Table -HideTableHeaders`, ifaceName) - out, err := runCommand(cmd) - if err != nil { - return "", err - } - return strings.TrimSpace(out), nil -} - -// EnableHostInterface sets the specified interface status as UP. -func EnableHostInterface(ifaceName string) error { - cmd := fmt.Sprintf(`Enable-NetAdapter -InterfaceAlias "%s"`, ifaceName) - // Enable-NetAdapter is not a blocking operation based on our testing. - // It returns immediately no matter whether the interface has been enabled or not. - // So we need to check the interface status to ensure it is up before returning. - if err := wait.PollUntilContextTimeout(context.TODO(), commandRetryInterval, commandRetryTimeout, true, func(ctx context.Context) (done bool, err error) { - if _, err := runCommand(cmd); err != nil { - klog.Errorf("Failed to run command %s: %v", cmd, err) - return false, nil - } - status, err := GetHostInterfaceStatus(ifaceName) - if err != nil { - klog.Errorf("Failed to run command %s: %v", cmd, err) - return false, nil - } - if !strings.EqualFold(status, "Up") { - klog.Infof("Waiting for host interface %s to be up", ifaceName) - return false, nil - } - return true, nil - }); err != nil { - return fmt.Errorf("failed to enable interface %s: %v", ifaceName, err) - } - return nil -} - -// ConfigureInterfaceAddress adds IPAddress on the specified interface. -func ConfigureInterfaceAddress(ifaceName string, ipConfig *net.IPNet) error { - ipStr := strings.Split(ipConfig.String(), "/") - cmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s`, ifaceName, ipStr[0], ipStr[1]) - _, err := runCommand(cmd) - // If the address already exists, ignore the error. - if err != nil && !strings.Contains(err.Error(), "already exists") { - return err - } - return nil -} - -// RemoveInterfaceAddress removes IPAddress from the specified interface. -func RemoveInterfaceAddress(ifaceName string, ipAddr net.IP) error { - cmd := fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -Confirm:$false`, ifaceName, ipAddr.String()) - _, err := runCommand(cmd) - // If the address does not exist, ignore the error. - if err != nil && !strings.Contains(err.Error(), "No matching") { - return err - } - return nil -} - -// ConfigureInterfaceAddressWithDefaultGateway adds IPAddress on the specified interface and sets the default gateway -// for the host. -func ConfigureInterfaceAddressWithDefaultGateway(ifaceName string, ipConfig *net.IPNet, gateway string) error { - ipStr := strings.Split(ipConfig.String(), "/") - cmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s`, ifaceName, ipStr[0], ipStr[1]) - if gateway != "" { - cmd = fmt.Sprintf("%s -DefaultGateway %s", cmd, gateway) - } - _, err := runCommand(cmd) - // If the address already exists, ignore the error. - if err != nil && !strings.Contains(err.Error(), "already exists") { - return err - } - return nil -} - -// EnableIPForwarding enables the IP interface to forward packets that arrive at this interface to other interfaces. -func EnableIPForwarding(ifaceName string) error { - adapter, err := getAdapterInAllCompartmentsByName(ifaceName) - if err != nil { - return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", ifaceName, err) - } - return adapter.setForwarding(true, antreasyscall.AF_INET) -} - -func RenameVMNetworkAdapter(networkName string, macStr, newName string, renameNetAdapter bool) error { - cmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -SwitchName "%s" | ? MacAddress -EQ "%s" | Select-Object -Property Name | Format-Table -HideTableHeaders`, networkName, macStr) - stdout, err := runCommand(cmd) - if err != nil { - return err - } - stdout = strings.TrimSpace(stdout) - if len(stdout) == 0 { - return fmt.Errorf("unable to find vmnetwork adapter configured with uplink MAC address %s", macStr) - } - vmNetworkAdapterName := stdout - cmd = fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -Name "%s" | Rename-VMNetworkAdapter -NewName "%s"`, vmNetworkAdapterName, newName) - if _, err := runCommand(cmd); err != nil { - return err - } - if renameNetAdapter { - oriNetAdapterName := VirtualAdapterName(newName) - cmd = fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, oriNetAdapterName, newName) - if _, err := runCommand(cmd); err != nil { - return err - } - } - return nil -} - -// SetAdapterMACAddress sets specified MAC address on interface. -func SetAdapterMACAddress(adapterName string, macConfig *net.HardwareAddr) error { - macAddr := strings.Replace(macConfig.String(), ":", "", -1) - cmd := fmt.Sprintf(`Set-NetAdapterAdvancedProperty -Name "%s" -RegistryKeyword NetworkAddress -RegistryValue "%s"`, - adapterName, macAddr) - _, err := runCommand(cmd) - return err -} - // CreateHNSNetwork creates a new HNS Network, whose type is "Transparent". The NetworkAdapter is using the host // interface which is configured with Node IP. HNS Network properties "ManagementIP" and "SourceMac" are used to record // the original IP and MAC addresses on the network adapter. @@ -352,18 +126,28 @@ func EnableHNSNetworkExtension(hnsNetID string, vSwitchExtension string) error { } func SetLinkUp(name string) (net.HardwareAddr, int, error) { - // Set host gateway interface up. - if err := EnableHostInterface(name); err != nil { - klog.Errorf("Failed to set host link for %s up: %v", name, err) - if strings.Contains(err.Error(), "ObjectNotFound") { - return nil, 0, newLinkNotFoundError(name) + if err := wait.PollUntilContextTimeout(context.TODO(), time.Second, 5*time.Second, true, func(ctx context.Context) (done bool, err error) { + if err := winnetUtil.EnableNetAdapter(name); err != nil { + klog.Errorf("Failed to enable network adapter %s: %v", name, err) + return false, nil } - return nil, 0, err + enabled, err := winnetUtil.IsNetAdapterStatusUp(name) + if err != nil { + klog.Errorf("Failed to get network adapter status %s: %v", name, err) + return false, nil + } + if !enabled { + klog.Infof("Waiting for network adapter %s to be up", name) + return false, nil + } + return true, nil + }); err != nil { + return nil, 0, fmt.Errorf("failed to enable network adapter %s: %v", name, err) } iface, err := netInterfaceByName(name) if err != nil { - if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "no such network interface" { + if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "no such network adapter" { return nil, 0, newLinkNotFoundError(name) } return nil, 0, err @@ -432,7 +216,7 @@ func ConfigureLinkAddresses(idx int, ipNets []*net.IPNet) error { for _, addr := range addrsToRemove { klog.V(2).Infof("Removing address %v from interface %s", addr, ifaceName) - if err := RemoveInterfaceAddress(ifaceName, addr.IP); err != nil { + if err := winnetUtil.RemoveNetAdapterIPAddress(ifaceName, addr.IP); err != nil { return fmt.Errorf("failed to remove address %v from interface %s: %v", addr, ifaceName, err) } } @@ -443,7 +227,7 @@ func ConfigureLinkAddresses(idx int, ipNets []*net.IPNet) error { klog.Warningf("Windows only supports IPv4 addresses, skipping this address %v", addr) return nil } - if err := ConfigureInterfaceAddress(ifaceName, addr); err != nil { + if err := winnetUtil.AddNetAdapterIPAddress(ifaceName, addr, ""); err != nil { return fmt.Errorf("failed to add address %v to interface %s: %v", addr, ifaceName, err) } } @@ -471,7 +255,7 @@ func PrepareHNSNetwork(subnetCIDR *net.IPNet, nodeIPNet *net.IPNet, uplinkAdapte // Therefore, we set the timeout limit to triple of that value, allowing a maximum wait of 6 seconds here. err = wait.PollUntilContextTimeout(context.TODO(), 1*time.Second, 6*time.Second, true, func(ctx context.Context) (bool, error) { var checkErr error - adapter, ipFound, checkErr = adapterIPExists(nodeIPNet.IP, uplinkAdapter.HardwareAddr, ContainerVNICPrefix) + adapter, ipFound, checkErr = adapterIPExists(nodeIPNet.IP, uplinkAdapter.HardwareAddr, winnet.ContainerVNICPrefix) if checkErr != nil { return false, checkErr } @@ -479,7 +263,7 @@ func PrepareHNSNetwork(subnetCIDR *net.IPNet, nodeIPNet *net.IPNet, uplinkAdapte }) if err != nil { if wait.Interrupted(err) { - dhcpStatus, err := InterfaceIPv4DhcpEnabled(uplinkAdapter.Name) + dhcpStatus, err := winnetUtil.IsNetAdapterIPv4DHCPEnabled(uplinkAdapter.Name) if err != nil { klog.ErrorS(err, "Failed to get IPv4 DHCP status on the network adapter", "adapter", uplinkAdapter.Name) } else { @@ -495,25 +279,25 @@ func PrepareHNSNetwork(subnetCIDR *net.IPNet, nodeIPNet *net.IPNet, uplinkAdapte // Server fails to allocate IP to new virtual network. if !ipFound { klog.InfoS("Moving uplink configuration to the management virtual network adapter", "adapter", vNicName) - if err := ConfigureInterfaceAddressWithDefaultGateway(vNicName, nodeIPNet, nodeGateway); err != nil { + if err := winnetUtil.AddNetAdapterIPAddress(vNicName, nodeIPNet, nodeGateway); err != nil { klog.ErrorS(err, "Failed to configure IP and gateway on the management virtual network adapter", "adapter", vNicName, "ip", nodeIPNet.String()) return err } if dnsServers != "" { - if err := SetAdapterDNSServers(vNicName, dnsServers); err != nil { + if err := winnetUtil.SetNetAdapterDNSServers(vNicName, dnsServers); err != nil { klog.ErrorS(err, "Failed to configure DNS servers on the management virtual network adapter", "adapter", vNicName, "dnsServers", dnsServers) return err } } for _, route := range routes { - rt := route.(Route) - newRt := Route{ + rt := route.(winnet.Route) + newRt := winnet.Route{ LinkIndex: index, DestinationSubnet: rt.DestinationSubnet, GatewayAddress: rt.GatewayAddress, RouteMetric: rt.RouteMetric, } - if err := NewNetRoute(&newRt); err != nil { + if err := winnetUtil.AddNetRoute(&newRt); err != nil { return err } } @@ -524,7 +308,7 @@ func PrepareHNSNetwork(subnetCIDR *net.IPNet, nodeIPNet *net.IPNet, uplinkAdapte uplinkMACStr := strings.Replace(uplinkAdapter.HardwareAddr.String(), ":", "", -1) // Rename NetAdapter in the meanwhile, then the network adapter can be treated as a host network adapter other than // a vm network adapter. - if err = RenameVMNetworkAdapter(LocalHNSNetwork, uplinkMACStr, newName, true); err != nil { + if err = winnetUtil.RenameVMNetworkAdapter(LocalHNSNetwork, uplinkMACStr, newName, true); err != nil { return err } } @@ -532,11 +316,11 @@ func PrepareHNSNetwork(subnetCIDR *net.IPNet, nodeIPNet *net.IPNet, uplinkAdapte // Enable OVS Extension on the HNS Network. If an error occurs, delete the HNS Network and return the error. // While the hnsshim API allows for enabling the OVS extension when creating an HNS network, it can cause the adapter being unable // to obtain a valid DHCP IP in case of network interruption. Therefore, we have to enable the OVS extension after running adapterIPExists. - if err = EnableHNSNetworkExtension(hnsNet.Id, OVSExtensionID); err != nil { + if err = EnableHNSNetworkExtension(hnsNet.Id, winnet.OVSExtensionID); err != nil { return err } - if err = EnableRSCOnVSwitch(LocalHNSNetwork); err != nil { + if err = winnetUtil.EnableRSCOnVSwitch(LocalHNSNetwork); err != nil { return err } @@ -576,45 +360,15 @@ func adapterIPExists(ip net.IP, mac net.HardwareAddr, namePrefix string) (*net.I return nil, false, fmt.Errorf("unable to find a network adapter with MAC %s, IP %s, and name prefix %s", mac.String(), ip.String(), namePrefix) } -// EnableRSCOnVSwitch enables RSC in the vSwitch to reduce host CPU utilization and increase throughput for virtual -// workloads by coalescing multiple TCP segments into fewer, but larger segments. -func EnableRSCOnVSwitch(vSwitch string) error { - cmd := fmt.Sprintf("Get-VMSwitch -ComputerName $(hostname) -Name %s | Select-Object -Property SoftwareRscEnabled | Format-Table -HideTableHeaders", vSwitch) - stdout, err := runCommand(cmd) - if err != nil { - return err - } - stdout = strings.TrimSpace(stdout) - // RSC doc says it applies to Windows Server 2019, which is the only Windows operating system supported so far, so - // this should not happen. However, this is only an optimization, no need to crash the process even if it's not - // supported. - // https://docs.microsoft.com/en-us/windows-server/networking/technologies/hpn/rsc-in-the-vswitch - if len(stdout) == 0 { - klog.Warning("Receive Segment Coalescing (RSC) is not supported by this Windows Server version") - return nil - } - if strings.EqualFold(stdout, "True") { - klog.Infof("Receive Segment Coalescing (RSC) for vSwitch %s is already enabled", vSwitch) - return nil - } - cmd = fmt.Sprintf("Set-VMSwitch -ComputerName $(hostname) -Name %s -EnableSoftwareRsc $True", vSwitch) - _, err = runCommand(cmd) - if err != nil { - return err - } - klog.Infof("Enabled Receive Segment Coalescing (RSC) for vSwitch %s", vSwitch) - return nil -} - // GetDefaultGatewayByInterfaceIndex returns the default gateway configured on the specified interface. func GetDefaultGatewayByInterfaceIndex(ifIndex int) (string, error) { ip, defaultDestination, _ := net.ParseCIDR("0.0.0.0/0") - family := addressFamilyByIP(ip) - filter := &Route{ + family := winnet.AddressFamilyByIP(ip) + filter := &winnet.Route{ LinkIndex: ifIndex, DestinationSubnet: defaultDestination, } - routes, err := RouteListFiltered(family, filter, RT_FILTER_IF|RT_FILTER_DST) + routes, err := winnetUtil.RouteListFiltered(family, filter, winnet.RT_FILTER_IF|winnet.RT_FILTER_DST) if err != nil { return "", err } @@ -624,29 +378,8 @@ func GetDefaultGatewayByInterfaceIndex(ifIndex int) (string, error) { return routes[0].GatewayAddress.String(), nil } -// GetDNServersByInterfaceIndex returns the DNS servers configured on the specified interface. -func GetDNServersByInterfaceIndex(ifIndex int) (string, error) { - cmd := fmt.Sprintf("$(Get-DnsClientServerAddress -InterfaceIndex %d -AddressFamily IPv4).ServerAddresses", ifIndex) - dnsServers, err := runCommand(cmd) - if err != nil { - return "", err - } - dnsServers = strings.ReplaceAll(dnsServers, "\r\n", ",") - dnsServers = strings.TrimRight(dnsServers, ",") - return dnsServers, nil -} - -// SetAdapterDNSServers configures DNSServers on network adapter. -func SetAdapterDNSServers(adapterName, dnsServers string) error { - cmd := fmt.Sprintf(`Set-DnsClientServerAddress -InterfaceAlias "%s" -ServerAddresses "%s"`, adapterName, dnsServers) - if _, err := runCommand(cmd); err != nil { - return err - } - return nil -} - // ListenLocalSocket creates a listener on a Unix domain socket or a Windows named pipe. -// - If the specified address starts with "\\.\pipe\", create a listener on the a Windows named pipe path. +// - If the specified address starts with "\\.\pipe\", create a listener on a Windows named pipe path. // - Else create a listener on a local Unix domain socket. func ListenLocalSocket(address string) (net.Listener, error) { if strings.HasPrefix(address, namedPipePrefix) { @@ -656,338 +389,7 @@ func ListenLocalSocket(address string) (net.Listener, error) { } func HostInterfaceExists(ifaceName string) bool { - _, err := getAdapterInAllCompartmentsByName(ifaceName) - if err != nil { - return false - } - return true -} - -// InterfaceIPv4DhcpEnabled returns the IPv4 DHCP status on the specified interface. -func InterfaceIPv4DhcpEnabled(ifaceName string) (bool, error) { - adapter, err := getAdapterInAllCompartmentsByName(ifaceName) - if err != nil { - return false, err - } - ipv4Dhcp := (adapter.flags&IP_ADAPTER_DHCP_ENABLED != 0) - return ipv4Dhcp, nil -} - -// SetInterfaceMTU configures interface MTU on host for Pods. MTU change cannot be realized with HNSEndpoint because -// there's no MTU field in HNSEndpoint: -// https://github.com/Microsoft/hcsshim/blob/4a468a6f7ae547974bc32911395c51fb1862b7df/internal/hns/hnsendpoint.go#L12 -func SetInterfaceMTU(ifaceName string, mtu int) error { - adapter, err := getAdapterInAllCompartmentsByName(ifaceName) - if err != nil { - return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", ifaceName, err) - } - return adapter.setMTU(mtu, antreasyscall.AF_INET) -} - -func NewNetRoute(route *Route) error { - if route == nil { - return nil - } - row := route.toMibIPForwardRow() - if err := antreaNetIO.CreateIPForwardEntry(row); err != nil { - return fmt.Errorf("failed to create new IPForward row: %v", err) - } - return nil -} - -func RemoveNetRoute(route *Route) error { - if route == nil || route.DestinationSubnet == nil { - return nil - } - family := addressFamilyByIP(route.DestinationSubnet.IP) - rows, err := antreaNetIO.ListIPForwardRows(family) - if err != nil { - return fmt.Errorf("unable to list Windows IPForward rows: %v", err) - } - for i := range rows { - row := rows[i] - if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) && row.NextHop.IP().Equal(route.GatewayAddress) { - if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { - return fmt.Errorf("failed to delete existing route %s: %v", route.String(), err) - } - } - } - return nil -} - -func ReplaceNetRoute(route *Route) error { - if route == nil || route.DestinationSubnet == nil { - return nil - } - family := addressFamilyByIP(route.DestinationSubnet.IP) - rows, err := antreaNetIO.ListIPForwardRows(family) - if err != nil { - return fmt.Errorf("unable to list Windows IPForward rows: %v", err) - } - for i := range rows { - row := rows[i] - if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) { - if row.NextHop.IP().Equal(route.GatewayAddress) { - return nil - } else { - if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { - return fmt.Errorf("failed to delete existing route with nextHop %s: %v", route.GatewayAddress, err) - } - } - } - } - return NewNetRoute(route) -} - -func RouteListFiltered(family uint16, filter *Route, filterMask uint64) ([]Route, error) { - rows, err := antreaNetIO.ListIPForwardRows(family) - if err != nil { - return nil, fmt.Errorf("unable to list Windows IPForward rows: %v", err) - } - rts := make([]Route, 0, len(rows)) - for i := range rows { - route := routeFromIPForwardRow(&rows[i]) - if filter != nil { - if filterMask&RT_FILTER_IF != 0 && filter.LinkIndex != route.LinkIndex { - continue - } - if filterMask&RT_FILTER_DST != 0 && !iputil.IPNetEqual(filter.DestinationSubnet, route.DestinationSubnet) { - continue - } - if filterMask&RT_FILTER_GW != 0 && !filter.GatewayAddress.Equal(route.GatewayAddress) { - continue - } - if filterMask&RT_FILTER_METRIC != 0 && filter.RouteMetric != route.RouteMetric { - continue - } - } - rts = append(rts, *route) - } - return rts, nil -} - -func addressFamilyByIP(ip net.IP) uint16 { - if ip.To4() != nil { - return antreasyscall.AF_INET - } - return antreasyscall.AF_INET6 -} - -func parseGetNetCmdResult(result string, itemNum int) [][]string { - scanner := bufio.NewScanner(strings.NewReader(result)) - parsed := [][]string{} - for scanner.Scan() { - items := strings.Fields(scanner.Text()) - if len(items) < itemNum { - // Skip if an empty line or something similar - continue - } - parsed = append(parsed, items) - } - return parsed -} - -func NewNetNat(netNatName string, subnetCIDR *net.IPNet) error { - cmd := fmt.Sprintf(`Get-NetNat -Name %s | Select-Object InternalIPInterfaceAddressPrefix | Format-Table -HideTableHeaders`, netNatName) - if internalNet, err := runCommand(cmd); err != nil { - if !strings.Contains(err.Error(), "No MSFT_NetNat objects found") { - klog.ErrorS(err, "Failed to check the existing netnat", "name", netNatName) - return err - } - } else { - if strings.Contains(internalNet, subnetCIDR.String()) { - klog.V(4).InfoS("The existing netnat matched the subnet CIDR", "name", internalNet, "subnetCIDR", subnetCIDR.String()) - return nil - } - klog.InfoS("Removing the existing NetNat", "name", netNatName, "internalIPInterfaceAddressPrefix", internalNet) - cmd = fmt.Sprintf("Remove-NetNat -Name %s -Confirm:$false", netNatName) - if _, err := runCommand(cmd); err != nil { - klog.ErrorS(err, "Failed to remove the existing netnat", "name", netNatName, "internalIPInterfaceAddressPrefix", internalNet) - return err - } - } - cmd = fmt.Sprintf(`New-NetNat -Name %s -InternalIPInterfaceAddressPrefix %s`, netNatName, subnetCIDR.String()) - _, err := runCommand(cmd) - if err != nil { - klog.ErrorS(err, "Failed to add netnat", "name", netNatName, "internalIPInterfaceAddressPrefix", subnetCIDR.String()) - return err - } - return nil -} - -func ReplaceNetNatStaticMapping(mapping *NetNatStaticMapping) error { - staticMappingStr, err := GetNetNatStaticMapping(mapping) - if err != nil { - return err - } - parsed := parseGetNetCmdResult(staticMappingStr, 6) - if len(parsed) > 0 { - items := parsed[0] - if items[4] == mapping.InternalIP.String() && items[5] == strconv.Itoa(int(mapping.InternalPort)) { - return nil - } - firstCol := strings.Split(items[0], ";") - id, err := strconv.Atoi(firstCol[1]) - if err != nil { - return err - } - if err := RemoveNetNatStaticMappingByID(mapping.Name, id); err != nil { - return err - } - } - return AddNetNatStaticMapping(mapping) -} - -// GetNetNatStaticMapping checks if a NetNatStaticMapping exists. -func GetNetNatStaticMapping(mapping *NetNatStaticMapping) (string, error) { - cmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", mapping.Name) + - fmt.Sprintf("|? ExternalIPAddress -EQ %s", mapping.ExternalIP) + - fmt.Sprintf("|? ExternalPort -EQ %d", mapping.ExternalPort) + - fmt.Sprintf("|? Protocol -EQ %s", mapping.Protocol) + - "| Format-Table -HideTableHeaders" - staticMappingStr, err := runCommand(cmd) - if err != nil && !strings.Contains(err.Error(), "No MSFT_NetNatStaticMapping objects found") { - return "", err - } - return staticMappingStr, nil -} - -// AddNetNatStaticMapping adds a static mapping to a NAT instance. -func AddNetNatStaticMapping(mapping *NetNatStaticMapping) error { - cmd := fmt.Sprintf("Add-NetNatStaticMapping -NatName %s -ExternalIPAddress %s -ExternalPort %d -InternalIPAddress %s -InternalPort %d -Protocol %s", - mapping.Name, mapping.ExternalIP, mapping.ExternalPort, mapping.InternalIP, mapping.InternalPort, mapping.Protocol) - _, err := runCommand(cmd) - return err -} - -func RemoveNetNatStaticMapping(mapping *NetNatStaticMapping) error { - staticMappingStr, err := GetNetNatStaticMapping(mapping) - if err != nil { - return err - } - parsed := parseGetNetCmdResult(staticMappingStr, 6) - if len(parsed) == 0 { - return nil - } - - firstCol := strings.Split(parsed[0][0], ";") - id, err := strconv.Atoi(firstCol[1]) - if err != nil { - return err - } - return RemoveNetNatStaticMappingByID(mapping.Name, id) -} - -func RemoveNetNatStaticMappingByNPLTuples(mapping *NetNatStaticMapping) error { - staticMappingStr, err := GetNetNatStaticMapping(mapping) - if err != nil { - return err - } - parsed := parseGetNetCmdResult(staticMappingStr, 6) - if len(parsed) > 0 { - items := parsed[0] - if items[4] == mapping.InternalIP.String() && items[5] == strconv.Itoa(int(mapping.InternalPort)) { - firstCol := strings.Split(items[0], ";") - id, err := strconv.Atoi(firstCol[1]) - if err != nil { - return err - } - if err := RemoveNetNatStaticMappingByID(mapping.Name, id); err != nil { - return err - } - return nil - } - } - return nil -} - -func RemoveNetNatStaticMappingByID(netNatName string, id int) error { - cmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", netNatName, id) - _, err := runCommand(cmd) - return err -} - -func RemoveNetNatStaticMappingByNAME(netNatName string) error { - cmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -Confirm:$false", netNatName) - _, err := runCommand(cmd) - return err -} - -// GetNetNeighbor gets neighbor cache entries with Get-NetNeighbor. -func GetNetNeighbor(neighbor *Neighbor) ([]Neighbor, error) { - cmd := fmt.Sprintf("Get-NetNeighbor -InterfaceIndex %d -IPAddress %s | Format-Table -HideTableHeaders", neighbor.LinkIndex, neighbor.IPAddress.String()) - neighborsStr, err := runCommand(cmd) - if err != nil && !strings.Contains(err.Error(), "No matching MSFT_NetNeighbor objects") { - return nil, err - } - - parsed := parseGetNetCmdResult(neighborsStr, 5) - var neighbors []Neighbor - for _, items := range parsed { - idx, err := strconv.Atoi(items[0]) - if err != nil { - return nil, fmt.Errorf("failed to parse the LinkIndex '%s': %v", items[0], err) - } - dstIP := net.ParseIP(items[1]) - if err != nil { - return nil, fmt.Errorf("failed to parse the DestinationIP '%s': %v", items[1], err) - } - // Get-NetNeighbor returns LinkLayerAddress like "AA-BB-CC-DD-EE-FF". - mac, err := net.ParseMAC(strings.ReplaceAll(items[2], "-", ":")) - if err != nil { - return nil, fmt.Errorf("failed to parse the Gateway MAC '%s': %v", items[2], err) - } - neighbor := Neighbor{ - LinkIndex: idx, - IPAddress: dstIP, - LinkLayerAddress: mac, - State: items[3], - } - neighbors = append(neighbors, neighbor) - } - return neighbors, nil -} - -// NewNetNeighbor creates a new neighbor cache entry with New-NetNeighbor. -func NewNetNeighbor(neighbor *Neighbor) error { - cmd := fmt.Sprintf("New-NetNeighbor -InterfaceIndex %d -IPAddress %s -LinkLayerAddress %s -State Permanent", - neighbor.LinkIndex, neighbor.IPAddress, neighbor.LinkLayerAddress) - _, err := runCommand(cmd) - return err -} - -func RemoveNetNeighbor(neighbor *Neighbor) error { - cmd := fmt.Sprintf("Remove-NetNeighbor -InterfaceIndex %d -IPAddress %s -Confirm:$false", - neighbor.LinkIndex, neighbor.IPAddress) - _, err := runCommand(cmd) - return err -} - -func ReplaceNetNeighbor(neighbor *Neighbor) error { - neighbors, err := GetNetNeighbor(neighbor) - if err != nil { - return err - } - - if len(neighbors) == 0 { - if err := NewNetNeighbor(neighbor); err != nil { - return err - } - return nil - } - for _, n := range neighbors { - if n.LinkLayerAddress.String() == neighbor.LinkLayerAddress.String() && n.State == neighbor.State { - return nil - } - } - if err := RemoveNetNeighbor(neighbor); err != nil { - return err - } - return NewNetNeighbor(neighbor) -} - -func VirtualAdapterName(name string) string { - return fmt.Sprintf("%s (%s)", ContainerVNICPrefix, name) + return winnetUtil.NetAdapterExists(ifaceName) } func GetInterfaceConfig(ifName string) (*net.Interface, []*net.IPNet, []interface{}, error) { @@ -999,7 +401,7 @@ func GetInterfaceConfig(ifName string) (*net.Interface, []*net.IPNet, []interfac if err != nil { return nil, nil, nil, fmt.Errorf("failed to get address for interface %s: %v", iface.Name, err) } - rts, err := RouteListFiltered(antreasyscall.AF_UNSPEC, &Route{LinkIndex: iface.Index}, RT_FILTER_IF) + rts, err := winnetUtil.RouteListFiltered(antreasyscall.AF_UNSPEC, &winnet.Route{LinkIndex: iface.Index}, winnet.RT_FILTER_IF) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get routes for interface index %d: %v", iface.Index, err) } @@ -1017,7 +419,7 @@ func GetInterfaceConfig(ifName string) (*net.Interface, []*net.IPNet, []interfac func RenameInterface(from, to string) error { var renameErr error pollErr := wait.PollUntilContextTimeout(context.TODO(), time.Millisecond*100, time.Second, false, func(ctx context.Context) (done bool, err error) { - renameErr = renameHostInterface(from, to) + renameErr = winnetUtil.RenameNetAdapter(from, to) if renameErr != nil { klog.ErrorS(renameErr, "Failed to rename adapter, retrying") return false, nil @@ -1030,298 +432,6 @@ func RenameInterface(from, to string) error { return nil } -func GetVMSwitchInterfaceName() (string, error) { - cmd := fmt.Sprintf(`Get-VMSwitchTeam -Name "%s" | select NetAdapterInterfaceDescription | Format-Table -HideTableHeaders`, LocalVMSwitch) - out, err := runCommand(cmd) - if err != nil { - return "", err - } - out = strings.TrimSpace(out) - // Remove the leading and trailing {} brackets - out = out[1 : len(out)-1] - cmd = fmt.Sprintf(`Get-NetAdapter -InterfaceDescription "%s" | select Name | Format-Table -HideTableHeaders`, out) - out, err = runCommand(cmd) - if err != nil { - return "", err - } - out = strings.TrimSpace(out) - return out, err -} - -func VMSwitchExists() (bool, error) { - cmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)`, LocalVMSwitch) - _, err := runCommand(cmd) - if err == nil { - return true, nil - } - if strings.Contains(err.Error(), fmt.Sprintf(`unable to find a virtual switch with name "%s"`, LocalVMSwitch)) { - return false, nil - } - return false, err -} - -// CreateVMSwitch creates a virtual switch and enables openvswitch extension. -// If switch exists and extension is enabled, then it will return no error. -// Otherwise, it will throw an error. -// TODO: Handle for multiple interfaces -func CreateVMSwitch(ifName string) error { - exists, err := VMSwitchExists() - if err != nil { - return err - } - if !exists { - if err = createVMSwitchWithTeaming(LocalVMSwitch, ifName); err != nil { - return err - } - } - - enabled, err := isOVSExtensionEnabled() - if err != nil { - return err - } - if !enabled { - if err = enableOVSExtension(); err != nil { - return err - } - } - return nil -} - -func RemoveVMSwitch() error { - exists, err := VMSwitchExists() - if err != nil { - return err - } - if exists { - cmd := fmt.Sprintf(`Remove-VMSwitch -Name "%s" -ComputerName $(hostname) -Force`, LocalVMSwitch) - _, err = runCommand(cmd) - if err != nil { - return err - } - } - return nil -} - func GenHostInterfaceName(upLinkIfName string) string { return strings.TrimSuffix(upLinkIfName, bridgedUplinkSuffix) } - -type updateIPInterfaceFunc func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow - -type adapter struct { - net.Interface - compartmentID uint32 - flags uint32 -} - -func (a *adapter) setMTU(mtu int, family uint16) error { - if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { - newEntry := *entry - newEntry.NlMtu = uint32(mtu) - return &newEntry - }); err != nil { - return fmt.Errorf("unable to set IPInterface with MTU %d: %v", mtu, err) - } - return nil -} - -func (a *adapter) setForwarding(enabledForwarding bool, family uint16) error { - if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { - newEntry := *entry - newEntry.ForwardingEnabled = enabledForwarding - return &newEntry - }); err != nil { - return fmt.Errorf("unable to enable IPForwarding on net adapter: %v", err) - } - return nil -} - -func (a *adapter) setIPInterfaceEntry(family uint16, updateFunc updateIPInterfaceFunc) error { - if a.compartmentID > 1 { - runtime.LockOSThread() - defer func() { - hcsshim.SetCurrentThreadCompartmentId(0) - runtime.UnlockOSThread() - }() - if err := hcsshim.SetCurrentThreadCompartmentId(a.compartmentID); err != nil { - klog.ErrorS(err, "Failed to change current thread's compartment", "compartment", a.compartmentID) - return err - } - } - ipInterfaceRow := &antreasyscall.MibIPInterfaceRow{Family: family, Index: uint32(a.Index)} - if err := antreaNetIO.GetIPInterfaceEntry(ipInterfaceRow); err != nil { - return fmt.Errorf("unable to get IPInterface entry with Index %d: %v", a.Index, err) - } - updatedRow := updateFunc(ipInterfaceRow) - updatedRow.SitePrefixLength = 0 - return antreaNetIO.SetIPInterfaceEntry(updatedRow) -} - -var ( - errInvalidInterfaceName = errors.New("invalid network interface name") - errNoSuchInterface = errors.New("no such network interface") -) - -func getAdapterInAllCompartmentsByName(name string) (*adapter, error) { - if name == "" { - return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName} - } - adapters, err := getAdaptersByName(name) - if err != nil { - return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} - } - if len(adapters) == 0 { - return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface} - } - return &adapters[0], nil -} - -// createVMSwitchWithTeaming creates VMSwitch and enables OVS extension. -// Connection to VM is lost for few seconds -func createVMSwitchWithTeaming(switchName, ifName string) error { - cmd := fmt.Sprintf(`New-VMSwitch -Name "%s" -NetAdapterName "%s" -EnableEmbeddedTeaming $true -AllowManagementOS $true -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, switchName, ifName, ovsExtensionName) - _, err := runCommand(cmd) - if err != nil { - return err - } - return nil -} - -func enableOVSExtension() error { - cmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, LocalVMSwitch, ovsExtensionName) - _, err := runCommand(cmd) - if err != nil { - return err - } - return nil -} - -// parseOVSExtensionOutput parses the VM extension output -// and returns the value of Enabled field -func parseOVSExtensionOutput(s string) bool { - scanner := bufio.NewScanner(strings.NewReader(s)) - for scanner.Scan() { - temp := strings.Fields(scanner.Text()) - line := strings.Join(temp, "") - if strings.Contains(line, "Enabled") { - if strings.Contains(line, "True") { - return true - } - return false - } - } - return false -} - -func isOVSExtensionEnabled() (bool, error) { - cmd := fmt.Sprintf(`Get-VMSwitchExtension -VMSwitchName "%s" -ComputerName $(hostname) | ? Id -EQ "%s"`, LocalVMSwitch, OVSExtensionID) - out, err := runCommand(cmd) - if err != nil { - return false, err - } - if !strings.Contains(out, ovsExtensionName) { - return false, fmt.Errorf("open vswitch extension driver is not installed") - } - return parseOVSExtensionOutput(out), nil -} - -func renameHostInterface(oriName string, newName string) error { - cmd := fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, oriName, newName) - _, err := runCommand(cmd) - return err -} - -func getAdaptersByName(name string) ([]adapter, error) { - aas, err := adapterAddresses() - if err != nil { - return nil, err - } - var adapters []adapter - for _, aa := range aas { - ifName := windows.UTF16PtrToString(aa.FriendlyName) - if ifName != name { - continue - } - index := aa.IfIndex - if index == 0 { // ipv6IfIndex is a substitute for ifIndex - index = aa.Ipv6IfIndex - } - ifi := net.Interface{ - Index: int(index), - Name: ifName, - } - if aa.OperStatus == windows.IfOperStatusUp { - ifi.Flags |= net.FlagUp - } - // For now we need to infer link-layer service capabilities from media types. - // TODO: use MIB_IF_ROW2.AccessType now that we no longer support Windows XP. - switch aa.IfType { - case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394: - ifi.Flags |= net.FlagBroadcast | net.FlagMulticast - case windows.IF_TYPE_PPP, windows.IF_TYPE_TUNNEL: - ifi.Flags |= net.FlagPointToPoint | net.FlagMulticast - case windows.IF_TYPE_SOFTWARE_LOOPBACK: - ifi.Flags |= net.FlagLoopback | net.FlagMulticast - case windows.IF_TYPE_ATM: - ifi.Flags |= net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast // assume all services available; LANE, point-to-point and point-to-multipoint - } - if aa.Mtu == 0xffffffff { - ifi.MTU = -1 - } else { - ifi.MTU = int(aa.Mtu) - } - if aa.PhysicalAddressLength > 0 { - ifi.HardwareAddr = make(net.HardwareAddr, aa.PhysicalAddressLength) - copy(ifi.HardwareAddr, aa.PhysicalAddress[:]) - } - adapter := adapter{ - Interface: ifi, - compartmentID: aa.CompartmentId, - flags: aa.Flags, - } - adapters = append(adapters, adapter) - } - return adapters, nil -} - -// GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is used in windows.GetAdapterAddresses parameter -// flags to return addresses in all routing compartments. -const GAA_FLAG_INCLUDE_ALL_COMPARTMENTS = 0x00000200 - -// GAA_FLAG_INCLUDE_ALL_INTERFACES is used in windows.GetAdapterAddresses parameter -// flags to return addresses for all NDIS interfaces. -const GAA_FLAG_INCLUDE_ALL_INTERFACES = 0x00000100 - -// adapterAddresses returns a list of IpAdapterAddresses structures. The structure -// contains an IP adapter and flattened multiple IP addresses including unicast, anycast -// and multicast addresses. -// This function is copied from go/src/net/interface_windows.go, with a change that flag -// GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is introduced to query interfaces in all compartments, -// and GAA_FLAG_INCLUDE_ALL_INTERFACES is introduced to query all NDIS interfaces even they -// are not configured with any IP addresses, e.g., uplink. -func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { - flags := uint32(windows.GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_ALL_COMPARTMENTS | GAA_FLAG_INCLUDE_ALL_INTERFACES) - var b []byte - l := uint32(15000) // recommended initial size - for { - b = make([]byte, l) - err := getAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) - if err == nil { - if l == 0 { - return nil, nil - } - break - } - if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - if l <= uint32(len(b)) { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - } - var aas []*windows.IpAdapterAddresses - for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { - aas = append(aas, aa) - } - return aas, nil -} diff --git a/pkg/agent/util/net_windows_test.go b/pkg/agent/util/net_windows_test.go index 4b7e2fcc934..06c1dd06852 100644 --- a/pkg/agent/util/net_windows_test.go +++ b/pkg/agent/util/net_windows_test.go @@ -20,57 +20,19 @@ package util import ( "fmt" "net" - "os" "strings" "testing" - antreasyscalltest "antrea.io/antrea/pkg/agent/util/syscall/testing" - "github.com/Microsoft/hcsshim" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sys/windows" + "go.uber.org/mock/gomock" antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" - "antrea.io/antrea/pkg/ovs/openflow" + "antrea.io/antrea/pkg/agent/util/winnet" + winnettesting "antrea.io/antrea/pkg/agent/util/winnet/testing" ) -func TestRouteString(t *testing.T) { - gw, subnet, _ := net.ParseCIDR("192.168.2.0/24") - testRoute := Route{ - LinkIndex: 1, - DestinationSubnet: subnet, - GatewayAddress: gw, - RouteMetric: MetricDefault, - } - gotRoute := testRoute.String() - assert.Equal(t, "LinkIndex: 1, DestinationSubnet: 192.168.2.0/24, GatewayAddress: 192.168.2.0, RouteMetric: 256", gotRoute) -} - -func TestRouteTranslation(t *testing.T) { - _, subnet, _ := net.ParseCIDR("1.1.1.0/28") - oriRoute := &Route{ - LinkIndex: 27, - RouteMetric: 35, - DestinationSubnet: subnet, - GatewayAddress: net.ParseIP("1.1.1.254"), - } - row := oriRoute.toMibIPForwardRow() - newRoute := routeFromIPForwardRow(row) - assert.Equal(t, oriRoute, newRoute) -} - -func TestNeighborString(t *testing.T) { - testNeighbor := Neighbor{ - LinkIndex: 1, - IPAddress: net.ParseIP("169.254.0.253"), - LinkLayerAddress: testMACAddr, - State: "Permanent", - } - gotNeighbor := testNeighbor.String() - assert.Equal(t, "LinkIndex: 1, IPAddress: 169.254.0.253, LinkLayerAddress: aa:bb:cc:dd:ee:ff", gotNeighbor) -} - func TestGetNSPath(t *testing.T) { testNSPath := "/dev/null" gotNSPath, err := GetNSPath(testNSPath) @@ -78,81 +40,51 @@ func TestGetNSPath(t *testing.T) { assert.Equal(t, testNSPath, gotNSPath) } -func TestIsVirtualAdapter(t *testing.T) { - adapter := "test-adapter" - tests := []struct { - name string - commandOut string - commandErr error - adapter string - wantIsVirtual bool - }{ - { - name: "Virtual adapter", - commandOut: " true ", - wantIsVirtual: true, - }, - { - name: "Virtual adapter Err", - commandErr: testInvalidErr, - wantIsVirtual: false, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, []string{ - fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Virtual | Format-Table -HideTableHeaders`, adapter), - }, tc.commandOut, tc.commandErr, true)() - gotIsVirtual, err := IsVirtualAdapter(adapter) - assert.Equal(t, tc.wantIsVirtual, gotIsVirtual) - assert.Equal(t, tc.commandErr, err) - }) - } -} - func TestSetLinkUp(t *testing.T) { - testName := "test-en0" - enableCmd := fmt.Sprintf(`Enable-NetAdapter -InterfaceAlias "%s"`, testName) - getCmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Status | Format-Table -HideTableHeaders`, testName) + testName := "test-link" tests := []struct { name string - commandOut string - commandErr error gwInterface *net.Interface gwInterfaceErr error - wantCmds []string + expectedError error + expectedCalls func(mockNetUtil *winnettesting.MockInterfaceMockRecorder) }{ { - name: "Set Link Up Normal", - commandOut: " UP ", + name: "Set Link Up Normal", gwInterface: &net.Interface{ Index: 1, Name: testName, HardwareAddr: testMACAddr, }, - wantCmds: []string{enableCmd, getCmd}, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.EnableNetAdapter(testName).Return(nil).MinTimes(1) + mockUtil.IsNetAdapterStatusUp(testName).Return(true, nil).Times(1) + }, }, { name: "Enable Interface Err", - commandErr: fmt.Errorf("fail"), gwInterface: &net.Interface{Index: 0}, - gwInterfaceErr: fmt.Errorf("failed to enable interface %s", testName), - wantCmds: []string{enableCmd}, + gwInterfaceErr: fmt.Errorf("failed to enable network adapter %s", testName), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.EnableNetAdapter(testName).Return(fmt.Errorf("failed to enable interface %s: failed reason", testName)).MinTimes(1) + }, }, { name: "Get Interface Err", - commandOut: " Up ", gwInterface: &net.Interface{Index: 0}, gwInterfaceErr: testInvalidErr, - wantCmds: []string{enableCmd, getCmd}, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.EnableNetAdapter(testName).Return(nil).MinTimes(1) + mockUtil.IsNetAdapterStatusUp(testName).Return(true, nil).Times(1) + }, }, } - for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, false)() + ctrl := gomock.NewController(t) + defer mockUtilWinnet(ctrl, tc.expectedCalls)() defer mockNetInterfaceByName(tc.gwInterface, tc.gwInterfaceErr)() + gotMac, gotIndex, err := SetLinkUp(testName) assert.Equal(t, tc.gwInterface.HardwareAddr, gotMac) assert.Equal(t, tc.gwInterface.Index, gotIndex) @@ -167,24 +99,21 @@ func TestSetLinkUp(t *testing.T) { func TestConfigureLinkAddresses(t *testing.T) { testNetInterface := generateNetInterface("0") - ipStr := strings.Split(ipv4ZeroIPNet.String(), "/") - removeCmd := fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -Confirm:$false`, "0", ipv4Public.String()) - newCmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s`, "0", ipStr[0], ipStr[1]) tests := []struct { name string ipNets []*net.IPNet - commandOut string - commandErr error testNetInterfaceErr error testNetAddrsErr error - wantCmds []string + expectedCalls func(mockNetUtil *winnettesting.MockInterfaceMockRecorder) wantErr error }{ { - name: "Configure Link Addr", - ipNets: []*net.IPNet{&ipv4ZeroIPNet}, - commandOut: "success", - wantCmds: []string{removeCmd, newCmd}, + name: "Configure Link Addr", + ipNets: []*net.IPNet{&ipv4ZeroIPNet}, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RemoveNetAdapterIPAddress("0", ipv4Public).Return(nil).Times(1) + mockUtil.AddNetAdapterIPAddress("0", &ipv4ZeroIPNet, "").Return(nil).Times(1) + }, }, { name: "Net Interface Err", @@ -199,23 +128,25 @@ func TestConfigureLinkAddresses(t *testing.T) { wantErr: fmt.Errorf("failed to query IPv4 address list for interface 0: invalid"), }, { - name: "Link Addr No Change", - ipNets: []*net.IPNet{&ipv4PublicIPNet}, - commandOut: "success", + name: "Link Addr No Change", + ipNets: []*net.IPNet{&ipv4PublicIPNet}, }, { - name: "Link Addr Configure Err", - ipNets: []*net.IPNet{&ipv4ZeroIPNet}, - commandErr: fmt.Errorf("interface No matching"), - wantCmds: []string{removeCmd, newCmd}, - wantErr: fmt.Errorf("failed to add address 0.0.0.0/32 to interface 0: interface No matching"), + name: "Link Addr Configure Err", + ipNets: []*net.IPNet{&ipv4ZeroIPNet}, + wantErr: fmt.Errorf("failed to add address 0.0.0.0/32 to interface 0: interface No matching"), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RemoveNetAdapterIPAddress("0", ipv4Public).Return(nil).Times(1) + mockUtil.AddNetAdapterIPAddress("0", &ipv4ZeroIPNet, "").Return(fmt.Errorf("interface No matching")).Times(1) + }, }, { - name: "Link Addr Remove Err", - ipNets: []*net.IPNet{&ipv4ZeroIPNet}, - commandErr: fmt.Errorf("interface already exists"), - wantCmds: []string{removeCmd}, - wantErr: fmt.Errorf("failed to remove address 8.8.8.8/32 from interface 0: interface already exists"), + name: "Link Addr Remove Err", + ipNets: []*net.IPNet{&ipv4ZeroIPNet}, + wantErr: fmt.Errorf("failed to remove address 8.8.8.8/32 from interface 0: interface already exists"), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RemoveNetAdapterIPAddress("0", ipv4Public).Return(fmt.Errorf("interface already exists")).Times(1) + }, }, { name: "Link Addr IPv6 Not Supported", @@ -225,13 +156,16 @@ func TestConfigureLinkAddresses(t *testing.T) { Mask: net.CIDRMask(128, 128), }, }, - wantCmds: []string{removeCmd}, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RemoveNetAdapterIPAddress("0", ipv4Public).Return(nil).Times(1) + }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() + ctrl := gomock.NewController(t) + defer mockUtilWinnet(ctrl, tc.expectedCalls)() defer mockNetInterfaceByIndex(&testNetInterface, tc.testNetInterfaceErr)() defer mockNetInterfaceAddrs(testNetInterface, tc.testNetAddrsErr)() gotErr := ConfigureLinkAddresses(0, tc.ipNets) @@ -240,45 +174,14 @@ func TestConfigureLinkAddresses(t *testing.T) { } } -func TestSetAdapterMACAddress(t *testing.T) { - tests := []struct { - name string - commandOut string - commandErr error - wantErr error - }{ - { - name: "Set adapter MAC", - commandOut: "success", - }, - { - name: "Set Err", - commandErr: testInvalidErr, - wantErr: testInvalidErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, []string{ - fmt.Sprintf(`Set-NetAdapterAdvancedProperty -Name "%s" -RegistryKeyword NetworkAddress -RegistryValue "%s"`, - "test-adapter", strings.Replace(testMACAddr.String(), ":", "", -1)), - }, tc.commandOut, tc.commandErr, true)() - gotErr := SetAdapterMACAddress("test-adapter", &testMACAddr) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - func TestPrepareHNSNetwork(t *testing.T) { gw, subnet, _ := net.ParseCIDR("8.8.8.8/32") alreadyExistsErr := fmt.Errorf("already exists") - nodeZeroIPNetStr := strings.Split(ipv4ZeroIPNet.String(), "/") - routes := []Route{{ + routes := []winnet.Route{{ LinkIndex: 0, DestinationSubnet: subnet, GatewayAddress: gw, - RouteMetric: MetricDefault, + RouteMetric: winnet.MetricDefault, }} testRoutes := convertTestRoutes(routes) testSubnetCIDR := &net.IPNet{ @@ -296,72 +199,61 @@ func TestPrepareHNSNetwork(t *testing.T) { testDNSServer := "192.168.1.21" testNetInterfaces := generateNetInterfaces() for i, itf := range testNetInterfaces { - testNetInterfaces[i].Name = VirtualAdapterName(itf.Name) + testNetInterfaces[i].Name = winnet.VirtualAdapterName(itf.Name) } - newIPCmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s -DefaultGateway %s`, VirtualAdapterName("0"), nodeZeroIPNetStr[0], nodeZeroIPNetStr[1], "testGateway") - setServerCmd := fmt.Sprintf(`Set-DnsClientServerAddress -InterfaceAlias "%s" -ServerAddresses "%s"`, VirtualAdapterName("0"), testDNSServer) - getAdapterCmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -SwitchName "%s" | ? MacAddress -EQ "%s" | Select-Object -Property Name | Format-Table -HideTableHeaders`, LocalHNSNetwork, testUplinkMACStr) - renameAdapterCmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -Name "%s" | Rename-VMNetworkAdapter -NewName "%s"`, testAdapterName, testNewName) - renameNetCmd := fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, VirtualAdapterName(testNewName), testNewName) - getVMCmd := fmt.Sprintf("Get-VMSwitch -ComputerName $(hostname) -Name %s | Select-Object -Property SoftwareRscEnabled | Format-Table -HideTableHeaders", LocalHNSNetwork) - setVMCmd := fmt.Sprintf("Set-VMSwitch -ComputerName $(hostname) -Name %s -EnableSoftwareRsc $True", LocalHNSNetwork) tests := []struct { name string - testAdapterAddresses *windows.IpAdapterAddresses nodeIPNet *net.IPNet dnsServers string newName string ipFound bool hnsNetworkCreateErr error - commandErr error hnsNetworkRequestError error testNetInterfaceErr error createRowErr error - wantCmds []string + expectedCalls func(mockNetUtil *winnettesting.MockInterfaceMockRecorder) wantErr error }{ { - name: "Prepare Success", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4PublicIPNet, - dnsServers: testDNSServer, - newName: testNewName, - ipFound: true, - wantCmds: []string{getAdapterCmd, renameAdapterCmd, renameNetCmd, - getVMCmd, setVMCmd}, + name: "Prepare Success", + nodeIPNet: &ipv4PublicIPNet, + dnsServers: testDNSServer, + newName: testNewName, + ipFound: true, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RenameVMNetworkAdapter(LocalHNSNetwork, testUplinkMACStr, testNewName, true).Times(1) + mockUtil.EnableRSCOnVSwitch(LocalHNSNetwork).Times(1) + }, }, { - name: "Create Error", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4PublicIPNet, - dnsServers: testDNSServer, - ipFound: true, - hnsNetworkCreateErr: testInvalidErr, - wantErr: fmt.Errorf("error creating HNSNetwork: invalid"), + name: "Create Error", + nodeIPNet: &ipv4PublicIPNet, + dnsServers: testDNSServer, + ipFound: true, + hnsNetworkCreateErr: testInvalidErr, + wantErr: fmt.Errorf("error creating HNSNetwork: invalid"), }, { - name: "adapter Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4PublicIPNet, - dnsServers: testDNSServer, - ipFound: true, - testNetInterfaceErr: testInvalidErr, - wantErr: testInvalidErr, + name: "adapter Err", + nodeIPNet: &ipv4PublicIPNet, + dnsServers: testDNSServer, + ipFound: true, + testNetInterfaceErr: testInvalidErr, + wantErr: testInvalidErr, }, { - name: "Rename Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4PublicIPNet, - dnsServers: testDNSServer, - newName: testNewName, - ipFound: true, - commandErr: testInvalidErr, - wantCmds: []string{getAdapterCmd}, - wantErr: testInvalidErr, + name: "Rename Err", + nodeIPNet: &ipv4PublicIPNet, + dnsServers: testDNSServer, + newName: testNewName, + ipFound: true, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RenameVMNetworkAdapter(LocalHNSNetwork, testUplinkMACStr, testNewName, true).Return(testInvalidErr).Times(1) + }, + wantErr: testInvalidErr, }, { name: "Enable HNS Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), nodeIPNet: &ipv4PublicIPNet, dnsServers: testDNSServer, ipFound: true, @@ -369,64 +261,62 @@ func TestPrepareHNSNetwork(t *testing.T) { wantErr: testInvalidErr, }, { - name: "Enable RSC Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4PublicIPNet, - dnsServers: testDNSServer, - ipFound: true, - commandErr: testInvalidErr, - wantCmds: []string{getVMCmd}, - wantErr: testInvalidErr, - }, - { - name: "IP Not Found", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4ZeroIPNet, - dnsServers: testDNSServer, - ipFound: false, - wantCmds: []string{newIPCmd, setServerCmd, getVMCmd, setVMCmd}, - }, - { - name: "IP Not Found Configure Default Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4ZeroIPNet, - dnsServers: testDNSServer, - ipFound: false, - commandErr: testInvalidErr, - wantCmds: []string{newIPCmd}, - wantErr: testInvalidErr, + name: "Enable RSC Err", + nodeIPNet: &ipv4PublicIPNet, + dnsServers: testDNSServer, + ipFound: true, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.EnableRSCOnVSwitch(LocalHNSNetwork).Return(testInvalidErr).Times(1) + }, + wantErr: testInvalidErr, + }, + { + name: "IP Not Found", + nodeIPNet: &ipv4ZeroIPNet, + dnsServers: testDNSServer, + ipFound: false, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.IsNetAdapterIPv4DHCPEnabled(testAdapterName).Times(1) + mockUtil.AddNetAdapterIPAddress(winnet.VirtualAdapterName("0"), &ipv4ZeroIPNet, "testGateway").Times(1) + mockUtil.SetNetAdapterDNSServers(winnet.VirtualAdapterName("0"), testDNSServer).Times(1) + mockUtil.AddNetRoute(gomock.Any()).Times(1) + mockUtil.EnableRSCOnVSwitch(LocalHNSNetwork).Times(1) + }, }, { - name: "IP Not Found Set adapter Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4ZeroIPNet, - dnsServers: testDNSServer, - ipFound: false, - commandErr: alreadyExistsErr, - wantCmds: []string{newIPCmd, setServerCmd}, - wantErr: alreadyExistsErr, + name: "IP Not Found Configure adapter Err", + nodeIPNet: &ipv4ZeroIPNet, + dnsServers: testDNSServer, + ipFound: false, + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.IsNetAdapterIPv4DHCPEnabled(testAdapterName).Times(1) + mockUtil.AddNetAdapterIPAddress(winnet.VirtualAdapterName("0"), &ipv4ZeroIPNet, "testGateway").Return(alreadyExistsErr).Times(1) + }, + wantErr: alreadyExistsErr, }, { - name: "IP Not Found New Net Route Err", - testAdapterAddresses: createTestAdapterAddresses(testAdapterName), - nodeIPNet: &ipv4ZeroIPNet, - ipFound: false, - createRowErr: fmt.Errorf("ip route not found"), - wantCmds: []string{newIPCmd}, - wantErr: fmt.Errorf("failed to create new IPForward row: ip route not found"), + name: "IP Not Found New Net Route Err", + nodeIPNet: &ipv4ZeroIPNet, + ipFound: false, + createRowErr: fmt.Errorf("ip route not found"), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.IsNetAdapterIPv4DHCPEnabled(testAdapterName).Times(1) + mockUtil.AddNetAdapterIPAddress(winnet.VirtualAdapterName("0"), &ipv4ZeroIPNet, "testGateway").Times(1) + mockUtil.AddNetRoute(gomock.Any()).Return(fmt.Errorf("failed to create new IPForward row: ip route not found")).Times(1) + }, + wantErr: fmt.Errorf("failed to create new IPForward row: ip route not found"), }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, testAdapterName, tc.commandErr, true)() + ctrl := gomock.NewController(t) + defer mockUtilWinnet(ctrl, tc.expectedCalls)() defer mockNetInterfaceGet(testNetInterfaces, tc.testNetInterfaceErr)() defer mockNetInterfaceAddrsMultiple(testNetInterfaces, tc.ipFound, nil)() defer mockHNSNetworkRequest(nil, tc.hnsNetworkRequestError)() defer mockHNSNetworkCreate(tc.hnsNetworkCreateErr)() defer mockHNSNetworkDelete(nil)() - defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{CreateIPForwardEntryErr: tc.createRowErr})() - defer mockGetAdaptersAddresses(tc.testAdapterAddresses, nil)() gotErr := PrepareHNSNetwork(testSubnetCIDR, tc.nodeIPNet, testUplinkAdapter, "testGateway", tc.dnsServers, testRoutes, tc.newName) assert.Equal(t, tc.wantErr, gotErr) }) @@ -434,463 +324,103 @@ func TestPrepareHNSNetwork(t *testing.T) { } func TestGetDefaultGatewayByInterfaceIndex(t *testing.T) { - _, subnet, _ := net.ParseCIDR("0.0.0.0/0") - testIndex := uint32(27) - testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.254")) - listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: ip route not found") - tests := []struct { - name string - listRows []antreasyscall.MibIPForwardRow - listRowsErr error - wantGateway string - wantErr error - }{ - { - name: "Index Success", - listRows: []antreasyscall.MibIPForwardRow{testIPForwardRow}, - wantGateway: "1.1.1.254", - }, - { - name: "Index Error", - listRowsErr: fmt.Errorf("ip route not found"), - wantErr: listIPForwardRowsErr, - }, + _, subnet, _ := net.ParseCIDR("1.1.1.0/28") + testGateway := net.ParseIP("1.1.1.254") + testIndex := 27 + testRoutes := []winnet.Route{ { - name: "Routes not found", - listRows: []antreasyscall.MibIPForwardRow{}, + LinkIndex: testIndex, + DestinationSubnet: subnet, + GatewayAddress: testGateway, + RouteMetric: winnet.MetricDefault, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{ListIPForwardRowsErr: tc.listRowsErr, IPForwardRows: tc.listRows})() - gotGateway, err := GetDefaultGatewayByInterfaceIndex((int)(testIndex)) - assert.Equal(t, tc.wantGateway, gotGateway) - assert.Equal(t, tc.wantErr, err) - }) + ip, defaultDestination, _ := net.ParseCIDR("0.0.0.0/0") + family := winnet.AddressFamilyByIP(ip) + filter := &winnet.Route{ + LinkIndex: testIndex, + DestinationSubnet: defaultDestination, } -} + filterMask := winnet.RT_FILTER_IF | winnet.RT_FILTER_DST + listRouteErr := fmt.Errorf("unable to list Windows IPForward rows: ip route not found") -func TestGetDNServersByInterfaceIndex(t *testing.T) { - testIndex := 1 tests := []struct { name string - commandOut string - commandErr error - wantDNSServer string + expectedCalls func(mockNetUtil *winnettesting.MockInterfaceMockRecorder) + wantGateway string + wantErr error }{ { - name: "Index Success", - commandOut: "hello\r\nworld\r\n\r\n", - wantDNSServer: "hello,world", - }, - { - name: "Index Error", - commandOut: "fail", - commandErr: testInvalidErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, []string{ - fmt.Sprintf("$(Get-DnsClientServerAddress -InterfaceIndex %d -AddressFamily IPv4).ServerAddresses", testIndex), - }, tc.commandOut, tc.commandErr, true)() - gotDNSServer, err := GetDNServersByInterfaceIndex(testIndex) - assert.Equal(t, tc.wantDNSServer, gotDNSServer) - assert.Equal(t, tc.commandErr, err) - }) - } -} - -func TestHostInterfaceExists(t *testing.T) { - tests := []struct { - name string - testNetInterfaceName string - testAdapterAddresses *windows.IpAdapterAddresses - }{ - { - name: "Normal Exist", - testNetInterfaceName: "host", - testAdapterAddresses: createTestAdapterAddresses("host"), - }, - { - name: "Interface not exist", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockGetAdaptersAddresses(tc.testAdapterAddresses, nil)() - gotExists := HostInterfaceExists(tc.testNetInterfaceName) - assert.Equal(t, tc.testNetInterfaceName != "", gotExists) - }) - } -} - -func TestSetInterfaceMTU(t *testing.T) { - testName := "host" - testAdapterAddresses := createTestAdapterAddresses(testName) - testMTU := 2 - tests := []struct { - name string - testNetInterfaceName string - testAdapterAddresses *windows.IpAdapterAddresses - getIPInterfaceErr error - setIPInterfaceErr error - wantErr error - }{ - { - name: "Set Success", - testNetInterfaceName: testName, - testAdapterAddresses: testAdapterAddresses, - }, - { - name: "Interface name invalid", - wantErr: fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", "", - &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}), + name: "Index Success", + wantGateway: testGateway.String(), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RouteListFiltered(family, filter, filterMask).Return(testRoutes, nil).Times(1) + }, }, { - name: "Get Interface Err", - testNetInterfaceName: testName, - testAdapterAddresses: testAdapterAddresses, - getIPInterfaceErr: fmt.Errorf("IP interface not found"), - wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: %v", testMTU, - fmt.Errorf("unable to get IPInterface entry with Index %d: IP interface not found", (int)(testAdapterAddresses.IfIndex))), + name: "Index Error", + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RouteListFiltered(family, filter, filterMask).Return(nil, listRouteErr).Times(1) + }, + wantErr: listRouteErr, }, { - name: "Set Interface Err", - testNetInterfaceName: testName, - testAdapterAddresses: testAdapterAddresses, - setIPInterfaceErr: fmt.Errorf("IP interface set error"), - wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: IP interface set error", testMTU), + name: "Routes not found", + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RouteListFiltered(family, filter, filterMask).Return(nil, nil).Times(1) + }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockGetAdaptersAddresses(tc.testAdapterAddresses, nil)() - defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{GetIPInterfaceEntryErr: tc.getIPInterfaceErr, SetIPInterfaceEntryErr: tc.setIPInterfaceErr})() - gotErr := SetInterfaceMTU(tc.testNetInterfaceName, testMTU) - assert.Equal(t, tc.wantErr, gotErr) + ctrl := gomock.NewController(t) + defer mockUtilWinnet(ctrl, tc.expectedCalls)() + gotGateway, err := GetDefaultGatewayByInterfaceIndex((int)(testIndex)) + assert.Equal(t, tc.wantGateway, gotGateway) + assert.Equal(t, tc.wantErr, err) }) } } -func TestReplaceNetRoute(t *testing.T) { +func TestGetInterfaceConfig(t *testing.T) { _, subnet, _ := net.ParseCIDR("1.1.1.0/28") - testIP := net.ParseIP("1.1.1.254") - testIndex := uint32(27) - testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, testIP) - testRoute := Route{ - LinkIndex: (int)(testIPForwardRow.Index), - DestinationSubnet: subnet, - GatewayAddress: net.ParseIP("1.1.1.254"), - RouteMetric: MetricDefault, - } - listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: unable to list IP forward entry") - deleteIPForwardEntryErr := fmt.Errorf("failed to delete existing route with nextHop %s: unable to delete IP forward entry", testRoute.GatewayAddress) - createIPForwardEntryErr := fmt.Errorf("failed to create new IPForward row: unable to create IP forward entry") - tests := []struct { - name string - listRows []antreasyscall.MibIPForwardRow - listRowsErr error - createIPForwardErr error - deleteIPForwardErr error - wantErr error - }{ - { - name: "Replace Success", - listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, - }, + testGateway := net.ParseIP("1.1.1.254") + testIndex := 0 + testRoutes := []winnet.Route{ { - name: "Same GatewayAddress", - listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, testIP)}, - }, - { - name: "List Rows Err", - listRowsErr: fmt.Errorf("unable to list IP forward entry"), - wantErr: listIPForwardRowsErr, - }, - { - name: "Delete Ip Forward Entry Err", - listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, - deleteIPForwardErr: fmt.Errorf("unable to delete IP forward entry"), - wantErr: deleteIPForwardEntryErr, - }, - { - name: "Add Route Err", - listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, - createIPForwardErr: fmt.Errorf("unable to create IP forward entry"), - wantErr: createIPForwardEntryErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{CreateIPForwardEntryErr: tc.createIPForwardErr, DeleteIPForwardEntryErr: tc.deleteIPForwardErr, ListIPForwardRowsErr: tc.listRowsErr, IPForwardRows: tc.listRows})() - gotErr := ReplaceNetRoute(&testRoute) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestNewNetNat(t *testing.T) { - notFoundErr := fmt.Errorf("received error No MSFT_NetNat objects found") - testNetNat := "test-nat" - testSubnetCIDR := &net.IPNet{ - IP: net.ParseIP("192.168.1.21"), - Mask: net.CIDRMask(32, 32), - } - getCmd := fmt.Sprintf(`Get-NetNat -Name %s | Select-Object InternalIPInterfaceAddressPrefix | Format-Table -HideTableHeaders`, testNetNat) - removeCmd := fmt.Sprintf("Remove-NetNat -Name %s -Confirm:$false", testNetNat) - newCmd := fmt.Sprintf(`New-NetNat -Name %s -InternalIPInterfaceAddressPrefix %s`, testNetNat, testSubnetCIDR.String()) - tests := []struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - }{ - { - name: "New Net Nat", - commandOut: "0.0.0.0/32", - wantCmds: []string{getCmd, removeCmd, newCmd}, - }, - { - name: "Net Nat Not Found", - commandErr: testInvalidErr, - wantCmds: []string{getCmd}, - wantErr: testInvalidErr, - }, - { - name: "Net Nat Exist", - commandOut: "192.168.1.21/32", - wantCmds: []string{getCmd}, - }, - { - name: "Net Nat Add Fail", - commandErr: notFoundErr, - wantCmds: []string{getCmd, newCmd}, - wantErr: notFoundErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := NewNetNat(testNetNat, testSubnetCIDR) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestReplaceNetNatStaticMapping(t *testing.T) { - notFoundErr := fmt.Errorf("received error No MSFT_NetNatStaticMapping objects found") - testNetNatName := "test-nat" - testExternalPort, testInternalPort := (uint16)(80), (uint16)(8080) - testExternalIPAddr, testInternalIPAddr := "10.10.0.1", "192.0.2.179" - testProto := openflow.ProtocolTCP - testNetNat := &NetNatStaticMapping{ - Name: testNetNatName, - ExternalIP: net.ParseIP(testExternalIPAddr), - ExternalPort: testExternalPort, - InternalIP: net.ParseIP(testInternalIPAddr), - InternalPort: testInternalPort, - Protocol: testProto, - } - - getCmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", testNetNatName) + - fmt.Sprintf("|? ExternalIPAddress -EQ %s", testExternalIPAddr) + - fmt.Sprintf("|? ExternalPort -EQ %d", testExternalPort) + - fmt.Sprintf("|? Protocol -EQ %s", testProto) + - "| Format-Table -HideTableHeaders" - removeCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", testNetNatName, 1) - addCmd := fmt.Sprintf("Add-NetNatStaticMapping -NatName %s -ExternalIPAddress %s -ExternalPort %d -InternalIPAddress %s -InternalPort %d -Protocol %s", - testNetNatName, testExternalIPAddr, testExternalPort, testInternalIPAddr, testInternalPort, testProto) - type testFormat struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - } - tests := []testFormat{ - { - name: "Replace Net Nat", - commandOut: "0;1 nil nil nil 192.168.1.21 80", - wantCmds: []string{getCmd, removeCmd, addCmd}, - }, - { - name: "Get Net Nat Err", - commandErr: testInvalidErr, - wantCmds: []string{getCmd}, - wantErr: testInvalidErr, - }, - { - name: "Remove Net Nat Err", - commandOut: "0;1 nil nil nil 192.168.1.21 80", - commandErr: notFoundErr, - wantCmds: []string{getCmd, removeCmd}, - wantErr: notFoundErr, - }, - { - name: "Add Net Nat Err", - commandOut: "empty", - commandErr: notFoundErr, - wantCmds: []string{getCmd, addCmd}, - wantErr: notFoundErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := ReplaceNetNatStaticMapping(testNetNat) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestRemoveNetNatStaticMapping(t *testing.T) { - testNetNatName := "test-nat" - testExternalPort, testInternalPort := (uint16)(80), (uint16)(8080) - testExternalIPAddr, testInternalIPAddr := "10.10.0.1", "192.0.2.179" - testProto := openflow.ProtocolTCP - testNetNat := &NetNatStaticMapping{ - Name: testNetNatName, - ExternalIP: net.ParseIP(testExternalIPAddr), - ExternalPort: testExternalPort, - InternalIP: net.ParseIP(testInternalIPAddr), - InternalPort: testInternalPort, - Protocol: testProto, - } - getCmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", testNetNatName) + - fmt.Sprintf("|? ExternalIPAddress -EQ %s", testExternalIPAddr) + - fmt.Sprintf("|? ExternalPort -EQ %d", testExternalPort) + - fmt.Sprintf("|? Protocol -EQ %s", testProto) + - "| Format-Table -HideTableHeaders" - removeIDCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", testNetNatName, 1) - removeCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -Confirm:$false", testNetNatName) - tests := []struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - }{ - { - name: "Remove Net Nat Static Mapping", - commandOut: "0;1 nil nil nil 192.0.02.179 8080", - wantCmds: []string{getCmd, removeIDCmd, removeCmd}, - }, - { - name: "Remove Err", - commandErr: testInvalidErr, - wantCmds: []string{getCmd, removeCmd}, - wantErr: testInvalidErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, false)() - gotErr := RemoveNetNatStaticMapping(testNetNat) - assert.Equal(t, tc.wantErr, gotErr) - gotErr = RemoveNetNatStaticMappingByNPLTuples(testNetNat) - assert.Equal(t, tc.wantErr, gotErr) - gotErr = RemoveNetNatStaticMappingByNAME(testNetNat.Name) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestReplaceNetNeighbor(t *testing.T) { - netNeighborNotFoundErr := fmt.Errorf("received error No matching MSFT_NetNeighbor objects") - testNeighbor := &Neighbor{ - LinkIndex: 1, - IPAddress: net.ParseIP("169.254.0.253"), - LinkLayerAddress: testMACAddr, - State: "Permanent", - } - getCmd := fmt.Sprintf("Get-NetNeighbor -InterfaceIndex %d -IPAddress %s | Format-Table -HideTableHeaders", testNeighbor.LinkIndex, testNeighbor.IPAddress.String()) - newCmd := fmt.Sprintf("New-NetNeighbor -InterfaceIndex %d -IPAddress %s -LinkLayerAddress %s -State Permanent", - testNeighbor.LinkIndex, testNeighbor.IPAddress, testNeighbor.LinkLayerAddress) - removeCmd := fmt.Sprintf("Remove-NetNeighbor -InterfaceIndex %d -IPAddress %s -Confirm:$false", - testNeighbor.LinkIndex, testNeighbor.IPAddress) - type testFormat struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - } - tests := []testFormat{ - { - name: "Replace Neighbor", - commandOut: "1 169.254.1.253 aa:bb:cc:dd:ff:ff Permanent nil", - wantCmds: []string{getCmd, removeCmd, newCmd}, - }, - { - name: "Get Net Neighbor Err", - commandErr: testInvalidErr, - wantCmds: []string{getCmd}, - wantErr: testInvalidErr, - }, - { - name: "Remove Net Neighbor Err", - commandOut: "1 169.254.1.253 aa:bb:cc:dd:ff:ff Permanent nil", - commandErr: netNeighborNotFoundErr, - wantCmds: []string{getCmd, removeCmd}, - wantErr: netNeighborNotFoundErr, - }, - { - name: "New Net Neighbor Err", - commandErr: netNeighborNotFoundErr, - wantCmds: []string{getCmd, newCmd}, - wantErr: netNeighborNotFoundErr, - }, - { - name: "Duplicate Neighbor", - commandOut: "1 169.254.0.253 aa:bb:cc:dd:ee:ff Permanent nil", - wantCmds: []string{getCmd}, + LinkIndex: testIndex, + DestinationSubnet: subnet, + GatewayAddress: testGateway, + RouteMetric: winnet.MetricDefault, }, } + testNetInterface := generateNetInterface("0") - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := ReplaceNetNeighbor(testNeighbor) - assert.Equal(t, tc.wantErr, gotErr) - }) + family := antreasyscall.AF_UNSPEC + filter := &winnet.Route{ + LinkIndex: testIndex, } -} + filterMask := winnet.RT_FILTER_IF -func TestVirtualAdapterName(t *testing.T) { - gotName := VirtualAdapterName("0") - assert.Equal(t, "vEthernet (0)", gotName) -} + listRouteErr := fmt.Errorf("unable to list Windows IPForward rows: unable to list IP forward rows") -func TestGetInterfaceConfig(t *testing.T) { - gw, subnet, _ := net.ParseCIDR("192.168.2.0/24") - testRow := createTestMibIPForwardRow(0, subnet, gw) - routes := []Route{*routeFromIPForwardRow(&testRow)} - testRoutes := convertTestRoutes(routes) - testNetInterface := generateNetInterface("0") tests := []struct { name string testNetInterfaceErr error - listRows []antreasyscall.MibIPForwardRow - listRowsErr error + expectedCalls func(mockNetUtil *winnettesting.MockInterfaceMockRecorder) wantAddrs []*net.IPNet wantRoutes []interface{} wantErr error }{ { name: "Get Interface Config Success", - listRows: []antreasyscall.MibIPForwardRow{testRow}, wantAddrs: []*net.IPNet{&ipv4PublicIPNet}, - wantRoutes: testRoutes, + wantRoutes: convertTestRoutes(testRoutes), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RouteListFiltered(family, filter, filterMask).Return(testRoutes, nil).Times(1) + }, }, { name: "Interface Err", @@ -898,18 +428,20 @@ func TestGetInterfaceConfig(t *testing.T) { wantErr: fmt.Errorf("failed to get interface %s: %v", "0", testInvalidErr), }, { - name: "Route Err", - listRows: []antreasyscall.MibIPForwardRow{testRow}, - listRowsErr: fmt.Errorf("unable to list IP forward rows"), - wantErr: fmt.Errorf("failed to get routes for interface index %d: %v", testNetInterface.Index, - fmt.Errorf("unable to list Windows IPForward rows: unable to list IP forward rows")), + name: "Route Err", + wantErr: fmt.Errorf("failed to get routes for interface index %d: %v", testNetInterface.Index, listRouteErr), + expectedCalls: func(mockUtil *winnettesting.MockInterfaceMockRecorder) { + mockUtil.RouteListFiltered(family, filter, filterMask).Return(nil, listRouteErr).Times(1) + }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer mockUtilWinnet(ctrl, tc.expectedCalls)() defer mockNetInterfaceByName(&testNetInterface, tc.testNetInterfaceErr)() defer mockNetInterfaceAddrs(testNetInterface, nil)() - defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{IPForwardRows: tc.listRows, ListIPForwardRowsErr: tc.listRowsErr})() + gotInterface, gotAddrs, gotRoutes, gotErr := GetInterfaceConfig("0") if tc.wantErr == nil { assert.EqualValues(t, testNetInterface, *gotInterface) @@ -921,217 +453,7 @@ func TestGetInterfaceConfig(t *testing.T) { } } -func TestRenameInterface(t *testing.T) { - tests := []struct { - name string - commandOut string - commandErr error - wantErr error - }{ - { - name: "Rename Interface", - commandOut: "success", - }, - { - name: "Rename Err", - commandErr: testInvalidErr, - wantErr: fmt.Errorf("failed to rename host interface name test1 to test2"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, []string{ - fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, "test1", "test2"), - }, tc.commandOut, tc.commandErr, false)() - gotErr := RenameInterface("test1", "test2") - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestCreateVMSwitch(t *testing.T) { - notfoundErr := fmt.Errorf("unable to find a virtual switch with name \"antrea-switch\"") - testSwitchName := "test-switch" - getVMCmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)`, LocalVMSwitch) - getExtensionCmd := fmt.Sprintf(`Get-VMSwitchExtension -VMSwitchName "%s" -ComputerName $(hostname) | ? Id -EQ "%s"`, LocalVMSwitch, OVSExtensionID) - newVMCmd := fmt.Sprintf(`New-VMSwitch -Name "%s" -NetAdapterName "%s" -EnableEmbeddedTeaming $true -AllowManagementOS $true -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, LocalVMSwitch, testSwitchName, ovsExtensionName) - enableCmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, LocalVMSwitch, ovsExtensionName) - type testFormat struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - } - tests := []testFormat{ - { - name: "VM Exists Enabled", - commandOut: "Open vSwitch Extension Enabled True", - wantCmds: []string{getVMCmd, getExtensionCmd}, - }, - { - name: "Create Err", - commandErr: notfoundErr, - wantCmds: []string{getVMCmd, newVMCmd}, - wantErr: notfoundErr, - }, - { - name: "VM Exists Err", - commandErr: testInvalidErr, - wantCmds: []string{getVMCmd}, - wantErr: testInvalidErr, - }, - { - name: "VM Not Enabled", - commandOut: "Open vSwitch Extension False", - wantCmds: []string{getVMCmd, getExtensionCmd, enableCmd}, - }, - { - name: "Extension Err", - commandOut: "Extension False", - wantCmds: []string{getVMCmd, getExtensionCmd}, - wantErr: fmt.Errorf("open vswitch extension driver is not installed"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := CreateVMSwitch(testSwitchName) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestGetVMSwitchInterfaceName(t *testing.T) { - getVMCmd := fmt.Sprintf(`Get-VMSwitchTeam -Name "%s" | select NetAdapterInterfaceDescription | Format-Table -HideTableHeaders`, LocalVMSwitch) - getAdapterCmd := fmt.Sprintf(`Get-NetAdapter -InterfaceDescription "%s" | select Name | Format-Table -HideTableHeaders`, "test") - tests := []struct { - name string - commandOut string - commandErr error - wantCmds []string - wantName string - wantErr error - }{ - { - name: "Get Interface Name", - commandOut: " {test} ", - wantCmds: []string{getVMCmd, getAdapterCmd}, - wantName: "{test}", - }, - { - name: "Get Err", - commandErr: testInvalidErr, - wantCmds: []string{getVMCmd}, - wantErr: testInvalidErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotName, gotErr := GetVMSwitchInterfaceName() - assert.Equal(t, tc.wantName, gotName) - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestRemoveVMSwitch(t *testing.T) { - getCmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)`, LocalVMSwitch) - removeCmd := fmt.Sprintf(`Remove-VMSwitch -Name "%s" -ComputerName $(hostname) -Force`, LocalVMSwitch) - tests := []struct { - name string - commandOut string - commandErr error - wantCmds []string - wantErr error - }{ - { - name: "Remove VMSwitch", - commandOut: "true", - wantCmds: []string{getCmd, removeCmd}, - }, - { - name: "Get Err", - commandErr: testInvalidErr, - wantCmds: []string{getCmd}, - wantErr: testInvalidErr, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := RemoveVMSwitch() - assert.Equal(t, tc.wantErr, gotErr) - }) - } -} - -func TestGenHostInterfaceName(t *testing.T) { - hostInterface := GenHostInterfaceName("host~") - assert.Equal(t, "host", hostInterface) -} - -func TestGetAdapterInAllCompartmentsByName(t *testing.T) { - testName := "host" - testFlags := net.FlagUp | net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast - testAdapter := adapter{ - Interface: net.Interface{ - Index: 1, - Name: testName, - Flags: testFlags, - MTU: 1, - HardwareAddr: testMACAddr, - }, - compartmentID: 1, - flags: IP_ADAPTER_DHCP_ENABLED, - } - tests := []struct { - name string - testName string - testAdapters *windows.IpAdapterAddresses - testAdaptersErr error - wantAdapters *adapter - wantErr error - }{ - { - name: "Normal", - testName: testName, - testAdapters: createTestAdapterAddresses(testName), - wantAdapters: &testAdapter, - }, - { - name: "Invalid name", - wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}, - }, - { - name: "adapter Err", - testName: testName, - testAdaptersErr: windows.ERROR_FILE_NOT_FOUND, - wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: os.NewSyscallError("getadaptersaddresses", windows.ERROR_FILE_NOT_FOUND)}, - }, - { - name: "adapter not found", - testName: testName, - wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - defer mockGetAdaptersAddresses(tc.testAdapters, tc.testAdaptersErr)() - gotAdapters, gotErr := getAdapterInAllCompartmentsByName(tc.testName) - assert.EqualValues(t, tc.wantAdapters, gotAdapters) - assert.EqualValues(t, tc.wantErr, gotErr) - }) - } -} - -func convertTestRoutes(routes []Route) []interface{} { +func convertTestRoutes(routes []winnet.Route) []interface{} { testRoutes := make([]interface{}, len(routes)) for i, v := range routes { testRoutes[i] = v @@ -1139,82 +461,20 @@ func convertTestRoutes(routes []Route) []interface{} { return testRoutes } -func createTestAdapterAddresses(name string) *windows.IpAdapterAddresses { - testPhysicalAddress := [8]byte{} - copy(testPhysicalAddress[:6], testMACAddr) - testName, _ := windows.UTF16FromString(name) - return &windows.IpAdapterAddresses{ - FriendlyName: &testName[0], - IfIndex: 1, - OperStatus: windows.IfOperStatusUp, - IfType: windows.IF_TYPE_ATM, - Mtu: 1, - PhysicalAddressLength: 6, - PhysicalAddress: testPhysicalAddress, - CompartmentId: 1, - Flags: IP_ADAPTER_DHCP_ENABLED, - } -} - -func createTestMibIPForwardRow(index uint32, subnet *net.IPNet, ip net.IP) antreasyscall.MibIPForwardRow { - return antreasyscall.MibIPForwardRow{ - Index: index, - Metric: MetricDefault, - DestinationPrefix: *antreasyscall.NewAddressPrefixFromIPNet(subnet), - NextHop: *antreasyscall.NewRawSockAddrInetFromIP(ip), - } -} - -func mockAntreaNetIO(mockNetIO *antreasyscalltest.MockNetIO) func() { - originalNetIO := antreaNetIO - antreaNetIO = mockNetIO - return func() { - antreaNetIO = originalNetIO - } -} - -func mockGetAdaptersAddresses(testAdaptersAddresses *windows.IpAdapterAddresses, err error) func() { - originalGetAdaptersAddresses := getAdaptersAddresses - getAdaptersAddresses = func(family uint32, flags uint32, reserved uintptr, adapterAddresses *windows.IpAdapterAddresses, sizePointer *uint32) (errcode error) { - if adapterAddresses != nil && testAdaptersAddresses != nil { - adapterAddresses.IfIndex = testAdaptersAddresses.IfIndex - adapterAddresses.FriendlyName = testAdaptersAddresses.FriendlyName - adapterAddresses.OperStatus = testAdaptersAddresses.OperStatus - adapterAddresses.IfType = testAdaptersAddresses.IfType - adapterAddresses.Mtu = testAdaptersAddresses.Mtu - adapterAddresses.PhysicalAddressLength = testAdaptersAddresses.PhysicalAddressLength - adapterAddresses.PhysicalAddress = testAdaptersAddresses.PhysicalAddress - adapterAddresses.CompartmentId = testAdaptersAddresses.CompartmentId - adapterAddresses.Flags = testAdaptersAddresses.Flags - } - return err - } - return func() { - getAdaptersAddresses = originalGetAdaptersAddresses - } +func TestGenHostInterfaceName(t *testing.T) { + hostInterface := GenHostInterfaceName("host~") + assert.Equal(t, "host", hostInterface) } -// mockRunCommand mocks runCommand with a custom command output and error message. -// If exactMatch is enabled, this function asserts that the executed commands are -// exactly the same with wantCmds in terms of order and value. Otherwise, for tests -// with retry functions, the commands will be executed multiple times. This function -// asserts that wantCmds is strictly a subset of these executed commands. -func mockRunCommand(t *testing.T, wantCmds []string, commandOut string, err error, exactMatch bool) func() { - originalRunCommand := runCommand - actCmds := make([]string, 0) - runCommand = func(cmd string) (string, error) { - actCmds = append(actCmds, cmd) - return commandOut, err +func mockUtilWinnet(ctrl *gomock.Controller, expectedCalls func(mockWinnet *winnettesting.MockInterfaceMockRecorder)) func() { + originalWinnetInterface := winnetUtil + testWinnetInterface := winnettesting.NewMockInterface(ctrl) + winnetUtil = testWinnetInterface + if expectedCalls != nil { + expectedCalls(testWinnetInterface.EXPECT()) } return func() { - runCommand = originalRunCommand - if wantCmds == nil { - assert.Empty(t, actCmds) - } else if exactMatch { - assert.Equal(t, wantCmds, actCmds) - } else { - assert.Subset(t, actCmds, wantCmds) - } + winnetUtil = originalWinnetInterface } } diff --git a/pkg/agent/util/winnet/interface.go b/pkg/agent/util/winnet/interface.go new file mode 100644 index 00000000000..22fa857e6ee --- /dev/null +++ b/pkg/agent/util/winnet/interface.go @@ -0,0 +1,82 @@ +// Copyright 2024 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package winnet + +import ( + "net" +) + +// Interface is an interface for configuring Windows networking. +type Interface interface { + AddNetRoute(route *Route) error + + RemoveNetRoute(route *Route) error + + ReplaceNetRoute(route *Route) error + + RouteListFiltered(family uint16, filter *Route, filterMask uint64) ([]Route, error) + + ReplaceNetNeighbor(neighbor *Neighbor) error + + AddNetNat(netNatName string, subnetCIDR *net.IPNet) error + + AddNetNatStaticMapping(mapping *NetNatStaticMapping) error + + ReplaceNetNatStaticMapping(mapping *NetNatStaticMapping) error + + RemoveNetNatStaticMapping(mapping *NetNatStaticMapping) error + + RemoveNetNatStaticMappingsByNetNat(netNatName string) error + + EnableNetAdapter(adapterName string) error + + IsNetAdapterStatusUp(adapterName string) (bool, error) + + NetAdapterExists(adapterName string) bool + + AddNetAdapterIPAddress(adapterName string, ipConfig *net.IPNet, gateway string) error + + RemoveNetAdapterIPAddress(adapterName string, ipAddr net.IP) error + + RenameNetAdapter(oriName string, newName string) error + + SetNetAdapterMTU(adapterName string, mtu int) error + + SetNetAdapterDNSServers(adapterName, dnsServers string) error + + IsNetAdapterIPv4DHCPEnabled(adapterName string) (bool, error) + + IsVirtualNetAdapter(adapterName string) (bool, error) + + EnableIPForwarding(adapterName string) error + + EnableRSCOnVSwitch(vSwitch string) error + + GetDNServersByNetAdapterIndex(adapterIndex int) (string, error) + + AddVMSwitch(adapterName string, vmSwitch string) error + + VMSwitchExists(vmSwitch string) (bool, error) + + IsVMSwitchOVSExtensionEnabled(vmSwitch string) (bool, error) + + EnableVMSwitchOVSExtension(vmSwitch string) error + + RemoveVMSwitch(vmSwitch string) error + + GetVMSwitchNetAdapterName(vmSwitch string) (string, error) + + RenameVMNetworkAdapter(networkName, macStr, newName string, renameNetAdapter bool) error +} diff --git a/pkg/agent/util/winnet/net_windows.go b/pkg/agent/util/winnet/net_windows.go new file mode 100644 index 00000000000..b9798d4a95f --- /dev/null +++ b/pkg/agent/util/winnet/net_windows.go @@ -0,0 +1,826 @@ +//go:build windows +// +build windows + +// Copyright 2024 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package winnet + +import ( + "bufio" + "errors" + "fmt" + "net" + "os" + "runtime" + "strconv" + "strings" + "syscall" + "unsafe" + + "github.com/Microsoft/hcsshim" + "golang.org/x/sys/windows" + "k8s.io/klog/v2" + + ps "antrea.io/antrea/pkg/agent/util/powershell" + antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" + iputil "antrea.io/antrea/pkg/util/ip" +) + +const ( + ContainerVNICPrefix = "vEthernet" + OVSExtensionID = "583CC151-73EC-4A6A-8B47-578297AD7623" + ovsExtensionName = "Open vSwitch Extension" + + MetricDefault = 256 + MetricHigh = 50 + + // Filter masks are used to indicate the attributes used for route filtering. + RT_FILTER_IF uint64 = 1 << (1 + iota) + RT_FILTER_METRIC + RT_FILTER_DST + RT_FILTER_GW + + // IP_ADAPTER_DHCP_ENABLED is defined in the Win32 API document. + // https://learn.microsoft.com/en-us/windows/win32/api/iptypes/ns-iptypes-ip_adapter_addresses_lh + IP_ADAPTER_DHCP_ENABLED = 0x00000004 + + // GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is used in windows.GetAdapterAddresses parameter + // flags to return addresses in all routing compartments. + GAA_FLAG_INCLUDE_ALL_COMPARTMENTS = 0x00000200 + + // GAA_FLAG_INCLUDE_ALL_INTERFACES is used in windows.GetAdapterAddresses parameter + // flags to return addresses for all NDIS interfaces. + GAA_FLAG_INCLUDE_ALL_INTERFACES = 0x00000100 +) + +type Handle struct{} + +var ( + // Declared variables which are meant to be overridden for testing. + antreaNetIO = antreasyscall.NewNetIO() + getAdaptersAddresses = windows.GetAdaptersAddresses + runCommand = ps.RunCommand +) + +func routeFromIPForwardRow(row *antreasyscall.MibIPForwardRow) *Route { + destination := row.DestinationPrefix.IPNet() + gatewayAddr := row.NextHop.IP() + return &Route{ + DestinationSubnet: destination, + GatewayAddress: gatewayAddr, + LinkIndex: int(row.Index), + RouteMetric: int(row.Metric), + } +} + +// IsVirtualNetAdapter checks if the provided network adapter is virtual. +func (h *Handle) IsVirtualNetAdapter(adapterName string) (bool, error) { + cmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Virtual | Format-Table -HideTableHeaders`, adapterName) + out, err := runCommand(cmd) + if err != nil { + return false, err + } + isVirtual, err := strconv.ParseBool(strings.TrimSpace(out)) + if err != nil { + return false, err + } + return isVirtual, nil +} + +// IsNetAdapterStatusUp checks if the status of the provided network adapter is UP. +func (h *Handle) IsNetAdapterStatusUp(adapterName string) (bool, error) { + cmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Status | Format-Table -HideTableHeaders`, adapterName) + out, err := runCommand(cmd) + if err != nil { + return false, err + } + status := strings.TrimSpace(out) + if !strings.EqualFold(status, "Up") { + return false, nil + } + return true, nil +} + +// EnableNetAdapter sets the specified network adapter status as UP. +func (h *Handle) EnableNetAdapter(adapterName string) error { + cmd := fmt.Sprintf(`Enable-NetAdapter -InterfaceAlias "%s"`, adapterName) + if _, err := runCommand(cmd); err != nil { + return err + } + return nil +} + +// AddNetAdapterIPAddress adds the specified IP address on the specified network adapter. +func (h *Handle) AddNetAdapterIPAddress(adapterName string, ipConfig *net.IPNet, gateway string) error { + ipStr := strings.Split(ipConfig.String(), "/") + cmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s`, adapterName, ipStr[0], ipStr[1]) + if gateway != "" { + cmd = fmt.Sprintf("%s -DefaultGateway %s", cmd, gateway) + } + _, err := runCommand(cmd) + // If the address already exists, ignore the error. + if err != nil && !strings.Contains(err.Error(), "already exists") { + return err + } + return nil +} + +// RemoveNetAdapterIPAddress removes the specified IP address from the specified network adapter. +func (h *Handle) RemoveNetAdapterIPAddress(adapterName string, ipAddr net.IP) error { + cmd := fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -Confirm:$false`, adapterName, ipAddr.String()) + _, err := runCommand(cmd) + // If the address does not exist, ignore the error. + if err != nil && !strings.Contains(err.Error(), "No matching") { + return err + } + return nil +} + +// EnableIPForwarding enables the network adapter to forward IP packets that arrive at this network adapter to other ones. +func (h *Handle) EnableIPForwarding(adapterName string) error { + adapter, err := getAdapterInAllCompartmentsByName(adapterName) + if err != nil { + return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %w", adapterName, err) + } + return adapter.setForwarding(true, antreasyscall.AF_INET) +} + +func (h *Handle) RenameVMNetworkAdapter(networkName, macStr, newName string, renameNetAdapter bool) error { + cmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -SwitchName "%s" | ? MacAddress -EQ "%s" | Select-Object -Property Name | Format-Table -HideTableHeaders`, networkName, macStr) + stdout, err := runCommand(cmd) + if err != nil { + return err + } + stdout = strings.TrimSpace(stdout) + if len(stdout) == 0 { + return fmt.Errorf("unable to find vmnetwork adapter configured with uplink MAC address %s", macStr) + } + vmNetworkAdapterName := stdout + cmd = fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -Name "%s" | Rename-VMNetworkAdapter -NewName "%s"`, vmNetworkAdapterName, newName) + if _, err := runCommand(cmd); err != nil { + return err + } + if renameNetAdapter { + oriNetAdapterName := VirtualAdapterName(newName) + if err := h.RenameNetAdapter(oriNetAdapterName, newName); err != nil { + return err + } + } + return nil +} + +// EnableRSCOnVSwitch enables RSC in the vSwitch to reduce host CPU utilization and increase throughput for virtual +// workloads by coalescing multiple TCP segments into fewer, but larger segments. +func (h *Handle) EnableRSCOnVSwitch(vSwitch string) error { + cmd := fmt.Sprintf("Get-VMSwitch -ComputerName $(hostname) -Name %s | Select-Object -Property SoftwareRscEnabled | Format-Table -HideTableHeaders", vSwitch) + stdout, err := runCommand(cmd) + if err != nil { + return err + } + stdout = strings.TrimSpace(stdout) + // RSC doc says it applies to Windows Server 2019, which is the only Windows operating system supported so far, so + // this should not happen. However, this is only an optimization, no need to crash the process even if it's not + // supported. + // https://docs.microsoft.com/en-us/windows-server/networking/technologies/hpn/rsc-in-the-vswitch + if len(stdout) == 0 { + klog.Warning("Receive Segment Coalescing (RSC) is not supported by this Windows Server version") + return nil + } + if strings.EqualFold(stdout, "True") { + klog.Infof("Receive Segment Coalescing (RSC) for vSwitch %s is already enabled", vSwitch) + return nil + } + cmd = fmt.Sprintf("Set-VMSwitch -ComputerName $(hostname) -Name %s -EnableSoftwareRsc $True", vSwitch) + _, err = runCommand(cmd) + if err != nil { + return err + } + klog.Infof("Enabled Receive Segment Coalescing (RSC) for vSwitch %s", vSwitch) + return nil +} + +// GetDefaultGatewayByNetAdapterIndex returns the default gateway configured on the specified network adapter. +func (h *Handle) GetDefaultGatewayByNetAdapterIndex(adapterIndex int) (string, error) { + ip, defaultDestination, _ := net.ParseCIDR("0.0.0.0/0") + family := AddressFamilyByIP(ip) + filter := &Route{ + LinkIndex: adapterIndex, + DestinationSubnet: defaultDestination, + } + routes, err := h.RouteListFiltered(family, filter, RT_FILTER_IF|RT_FILTER_DST) + if err != nil { + return "", err + } + if len(routes) == 0 { + return "", nil + } + return routes[0].GatewayAddress.String(), nil +} + +// GetDNServersByNetAdapterIndex returns the DNS servers configured on the specified network adapter. +func (h *Handle) GetDNServersByNetAdapterIndex(adapterIndex int) (string, error) { + cmd := fmt.Sprintf("$(Get-DnsClientServerAddress -InterfaceIndex %d -AddressFamily IPv4).ServerAddresses", adapterIndex) + dnsServers, err := runCommand(cmd) + if err != nil { + return "", err + } + dnsServers = strings.ReplaceAll(dnsServers, "\r\n", ",") + dnsServers = strings.TrimRight(dnsServers, ",") + return dnsServers, nil +} + +// SetNetAdapterDNSServers configures DNS servers on network adapter. +func (h *Handle) SetNetAdapterDNSServers(adapterName, dnsServers string) error { + cmd := fmt.Sprintf(`Set-DnsClientServerAddress -InterfaceAlias "%s" -ServerAddresses "%s"`, adapterName, dnsServers) + if _, err := runCommand(cmd); err != nil { + return err + } + return nil +} + +func (h *Handle) NetAdapterExists(adapterName string) bool { + _, err := getAdapterInAllCompartmentsByName(adapterName) + if err != nil { + return false + } + return true +} + +// IsNetAdapterIPv4DHCPEnabled returns the IPv4 DHCP status on the specified network adapter. +func (h *Handle) IsNetAdapterIPv4DHCPEnabled(adapterName string) (bool, error) { + adapter, err := getAdapterInAllCompartmentsByName(adapterName) + if err != nil { + return false, err + } + ipv4DHCP := adapter.flags&IP_ADAPTER_DHCP_ENABLED != 0 + return ipv4DHCP, nil +} + +// SetNetAdapterMTU configures network adapter MTU on host for Pods. MTU change cannot be realized with HNSEndpoint because +// there's no MTU field in HNSEndpoint: +// https://github.com/Microsoft/hcsshim/blob/4a468a6f7ae547974bc32911395c51fb1862b7df/internal/hns/hnsendpoint.go#L12 +func (h *Handle) SetNetAdapterMTU(adapterName string, mtu int) error { + adapter, err := getAdapterInAllCompartmentsByName(adapterName) + if err != nil { + return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %w", adapterName, err) + } + return adapter.setMTU(mtu, antreasyscall.AF_INET) +} + +func AddressFamilyByIP(ip net.IP) uint16 { + if ip.To4() != nil { + return antreasyscall.AF_INET + } + return antreasyscall.AF_INET6 +} + +func VirtualAdapterName(name string) string { + return fmt.Sprintf("%s (%s)", ContainerVNICPrefix, name) +} + +func toMibIPForwardRow(r *Route) *antreasyscall.MibIPForwardRow { + row := antreasyscall.NewIPForwardRow() + row.DestinationPrefix = *antreasyscall.NewAddressPrefixFromIPNet(r.DestinationSubnet) + row.NextHop = *antreasyscall.NewRawSockAddrInetFromIP(r.GatewayAddress) + row.Metric = uint32(r.RouteMetric) + row.Index = uint32(r.LinkIndex) + return row +} + +func (h *Handle) AddNetRoute(route *Route) error { + if route == nil { + return nil + } + row := toMibIPForwardRow(route) + if err := antreaNetIO.CreateIPForwardEntry(row); err != nil { + return fmt.Errorf("failed to create new IPForward row: %w", err) + } + return nil +} + +func (h *Handle) RemoveNetRoute(route *Route) error { + if route == nil || route.DestinationSubnet == nil { + return nil + } + family := AddressFamilyByIP(route.DestinationSubnet.IP) + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return fmt.Errorf("unable to list Windows IPForward rows: %w", err) + } + for i := range rows { + row := rows[i] + if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) && row.NextHop.IP().Equal(route.GatewayAddress) { + if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { + return fmt.Errorf("failed to delete existing route with nextHop %s: %w", route.GatewayAddress, err) + } + } + } + return nil +} + +func (h *Handle) ReplaceNetRoute(route *Route) error { + if route == nil || route.DestinationSubnet == nil { + return nil + } + family := AddressFamilyByIP(route.DestinationSubnet.IP) + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return fmt.Errorf("unable to list Windows IPForward rows: %w", err) + } + for i := range rows { + row := rows[i] + if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) { + if row.NextHop.IP().Equal(route.GatewayAddress) { + return nil + } else { + if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { + return fmt.Errorf("failed to delete existing route with nextHop %s: %w", route.GatewayAddress, err) + } + } + } + } + return h.AddNetRoute(route) +} + +func (h *Handle) RouteListFiltered(family uint16, filter *Route, filterMask uint64) ([]Route, error) { + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return nil, fmt.Errorf("unable to list Windows IPForward rows: %w", err) + } + rts := make([]Route, 0, len(rows)) + for i := range rows { + route := routeFromIPForwardRow(&rows[i]) + if filter != nil { + if filterMask&RT_FILTER_IF != 0 && filter.LinkIndex != route.LinkIndex { + continue + } + if filterMask&RT_FILTER_DST != 0 && !iputil.IPNetEqual(filter.DestinationSubnet, route.DestinationSubnet) { + continue + } + if filterMask&RT_FILTER_GW != 0 && !filter.GatewayAddress.Equal(route.GatewayAddress) { + continue + } + if filterMask&RT_FILTER_METRIC != 0 && filter.RouteMetric != route.RouteMetric { + continue + } + } + rts = append(rts, *route) + } + return rts, nil +} + +func parseCmdResult(result string, columns int) [][]string { + scanner := bufio.NewScanner(strings.NewReader(result)) + parsed := [][]string{} + for scanner.Scan() { + items := strings.Fields(scanner.Text()) + if len(items) < columns { + // Skip if an empty line or something similar + continue + } + parsed = append(parsed, items) + } + return parsed +} + +func (h *Handle) AddNetNat(netNatName string, subnetCIDR *net.IPNet) error { + cmd := fmt.Sprintf("Get-NetNat -Name %s | Select-Object InternalIPInterfaceAddressPrefix | Format-Table -HideTableHeaders", netNatName) + if internalNet, err := runCommand(cmd); err != nil { + if !strings.Contains(err.Error(), "No MSFT_NetNat objects found") { + return fmt.Errorf("failed to check the existing netnat '%s': %w", netNatName, err) + } + } else { + if strings.Contains(internalNet, subnetCIDR.String()) { + klog.V(4).InfoS("The existing netnat matched the subnet CIDR", "name", internalNet, "subnetCIDR", subnetCIDR.String()) + return nil + } + klog.InfoS("Removing the existing NetNat", "name", netNatName, "internalIPInterfaceAddressPrefix", internalNet) + cmd = fmt.Sprintf("Remove-NetNat -Name %s -Confirm:$false", netNatName) + if _, err := runCommand(cmd); err != nil { + return fmt.Errorf("failed to remove the existing netnat '%s' with internalIPInterfaceAddressPrefix '%s': %w", netNatName, internalNet, err) + } + } + cmd = fmt.Sprintf("New-NetNat -Name %s -InternalIPInterfaceAddressPrefix %s", netNatName, subnetCIDR.String()) + _, err := runCommand(cmd) + if err != nil { + return fmt.Errorf("failed to add netnat '%s' with internalIPInterfaceAddressPrefix '%s': %w", netNatName, subnetCIDR.String(), err) + } + return nil +} + +func (h *Handle) ReplaceNetNatStaticMapping(mapping *NetNatStaticMapping) error { + staticMappingStr, err := getNetNatStaticMapping(mapping) + if err != nil { + return err + } + parsed := parseCmdResult(staticMappingStr, 6) + if len(parsed) > 0 { + items := parsed[0] + if items[4] == mapping.InternalIP.String() && items[5] == strconv.Itoa(int(mapping.InternalPort)) { + return nil + } + firstCol := strings.Split(items[0], ";") + id, err := strconv.Atoi(firstCol[1]) + if err != nil { + return err + } + if err := removeNetNatStaticMappingByID(mapping.Name, id); err != nil { + return err + } + } + return h.AddNetNatStaticMapping(mapping) +} + +// getNetNatStaticMapping checks if a NetNatStaticMapping exists. +func getNetNatStaticMapping(mapping *NetNatStaticMapping) (string, error) { + cmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", mapping.Name) + + fmt.Sprintf("|? ExternalIPAddress -EQ %s", mapping.ExternalIP) + + fmt.Sprintf("|? ExternalPort -EQ %d", mapping.ExternalPort) + + fmt.Sprintf("|? Protocol -EQ %s", mapping.Protocol) + + "| Format-Table -HideTableHeaders" + staticMappingStr, err := runCommand(cmd) + if err != nil && !strings.Contains(err.Error(), "No MSFT_NetNatStaticMapping objects found") { + return "", err + } + return staticMappingStr, nil +} + +// AddNetNatStaticMapping adds a static mapping to a NAT instance. +func (h *Handle) AddNetNatStaticMapping(mapping *NetNatStaticMapping) error { + cmd := fmt.Sprintf("Add-NetNatStaticMapping -NatName %s -ExternalIPAddress %s -ExternalPort %d -InternalIPAddress %s -InternalPort %d -Protocol %s", + mapping.Name, mapping.ExternalIP, mapping.ExternalPort, mapping.InternalIP, mapping.InternalPort, mapping.Protocol) + _, err := runCommand(cmd) + return err +} + +// RemoveNetNatStaticMapping removes a static mapping from a NetNat instance. +func (h *Handle) RemoveNetNatStaticMapping(mapping *NetNatStaticMapping) error { + staticMappingStr, err := getNetNatStaticMapping(mapping) + if err != nil { + return err + } + parsed := parseCmdResult(staticMappingStr, 6) + if len(parsed) == 0 { + return nil + } + + firstCol := strings.Split(parsed[0][0], ";") + id, err := strconv.Atoi(firstCol[1]) + if err != nil { + return err + } + return removeNetNatStaticMappingByID(mapping.Name, id) +} + +func removeNetNatStaticMappingByID(netNatName string, id int) error { + cmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", netNatName, id) + _, err := runCommand(cmd) + return err +} + +// RemoveNetNatStaticMappingsByNetNat removes all static mappings from a NetNat instance. +func (h *Handle) RemoveNetNatStaticMappingsByNetNat(netNatName string) error { + cmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -Confirm:$false", netNatName) + _, err := runCommand(cmd) + return err +} + +// getNetNeighbor gets neighbor cache entries with Get-NetNeighbor. +func getNetNeighbor(neighbor *Neighbor) ([]Neighbor, error) { + cmd := fmt.Sprintf("Get-NetNeighbor -InterfaceIndex %d -IPAddress %s | Format-Table -HideTableHeaders", neighbor.LinkIndex, neighbor.IPAddress.String()) + neighborsStr, err := runCommand(cmd) + if err != nil && !strings.Contains(err.Error(), "No matching MSFT_NetNeighbor objects") { + return nil, err + } + + parsed := parseCmdResult(neighborsStr, 5) + var neighbors []Neighbor + for _, items := range parsed { + idx, err := strconv.Atoi(items[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse the LinkIndex '%s': %w", items[0], err) + } + dstIP := net.ParseIP(items[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse the DestinationIP '%s': %w", items[1], err) + } + // Get-NetNeighbor returns LinkLayerAddress like "AA-BB-CC-DD-EE-FF". + mac, err := net.ParseMAC(strings.ReplaceAll(items[2], "-", ":")) + if err != nil { + return nil, fmt.Errorf("failed to parse the Gateway MAC '%s': %w", items[2], err) + } + neighbor := Neighbor{ + LinkIndex: idx, + IPAddress: dstIP, + LinkLayerAddress: mac, + State: items[3], + } + neighbors = append(neighbors, neighbor) + } + return neighbors, nil +} + +// newNetNeighbor creates a new neighbor cache entry with New-NetNeighbor. +func newNetNeighbor(neighbor *Neighbor) error { + cmd := fmt.Sprintf("New-NetNeighbor -InterfaceIndex %d -IPAddress %s -LinkLayerAddress %s -State Permanent", + neighbor.LinkIndex, neighbor.IPAddress, neighbor.LinkLayerAddress) + _, err := runCommand(cmd) + return err +} + +func removeNetNeighbor(neighbor *Neighbor) error { + cmd := fmt.Sprintf("Remove-NetNeighbor -InterfaceIndex %d -IPAddress %s -Confirm:$false", + neighbor.LinkIndex, neighbor.IPAddress) + _, err := runCommand(cmd) + return err +} + +func (h *Handle) ReplaceNetNeighbor(neighbor *Neighbor) error { + neighbors, err := getNetNeighbor(neighbor) + if err != nil { + return err + } + + if len(neighbors) == 0 { + if err := newNetNeighbor(neighbor); err != nil { + return err + } + return nil + } + for _, n := range neighbors { + if n.LinkLayerAddress.String() == neighbor.LinkLayerAddress.String() && n.State == neighbor.State { + return nil + } + } + if err := removeNetNeighbor(neighbor); err != nil { + return err + } + return newNetNeighbor(neighbor) +} + +func (h *Handle) GetVMSwitchNetAdapterName(vmSwitch string) (string, error) { + cmd := fmt.Sprintf(`Get-VMSwitchTeam -Name "%s" | select NetAdapterInterfaceDescription | Format-Table -HideTableHeaders`, vmSwitch) + out, err := runCommand(cmd) + if err != nil { + return "", err + } + out = strings.TrimSpace(out) + // Remove the leading and trailing {} brackets + out = out[1 : len(out)-1] + cmd = fmt.Sprintf(`Get-NetAdapter -InterfaceDescription "%s" | select Name | Format-Table -HideTableHeaders`, out) + out, err = runCommand(cmd) + if err != nil { + return "", err + } + out = strings.TrimSpace(out) + return out, err +} + +func (h *Handle) VMSwitchExists(vmSwitch string) (bool, error) { + cmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)`, vmSwitch) + _, err := runCommand(cmd) + if err == nil { + return true, nil + } + if strings.Contains(err.Error(), fmt.Sprintf(`unable to find a virtual switch with name "%s"`, vmSwitch)) { + return false, nil + } + return false, err +} + +// AddVMSwitch creates a VMSwitch and enables OVS extension. Connection to VMSwitch is lost for few seconds. +// TODO: Handle for multiple interfaces +func (h *Handle) AddVMSwitch(adapterName, vmSwitch string) error { + cmd := fmt.Sprintf(`New-VMSwitch -Name "%s" -NetAdapterName "%s" -EnableEmbeddedTeaming $true -AllowManagementOS $true -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, vmSwitch, adapterName, ovsExtensionName) + _, err := runCommand(cmd) + if err != nil { + return err + } + return nil +} + +func (h *Handle) RemoveVMSwitch(vmSwitch string) error { + exists, err := h.VMSwitchExists(vmSwitch) + if err != nil { + return err + } + if exists { + cmd := fmt.Sprintf(`Remove-VMSwitch -Name "%s" -ComputerName $(hostname) -Force`, vmSwitch) + _, err = runCommand(cmd) + if err != nil { + return err + } + } + return nil +} + +type updateIPInterfaceFunc func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow + +type adapter struct { + net.Interface + compartmentID uint32 + flags uint32 +} + +func (a *adapter) setMTU(mtu int, family uint16) error { + if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { + newEntry := *entry + newEntry.NlMtu = uint32(mtu) + return &newEntry + }); err != nil { + return fmt.Errorf("unable to set IPInterface with MTU %d: %w", mtu, err) + } + return nil +} + +func (a *adapter) setForwarding(enabledForwarding bool, family uint16) error { + if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { + newEntry := *entry + newEntry.ForwardingEnabled = enabledForwarding + return &newEntry + }); err != nil { + return fmt.Errorf("unable to enable IPForwarding on network adapter: %w", err) + } + return nil +} + +func (a *adapter) setIPInterfaceEntry(family uint16, updateFunc updateIPInterfaceFunc) error { + if a.compartmentID > 1 { + runtime.LockOSThread() + defer func() { + hcsshim.SetCurrentThreadCompartmentId(0) + runtime.UnlockOSThread() + }() + if err := hcsshim.SetCurrentThreadCompartmentId(a.compartmentID); err != nil { + return fmt.Errorf("failed to change current thread's compartment '%d': %w", a.compartmentID, err) + } + } + ipInterfaceRow := &antreasyscall.MibIPInterfaceRow{Family: family, Index: uint32(a.Index)} + if err := antreaNetIO.GetIPInterfaceEntry(ipInterfaceRow); err != nil { + return fmt.Errorf("unable to get IPInterface entry with Index %d: %w", a.Index, err) + } + updatedRow := updateFunc(ipInterfaceRow) + updatedRow.SitePrefixLength = 0 + return antreaNetIO.SetIPInterfaceEntry(updatedRow) +} + +var ( + errInvalidInterfaceName = errors.New("invalid network interface name") + errNoSuchInterface = errors.New("no such network interface") +) + +func getAdapterInAllCompartmentsByName(name string) (*adapter, error) { + if name == "" { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName} + } + adapters, err := getAdaptersByName(name) + if err != nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + if len(adapters) == 0 { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface} + } + return &adapters[0], nil +} + +func (h *Handle) EnableVMSwitchOVSExtension(vmSwitch string) error { + cmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, vmSwitch, ovsExtensionName) + _, err := runCommand(cmd) + if err != nil { + return err + } + return nil +} + +// parseOVSExtensionOutput parses the VM extension output and returns the value of Enabled field. +func parseOVSExtensionOutput(s string) bool { + scanner := bufio.NewScanner(strings.NewReader(s)) + for scanner.Scan() { + temp := strings.Fields(scanner.Text()) + line := strings.Join(temp, "") + if strings.Contains(line, "Enabled") { + if strings.Contains(line, "True") { + return true + } + return false + } + } + return false +} + +func (h *Handle) IsVMSwitchOVSExtensionEnabled(vmSwitch string) (bool, error) { + cmd := fmt.Sprintf(`Get-VMSwitchExtension -VMSwitchName "%s" -ComputerName $(hostname) | ? Id -EQ "%s"`, vmSwitch, OVSExtensionID) + out, err := runCommand(cmd) + if err != nil { + return false, err + } + if !strings.Contains(out, ovsExtensionName) { + return false, fmt.Errorf("open vswitch extension driver is not installed") + } + return parseOVSExtensionOutput(out), nil +} + +func (h *Handle) RenameNetAdapter(oriName string, newName string) error { + cmd := fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, oriName, newName) + _, err := runCommand(cmd) + return err +} + +func getAdaptersByName(name string) ([]adapter, error) { + aas, err := adapterAddresses() + if err != nil { + return nil, err + } + var adapters []adapter + for _, aa := range aas { + ifName := windows.UTF16PtrToString(aa.FriendlyName) + if ifName != name { + continue + } + index := aa.IfIndex + if index == 0 { // ipv6IfIndex is a substitute for ifIndex + index = aa.Ipv6IfIndex + } + ifi := net.Interface{ + Index: int(index), + Name: ifName, + } + if aa.OperStatus == windows.IfOperStatusUp { + ifi.Flags |= net.FlagUp + } + // For now we need to infer link-layer service capabilities from media types. + // TODO: use MIB_IF_ROW2.AccessType now that we no longer support Windows XP. + switch aa.IfType { + case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394: + ifi.Flags |= net.FlagBroadcast | net.FlagMulticast + case windows.IF_TYPE_PPP, windows.IF_TYPE_TUNNEL: + ifi.Flags |= net.FlagPointToPoint | net.FlagMulticast + case windows.IF_TYPE_SOFTWARE_LOOPBACK: + ifi.Flags |= net.FlagLoopback | net.FlagMulticast + case windows.IF_TYPE_ATM: + ifi.Flags |= net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast // assume all services available; LANE, point-to-point and point-to-multipoint + } + if aa.Mtu == 0xffffffff { + ifi.MTU = -1 + } else { + ifi.MTU = int(aa.Mtu) + } + if aa.PhysicalAddressLength > 0 { + ifi.HardwareAddr = make(net.HardwareAddr, aa.PhysicalAddressLength) + copy(ifi.HardwareAddr, aa.PhysicalAddress[:]) + } + adapter := adapter{ + Interface: ifi, + compartmentID: aa.CompartmentId, + flags: aa.Flags, + } + adapters = append(adapters, adapter) + } + return adapters, nil +} + +// adapterAddresses returns a list of IpAdapterAddresses structures. The structure +// contains an IP adapter and flattened multiple IP addresses including unicast, anycast +// and multicast addresses. +// This function is copied from go/src/net/interface_windows.go, with a change that flag +// GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is introduced to query interfaces in all compartments, +// and GAA_FLAG_INCLUDE_ALL_INTERFACES is introduced to query all NDIS interfaces even they +// are not configured with any IP addresses, e.g., uplink. +func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { + flags := uint32(windows.GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_ALL_COMPARTMENTS | GAA_FLAG_INCLUDE_ALL_INTERFACES) + var b []byte + l := uint32(15000) // recommended initial size + for { + b = make([]byte, l) + err := getAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) + if err == nil { + if l == 0 { + return nil, nil + } + break + } + if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + if l <= uint32(len(b)) { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + } + var aas []*windows.IpAdapterAddresses + for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { + aas = append(aas, aa) + } + return aas, nil +} diff --git a/pkg/agent/util/winnet/net_windows_test.go b/pkg/agent/util/winnet/net_windows_test.go new file mode 100644 index 00000000000..d69436b2df0 --- /dev/null +++ b/pkg/agent/util/winnet/net_windows_test.go @@ -0,0 +1,1118 @@ +//go:build windows +// +build windows + +// Copyright 2023 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package winnet + +import ( + "fmt" + "net" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" + + antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" + antreasyscalltest "antrea.io/antrea/pkg/agent/util/syscall/testing" + "antrea.io/antrea/pkg/ovs/openflow" + "antrea.io/antrea/pkg/util/ip" +) + +var ( + testMACAddr, _ = net.ParseMAC("aa:bb:cc:dd:ee:ff") + ipv4Public = net.ParseIP("8.8.8.8") + ipv4PublicIPNet = ip.MustParseCIDR("8.8.8.8/32") + + testInvalidErr = fmt.Errorf("invalid") + + h = &Handle{} +) + +const ( + testVMSwitchName = "antrea-switch" + testAdapterName = "test-en0" +) + +func TestNetRouteString(t *testing.T) { + gw, subnet, _ := net.ParseCIDR("192.168.2.0/24") + testRoute := Route{ + LinkIndex: 1, + DestinationSubnet: subnet, + GatewayAddress: gw, + RouteMetric: MetricDefault, + } + gotRoute := testRoute.String() + assert.Equal(t, "LinkIndex: 1, DestinationSubnet: 192.168.2.0/24, GatewayAddress: 192.168.2.0, RouteMetric: 256", gotRoute) +} + +func TestNetRouteTranslation(t *testing.T) { + subnet := ip.MustParseCIDR("1.1.1.0/28") + oriRoute := &Route{ + LinkIndex: 27, + RouteMetric: 35, + DestinationSubnet: subnet, + GatewayAddress: net.ParseIP("1.1.1.254"), + } + row := toMibIPForwardRow(oriRoute) + newRoute := routeFromIPForwardRow(row) + assert.Equal(t, oriRoute, newRoute) +} + +func TestNetNeighborString(t *testing.T) { + testNeighbor := Neighbor{ + LinkIndex: 1, + IPAddress: net.ParseIP("169.254.0.253"), + LinkLayerAddress: testMACAddr, + State: "Permanent", + } + gotNeighbor := testNeighbor.String() + assert.Equal(t, "LinkIndex: 1, IPAddress: 169.254.0.253, LinkLayerAddress: aa:bb:cc:dd:ee:ff", gotNeighbor) +} + +func TestIsVirtualNetAdapter(t *testing.T) { + adapter := "test-adapter" + tests := []struct { + name string + commandOut string + commandErr error + adapter string + wantIsVirtual bool + }{ + { + name: "Virtual adapter", + commandOut: " true ", + wantIsVirtual: true, + }, + { + name: "Virtual adapter Err", + commandErr: testInvalidErr, + wantIsVirtual: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{ + fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s" | Select-Object -Property Virtual | Format-Table -HideTableHeaders`, adapter), + }, tc.commandOut, tc.commandErr, true) + gotIsVirtual, err := h.IsVirtualNetAdapter(adapter) + assert.Equal(t, tc.wantIsVirtual, gotIsVirtual) + assert.Equal(t, tc.commandErr, err) + }) + } +} + +func TestGetDNServersByNetAdapterIndex(t *testing.T) { + testIndex := 1 + tests := []struct { + name string + commandOut string + commandErr error + wantDNSServer string + }{ + { + name: "Index Success", + commandOut: "hello\r\nworld\r\n\r\n", + wantDNSServer: "hello,world", + }, + { + name: "Index Error", + commandOut: "fail", + commandErr: testInvalidErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{ + fmt.Sprintf("$(Get-DnsClientServerAddress -InterfaceIndex %d -AddressFamily IPv4).ServerAddresses", testIndex), + }, tc.commandOut, tc.commandErr, true) + gotDNSServer, err := h.GetDNServersByNetAdapterIndex(testIndex) + assert.Equal(t, tc.wantDNSServer, gotDNSServer) + assert.Equal(t, tc.commandErr, err) + }) + } +} + +func TestNetAdapterExists(t *testing.T) { + tests := []struct { + name string + testNetInterfaceName string + testAdapterAddresses *windows.IpAdapterAddresses + }{ + { + name: "Normal Exist", + testNetInterfaceName: "host", + testAdapterAddresses: createTestAdapterAddresses("host"), + }, + { + name: "Interface not exist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockGetAdaptersAddresses(t, tc.testAdapterAddresses, nil) + gotExists := h.NetAdapterExists(tc.testNetInterfaceName) + assert.Equal(t, tc.testNetInterfaceName != "", gotExists) + }) + } +} + +func TestSetNetAdapterMTU(t *testing.T) { + testName := "host" + testAdapterAddresses := createTestAdapterAddresses(testName) + testMTU := 2 + tests := []struct { + name string + testNetInterfaceName string + testAdapterAddresses *windows.IpAdapterAddresses + getIPInterfaceErr error + setIPInterfaceErr error + wantErr error + }{ + { + name: "Set Success", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, + }, + { + name: "Interface name invalid", + wantErr: fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %w", "", + &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}), + }, + { + name: "Get Interface Err", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, + getIPInterfaceErr: fmt.Errorf("IP interface not found"), + wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: %w", testMTU, + fmt.Errorf("unable to get IPInterface entry with Index %d: %w", (int)(testAdapterAddresses.IfIndex), fmt.Errorf("IP interface not found"))), + }, + { + name: "Set Interface Err", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, + setIPInterfaceErr: fmt.Errorf("IP interface set error"), + wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: %w", testMTU, fmt.Errorf("IP interface set error")), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockGetAdaptersAddresses(t, tc.testAdapterAddresses, nil) + mockAntreaNetIO(t, + &antreasyscalltest.MockNetIO{ + GetIPInterfaceEntryErr: tc.getIPInterfaceErr, + SetIPInterfaceEntryErr: tc.setIPInterfaceErr}) + gotErr := h.SetNetAdapterMTU(tc.testNetInterfaceName, testMTU) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestReplaceNetRoute(t *testing.T) { + subnet := ip.MustParseCIDR("1.1.1.0/28") + testGateway := net.ParseIP("1.1.1.254") + testIndex := uint32(27) + testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, testGateway) + testRoute := Route{ + LinkIndex: (int)(testIPForwardRow.Index), + DestinationSubnet: subnet, + GatewayAddress: net.ParseIP("1.1.1.254"), + RouteMetric: MetricDefault, + } + listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: %w", fmt.Errorf("unable to list IP forward entry")) + deleteIPForwardEntryErr := fmt.Errorf("failed to delete existing route with nextHop %s: %w", testRoute.GatewayAddress, fmt.Errorf("unable to delete IP forward entry")) + createIPForwardEntryErr := fmt.Errorf("failed to create new IPForward row: %w", fmt.Errorf("unable to create IP forward entry")) + tests := []struct { + name string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error + createIPForwardErr error + deleteIPForwardErr error + wantErr error + }{ + { + name: "Replace Success", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, + }, + { + name: "Same GatewayAddress", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, testGateway)}, + }, + { + name: "List Rows Err", + listRowsErr: fmt.Errorf("unable to list IP forward entry"), + wantErr: listIPForwardRowsErr, + }, + { + name: "Delete Ip Forward Entry Err", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, + deleteIPForwardErr: fmt.Errorf("unable to delete IP forward entry"), + wantErr: deleteIPForwardEntryErr, + }, + { + name: "Add Route Err", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, + createIPForwardErr: fmt.Errorf("unable to create IP forward entry"), + wantErr: createIPForwardEntryErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockAntreaNetIO(t, + &antreasyscalltest.MockNetIO{ + CreateIPForwardEntryErr: tc.createIPForwardErr, + DeleteIPForwardEntryErr: tc.deleteIPForwardErr, + ListIPForwardRowsErr: tc.listRowsErr, + IPForwardRows: tc.listRows}) + gotErr := h.ReplaceNetRoute(&testRoute) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestRemoveNetRoute(t *testing.T) { + subnet := ip.MustParseCIDR("1.1.1.0/28") + testGateway := net.ParseIP("1.1.1.254") + testIndex := uint32(27) + testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, testGateway) + testRoute := Route{ + LinkIndex: (int)(testIPForwardRow.Index), + DestinationSubnet: subnet, + GatewayAddress: testGateway, + RouteMetric: MetricDefault, + } + listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: %w", fmt.Errorf("unable to list IP forward entry")) + deleteIPForwardEntryErr := fmt.Errorf("failed to delete existing route with nextHop %s: %w", testRoute.GatewayAddress, fmt.Errorf("unable to delete IP forward entry")) + tests := []struct { + name string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error + deleteIPForwardErr error + wantErr error + }{ + { + name: "Remove Success", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, testGateway)}, + }, + { + name: "List Rows Err", + listRowsErr: fmt.Errorf("unable to list IP forward entry"), + wantErr: listIPForwardRowsErr, + }, + { + name: "Remove Failed", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, testGateway)}, + deleteIPForwardErr: fmt.Errorf("unable to delete IP forward entry"), + wantErr: deleteIPForwardEntryErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockAntreaNetIO(t, + &antreasyscalltest.MockNetIO{ + DeleteIPForwardEntryErr: tc.deleteIPForwardErr, + ListIPForwardRowsErr: tc.listRowsErr, + IPForwardRows: tc.listRows}) + gotErr := h.RemoveNetRoute(&testRoute) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestRouteListFiltered(t *testing.T) { + subnet1 := ip.MustParseCIDR("1.1.1.0/28") + subnet2 := ip.MustParseCIDR("1.1.1.128/28") + testGateway1 := net.ParseIP("1.1.1.254") + testGateway2 := net.ParseIP("1.1.1.254") + testIndex1 := uint32(27) + testIndex2 := uint32(28) + testIPForwardRow1 := createTestMibIPForwardRow(testIndex1, subnet1, testGateway1) + testIPForwardRow2 := createTestMibIPForwardRow(testIndex2, subnet2, testGateway2) + testRoute1 := Route{ + LinkIndex: (int)(testIPForwardRow1.Index), + DestinationSubnet: subnet1, + GatewayAddress: testGateway1, + RouteMetric: MetricDefault, + } + testRoute2 := Route{ + LinkIndex: (int)(testIPForwardRow2.Index), + DestinationSubnet: subnet2, + GatewayAddress: testGateway2, + RouteMetric: MetricDefault, + } + listRows := []antreasyscall.MibIPForwardRow{ + createTestMibIPForwardRow(testIndex1, subnet1, testGateway1), + createTestMibIPForwardRow(testIndex2, subnet2, testGateway2), + } + + listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: %w", fmt.Errorf("unable to list IP forward entry")) + tests := []struct { + name string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error + filterRoute *Route + filterMasks uint64 + wantRoutes []Route + wantErr error + }{ + { + name: "List Rows Err", + listRowsErr: fmt.Errorf("unable to list IP forward entry"), + wantErr: listIPForwardRowsErr, + }, + { + name: "Filter Link Index", + listRows: listRows, + filterRoute: &Route{ + LinkIndex: (int)(testIPForwardRow1.Index), + }, + filterMasks: RT_FILTER_IF, + wantRoutes: []Route{testRoute1}, + }, + { + name: "Filter Destination", + listRows: listRows, + filterRoute: &Route{ + DestinationSubnet: subnet1, + }, + filterMasks: RT_FILTER_DST, + wantRoutes: []Route{testRoute1}, + }, + { + name: "Filter Gateway", + listRows: listRows, + filterRoute: &Route{ + GatewayAddress: testGateway1, + }, + filterMasks: RT_FILTER_GW, + wantRoutes: []Route{testRoute1, testRoute2}, + }, + { + name: "Filter Metric", + listRows: listRows, + filterRoute: &Route{ + RouteMetric: MetricDefault, + }, + filterMasks: RT_FILTER_METRIC, + wantRoutes: []Route{testRoute1, testRoute2}, + }, + { + name: "Multiple Filters", + listRows: listRows, + filterRoute: &Route{ + LinkIndex: (int)(testIPForwardRow1.Index), + DestinationSubnet: subnet1, + GatewayAddress: testGateway1, + RouteMetric: MetricDefault, + }, + filterMasks: RT_FILTER_IF | RT_FILTER_DST | RT_FILTER_GW | RT_FILTER_METRIC, + wantRoutes: []Route{testRoute1}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockAntreaNetIO(t, + &antreasyscalltest.MockNetIO{ + ListIPForwardRowsErr: tc.listRowsErr, + IPForwardRows: tc.listRows}) + routes, gotErr := h.RouteListFiltered(antreasyscall.AF_INET, tc.filterRoute, tc.filterMasks) + assert.Equal(t, tc.wantErr, gotErr) + assert.ElementsMatch(t, tc.wantRoutes, routes) + }) + } +} + +func TestAddNetNat(t *testing.T) { + notFoundErr := fmt.Errorf("received error No MSFT_NetNat objects found") + testNetNat := "test-nat" + testSubnetCIDR := &net.IPNet{ + IP: net.ParseIP("192.168.1.21"), + Mask: net.CIDRMask(32, 32), + } + getCmd := fmt.Sprintf(`Get-NetNat -Name %s | Select-Object InternalIPInterfaceAddressPrefix | Format-Table -HideTableHeaders`, testNetNat) + removeCmd := fmt.Sprintf("Remove-NetNat -Name %s -Confirm:$false", testNetNat) + newCmd := fmt.Sprintf(`New-NetNat -Name %s -InternalIPInterfaceAddressPrefix %s`, testNetNat, testSubnetCIDR.String()) + tests := []struct { + name string + commandOut string + commandErr error + wantCmds []string + wantErr error + }{ + { + name: "New Net Nat", + commandOut: "0.0.0.0/32", + wantCmds: []string{getCmd, removeCmd, newCmd}, + }, + { + name: "Net Nat Not Found", + commandErr: testInvalidErr, + wantCmds: []string{getCmd}, + wantErr: fmt.Errorf("failed to check the existing netnat '%s': %w", testNetNat, testInvalidErr), + }, + { + name: "Net Nat Exist", + commandOut: "192.168.1.21/32", + wantCmds: []string{getCmd}, + }, + { + name: "Net Nat Add Fail", + commandErr: notFoundErr, + wantCmds: []string{getCmd, newCmd}, + wantErr: fmt.Errorf("failed to add netnat '%s' with internalIPInterfaceAddressPrefix '%s': %w", testNetNat, testSubnetCIDR.String(), notFoundErr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.AddNetNat(testNetNat, testSubnetCIDR) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestReplaceNetNatStaticMapping(t *testing.T) { + notFoundErr := fmt.Errorf("received error No MSFT_NetNatStaticMapping objects found") + testNetNatName := "test-nat" + testExternalPort, testInternalPort := (uint16)(80), (uint16)(8080) + testExternalIPAddr, testInternalIPAddr := "10.10.0.1", "192.0.2.179" + testProto := openflow.ProtocolTCP + testNetNat := &NetNatStaticMapping{ + Name: testNetNatName, + ExternalIP: net.ParseIP(testExternalIPAddr), + ExternalPort: testExternalPort, + InternalIP: net.ParseIP(testInternalIPAddr), + InternalPort: testInternalPort, + Protocol: testProto, + } + + getCmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", testNetNatName) + + fmt.Sprintf("|? ExternalIPAddress -EQ %s", testExternalIPAddr) + + fmt.Sprintf("|? ExternalPort -EQ %d", testExternalPort) + + fmt.Sprintf("|? Protocol -EQ %s", testProto) + + "| Format-Table -HideTableHeaders" + removeCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", testNetNatName, 1) + addCmd := fmt.Sprintf("Add-NetNatStaticMapping -NatName %s -ExternalIPAddress %s -ExternalPort %d -InternalIPAddress %s -InternalPort %d -Protocol %s", + testNetNatName, testExternalIPAddr, testExternalPort, testInternalIPAddr, testInternalPort, testProto) + type testFormat struct { + name string + commandOut string + commandErr error + wantCmds []string + wantErr error + } + tests := []testFormat{ + { + name: "Replace Net Nat", + commandOut: "0;1 nil nil nil 192.168.1.21 80", + wantCmds: []string{getCmd, removeCmd, addCmd}, + }, + { + name: "Get Net Nat Err", + commandErr: testInvalidErr, + wantCmds: []string{getCmd}, + wantErr: testInvalidErr, + }, + { + name: "Remove Net Nat Err", + commandOut: "0;1 nil nil nil 192.168.1.21 80", + commandErr: notFoundErr, + wantCmds: []string{getCmd, removeCmd}, + wantErr: notFoundErr, + }, + { + name: "Add Net Nat Err", + commandOut: "empty", + commandErr: notFoundErr, + wantCmds: []string{getCmd, addCmd}, + wantErr: notFoundErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.ReplaceNetNatStaticMapping(testNetNat) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestRemoveNetNatStaticMapping(t *testing.T) { + testNetNatName := "test-nat" + testExternalPort, testInternalPort := (uint16)(80), (uint16)(8080) + testExternalIPAddr, testInternalIPAddr := "10.10.0.1", "192.0.2.179" + testProto := openflow.ProtocolTCP + testNetNat := &NetNatStaticMapping{ + Name: testNetNatName, + ExternalIP: net.ParseIP(testExternalIPAddr), + ExternalPort: testExternalPort, + InternalIP: net.ParseIP(testInternalIPAddr), + InternalPort: testInternalPort, + Protocol: testProto, + } + getCmd := fmt.Sprintf("Get-NetNatStaticMapping -NatName %s", testNetNatName) + + fmt.Sprintf("|? ExternalIPAddress -EQ %s", testExternalIPAddr) + + fmt.Sprintf("|? ExternalPort -EQ %d", testExternalPort) + + fmt.Sprintf("|? Protocol -EQ %s", testProto) + + "| Format-Table -HideTableHeaders" + removeIDCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -StaticMappingID %d -Confirm:$false", testNetNatName, 1) + removeCmd := fmt.Sprintf("Remove-NetNatStaticMapping -NatName %s -Confirm:$false", testNetNatName) + tests := []struct { + name string + commandOut string + commandErr error + wantCmds []string + wantErr error + }{ + { + name: "Remove Net Nat Static Mapping", + commandOut: "0;1 nil nil nil 192.0.02.179 8080", + wantCmds: []string{getCmd, removeIDCmd, removeCmd}, + }, + { + name: "Remove Err", + commandErr: testInvalidErr, + wantCmds: []string{getCmd, removeCmd}, + wantErr: testInvalidErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, false) + gotErr := h.RemoveNetNatStaticMapping(testNetNat) + assert.Equal(t, tc.wantErr, gotErr) + assert.Equal(t, tc.wantErr, gotErr) + gotErr = h.RemoveNetNatStaticMappingsByNetNat(testNetNat.Name) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestReplaceNetNeighbor(t *testing.T) { + netNeighborNotFoundErr := fmt.Errorf("received error No matching MSFT_NetNeighbor objects") + testNeighbor := &Neighbor{ + LinkIndex: 1, + IPAddress: net.ParseIP("169.254.0.253"), + LinkLayerAddress: testMACAddr, + State: "Permanent", + } + getCmd := fmt.Sprintf("Get-NetNeighbor -InterfaceIndex %d -IPAddress %s | Format-Table -HideTableHeaders", testNeighbor.LinkIndex, testNeighbor.IPAddress.String()) + newCmd := fmt.Sprintf("New-NetNeighbor -InterfaceIndex %d -IPAddress %s -LinkLayerAddress %s -State Permanent", + testNeighbor.LinkIndex, testNeighbor.IPAddress, testNeighbor.LinkLayerAddress) + removeCmd := fmt.Sprintf("Remove-NetNeighbor -InterfaceIndex %d -IPAddress %s -Confirm:$false", + testNeighbor.LinkIndex, testNeighbor.IPAddress) + type testFormat struct { + name string + commandOut string + commandErr error + wantCmds []string + wantErr error + } + tests := []testFormat{ + { + name: "Replace Neighbor", + commandOut: "1 169.254.1.253 aa:bb:cc:dd:ff:ff Permanent nil", + wantCmds: []string{getCmd, removeCmd, newCmd}, + }, + { + name: "Get Net Neighbor Err", + commandErr: testInvalidErr, + wantCmds: []string{getCmd}, + wantErr: testInvalidErr, + }, + { + name: "Remove Net Neighbor Err", + commandOut: "1 169.254.1.253 aa:bb:cc:dd:ff:ff Permanent nil", + commandErr: netNeighborNotFoundErr, + wantCmds: []string{getCmd, removeCmd}, + wantErr: netNeighborNotFoundErr, + }, + { + name: "New Net Neighbor Err", + commandErr: netNeighborNotFoundErr, + wantCmds: []string{getCmd, newCmd}, + wantErr: netNeighborNotFoundErr, + }, + { + name: "Duplicate Neighbor", + commandOut: "1 169.254.0.253 aa:bb:cc:dd:ee:ff Permanent nil", + wantCmds: []string{getCmd}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.ReplaceNetNeighbor(testNeighbor) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestVirtualAdapterName(t *testing.T) { + gotName := VirtualAdapterName("0") + assert.Equal(t, "vEthernet (0)", gotName) +} + +func TestRenameNetAdapter(t *testing.T) { + tests := []struct { + name string + commandOut string + commandErr error + wantErr error + }{ + { + name: "Rename Interface", + commandOut: "success", + }, + { + name: "Rename Err", + commandErr: testInvalidErr, + wantErr: fmt.Errorf("invalid"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{ + fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, "test1", "test2"), + }, tc.commandOut, tc.commandErr, false) + gotErr := h.RenameNetAdapter("test1", "test2") + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestAddVMSwitch(t *testing.T) { + testSwitchName := "test-switch" + tests := []struct { + name string + commandErr error + wantErr error + }{ + { + name: "Success", + }, + { + name: "Error", + commandErr: testInvalidErr, + wantErr: fmt.Errorf("invalid"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{fmt.Sprintf(`New-VMSwitch -Name "%s" -NetAdapterName "%s" -EnableEmbeddedTeaming $true -AllowManagementOS $true -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, testVMSwitchName, testSwitchName, ovsExtensionName)}, "", tc.commandErr, false) + gotErr := h.AddVMSwitch(testSwitchName, testVMSwitchName) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestEnableVMSwitchOVSExtension(t *testing.T) { + tests := []struct { + name string + commandErr error + wantErr error + }{ + { + name: "Enable", + }, + { + name: "Error", + commandErr: testInvalidErr, + wantErr: fmt.Errorf("invalid"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)| Enable-VMSwitchExtension "%s"`, testVMSwitchName, ovsExtensionName)}, "", tc.commandErr, false) + gotErr := h.EnableVMSwitchOVSExtension(testVMSwitchName) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestIsVMSwitchOVSExtensionEnabled(t *testing.T) { + tests := []struct { + name string + commandOut string + commandErr error + wantErr error + wantRes bool + }{ + { + name: "Enabled", + commandOut: "Open vSwitch Extension Enabled True", + wantRes: true, + }, + { + name: "Not enabled", + commandOut: "Open vSwitch Extension False", + wantRes: false, + }, + { + name: "Error", + commandErr: testInvalidErr, + wantErr: fmt.Errorf("invalid"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, []string{fmt.Sprintf(`Get-VMSwitchExtension -VMSwitchName "%s" -ComputerName $(hostname) | ? Id -EQ "%s"`, testVMSwitchName, OVSExtensionID)}, tc.commandOut, tc.commandErr, false) + res, gotErr := h.IsVMSwitchOVSExtensionEnabled(testVMSwitchName) + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestGetVMSwitchInterfaceName(t *testing.T) { + getVMCmd := fmt.Sprintf(`Get-VMSwitchTeam -Name "%s" | select NetAdapterInterfaceDescription | Format-Table -HideTableHeaders`, testVMSwitchName) + getAdapterCmd := fmt.Sprintf(`Get-NetAdapter -InterfaceDescription "%s" | select Name | Format-Table -HideTableHeaders`, "test") + tests := []struct { + name string + commandOut string + commandErr error + wantCmds []string + wantName string + wantErr error + }{ + { + name: "Get Interface Name", + commandOut: " {test} ", + wantCmds: []string{getVMCmd, getAdapterCmd}, + wantName: "{test}", + }, + { + name: "Get Err", + commandErr: testInvalidErr, + wantCmds: []string{getVMCmd}, + wantErr: testInvalidErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotName, gotErr := h.GetVMSwitchNetAdapterName(testVMSwitchName) + assert.Equal(t, tc.wantName, gotName) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestRemoveVMSwitch(t *testing.T) { + getCmd := fmt.Sprintf(`Get-VMSwitch -Name "%s" -ComputerName $(hostname)`, testVMSwitchName) + removeCmd := fmt.Sprintf(`Remove-VMSwitch -Name "%s" -ComputerName $(hostname) -Force`, testVMSwitchName) + tests := []struct { + name string + commandOut string + commandErr error + wantCmds []string + wantErr error + }{ + { + name: "Remove VMSwitch", + commandOut: "true", + wantCmds: []string{getCmd, removeCmd}, + }, + { + name: "Get Err", + commandErr: testInvalidErr, + wantCmds: []string{getCmd}, + wantErr: testInvalidErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.RemoveVMSwitch(testVMSwitchName) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestGetAdapterInAllCompartmentsByName(t *testing.T) { + testName := "host" + testFlags := net.FlagUp | net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast + testAdapter := adapter{ + Interface: net.Interface{ + Index: 1, + Name: testName, + Flags: testFlags, + MTU: 1, + HardwareAddr: testMACAddr, + }, + compartmentID: 1, + flags: IP_ADAPTER_DHCP_ENABLED, + } + tests := []struct { + name string + testName string + testAdapters *windows.IpAdapterAddresses + testAdaptersErr error + wantAdapters *adapter + wantErr error + }{ + { + name: "Normal", + testName: testName, + testAdapters: createTestAdapterAddresses(testName), + wantAdapters: &testAdapter, + }, + { + name: "Invalid name", + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}, + }, + { + name: "adapter Err", + testName: testName, + testAdaptersErr: windows.ERROR_FILE_NOT_FOUND, + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: os.NewSyscallError("getadaptersaddresses", windows.ERROR_FILE_NOT_FOUND)}, + }, + { + name: "adapter not found", + testName: testName, + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockGetAdaptersAddresses(t, tc.testAdapters, tc.testAdaptersErr) + gotAdapters, gotErr := getAdapterInAllCompartmentsByName(tc.testName) + assert.EqualValues(t, tc.wantAdapters, gotAdapters) + assert.EqualValues(t, tc.wantErr, gotErr) + }) + } +} + +func TestEnableNetAdapter(t *testing.T) { + enableCmd := fmt.Sprintf(`Enable-NetAdapter -InterfaceAlias "%s"`, testAdapterName) + tests := []struct { + name string + commandErr error + gwInterfaceErr error + wantCmds []string + }{ + { + name: "Set Link Up Normal", + wantCmds: []string{enableCmd}, + }, + { + name: "Enable Interface Err", + commandErr: fmt.Errorf("failed to enable interface test-en0: fail"), + gwInterfaceErr: fmt.Errorf("failed to enable interface %s", testAdapterName), + wantCmds: []string{enableCmd}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, "", tc.commandErr, false) + err := h.EnableNetAdapter(testAdapterName) + if tc.gwInterfaceErr == nil { + require.NoError(t, err) + } else { + assert.ErrorContains(t, err, tc.gwInterfaceErr.Error()) + } + }) + } +} + +func TestRemoveNetAdapterIPAddress(t *testing.T) { + removeCmd := fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -Confirm:$false`, testAdapterName, ipv4Public.String()) + tests := []struct { + name string + ip net.IP + commandOut string + commandErr error + wantCmds []string + wantErr error + }{ + { + name: "Link Addr Remove Success", + ip: ipv4Public, + wantCmds: []string{removeCmd}, + }, + { + name: "Link Addr Remove Failure", + ip: ipv4Public, + commandErr: fmt.Errorf("fail"), + wantCmds: []string{removeCmd}, + wantErr: fmt.Errorf("fail"), + }, + { + name: "Link Addr Remove Failure with Error 'No Matching'", + ip: ipv4Public, + commandErr: fmt.Errorf("No matching"), + wantCmds: []string{removeCmd}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.RemoveNetAdapterIPAddress(testAdapterName, tc.ip) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestAddNetAdapterIPAddress(t *testing.T) { + ipStr := strings.Split(ipv4PublicIPNet.String(), "/") + configCmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s`, testAdapterName, ipStr[0], ipStr[1]) + gateway := "8.8.8.1" + tests := []struct { + name string + ipNet *net.IPNet + gateway string + commandOut string + commandErr error + wantCmds []string + wantErr error + }{ + { + name: "Configure Link IP Address Success", + ipNet: ipv4PublicIPNet, + commandOut: "success", + wantCmds: []string{configCmd}, + }, + { + name: "Configure Link IP Address and Gateway Success", + ipNet: ipv4PublicIPNet, + gateway: gateway, + commandOut: "success", + wantCmds: []string{fmt.Sprintf(`%s -DefaultGateway %s`, configCmd, gateway)}, + }, + { + name: "Configure Link IP Failure", + ipNet: ipv4PublicIPNet, + commandErr: fmt.Errorf("failed"), + wantErr: fmt.Errorf("failed"), + wantCmds: []string{configCmd}, + }, + { + name: "Configure Link IP Failure with Error 'already exists'", + ipNet: ipv4PublicIPNet, + commandErr: fmt.Errorf("already exists"), + wantCmds: []string{configCmd}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true) + gotErr := h.AddNetAdapterIPAddress(testAdapterName, tc.ipNet, tc.gateway) + assert.Equal(t, tc.wantErr, gotErr) + }) + } + +} + +func createTestAdapterAddresses(name string) *windows.IpAdapterAddresses { + testPhysicalAddress := [8]byte{} + copy(testPhysicalAddress[:6], testMACAddr) + testName, _ := windows.UTF16FromString(name) + return &windows.IpAdapterAddresses{ + FriendlyName: &testName[0], + IfIndex: 1, + OperStatus: windows.IfOperStatusUp, + IfType: windows.IF_TYPE_ATM, + Mtu: 1, + PhysicalAddressLength: 6, + PhysicalAddress: testPhysicalAddress, + CompartmentId: 1, + Flags: IP_ADAPTER_DHCP_ENABLED, + } +} + +func createTestMibIPForwardRow(index uint32, subnet *net.IPNet, ip net.IP) antreasyscall.MibIPForwardRow { + return antreasyscall.MibIPForwardRow{ + Index: index, + Metric: MetricDefault, + DestinationPrefix: *antreasyscall.NewAddressPrefixFromIPNet(subnet), + NextHop: *antreasyscall.NewRawSockAddrInetFromIP(ip), + } +} + +func mockAntreaNetIO(t *testing.T, mockNetIO *antreasyscalltest.MockNetIO) { + originalNetIO := antreaNetIO + antreaNetIO = mockNetIO + t.Cleanup(func() { + antreaNetIO = originalNetIO + }) +} + +func mockGetAdaptersAddresses(t *testing.T, testAdaptersAddresses *windows.IpAdapterAddresses, err error) { + originalGetAdaptersAddresses := getAdaptersAddresses + getAdaptersAddresses = func(family uint32, flags uint32, reserved uintptr, adapterAddresses *windows.IpAdapterAddresses, sizePointer *uint32) (errcode error) { + if adapterAddresses != nil && testAdaptersAddresses != nil { + adapterAddresses.IfIndex = testAdaptersAddresses.IfIndex + adapterAddresses.FriendlyName = testAdaptersAddresses.FriendlyName + adapterAddresses.OperStatus = testAdaptersAddresses.OperStatus + adapterAddresses.IfType = testAdaptersAddresses.IfType + adapterAddresses.Mtu = testAdaptersAddresses.Mtu + adapterAddresses.PhysicalAddressLength = testAdaptersAddresses.PhysicalAddressLength + adapterAddresses.PhysicalAddress = testAdaptersAddresses.PhysicalAddress + adapterAddresses.CompartmentId = testAdaptersAddresses.CompartmentId + adapterAddresses.Flags = testAdaptersAddresses.Flags + } + return err + } + t.Cleanup(func() { + getAdaptersAddresses = originalGetAdaptersAddresses + }) +} + +// mockRunCommand mocks runCommand with a custom command output and error message. +// If exactMatch is enabled, this function asserts that the executed commands are +// exactly the same as wantCmds in terms of order and value. Otherwise, for tests +// with retry functions, the commands will be executed multiple times. This function +// asserts that wantCmds is strictly a subset of these executed commands. +func mockRunCommand(t *testing.T, wantCmds []string, commandOut string, err error, exactMatch bool) { + originalRunCommand := runCommand + actCmds := make([]string, 0) + runCommand = func(cmd string) (string, error) { + actCmds = append(actCmds, cmd) + return commandOut, err + } + t.Cleanup(func() { + runCommand = originalRunCommand + if wantCmds == nil { + assert.Empty(t, actCmds) + } else if exactMatch { + assert.Equal(t, wantCmds, actCmds) + } else { + assert.Subset(t, actCmds, wantCmds) + } + }) +} diff --git a/pkg/agent/util/winnet/testing/mock_net_windows.go b/pkg/agent/util/winnet/testing/mock_net_windows.go new file mode 100644 index 00000000000..8a68cf026d9 --- /dev/null +++ b/pkg/agent/util/winnet/testing/mock_net_windows.go @@ -0,0 +1,483 @@ +// Copyright 2024 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: antrea.io/antrea/pkg/agent/util/winnet (interfaces: Interface) +// +// Generated by this command: +// +// mockgen -copyright_file hack/boilerplate/license_header.raw.txt -destination pkg/agent/util/winnet/testing/mock_net_windows.go -package testing antrea.io/antrea/pkg/agent/util/winnet Interface +// +// Package testing is a generated GoMock package. +package testing + +import ( + net "net" + reflect "reflect" + + winnet "antrea.io/antrea/pkg/agent/util/winnet" + gomock "go.uber.org/mock/gomock" +) + +// MockInterface is a mock of Interface interface. +type MockInterface struct { + ctrl *gomock.Controller + recorder *MockInterfaceMockRecorder +} + +// MockInterfaceMockRecorder is the mock recorder for MockInterface. +type MockInterfaceMockRecorder struct { + mock *MockInterface +} + +// NewMockInterface creates a new mock instance. +func NewMockInterface(ctrl *gomock.Controller) *MockInterface { + mock := &MockInterface{ctrl: ctrl} + mock.recorder = &MockInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { + return m.recorder +} + +// AddNetAdapterIPAddress mocks base method. +func (m *MockInterface) AddNetAdapterIPAddress(arg0 string, arg1 *net.IPNet, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNetAdapterIPAddress", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNetAdapterIPAddress indicates an expected call of AddNetAdapterIPAddress. +func (mr *MockInterfaceMockRecorder) AddNetAdapterIPAddress(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNetAdapterIPAddress", reflect.TypeOf((*MockInterface)(nil).AddNetAdapterIPAddress), arg0, arg1, arg2) +} + +// AddNetNat mocks base method. +func (m *MockInterface) AddNetNat(arg0 string, arg1 *net.IPNet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNetNat", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNetNat indicates an expected call of AddNetNat. +func (mr *MockInterfaceMockRecorder) AddNetNat(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNetNat", reflect.TypeOf((*MockInterface)(nil).AddNetNat), arg0, arg1) +} + +// AddNetNatStaticMapping mocks base method. +func (m *MockInterface) AddNetNatStaticMapping(arg0 *winnet.NetNatStaticMapping) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNetNatStaticMapping", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNetNatStaticMapping indicates an expected call of AddNetNatStaticMapping. +func (mr *MockInterfaceMockRecorder) AddNetNatStaticMapping(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNetNatStaticMapping", reflect.TypeOf((*MockInterface)(nil).AddNetNatStaticMapping), arg0) +} + +// AddNetRoute mocks base method. +func (m *MockInterface) AddNetRoute(arg0 *winnet.Route) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNetRoute", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNetRoute indicates an expected call of AddNetRoute. +func (mr *MockInterfaceMockRecorder) AddNetRoute(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNetRoute", reflect.TypeOf((*MockInterface)(nil).AddNetRoute), arg0) +} + +// AddVMSwitch mocks base method. +func (m *MockInterface) AddVMSwitch(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddVMSwitch", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddVMSwitch indicates an expected call of AddVMSwitch. +func (mr *MockInterfaceMockRecorder) AddVMSwitch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddVMSwitch", reflect.TypeOf((*MockInterface)(nil).AddVMSwitch), arg0, arg1) +} + +// EnableIPForwarding mocks base method. +func (m *MockInterface) EnableIPForwarding(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableIPForwarding", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnableIPForwarding indicates an expected call of EnableIPForwarding. +func (mr *MockInterfaceMockRecorder) EnableIPForwarding(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableIPForwarding", reflect.TypeOf((*MockInterface)(nil).EnableIPForwarding), arg0) +} + +// EnableNetAdapter mocks base method. +func (m *MockInterface) EnableNetAdapter(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableNetAdapter", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnableNetAdapter indicates an expected call of EnableNetAdapter. +func (mr *MockInterfaceMockRecorder) EnableNetAdapter(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableNetAdapter", reflect.TypeOf((*MockInterface)(nil).EnableNetAdapter), arg0) +} + +// EnableRSCOnVSwitch mocks base method. +func (m *MockInterface) EnableRSCOnVSwitch(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableRSCOnVSwitch", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnableRSCOnVSwitch indicates an expected call of EnableRSCOnVSwitch. +func (mr *MockInterfaceMockRecorder) EnableRSCOnVSwitch(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableRSCOnVSwitch", reflect.TypeOf((*MockInterface)(nil).EnableRSCOnVSwitch), arg0) +} + +// EnableVMSwitchOVSExtension mocks base method. +func (m *MockInterface) EnableVMSwitchOVSExtension(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableVMSwitchOVSExtension", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnableVMSwitchOVSExtension indicates an expected call of EnableVMSwitchOVSExtension. +func (mr *MockInterfaceMockRecorder) EnableVMSwitchOVSExtension(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableVMSwitchOVSExtension", reflect.TypeOf((*MockInterface)(nil).EnableVMSwitchOVSExtension), arg0) +} + +// GetDNServersByNetAdapterIndex mocks base method. +func (m *MockInterface) GetDNServersByNetAdapterIndex(arg0 int) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNServersByNetAdapterIndex", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDNServersByNetAdapterIndex indicates an expected call of GetDNServersByNetAdapterIndex. +func (mr *MockInterfaceMockRecorder) GetDNServersByNetAdapterIndex(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNServersByNetAdapterIndex", reflect.TypeOf((*MockInterface)(nil).GetDNServersByNetAdapterIndex), arg0) +} + +// GetVMSwitchNetAdapterName mocks base method. +func (m *MockInterface) GetVMSwitchNetAdapterName(arg0 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVMSwitchNetAdapterName", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetVMSwitchNetAdapterName indicates an expected call of GetVMSwitchNetAdapterName. +func (mr *MockInterfaceMockRecorder) GetVMSwitchNetAdapterName(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMSwitchNetAdapterName", reflect.TypeOf((*MockInterface)(nil).GetVMSwitchNetAdapterName), arg0) +} + +// IsNetAdapterIPv4DHCPEnabled mocks base method. +func (m *MockInterface) IsNetAdapterIPv4DHCPEnabled(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNetAdapterIPv4DHCPEnabled", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsNetAdapterIPv4DHCPEnabled indicates an expected call of IsNetAdapterIPv4DHCPEnabled. +func (mr *MockInterfaceMockRecorder) IsNetAdapterIPv4DHCPEnabled(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNetAdapterIPv4DHCPEnabled", reflect.TypeOf((*MockInterface)(nil).IsNetAdapterIPv4DHCPEnabled), arg0) +} + +// IsNetAdapterStatusUp mocks base method. +func (m *MockInterface) IsNetAdapterStatusUp(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsNetAdapterStatusUp", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsNetAdapterStatusUp indicates an expected call of IsNetAdapterStatusUp. +func (mr *MockInterfaceMockRecorder) IsNetAdapterStatusUp(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNetAdapterStatusUp", reflect.TypeOf((*MockInterface)(nil).IsNetAdapterStatusUp), arg0) +} + +// IsVMSwitchOVSExtensionEnabled mocks base method. +func (m *MockInterface) IsVMSwitchOVSExtensionEnabled(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsVMSwitchOVSExtensionEnabled", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsVMSwitchOVSExtensionEnabled indicates an expected call of IsVMSwitchOVSExtensionEnabled. +func (mr *MockInterfaceMockRecorder) IsVMSwitchOVSExtensionEnabled(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsVMSwitchOVSExtensionEnabled", reflect.TypeOf((*MockInterface)(nil).IsVMSwitchOVSExtensionEnabled), arg0) +} + +// IsVirtualNetAdapter mocks base method. +func (m *MockInterface) IsVirtualNetAdapter(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsVirtualNetAdapter", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsVirtualNetAdapter indicates an expected call of IsVirtualNetAdapter. +func (mr *MockInterfaceMockRecorder) IsVirtualNetAdapter(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsVirtualNetAdapter", reflect.TypeOf((*MockInterface)(nil).IsVirtualNetAdapter), arg0) +} + +// NetAdapterExists mocks base method. +func (m *MockInterface) NetAdapterExists(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NetAdapterExists", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// NetAdapterExists indicates an expected call of NetAdapterExists. +func (mr *MockInterfaceMockRecorder) NetAdapterExists(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NetAdapterExists", reflect.TypeOf((*MockInterface)(nil).NetAdapterExists), arg0) +} + +// RemoveNetAdapterIPAddress mocks base method. +func (m *MockInterface) RemoveNetAdapterIPAddress(arg0 string, arg1 net.IP) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveNetAdapterIPAddress", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveNetAdapterIPAddress indicates an expected call of RemoveNetAdapterIPAddress. +func (mr *MockInterfaceMockRecorder) RemoveNetAdapterIPAddress(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNetAdapterIPAddress", reflect.TypeOf((*MockInterface)(nil).RemoveNetAdapterIPAddress), arg0, arg1) +} + +// RemoveNetNatStaticMapping mocks base method. +func (m *MockInterface) RemoveNetNatStaticMapping(arg0 *winnet.NetNatStaticMapping) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveNetNatStaticMapping", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveNetNatStaticMapping indicates an expected call of RemoveNetNatStaticMapping. +func (mr *MockInterfaceMockRecorder) RemoveNetNatStaticMapping(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNetNatStaticMapping", reflect.TypeOf((*MockInterface)(nil).RemoveNetNatStaticMapping), arg0) +} + +// RemoveNetNatStaticMappingsByNetNat mocks base method. +func (m *MockInterface) RemoveNetNatStaticMappingsByNetNat(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveNetNatStaticMappingsByNetNat", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveNetNatStaticMappingsByNetNat indicates an expected call of RemoveNetNatStaticMappingsByNetNat. +func (mr *MockInterfaceMockRecorder) RemoveNetNatStaticMappingsByNetNat(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNetNatStaticMappingsByNetNat", reflect.TypeOf((*MockInterface)(nil).RemoveNetNatStaticMappingsByNetNat), arg0) +} + +// RemoveNetRoute mocks base method. +func (m *MockInterface) RemoveNetRoute(arg0 *winnet.Route) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveNetRoute", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveNetRoute indicates an expected call of RemoveNetRoute. +func (mr *MockInterfaceMockRecorder) RemoveNetRoute(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNetRoute", reflect.TypeOf((*MockInterface)(nil).RemoveNetRoute), arg0) +} + +// RemoveVMSwitch mocks base method. +func (m *MockInterface) RemoveVMSwitch(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveVMSwitch", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveVMSwitch indicates an expected call of RemoveVMSwitch. +func (mr *MockInterfaceMockRecorder) RemoveVMSwitch(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveVMSwitch", reflect.TypeOf((*MockInterface)(nil).RemoveVMSwitch), arg0) +} + +// RenameNetAdapter mocks base method. +func (m *MockInterface) RenameNetAdapter(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenameNetAdapter", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenameNetAdapter indicates an expected call of RenameNetAdapter. +func (mr *MockInterfaceMockRecorder) RenameNetAdapter(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenameNetAdapter", reflect.TypeOf((*MockInterface)(nil).RenameNetAdapter), arg0, arg1) +} + +// RenameVMNetworkAdapter mocks base method. +func (m *MockInterface) RenameVMNetworkAdapter(arg0, arg1, arg2 string, arg3 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenameVMNetworkAdapter", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenameVMNetworkAdapter indicates an expected call of RenameVMNetworkAdapter. +func (mr *MockInterfaceMockRecorder) RenameVMNetworkAdapter(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenameVMNetworkAdapter", reflect.TypeOf((*MockInterface)(nil).RenameVMNetworkAdapter), arg0, arg1, arg2, arg3) +} + +// ReplaceNetNatStaticMapping mocks base method. +func (m *MockInterface) ReplaceNetNatStaticMapping(arg0 *winnet.NetNatStaticMapping) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplaceNetNatStaticMapping", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReplaceNetNatStaticMapping indicates an expected call of ReplaceNetNatStaticMapping. +func (mr *MockInterfaceMockRecorder) ReplaceNetNatStaticMapping(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceNetNatStaticMapping", reflect.TypeOf((*MockInterface)(nil).ReplaceNetNatStaticMapping), arg0) +} + +// ReplaceNetNeighbor mocks base method. +func (m *MockInterface) ReplaceNetNeighbor(arg0 *winnet.Neighbor) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplaceNetNeighbor", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReplaceNetNeighbor indicates an expected call of ReplaceNetNeighbor. +func (mr *MockInterfaceMockRecorder) ReplaceNetNeighbor(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceNetNeighbor", reflect.TypeOf((*MockInterface)(nil).ReplaceNetNeighbor), arg0) +} + +// ReplaceNetRoute mocks base method. +func (m *MockInterface) ReplaceNetRoute(arg0 *winnet.Route) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplaceNetRoute", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReplaceNetRoute indicates an expected call of ReplaceNetRoute. +func (mr *MockInterfaceMockRecorder) ReplaceNetRoute(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceNetRoute", reflect.TypeOf((*MockInterface)(nil).ReplaceNetRoute), arg0) +} + +// RouteListFiltered mocks base method. +func (m *MockInterface) RouteListFiltered(arg0 uint16, arg1 *winnet.Route, arg2 uint64) ([]winnet.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteListFiltered", arg0, arg1, arg2) + ret0, _ := ret[0].([]winnet.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteListFiltered indicates an expected call of RouteListFiltered. +func (mr *MockInterfaceMockRecorder) RouteListFiltered(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteListFiltered", reflect.TypeOf((*MockInterface)(nil).RouteListFiltered), arg0, arg1, arg2) +} + +// SetNetAdapterDNSServers mocks base method. +func (m *MockInterface) SetNetAdapterDNSServers(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetNetAdapterDNSServers", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetNetAdapterDNSServers indicates an expected call of SetNetAdapterDNSServers. +func (mr *MockInterfaceMockRecorder) SetNetAdapterDNSServers(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetAdapterDNSServers", reflect.TypeOf((*MockInterface)(nil).SetNetAdapterDNSServers), arg0, arg1) +} + +// SetNetAdapterMTU mocks base method. +func (m *MockInterface) SetNetAdapterMTU(arg0 string, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetNetAdapterMTU", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetNetAdapterMTU indicates an expected call of SetNetAdapterMTU. +func (mr *MockInterfaceMockRecorder) SetNetAdapterMTU(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetAdapterMTU", reflect.TypeOf((*MockInterface)(nil).SetNetAdapterMTU), arg0, arg1) +} + +// VMSwitchExists mocks base method. +func (m *MockInterface) VMSwitchExists(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VMSwitchExists", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// VMSwitchExists indicates an expected call of VMSwitchExists. +func (mr *MockInterfaceMockRecorder) VMSwitchExists(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VMSwitchExists", reflect.TypeOf((*MockInterface)(nil).VMSwitchExists), arg0) +} diff --git a/pkg/agent/util/winnet/types.go b/pkg/agent/util/winnet/types.go new file mode 100644 index 00000000000..b006f6f95cb --- /dev/null +++ b/pkg/agent/util/winnet/types.go @@ -0,0 +1,67 @@ +// Copyright 2024 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package winnet + +import ( + "fmt" + "net" + + binding "antrea.io/antrea/pkg/ovs/openflow" + iputil "antrea.io/antrea/pkg/util/ip" +) + +type Route struct { + LinkIndex int + DestinationSubnet *net.IPNet + GatewayAddress net.IP + RouteMetric int +} + +type Neighbor struct { + LinkIndex int + IPAddress net.IP + LinkLayerAddress net.HardwareAddr + State string +} + +type NetNatStaticMapping struct { + Name string + ExternalIP net.IP + ExternalPort uint16 + InternalIP net.IP + InternalPort uint16 + Protocol binding.Protocol +} + +func (r *Route) String() string { + return fmt.Sprintf("LinkIndex: %d, DestinationSubnet: %s, GatewayAddress: %s, RouteMetric: %d", + r.LinkIndex, r.DestinationSubnet, r.GatewayAddress, r.RouteMetric) +} + +func (r *Route) Equal(x Route) bool { + return x.LinkIndex == r.LinkIndex && + x.DestinationSubnet != nil && + r.DestinationSubnet != nil && + iputil.IPNetEqual(x.DestinationSubnet, r.DestinationSubnet) && + x.GatewayAddress.Equal(r.GatewayAddress) +} + +func (n *Neighbor) String() string { + return fmt.Sprintf("LinkIndex: %d, IPAddress: %s, LinkLayerAddress: %s", n.LinkIndex, n.IPAddress, n.LinkLayerAddress) +} + +func (n *NetNatStaticMapping) String() string { + return fmt.Sprintf("Name: %s, ExternalIP %s, ExternalPort: %d, InternalIP: %s, InternalPort: %d, Protocol: %s", n.Name, n.ExternalIP, n.ExternalPort, n.InternalIP, n.InternalPort, n.Protocol) +} diff --git a/test/integration/agent/net_windows_test.go b/test/integration/agent/net_windows_test.go index 6b5bd27134c..00bdce0fae7 100644 --- a/test/integration/agent/net_windows_test.go +++ b/test/integration/agent/net_windows_test.go @@ -26,10 +26,11 @@ import ( "antrea.io/antrea/pkg/agent/util" ps "antrea.io/antrea/pkg/agent/util/powershell" + "antrea.io/antrea/pkg/agent/util/winnet" ) func adapterName(name string) string { - return fmt.Sprintf("%s (%s)", util.ContainerVNICPrefix, name) + return fmt.Sprintf("%s (%s)", winnet.ContainerVNICPrefix, name) } // windowsHyperVEnabled checks if the Hyper-V is enabled on the host. @@ -160,6 +161,6 @@ func TestCreateHNSNetwork(t *testing.T) { assert.Equal(t, hnsNet.ManagementIP, nodeIP.String()) t.Logf("Enabling the Open vSwitch Extension for HNSNetwork '%s'", testNet) - err = util.EnableHNSNetworkExtension(hnsNet.Id, util.OVSExtensionID) + err = util.EnableHNSNetworkExtension(hnsNet.Id, winnet.OVSExtensionID) require.Nil(t, err, "No error expected when enabling the Open vSwitch Extension for the HNSNetwork") }