diff --git a/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go index ab0cf0ef7..a38f2f8a2 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go +++ b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go @@ -25,7 +25,6 @@ import ( "github.com/awslabs/operatorpkg/singleton" "github.com/patrickmn/go-cache" - // "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" "github.com/samber/lo" "go.uber.org/multierr" v1 "k8s.io/api/core/v1" diff --git a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go index 0ccd80fff..85d6e27e9 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go +++ b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go @@ -26,6 +26,7 @@ import ( "github.com/awslabs/operatorpkg/object" opstatus "github.com/awslabs/operatorpkg/status" + "github.com/patrickmn/go-cache" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" "github.com/Azure/karpenter-provider-azure/pkg/apis" @@ -64,7 +65,7 @@ var nodePool *karpv1.NodePool var nodeClass *v1alpha2.AKSNodeClass var cluster *state.Cluster var cloudProvider *cloudprovider.CloudProvider -var garbageCollectionController *garbagecollection.Controller +var virtualMachineGCController *garbagecollection.VirtualMachineController var prov *provisioning.Provisioner func TestAPIs(t *testing.T) { @@ -80,7 +81,7 @@ var _ = BeforeSuite(func() { // ctx, stop = context.WithCancel(ctx) azureEnv = test.NewEnvironment(ctx, env) cloudProvider = cloudprovider.New(azureEnv.InstanceTypesProvider, azureEnv.InstanceProvider, events.NewRecorder(&record.FakeRecorder{}), env.Client, azureEnv.ImageProvider) - garbageCollectionController = garbagecollection.NewController(env.Client, cloudProvider, azureEnv.InstanceProvider) + virtualMachineGCController = garbagecollection.NewVirtualMachineController(env.Client, cloudProvider, cache.New(time.Minute, time.Second)) fakeClock = &clock.FakeClock{} cluster = state.NewCluster(fakeClock, env.Client) prov = provisioning.NewProvisioner(env.Client, events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster, fakeClock) @@ -119,7 +120,7 @@ var _ = AfterEach(func() { // TODO: move before/after each into the tests (see AWS) // review tests themselves (very different from AWS?) // (e.g. AWS has not a single ExpectPRovisioned? why?) -var _ = Describe("GarbageCollection", func() { +var _ = Describe("VirtualMachine Garbage Collection", func() { var vm *armcompute.VirtualMachine var providerID string var err error @@ -147,7 +148,7 @@ var _ = Describe("GarbageCollection", func() { }) azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -180,7 +181,7 @@ var _ = Describe("GarbageCollection", func() { ids = append(ids, *vm.ID) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -233,7 +234,7 @@ var _ = Describe("GarbageCollection", func() { nodeClaims = append(nodeClaims, nodeClaim) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -259,7 +260,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -280,7 +281,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, nodeClaim, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).ToNot(HaveOccurred()) ExpectExists(ctx, env.Client, node) @@ -307,7 +308,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -323,7 +324,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -332,3 +333,4 @@ var _ = Describe("GarbageCollection", func() { }) }) }) + diff --git a/pkg/fake/azureresourcegraphapi.go b/pkg/fake/azureresourcegraphapi.go index 02e69cefb..e82ee7859 100644 --- a/pkg/fake/azureresourcegraphapi.go +++ b/pkg/fake/azureresourcegraphapi.go @@ -23,6 +23,7 @@ import ( "github.com/samber/lo" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" ) @@ -35,6 +36,7 @@ type AzureResourceGraphResourcesInput struct { type AzureResourceGraphBehavior struct { AzureResourceGraphResourcesBehavior MockedFunction[AzureResourceGraphResourcesInput, armresourcegraph.ClientResourcesResponse] VirtualMachinesAPI *VirtualMachinesAPI + NetworkInterfacesAPI *NetworkInterfacesAPI ResourceGroup string } @@ -75,12 +77,22 @@ func (c *AzureResourceGraphAPI) getResourceList(query string) []interface{} { return convertBytesToInterface(b) }) return resourceList + case instance.GetNICListQueryBuilder(c.ResourceGroup).String(): + nicList := lo.Filter(c.loadNicObjects(), func(nic armnetwork.Interface, _ int) bool { + return nic.Tags != nil && nic.Tags[instance.NodePoolTagKey] != nil + }) + resourceList := lo.Map(nicList, func(nic armnetwork.Interface, _ int) interface{} { + b, _ := json.Marshal(nic) + return convertBytesToInterface(b) + }) + return resourceList } return nil } -func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { - vmList := []armcompute.VirtualMachine{} + + +func (c *AzureResourceGraphAPI) loadVMObjects() (vmList []armcompute.VirtualMachine) { c.VirtualMachinesAPI.Instances.Range(func(k, v any) bool { vm, _ := c.VirtualMachinesAPI.Instances.Load(k) vmList = append(vmList, vm.(armcompute.VirtualMachine)) @@ -89,6 +101,16 @@ func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { return vmList } + +func (c *AzureResourceGraphAPI) loadNicObjects() (nicList []armnetwork.Interface) { + c.NetworkInterfacesAPI.NetworkInterfaces.Range(func(k, v any) bool { + nic, _ := c.NetworkInterfacesAPI.NetworkInterfaces.Load(k) + nicList = append(nicList, nic.(armnetwork.Interface)) + return true + }) + return nicList +} + func convertBytesToInterface(b []byte) interface{} { jsonObj := instance.Resource{} _ = json.Unmarshal(b, &jsonObj) diff --git a/pkg/fake/networkinterfaceapi.go b/pkg/fake/networkinterfaceapi.go index f1fed7163..1b4da8157 100644 --- a/pkg/fake/networkinterfaceapi.go +++ b/pkg/fake/networkinterfaceapi.go @@ -73,6 +73,7 @@ func (c *NetworkInterfacesAPI) BeginCreateOrUpdate(_ context.Context, resourceGr return c.NetworkInterfacesCreateOrUpdateBehavior.Invoke(input, func(input *NetworkInterfaceCreateOrUpdateInput) (*armnetwork.InterfacesClientCreateOrUpdateResponse, error) { iface := input.Interface + iface.Name = to.StringPtr(interfaceName) id := mkNetworkInterfaceID(input.ResourceGroupName, input.InterfaceName) iface.ID = to.StringPtr(id) c.NetworkInterfaces.Store(id, iface) diff --git a/pkg/providers/instance/suite_test.go b/pkg/providers/instance/suite_test.go index 923f24e17..48f95adc6 100644 --- a/pkg/providers/instance/suite_test.go +++ b/pkg/providers/instance/suite_test.go @@ -215,4 +215,13 @@ var _ = Describe("InstanceProvider", func() { return strings.Contains(key, "/") // ARM tags can't contain '/' })).To(HaveLen(0)) }) + It("should list nic from karpenter provisioning request", func(){ + ExpectApplied(ctx, env.Client, nodePool, nodeClass) + pod := coretest.UnschedulablePod(coretest.PodOptions{}) + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + ExpectScheduled(ctx, env.Client, pod) + interfaces, err := azureEnv.InstanceProvider.ListNics(ctx) + Expect(err).To(BeNil()) + Expect(len(interfaces)).To(Equal(1)) + }) }) diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 11d3faf0e..ad610a7f6 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -92,15 +92,21 @@ func NewRegionalEnvironment(ctx context.Context, env *coretest.Environment, regi // API virtualMachinesAPI := &fake.VirtualMachinesAPI{} - azureResourceGraphAPI := &fake.AzureResourceGraphAPI{AzureResourceGraphBehavior: fake.AzureResourceGraphBehavior{VirtualMachinesAPI: virtualMachinesAPI, ResourceGroup: resourceGroup}} - virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} + networkInterfacesAPI := &fake.NetworkInterfacesAPI{} + virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} pricingAPI := &fake.PricingAPI{} skuClientSingleton := &fake.MockSkuClientSingleton{SKUClient: &fake.ResourceSKUsAPI{Location: region}} communityImageVersionsAPI := &fake.CommunityGalleryImageVersionsAPI{} loadBalancersAPI := &fake.LoadBalancersAPI{} nodeImageVersionsAPI := &fake.NodeImageVersionsAPI{} + azureResourceGraphAPI := &fake.AzureResourceGraphAPI{ + AzureResourceGraphBehavior: fake.AzureResourceGraphBehavior{ + VirtualMachinesAPI: virtualMachinesAPI, + NetworkInterfacesAPI: networkInterfacesAPI, + ResourceGroup: resourceGroup, + }} // Cache kubernetesVersionCache := cache.New(azurecache.KubernetesVersionTTL, azurecache.DefaultCleanupInterval) instanceTypeCache := cache.New(instancetype.InstanceTypesCacheTTL, azurecache.DefaultCleanupInterval)