diff --git a/.github/workflows/approval-comment.yaml b/.github/workflows/approval-comment.yaml index 20c680785..f78d62b56 100644 --- a/.github/workflows/approval-comment.yaml +++ b/.github/workflows/approval-comment.yaml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true disable-sudo: true @@ -30,7 +30,7 @@ jobs: mkdir -p /tmp/artifacts { echo ${{ github.event.pull_request.number }}; echo ${{ github.event.review.commit_id }}; } >> /tmp/artifacts/metadata.txt cat /tmp/artifacts/metadata.txt - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 + - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: artifacts path: /tmp/artifacts diff --git a/.github/workflows/build-publish-mcr.yml b/.github/workflows/build-publish-mcr.yml index 61d29a087..1aaf5f7ff 100644 --- a/.github/workflows/build-publish-mcr.yml +++ b/.github/workflows/build-publish-mcr.yml @@ -23,7 +23,7 @@ jobs: labels: [self-hosted, "1ES.Pool=${{ vars.RELEASE_1ES_POOL }}"] steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: egress-policy: audit diff --git a/.github/workflows/ci-test.yml b/.github/workflows/ci-test.yml index b63ae1b5f..b9b4336da 100644 --- a/.github/workflows/ci-test.yml +++ b/.github/workflows/ci-test.yml @@ -19,7 +19,7 @@ jobs: K8S_VERSION: ${{ matrix.k8sVersion }} steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true egress-policy: block diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 63326b666..755af2b50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true egress-policy: block diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 98dc425a3..991223df3 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true egress-policy: block @@ -46,8 +46,8 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: ./.github/actions/install-deps - run: make vulncheck - - uses: github/codeql-action/init@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 + - uses: github/codeql-action/init@b6a472f63d85b9c78a3ac5e89422239fc15e9b3c # v3.28.1 with: languages: ${{ matrix.language }} - - uses: github/codeql-action/autobuild@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 - - uses: github/codeql-action/analyze@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 + - uses: github/codeql-action/autobuild@b6a472f63d85b9c78a3ac5e89422239fc15e9b3c # v3.28.1 + - uses: github/codeql-action/analyze@b6a472f63d85b9c78a3ac5e89422239fc15e9b3c # v3.28.1 diff --git a/.github/workflows/deflake.yml b/.github/workflows/deflake.yml index f97928dba..fc2e09084 100644 --- a/.github/workflows/deflake.yml +++ b/.github/workflows/deflake.yml @@ -14,7 +14,7 @@ jobs: statuses: write steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true egress-policy: block diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index e97e8c207..6f49fc692 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true disable-sudo: true diff --git a/.github/workflows/e2e-matrix.yaml b/.github/workflows/e2e-matrix.yaml index a854583e9..fd43bb889 100644 --- a/.github/workflows/e2e-matrix.yaml +++ b/.github/workflows/e2e-matrix.yaml @@ -29,7 +29,7 @@ jobs: E2E_HASH: ${{ steps.generate-e2e-run-hash.outputs.E2E_HASH }} steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true disable-sudo: true diff --git a/.github/workflows/e2e.yaml b/.github/workflows/e2e.yaml index 0cfdac413..bbff09ad8 100644 --- a/.github/workflows/e2e.yaml +++ b/.github/workflows/e2e.yaml @@ -45,7 +45,7 @@ jobs: AZURE_SUBSCRIPTION_ID: ${{ secrets.E2E_SUBSCRIPTION_ID }} steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true egress-policy: block diff --git a/.github/workflows/release-trigger.yaml b/.github/workflows/release-trigger.yaml index 6b227b4f5..1ee240e36 100644 --- a/.github/workflows/release-trigger.yaml +++ b/.github/workflows/release-trigger.yaml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-telemetry: true disable-sudo: true diff --git a/.github/workflows/resolve-args.yaml b/.github/workflows/resolve-args.yaml index 8588f8e32..d992176d6 100644 --- a/.github/workflows/resolve-args.yaml +++ b/.github/workflows/resolve-args.yaml @@ -16,7 +16,7 @@ jobs: steps: # Download the artifact and resolve the GIT_REF - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-sudo: true disable-telemetry: true diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 9439f8681..0171bc9af 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Harden Runner - uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 + uses: step-security/harden-runner@c95a14d0e5bab51a9f56296a4eb0e416910cd350 # v2.10.3 with: disable-sudo: true disable-telemetry: true @@ -82,7 +82,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 + uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: SARIF file path: results.sarif @@ -90,6 +90,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 + uses: github/codeql-action/upload-sarif@b6a472f63d85b9c78a3ac5e89422239fc15e9b3c # v3.28.1 with: sarif_file: results.sarif diff --git a/README.md b/README.md index 68031d77a..5c9fd72d7 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,8 @@ Karpenter provider for AKS can be used in two modes: * **Self-hosted mode**: Karpenter is run as a standalone deployment in the cluster. This mode is useful for advanced users who want to customize or experiment with Karpenter's deployment. The rest of this page describes how to use Karpenter in self-hosted mode. ## Known limitations - -* Only AKS clusters with Azure CNI Overlay + Cilium networking are supported. * Only Linux nodes are supported. +* Kubenet and Calico are not supported ## Installation (self-hosted) diff --git a/charts/karpenter-crd/templates/karpenter.azure.com_aksnodeclasses.yaml b/charts/karpenter-crd/templates/karpenter.azure.com_aksnodeclasses.yaml index db6f5f00a..6311bd8bf 100644 --- a/charts/karpenter-crd/templates/karpenter.azure.com_aksnodeclasses.yaml +++ b/charts/karpenter-crd/templates/karpenter.azure.com_aksnodeclasses.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.0 + controller-gen.kubebuilder.io/version: v0.17.1 name: aksnodeclasses.karpenter.azure.com spec: group: karpenter.azure.com diff --git a/go.mod b/go.mod index b11722a98..5864da261 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/Azure/karpenter-provider-azure -go 1.23.0 +go 1.23.5 require ( github.com/Azure/azure-kusto-go v0.16.1 diff --git a/pkg/apis/crds/karpenter.azure.com_aksnodeclasses.yaml b/pkg/apis/crds/karpenter.azure.com_aksnodeclasses.yaml index db6f5f00a..6311bd8bf 100644 --- a/pkg/apis/crds/karpenter.azure.com_aksnodeclasses.yaml +++ b/pkg/apis/crds/karpenter.azure.com_aksnodeclasses.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.0 + controller-gen.kubebuilder.io/version: v0.17.1 name: aksnodeclasses.karpenter.azure.com spec: group: karpenter.azure.com diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index b9c9764b3..fa5992bd8 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -138,6 +138,7 @@ func (c *CloudProvider) List(ctx context.Context) ([]*karpv1.NodeClaim, error) { if err != nil { return nil, fmt.Errorf("listing instances, %w", err) } + var nodeClaims []*karpv1.NodeClaim for _, instance := range instances { instanceType, err := c.resolveInstanceTypeFromInstance(ctx, instance) @@ -328,18 +329,15 @@ func (c *CloudProvider) instanceToNodeClaim(ctx context.Context, vm *armcompute. nodeClaim.Status.Allocatable = lo.PickBy(instanceType.Allocatable(), func(_ v1.ResourceName, v resource.Quantity) bool { return !resources.IsZero(v) }) } - // TODO: review logic for determining zone (AWS uses Zone from subnet resolved and aviailable from NodeClass conditions ...) - if zoneID, err := instance.GetZoneID(vm); err != nil { + if zone, err := utils.GetZone(vm); err != nil { logging.FromContext(ctx).Warnf("Failed to get zone for VM %s, %v", *vm.Name, err) } else { - zone := makeZone(*vm.Location, zoneID) // aks-node-validating-webhook protects v1.LabelTopologyZone, will be set elsewhere, so we use a different label labels[v1alpha2.AlternativeLabelTopologyZone] = zone } labels[karpv1.CapacityTypeLabelKey] = instance.GetCapacityType(vm) - // TODO: v1beta1 new kes/labels if tag, ok := vm.Tags[instance.NodePoolTagKey]; ok { labels[karpv1.NodePoolLabelKey] = *tag } @@ -369,14 +367,6 @@ func GenerateNodeClaimName(vmName string) string { return strings.TrimLeft("aks-", vmName) } -// makeZone returns the zone value in format of -. -func makeZone(location string, zoneID string) string { - if zoneID == "" { - return "" - } - return fmt.Sprintf("%s-%s", strings.ToLower(location), zoneID) -} - // newTerminatingNodeClassError returns a NotFound error for handling by func newTerminatingNodeClassError(name string) *errors.StatusError { qualifiedResource := schema.GroupResource{Group: apis.Group, Resource: "aksnodeclasses"} diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 0a9c86c0f..146502887 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -144,7 +144,7 @@ var _ = Describe("CloudProvider", func() { nodeClaims, _ := cloudProvider.List(ctx) Expect(azureEnv.AzureResourceGraphAPI.AzureResourceGraphResourcesBehavior.CalledWithInput.Len()).To(Equal(1)) queryRequest := azureEnv.AzureResourceGraphAPI.AzureResourceGraphResourcesBehavior.CalledWithInput.Pop().Query - Expect(*queryRequest.Query).To(Equal(instance.GetListQueryBuilder(azureEnv.AzureResourceGraphAPI.ResourceGroup).String())) + Expect(*queryRequest.Query).To(Equal(instance.GetVMListQueryBuilder(azureEnv.AzureResourceGraphAPI.ResourceGroup).String())) Expect(nodeClaims).To(HaveLen(1)) Expect(nodeClaims[0]).ToNot(BeNil()) resp, _ := azureEnv.VirtualMachinesAPI.Get(ctx, azureEnv.AzureResourceGraphAPI.ResourceGroup, nodeClaims[0].Name, nil) diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index e96be8104..793c1e93f 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -44,7 +44,10 @@ func NewControllers(ctx context.Context, mgr manager.Manager, kubeClient client. nodeclasshash.NewController(kubeClient), nodeclassstatus.NewController(kubeClient), nodeclasstermination.NewController(kubeClient, recorder), - nodeclaimgarbagecollection.NewController(kubeClient, cloudProvider), + + nodeclaimgarbagecollection.NewVirtualMachine(kubeClient, cloudProvider), + nodeclaimgarbagecollection.NewNetworkInterface(kubeClient, instanceProvider), + // TODO: nodeclaim tagging inplaceupdate.NewController(kubeClient, instanceProvider), status.NewController[*v1alpha2.AKSNodeClass](kubeClient, mgr.GetEventRecorderFor("karpenter")), diff --git a/pkg/controllers/nodeclaim/garbagecollection/controller.go b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go similarity index 87% rename from pkg/controllers/nodeclaim/garbagecollection/controller.go rename to pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go index 033dc31f3..f86fc9ada 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/controller.go +++ b/pkg/controllers/nodeclaim/garbagecollection/instance_garbagecollection.go @@ -23,7 +23,6 @@ import ( "github.com/awslabs/operatorpkg/singleton" - // "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" "github.com/samber/lo" "go.uber.org/multierr" v1 "k8s.io/api/core/v1" @@ -41,21 +40,21 @@ import ( corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider" ) -type Controller struct { +type VirtualMachine struct { kubeClient client.Client cloudProvider corecloudprovider.CloudProvider - successfulCount uint64 // keeps track of successful reconciles for more aggressive requeueing near the start of the controller + successfulCount uint64 // keeps track of successful reconciles for more aggressive requeuing near the start of the controller } -func NewController(kubeClient client.Client, cloudProvider corecloudprovider.CloudProvider) *Controller { - return &Controller{ +func NewVirtualMachine(kubeClient client.Client, cloudProvider corecloudprovider.CloudProvider) *VirtualMachine { + return &VirtualMachine{ kubeClient: kubeClient, cloudProvider: cloudProvider, successfulCount: 0, } } -func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { +func (c *VirtualMachine) Reconcile(ctx context.Context) (reconcile.Result, error) { ctx = injection.WithControllerName(ctx, "instance.garbagecollection") // We LIST VMs on the CloudProvider BEFORE we grab NodeClaims/Nodes on the cluster so that we make sure that, if @@ -65,6 +64,7 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { if err != nil { return reconcile.Result{}, fmt.Errorf("listing cloudprovider VMs, %w", err) } + managedRetrieved := lo.Filter(retrieved, func(nc *karpv1.NodeClaim, _ int) bool { return nc.DeletionTimestamp.IsZero() }) @@ -93,7 +93,7 @@ func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { return reconcile.Result{RequeueAfter: lo.Ternary(c.successfulCount <= 20, time.Second*10, time.Minute*2)}, nil } -func (c *Controller) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeClaim, nodeList *v1.NodeList) error { +func (c *VirtualMachine) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeClaim, nodeList *v1.NodeList) error { ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("provider-id", nodeClaim.Status.ProviderID)) if err := c.cloudProvider.Delete(ctx, nodeClaim); err != nil { return corecloudprovider.IgnoreNodeClaimNotFoundError(err) @@ -112,7 +112,7 @@ func (c *Controller) garbageCollect(ctx context.Context, nodeClaim *karpv1.NodeC return nil } -func (c *Controller) Register(_ context.Context, m manager.Manager) error { +func (c *VirtualMachine) Register(_ context.Context, m manager.Manager) error { return controllerruntime.NewControllerManagedBy(m). Named("instance.garbagecollection"). WatchesRawSource(singleton.Source()). diff --git a/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go b/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go new file mode 100644 index 000000000..7571e79d8 --- /dev/null +++ b/pkg/controllers/nodeclaim/garbagecollection/nic_garbagecollection.go @@ -0,0 +1,112 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package garbagecollection + +import ( + "context" + "fmt" + "time" + + "github.com/samber/lo" + "knative.dev/pkg/logging" + + "github.com/awslabs/operatorpkg/singleton" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/util/workqueue" + controllerruntime "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" + "sigs.k8s.io/karpenter/pkg/operator/injection" + + "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" +) + +const ( + NicReservationDuration = time.Second * 180 + // We set this interval at 5 minutes, as thats how often our NRP limits are reset. + // See: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/request-limits-and-throttling#network-throttling + NicGarbageCollectionInterval = time.Minute * 5 +) + +type NetworkInterface struct { + kubeClient client.Client + instanceProvider instance.Provider +} + +func NewNetworkInterface(kubeClient client.Client, instanceProvider instance.Provider) *NetworkInterface { + return &NetworkInterface{ + kubeClient: kubeClient, + instanceProvider: instanceProvider, + } +} + +func (c *NetworkInterface) populateUnremovableInterfaces(ctx context.Context) (sets.Set[string], error) { + unremovableInterfaces := sets.New[string]() + vms, err := c.instanceProvider.List(ctx) + if err != nil { + return unremovableInterfaces, fmt.Errorf("listing VMs: %w", err) + } + for _, vm := range vms { + unremovableInterfaces.Insert(lo.FromPtr(vm.Name)) + } + nodeClaimList := &karpv1.NodeClaimList{} + if err := c.kubeClient.List(ctx, nodeClaimList); err != nil { + return unremovableInterfaces, fmt.Errorf("listing NodeClaims for NIC GC: %w", err) + } + + for _, nodeClaim := range nodeClaimList.Items { + unremovableInterfaces.Insert(instance.GenerateResourceName(nodeClaim.Name)) + } + return unremovableInterfaces, nil +} + +func (c *NetworkInterface) Reconcile(ctx context.Context) (reconcile.Result, error) { + ctx = injection.WithControllerName(ctx, "networkinterface.garbagecollection") + nics, err := c.instanceProvider.ListNics(ctx) + if err != nil { + return reconcile.Result{}, fmt.Errorf("listing NICs: %w", err) + } + + unremovableInterfaces, err := c.populateUnremovableInterfaces(ctx) + if err != nil { + return reconcile.Result{}, fmt.Errorf("error listing resources needed to populate unremovable nics %w", err) + } + workqueue.ParallelizeUntil(ctx, 100, len(nics), func(i int) { + nicName := lo.FromPtr(nics[i].Name) + if !unremovableInterfaces.Has(nicName) { + err := c.instanceProvider.DeleteNic(ctx, nicName) + if err != nil { + logging.FromContext(ctx).Error(err) + return + } + + logging.FromContext(ctx).With("nic", nicName).Infof("garbage collected NIC") + } + }) + return reconcile.Result{ + RequeueAfter: NicGarbageCollectionInterval, + }, nil +} + +func (c *NetworkInterface) Register(_ context.Context, m manager.Manager) error { + return controllerruntime.NewControllerManagedBy(m). + Named("networkinterface.garbagecollection"). + WatchesRawSource(singleton.Source()). + Complete(singleton.AsReconciler(c)) +} diff --git a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go index 5e9a7edee..3197e5c94 100644 --- a/pkg/controllers/nodeclaim/garbagecollection/suite_test.go +++ b/pkg/controllers/nodeclaim/garbagecollection/suite_test.go @@ -32,9 +32,9 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" "github.com/Azure/karpenter-provider-azure/pkg/controllers/nodeclaim/garbagecollection" - "github.com/Azure/karpenter-provider-azure/pkg/fake" "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" + . "github.com/Azure/karpenter-provider-azure/pkg/test/expectations" "github.com/Azure/karpenter-provider-azure/pkg/utils" . "github.com/onsi/ginkgo/v2" @@ -64,7 +64,8 @@ var nodePool *karpv1.NodePool var nodeClass *v1alpha2.AKSNodeClass var cluster *state.Cluster var cloudProvider *cloudprovider.CloudProvider -var garbageCollectionController *garbagecollection.Controller +var virtualMachineGCController *garbagecollection.VirtualMachine +var networkInterfaceGCController *garbagecollection.NetworkInterface var prov *provisioning.Provisioner func TestAPIs(t *testing.T) { @@ -80,7 +81,8 @@ var _ = BeforeSuite(func() { // ctx, stop = context.WithCancel(ctx) azureEnv = test.NewEnvironment(ctx, env) cloudProvider = cloudprovider.New(azureEnv.InstanceTypesProvider, azureEnv.InstanceProvider, events.NewRecorder(&record.FakeRecorder{}), env.Client, azureEnv.ImageProvider) - garbageCollectionController = garbagecollection.NewController(env.Client, cloudProvider) + virtualMachineGCController = garbagecollection.NewVirtualMachine(env.Client, cloudProvider) + networkInterfaceGCController = garbagecollection.NewNetworkInterface(env.Client, azureEnv.InstanceProvider) fakeClock = &clock.FakeClock{} cluster = state.NewCluster(fakeClock, env.Client) prov = provisioning.NewProvisioner(env.Client, events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster, fakeClock) @@ -119,7 +121,7 @@ var _ = AfterEach(func() { // TODO: move before/after each into the tests (see AWS) // review tests themselves (very different from AWS?) // (e.g. AWS has not a single ExpectPRovisioned? why?) -var _ = Describe("GarbageCollection", func() { +var _ = Describe("VirtualMachine Garbage Collection", func() { var vm *armcompute.VirtualMachine var providerID string var err error @@ -147,7 +149,7 @@ var _ = Describe("GarbageCollection", func() { }) azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -164,23 +166,18 @@ var _ = Describe("GarbageCollection", func() { vm, err = azureEnv.InstanceProvider.Get(ctx, vmName) Expect(err).To(BeNil()) providerID = utils.ResourceIDToProviderID(ctx, *vm.ID) - azureEnv.VirtualMachinesAPI.Instances.Store( - *vm.ID, - armcompute.VirtualMachine{ - ID: vm.ID, - Name: vm.Name, - Location: lo.ToPtr(fake.Region), - Properties: &armcompute.VirtualMachineProperties{ - TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), - }, - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - }) + newVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: vmName, + NodepoolName: "default", + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), + }, + }) + azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(newVM.ID), newVM) ids = append(ids, *vm.ID) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -210,19 +207,14 @@ var _ = Describe("GarbageCollection", func() { vm, err = azureEnv.InstanceProvider.Get(ctx, vmName) Expect(err).To(BeNil()) providerID = utils.ResourceIDToProviderID(ctx, *vm.ID) - azureEnv.VirtualMachinesAPI.Instances.Store( - *vm.ID, - armcompute.VirtualMachine{ - ID: vm.ID, - Name: vm.Name, - Location: lo.ToPtr(fake.Region), - Properties: &armcompute.VirtualMachineProperties{ - TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), - }, - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - }) + newVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: vmName, + NodepoolName: "default", + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 10)), + }, + }) + azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(newVM.ID), newVM) nodeClaim := coretest.NodeClaim(karpv1.NodeClaim{ Status: karpv1.NodeClaimStatus{ ProviderID: utils.ResourceIDToProviderID(ctx, *vm.ID), @@ -233,7 +225,7 @@ var _ = Describe("GarbageCollection", func() { nodeClaims = append(nodeClaims, nodeClaim) } } - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) wg := sync.WaitGroup{} for _, id := range ids { @@ -259,7 +251,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).NotTo(HaveOccurred()) }) @@ -280,7 +272,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, nodeClaim, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err := cloudProvider.Get(ctx, providerID) Expect(err).ToNot(HaveOccurred()) ExpectExists(ctx, env.Client, node) @@ -289,16 +281,8 @@ var _ = Describe("GarbageCollection", func() { var _ = Context("Basic", func() { BeforeEach(func() { - id := utils.MkVMID(azureEnv.AzureResourceGraphAPI.ResourceGroup, "vm-a") - vm = &armcompute.VirtualMachine{ - ID: lo.ToPtr(id), - Name: lo.ToPtr("vm-a"), - Location: lo.ToPtr(fake.Region), - Tags: map[string]*string{ - instance.NodePoolTagKey: lo.ToPtr("default"), - }, - } - providerID = utils.ResourceIDToProviderID(ctx, id) + vm = test.VirtualMachine(test.VirtualMachineOptions{Name: "vm-a", NodepoolName: "default"}) + providerID = utils.ResourceIDToProviderID(ctx, lo.FromPtr(vm.ID)) }) It("should delete an instance if there is no NodeClaim owner", func() { // Launch happened 10m ago @@ -307,7 +291,7 @@ var _ = Describe("GarbageCollection", func() { } azureEnv.VirtualMachinesAPI.Instances.Store(lo.FromPtr(vm.ID), *vm) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -323,7 +307,7 @@ var _ = Describe("GarbageCollection", func() { }) ExpectApplied(ctx, env.Client, node) - ExpectSingletonReconciled(ctx, garbageCollectionController) + ExpectSingletonReconciled(ctx, virtualMachineGCController) _, err = cloudProvider.Get(ctx, providerID) Expect(err).To(HaveOccurred()) Expect(corecloudprovider.IsNodeClaimNotFoundError(err)).To(BeTrue()) @@ -332,3 +316,77 @@ var _ = Describe("GarbageCollection", func() { }) }) }) + +var _ = Describe("NetworkInterface Garbage Collection", func() { + It("should not delete a network interface if a nodeclaim exists for it", func() { + // Create and apply a NodeClaim that references this NIC + nodeClaim := coretest.NodeClaim() + ExpectApplied(ctx, env.Client, nodeClaim) + + // Create a managed NIC + nic := test.Interface(test.InterfaceOptions{Name: instance.GenerateResourceName(nodeClaim.Name), NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic.ID), *nic) + + nicsBeforeGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsBeforeGC)).To(Equal(1)) + + // Run garbage collection + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + + // Verify NIC still exists after GC + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + }) + It("should delete a NIC if there is no associated VM", func() { + nic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + nic2 := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic.ID), *nic) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(nic2.ID), *nic2) + nicsBeforeGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsBeforeGC)).To(Equal(2)) + // add a nic to azure env, and call reconcile. It should show up in the list before reconcile + // then it should not showup after + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(0)) + }) + It("should not delete a NIC if there is an associated VM", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + managedVM := test.VirtualMachine(test.VirtualMachineOptions{Name: lo.FromPtr(managedNic.Name), NodepoolName: nodePool.Name}) + azureEnv.VirtualMachinesAPI.VirtualMachinesBehavior.Instances.Store(lo.FromPtr(managedVM.ID), *managedVM) + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + // We should still have a network interface here + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + + }) + It("the vm gc controller should remove the nic if there is an associated vm", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + managedVM := test.VirtualMachine(test.VirtualMachineOptions{ + Name: lo.FromPtr(managedNic.Name), + NodepoolName: nodePool.Name, + Properties: &armcompute.VirtualMachineProperties{ + TimeCreated: lo.ToPtr(time.Now().Add(-time.Minute * 16)), // Needs to be older than the nodeclaim registration ttl + }, + }) + azureEnv.VirtualMachinesAPI.VirtualMachinesBehavior.Instances.Store(lo.FromPtr(managedVM.ID), *managedVM) + ExpectSingletonReconciled(ctx, networkInterfaceGCController) + // We should still have a network interface here + nicsAfterGC, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterGC)).To(Equal(1)) + + ExpectSingletonReconciled(ctx, virtualMachineGCController) + nicsAfterVMReconciliation, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(nicsAfterVMReconciliation)).To(Equal(0)) + + }) +}) diff --git a/pkg/fake/azureresourcegraphapi.go b/pkg/fake/azureresourcegraphapi.go index fe160ae62..9eea36287 100644 --- a/pkg/fake/azureresourcegraphapi.go +++ b/pkg/fake/azureresourcegraphapi.go @@ -23,6 +23,7 @@ import ( "github.com/samber/lo" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" ) @@ -35,6 +36,7 @@ type AzureResourceGraphResourcesInput struct { type AzureResourceGraphBehavior struct { AzureResourceGraphResourcesBehavior MockedFunction[AzureResourceGraphResourcesInput, armresourcegraph.ClientResourcesResponse] VirtualMachinesAPI *VirtualMachinesAPI + NetworkInterfacesAPI *NetworkInterfacesAPI ResourceGroup string } @@ -42,9 +44,23 @@ type AzureResourceGraphBehavior struct { var _ instance.AzureResourceGraphAPI = &AzureResourceGraphAPI{} type AzureResourceGraphAPI struct { + vmListQuery string + nicListQuery string AzureResourceGraphBehavior } +func NewAzureResourceGraphAPI(resourceGroup string, virtualMachinesAPI *VirtualMachinesAPI, networkInterfacesAPI *NetworkInterfacesAPI) *AzureResourceGraphAPI { + return &AzureResourceGraphAPI{ + vmListQuery: instance.GetVMListQueryBuilder(resourceGroup).String(), + nicListQuery: instance.GetNICListQueryBuilder(resourceGroup).String(), + AzureResourceGraphBehavior: AzureResourceGraphBehavior{ + VirtualMachinesAPI: virtualMachinesAPI, + NetworkInterfacesAPI: networkInterfacesAPI, + ResourceGroup: resourceGroup, + }, + } +} + // Reset must be called between tests otherwise tests will pollute each other. func (c *AzureResourceGraphAPI) Reset() {} @@ -66,7 +82,7 @@ func (c *AzureResourceGraphAPI) Resources(_ context.Context, query armresourcegr func (c *AzureResourceGraphAPI) getResourceList(query string) []interface{} { switch query { - case instance.GetListQueryBuilder(c.ResourceGroup).String(): + case c.vmListQuery: vmList := lo.Filter(c.loadVMObjects(), func(vm armcompute.VirtualMachine, _ int) bool { return vm.Tags != nil && vm.Tags[instance.NodePoolTagKey] != nil }) @@ -75,12 +91,20 @@ func (c *AzureResourceGraphAPI) getResourceList(query string) []interface{} { return convertBytesToInterface(b) }) return resourceList + case c.nicListQuery: + nicList := lo.Filter(c.loadNicObjects(), func(nic armnetwork.Interface, _ int) bool { + return nic.Tags != nil && nic.Tags[instance.NodePoolTagKey] != nil + }) + resourceList := lo.Map(nicList, func(nic armnetwork.Interface, _ int) interface{} { + b, _ := json.Marshal(nic) + return convertBytesToInterface(b) + }) + return resourceList } return nil } -func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { - vmList := []armcompute.VirtualMachine{} +func (c *AzureResourceGraphAPI) loadVMObjects() (vmList []armcompute.VirtualMachine) { c.VirtualMachinesAPI.Instances.Range(func(k, v any) bool { vm, _ := c.VirtualMachinesAPI.Instances.Load(k) vmList = append(vmList, vm.(armcompute.VirtualMachine)) @@ -89,6 +113,15 @@ func (c *AzureResourceGraphAPI) loadVMObjects() []armcompute.VirtualMachine { return vmList } +func (c *AzureResourceGraphAPI) loadNicObjects() (nicList []armnetwork.Interface) { + c.NetworkInterfacesAPI.NetworkInterfaces.Range(func(k, v any) bool { + nic, _ := c.NetworkInterfacesAPI.NetworkInterfaces.Load(k) + nicList = append(nicList, nic.(armnetwork.Interface)) + return true + }) + return nicList +} + func convertBytesToInterface(b []byte) interface{} { jsonObj := instance.Resource{} _ = json.Unmarshal(b, &jsonObj) diff --git a/pkg/fake/azureresourcegraphapi_test.go b/pkg/fake/azureresourcegraphapi_test.go index f23bc2796..f5bc7b1ef 100644 --- a/pkg/fake/azureresourcegraphapi_test.go +++ b/pkg/fake/azureresourcegraphapi_test.go @@ -32,7 +32,7 @@ func TestAzureResourceGraphAPI_Resources_VM(t *testing.T) { resourceGroup := "test_managed_cluster_rg" subscriptionID := "test_sub" virtualMachinesAPI := &VirtualMachinesAPI{} - azureResourceGraphAPI := &AzureResourceGraphAPI{AzureResourceGraphBehavior{VirtualMachinesAPI: virtualMachinesAPI, ResourceGroup: resourceGroup}} + azureResourceGraphAPI := NewAzureResourceGraphAPI(resourceGroup, virtualMachinesAPI, nil) cases := []struct { testName string vmNames []string @@ -67,7 +67,7 @@ func TestAzureResourceGraphAPI_Resources_VM(t *testing.T) { return } } - queryRequest := instance.NewQueryRequest(&subscriptionID, instance.GetListQueryBuilder(resourceGroup).String()) + queryRequest := instance.NewQueryRequest(&subscriptionID, instance.GetVMListQueryBuilder(resourceGroup).String()) data, err := instance.GetResourceData(context.Background(), azureResourceGraphAPI, *queryRequest) if err != nil { t.Errorf("Unexpected error %v", err) diff --git a/pkg/fake/networkinterfaceapi.go b/pkg/fake/networkinterfaceapi.go index f1fed7163..96404fdff 100644 --- a/pkg/fake/networkinterfaceapi.go +++ b/pkg/fake/networkinterfaceapi.go @@ -73,6 +73,7 @@ func (c *NetworkInterfacesAPI) BeginCreateOrUpdate(_ context.Context, resourceGr return c.NetworkInterfacesCreateOrUpdateBehavior.Invoke(input, func(input *NetworkInterfaceCreateOrUpdateInput) (*armnetwork.InterfacesClientCreateOrUpdateResponse, error) { iface := input.Interface + iface.Name = to.StringPtr(input.InterfaceName) id := mkNetworkInterfaceID(input.ResourceGroupName, input.InterfaceName) iface.ID = to.StringPtr(id) c.NetworkInterfaces.Store(id, iface) @@ -99,7 +100,7 @@ func (c *NetworkInterfacesAPI) BeginDelete(_ context.Context, resourceGroupName InterfaceName: interfaceName, } return c.NetworkInterfacesDeleteBehavior.Invoke(input, func(input *NetworkInterfaceDeleteInput) (*armnetwork.InterfacesClientDeleteResponse, error) { - id := mkNetworkInterfaceID(resourceGroupName, interfaceName) + id := mkNetworkInterfaceID(input.ResourceGroupName, input.InterfaceName) c.NetworkInterfaces.Delete(id) return &armnetwork.InterfacesClientDeleteResponse{}, nil }) diff --git a/pkg/fake/types.go b/pkg/fake/types.go index a76b22cf2..b2f427ce5 100644 --- a/pkg/fake/types.go +++ b/pkg/fake/types.go @@ -45,13 +45,12 @@ func (m *MockedFunction[I, O]) Reset() { } func (m *MockedFunction[I, O]) Invoke(input *I, defaultTransformer func(*I) (O, error)) (O, error) { + m.CalledWithInput.Add(input) err := m.Error.Get() if err != nil { m.failedCalls.Add(1) return *new(O), err } - m.CalledWithInput.Add(input) - if !m.Output.IsNil() { m.successfulCalls.Add(1) return *m.Output.Clone(), nil @@ -94,6 +93,8 @@ func (m *MockedLRO[I, O]) Reset() { } func (m *MockedLRO[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, error)) (*runtime.Poller[O], error) { + m.CalledWithInput.Add(input) + if err := m.BeginError.Get(); err != nil { m.failedCalls.Add(1) return nil, err @@ -103,8 +104,6 @@ func (m *MockedLRO[I, O]) Invoke(input *I, defaultTransformer func(*I) (*O, erro return newMockPoller[O](nil, err) } - m.CalledWithInput.Add(input) - if !m.Output.IsNil() { m.successfulCalls.Add(1) return newMockPoller(m.Output.Clone(), nil) diff --git a/pkg/providers/instance/azureresourcegraphlist.go b/pkg/providers/instance/azureresourcegraphlist.go new file mode 100644 index 000000000..fc41fd5f0 --- /dev/null +++ b/pkg/providers/instance/azureresourcegraphlist.go @@ -0,0 +1,108 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package instance + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/Azure/azure-kusto-go/kusto/kql" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/samber/lo" +) + +const ( + vmResourceType = "microsoft.compute/virtualmachines" + nicResourceType = "microsoft.network/networkinterfaces" +) + +// getResourceListQueryBuilder returns a KQL query builder for listing resources with nodepool tags +func getResourceListQueryBuilder(rg string, resourceType string) *kql.Builder { + return kql.New(`Resources`). + AddLiteral(` | where type == `).AddString(resourceType). + AddLiteral(` | where resourceGroup == `).AddString(strings.ToLower(rg)). // ARG resources appear to have lowercase RG + AddLiteral(` | where tags has_cs `).AddString(NodePoolTagKey) +} + +// GetVMListQueryBuilder returns a KQL query builder for listing VMs with nodepool tags +func GetVMListQueryBuilder(rg string) *kql.Builder { + return getResourceListQueryBuilder(rg, vmResourceType) +} + +// GetNICListQueryBuilder returns a KQL query builder for listing NICs with nodepool tags +func GetNICListQueryBuilder(rg string) *kql.Builder { + return getResourceListQueryBuilder(rg, nicResourceType) +} + +// createVMFromQueryResponseData converts ARG query response data into a VirtualMachine object +func createVMFromQueryResponseData(data map[string]interface{}) (*armcompute.VirtualMachine, error) { + jsonString, err := json.Marshal(data) + if err != nil { + return nil, err + } + vm := armcompute.VirtualMachine{} + err = json.Unmarshal(jsonString, &vm) + if err != nil { + return nil, err + } + if vm.ID == nil { + return nil, fmt.Errorf("virtual machine is missing id") + } + if vm.Name == nil { + return nil, fmt.Errorf("virtual machine is missing name") + } + if vm.Tags == nil { + return nil, fmt.Errorf("virtual machine is missing tags") + } + // We see inconsistent casing being returned by ARG for the last segment + // of the vm.ID string. This forces it to be lowercase. + parts := strings.Split(lo.FromPtr(vm.ID), "/") + parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) + vm.ID = lo.ToPtr(strings.Join(parts, "/")) + return &vm, nil +} + +// createNICFromQueryResponseData converts ARG query response data into a Network Interface object +func createNICFromQueryResponseData(data map[string]interface{}) (*armnetwork.Interface, error) { + jsonString, err := json.Marshal(data) + if err != nil { + return nil, err + } + + nic := armnetwork.Interface{} + err = json.Unmarshal(jsonString, &nic) + if err != nil { + return nil, err + } + if nic.ID == nil { + return nil, fmt.Errorf("network interface is missing id") + } + if nic.Name == nil { + return nil, fmt.Errorf("network interface is missing name") + } + if nic.Tags == nil { + return nil, fmt.Errorf("network interface is missing tags") + } + // We see inconsistent casing being returned by ARG for the last segment + // of the nic.ID string. This forces it to be lowercase. + parts := strings.Split(lo.FromPtr(nic.ID), "/") + parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) + nic.ID = lo.ToPtr(strings.Join(parts, "/")) + return &nic, nil +} diff --git a/pkg/providers/instance/argutils.go b/pkg/providers/instance/azureresourcegraphutils.go similarity index 100% rename from pkg/providers/instance/argutils.go rename to pkg/providers/instance/azureresourcegraphutils.go diff --git a/pkg/providers/instance/armutils.go b/pkg/providers/instance/azureresourcemanagerutils.go similarity index 100% rename from pkg/providers/instance/armutils.go rename to pkg/providers/instance/azureresourcemanagerutils.go diff --git a/pkg/providers/instance/instance.go b/pkg/providers/instance/instance.go index 10771f0d6..c2a543985 100644 --- a/pkg/providers/instance/instance.go +++ b/pkg/providers/instance/instance.go @@ -18,7 +18,6 @@ package instance import ( "context" - "encoding/json" "errors" "fmt" "math" @@ -32,11 +31,11 @@ import ( "k8s.io/apimachinery/pkg/util/sets" "knative.dev/pkg/logging" - "github.com/Azure/azure-kusto-go/kusto/kql" "github.com/Azure/karpenter-provider-azure/pkg/cache" "github.com/Azure/karpenter-provider-azure/pkg/providers/instancetype" "github.com/Azure/karpenter-provider-azure/pkg/providers/launchtemplate" "github.com/Azure/karpenter-provider-azure/pkg/providers/loadbalancer" + "github.com/Azure/karpenter-provider-azure/pkg/utils" corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/scheduling" @@ -54,9 +53,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" ) +var ( + vmListQuery string + nicListQuery string +) + var ( NodePoolTagKey = strings.ReplaceAll(karpv1.NodePoolLabelKey, "/", "_") - listQuery string CapacityTypeToPriority = map[string]string{ karpv1.CapacityTypeSpot: string(compute.Spot), @@ -86,6 +89,8 @@ type Provider interface { // CreateTags(context.Context, string, map[string]string) error Update(context.Context, string, armcompute.VirtualMachineUpdate) error GetNic(context.Context, string, string) (*armnetwork.Interface, error) + DeleteNic(context.Context, string) error + ListNics(context.Context) ([]*armnetwork.Interface, error) } // assert that DefaultProvider implements Provider interface @@ -114,7 +119,8 @@ func NewDefaultProvider( subscriptionID string, provisionMode string, ) *DefaultProvider { - listQuery = GetListQueryBuilder(resourceGroup).String() + vmListQuery = GetVMListQueryBuilder(resourceGroup).String() + nicListQuery = GetNICListQueryBuilder(resourceGroup).String() return &DefaultProvider{ azClient: azClient, instanceTypeProvider: instanceTypeProvider, @@ -140,7 +146,7 @@ func (p *DefaultProvider) Create(ctx context.Context, nodeClass *v1alpha2.AKSNod } return nil, err } - zone, err := GetZoneID(vm) + zone, err := utils.GetZone(vm) if err != nil { logging.FromContext(ctx).Error(err) } @@ -174,7 +180,7 @@ func (p *DefaultProvider) Get(ctx context.Context, vmName string) (*armcompute.V } func (p *DefaultProvider) List(ctx context.Context) ([]*armcompute.VirtualMachine, error) { - req := NewQueryRequest(&(p.subscriptionID), listQuery) + req := NewQueryRequest(&(p.subscriptionID), vmListQuery) client := p.azClient.azureResourceGraphClient data, err := GetResourceData(ctx, client, *req) if err != nil { @@ -196,6 +202,37 @@ func (p *DefaultProvider) Delete(ctx context.Context, resourceName string) error return p.cleanupAzureResources(ctx, resourceName) } +func (p *DefaultProvider) GetNic(ctx context.Context, rg, nicName string) (*armnetwork.Interface, error) { + nicResponse, err := p.azClient.networkInterfacesClient.Get(ctx, rg, nicName, nil) + if err != nil { + return nil, err + } + return &nicResponse.Interface, nil +} + +// ListNics returns all network interfaces in the resource group that have the nodepool tag +func (p *DefaultProvider) ListNics(ctx context.Context) ([]*armnetwork.Interface, error) { + req := NewQueryRequest(&(p.subscriptionID), nicListQuery) + client := p.azClient.azureResourceGraphClient + data, err := GetResourceData(ctx, client, *req) + if err != nil { + return nil, fmt.Errorf("querying azure resource graph, %w", err) + } + var nicList []*armnetwork.Interface + for i := range data { + nic, err := createNICFromQueryResponseData(data[i]) + if err != nil { + return nil, fmt.Errorf("creating NIC object from query response data, %w", err) + } + nicList = append(nicList, nic) + } + return nicList, nil +} + +func (p *DefaultProvider) DeleteNic(ctx context.Context, nicName string) error { + return deleteNicIfExists(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, nicName) +} + // createAKSIdentifyingExtension attaches a VM extension to identify that this VM participates in an AKS cluster func (p *DefaultProvider) createAKSIdentifyingExtension(ctx context.Context, vmName string) (err error) { vmExt := p.getAKSIdentifyingExtension() @@ -301,14 +338,6 @@ func (p *DefaultProvider) createNetworkInterface(ctx context.Context, opts *crea return *res.ID, nil } -func (p *DefaultProvider) GetNic(ctx context.Context, rg, nicName string) (*armnetwork.Interface, error) { - nicResponse, err := p.azClient.networkInterfacesClient.Get(ctx, rg, nicName, nil) - if err != nil { - return nil, err - } - return &nicResponse.Interface, nil -} - // newVMObject is a helper func that creates a new armcompute.VirtualMachine // from key input. func newVMObject( @@ -375,7 +404,7 @@ func newVMObject( CapacityTypeToPriority[capacityType]), ), }, - Zones: lo.Ternary(len(zone) > 0, []*string{&zone}, []*string{}), + Zones: utils.MakeVMZone(zone), Tags: launchTemplate.Tags, } setVMPropertiesOSDiskType(vm.Properties, launchTemplate.StorageProfile) @@ -628,11 +657,6 @@ func (p *DefaultProvider) pickSkuSizePriorityAndZone(ctx context.Context, nodeCl }) zonesWithPriority := lo.Map(priorityOfferings, func(o corecloudprovider.Offering, _ int) string { return getOfferingZone(o) }) if zone, ok := sets.New(zonesWithPriority...).PopAny(); ok { - if len(zone) > 0 { - // Zones in zonal Offerings have - format; the zone returned from here will be used for VM instantiation, - // which expects just the zone number, without region - zone = string(zone[len(zone)-1]) - } return instanceType, priority, zone } return nil, "", "" @@ -646,11 +670,11 @@ func (p *DefaultProvider) cleanupAzureResources(ctx context.Context, resourceNam // The order here is intentional, if the VM was created successfully, then we attempt to delete the vm, the // nic, disk and all associated resources will be removed. If the VM was not created successfully and a nic was found, // then we attempt to delete the nic. + nicErr := deleteNicIfExists(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, resourceName) if nicErr != nil { - logging.FromContext(ctx).Errorf("networkInterface.Delete for %s failed: %v", resourceName, nicErr) + logging.FromContext(ctx).Errorf("networkinterface.Delete for %s failed: %v", resourceName, nicErr) } - return errors.Join(vmErr, nicErr) } @@ -752,60 +776,6 @@ func (p *DefaultProvider) getCSExtension(cse string, isWindows bool) *armcompute } } -// GetZoneID returns the zone ID for the given virtual machine, or an empty string if there is no zone specified -func GetZoneID(vm *armcompute.VirtualMachine) (string, error) { - if vm == nil { - return "", fmt.Errorf("cannot pass in a nil virtual machine") - } - if vm.Name == nil { - return "", fmt.Errorf("virtual machine is missing name") - } - if vm.Zones == nil { - return "", nil - } - if len(vm.Zones) == 1 { - return *(vm.Zones)[0], nil - } - if len(vm.Zones) > 1 { - return "", fmt.Errorf("virtual machine %v has multiple zones", *vm.Name) - } - return "", nil -} - -func GetListQueryBuilder(rg string) *kql.Builder { - return kql.New(`Resources`). - AddLiteral(` | where type == "microsoft.compute/virtualmachines"`). - AddLiteral(` | where resourceGroup == `).AddString(strings.ToLower(rg)). // ARG VMs appear to have lowercase RG - AddLiteral(` | where tags has_cs `).AddString(NodePoolTagKey) -} - -func createVMFromQueryResponseData(data map[string]interface{}) (*armcompute.VirtualMachine, error) { - jsonString, err := json.Marshal(data) - if err != nil { - return nil, err - } - vm := armcompute.VirtualMachine{} - err = json.Unmarshal(jsonString, &vm) - if err != nil { - return nil, err - } - if vm.ID == nil { - return nil, fmt.Errorf("virtual machine is missing id") - } - if vm.Name == nil { - return nil, fmt.Errorf("virtual machine is missing name") - } - if vm.Tags == nil { - return nil, fmt.Errorf("virtual machine is missing tags") - } - // We see inconsistent casing being returned by ARG for the last segment - // of the vm.ID string. This forces it to be lowercase. - parts := strings.Split(lo.FromPtr(vm.ID), "/") - parts[len(parts)-1] = strings.ToLower(parts[len(parts)-1]) - vm.ID = lo.ToPtr(strings.Join(parts, "/")) - return &vm, nil -} - func ConvertToVirtualMachineIdentity(nodeIdentities []string) *armcompute.VirtualMachineIdentity { var identity *armcompute.VirtualMachineIdentity if len(nodeIdentities) > 0 { diff --git a/pkg/providers/instance/instance_test.go b/pkg/providers/instance/instance_test.go index 67ee915f3..0c08e6578 100644 --- a/pkg/providers/instance/instance_test.go +++ b/pkg/providers/instance/instance_test.go @@ -20,8 +20,8 @@ import ( "context" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/karpenter-provider-azure/pkg/cache" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" @@ -79,7 +79,7 @@ func TestGetPriorityCapacityAndInstanceType(t *testing.T) { }, nodeClaim: &karpv1.NodeClaim{}, expectedInstanceType: "Standard_D2s_v3", - expectedZone: "2", + expectedZone: "westus-2", expectedPriority: karpv1.CapacityTypeOnDemand, }, } @@ -101,54 +101,61 @@ func TestGetPriorityCapacityAndInstanceType(t *testing.T) { } } -func TestGetZone(t *testing.T) { - testVMName := "silly-armcompute" +func TestCreateNICFromQueryResponseData(t *testing.T) { + id := "nic_id" + name := "nic_name" + tag := "tag1" + val := "val1" + tags := map[string]*string{tag: &val} + tc := []struct { testName string - input *armcompute.VirtualMachine - expectedZone string + data map[string]interface{} expectedError string + expectedNIC *armnetwork.Interface }{ { - testName: "missing name", - input: &armcompute.VirtualMachine{ - Name: nil, + testName: "missing id", + data: map[string]interface{}{ + "name": name, }, - expectedError: "virtual machine is missing name", + expectedError: "network interface is missing id", + expectedNIC: nil, }, { - testName: "invalid virtual machine struct", - input: nil, - expectedError: "cannot pass in a nil virtual machine", - }, - { - testName: "invalid zones field in virtual machine struct", - input: &armcompute.VirtualMachine{ - Name: &testVMName, + testName: "missing name", + data: map[string]interface{}{ + "id": id, }, - expectedError: "virtual machine silly-armcompute zones are nil", + expectedError: "network interface is missing name", + expectedNIC: nil, }, { testName: "happy case", - input: &armcompute.VirtualMachine{ - Name: &testVMName, - Zones: []*string{to.Ptr("poland-central")}, + data: map[string]interface{}{ + "id": id, + "name": name, + "tags": map[string]interface{}{tag: val}, }, - expectedZone: "poland-central", - }, - { - testName: "emptyZones", - input: &armcompute.VirtualMachine{ - Name: &testVMName, - Zones: []*string{}, + expectedNIC: &armnetwork.Interface{ + ID: &id, + Name: &name, + Tags: tags, }, - expectedError: "virtual machine silly-armcompute does not have any zones specified", }, } for _, c := range tc { - zone, err := GetZoneID(c.input) - assert.Equal(t, c.expectedZone, zone, c.testName) + nic, err := createNICFromQueryResponseData(c.data) + if nic != nil { + expected := *c.expectedNIC + actual := *nic + assert.Equal(t, *expected.ID, *actual.ID, c.testName) + assert.Equal(t, *expected.Name, *actual.Name, c.testName) + for key := range expected.Tags { + assert.Equal(t, *(expected.Tags[key]), *(actual.Tags[key]), c.testName) + } + } if err != nil { assert.Equal(t, c.expectedError, err.Error(), c.testName) } diff --git a/pkg/providers/instance/suite_test.go b/pkg/providers/instance/suite_test.go index 923f24e17..6ef0dc1ec 100644 --- a/pkg/providers/instance/suite_test.go +++ b/pkg/providers/instance/suite_test.go @@ -30,8 +30,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clock "k8s.io/utils/clock/testing" - "k8s.io/client-go/tools/record" - "github.com/Azure/karpenter-provider-azure/pkg/apis" "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" @@ -39,6 +37,8 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" "github.com/Azure/karpenter-provider-azure/pkg/test" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/karpenter/pkg/controllers/provisioning" "sigs.k8s.io/karpenter/pkg/controllers/state" "sigs.k8s.io/karpenter/pkg/events" @@ -46,6 +46,7 @@ import ( karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1" corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider" + . "github.com/Azure/karpenter-provider-azure/pkg/test/expectations" . "knative.dev/pkg/logging/testing" . "sigs.k8s.io/karpenter/pkg/test/expectations" "sigs.k8s.io/karpenter/pkg/test/v1alpha1" @@ -215,4 +216,24 @@ var _ = Describe("InstanceProvider", func() { return strings.Contains(key, "/") // ARM tags can't contain '/' })).To(HaveLen(0)) }) + It("should list nic from karpenter provisioning request", func() { + ExpectApplied(ctx, env.Client, nodePool, nodeClass) + pod := coretest.UnschedulablePod(coretest.PodOptions{}) + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + ExpectScheduled(ctx, env.Client, pod) + interfaces, err := azureEnv.InstanceProvider.ListNics(ctx) + Expect(err).To(BeNil()) + Expect(len(interfaces)).To(Equal(1)) + }) + It("should only list nics that belong to karpenter", func() { + managedNic := test.Interface(test.InterfaceOptions{NodepoolName: nodePool.Name}) + unmanagedNic := test.Interface(test.InterfaceOptions{Tags: map[string]*string{"kubernetes.io/cluster/test-cluster": lo.ToPtr("random-aks-vm")}}) + + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(managedNic.ID), *managedNic) + azureEnv.NetworkInterfacesAPI.NetworkInterfaces.Store(lo.FromPtr(unmanagedNic.ID), *unmanagedNic) + interfaces, err := azureEnv.InstanceProvider.ListNics(ctx) + ExpectNoError(err) + Expect(len(interfaces)).To(Equal(1)) + Expect(interfaces[0].Name).To(Equal(managedNic.Name)) + }) }) diff --git a/pkg/providers/instancetype/instancetypes.go b/pkg/providers/instancetype/instancetypes.go index 946db59e8..4ed007112 100644 --- a/pkg/providers/instancetype/instancetypes.go +++ b/pkg/providers/instancetype/instancetypes.go @@ -172,7 +172,7 @@ func instanceTypeZones(sku *skewer.SKU, region string) sets.Set[string] { skuZones := lo.Keys(sku.AvailabilityZones(region)) if hasZonalSupport(region) && len(skuZones) > 0 { return sets.New(lo.Map(skuZones, func(zone string, _ int) string { - return fmt.Sprintf("%s-%s", region, zone) + return utils.MakeZone(region, zone) })...) } return sets.New("") // empty string means non-zonal offering diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index ae1f0d515..972be555f 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -76,6 +76,8 @@ var coreProvisioner, coreProvisionerNonZonal *provisioning.Provisioner var cluster, clusterNonZonal *state.Cluster var cloudProvider, cloudProviderNonZonal *cloudprovider.CloudProvider +var fakeZone1 = utils.MakeZone(fake.Region, "1") + func TestAzure(t *testing.T) { ctx = TestContextWithLogger(t) RegisterFailHandler(Fail) @@ -589,8 +591,8 @@ var _ = Describe("InstanceType Provider", func() { Context("Unavailable Offerings", func() { It("should not allocate a vm in a zone marked as unavailable", func() { - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeSpot) - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeOnDemand) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fakeZone1, karpv1.CapacityTypeSpot) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fakeZone1, karpv1.CapacityTypeOnDemand) coretest.ReplaceRequirements(nodePool, karpv1.NodeSelectorRequirementWithMinValues{ NodeSelectorRequirement: v1.NodeSelectorRequirement{ Key: v1.LabelInstanceTypeStable, @@ -599,19 +601,38 @@ var _ = Describe("InstanceType Provider", func() { }}) ExpectApplied(ctx, env.Client, nodePool, nodeClass) - // Try this 100 times to make sure we don't get a node in eastus-1, - // we pick from 3 zones so the likelihood of this test passing by chance is 1/3^100 - for i := 0; i < 100; i++ { - pod := coretest.UnschedulablePod() - ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) - ExpectScheduled(ctx, env.Client, pod) - nodes := &v1.NodeList{} - Expect(env.Client.List(ctx, nodes)).To(Succeed()) - for _, node := range nodes.Items { - Expect(node.Labels["karpenter.kubernetes.azure/zone"]).ToNot(Equal(fmt.Sprintf("%s-1", fake.Region))) - Expect(node.Labels["node.kubernetes.io/instance-type"]).To(Equal("Standard_D2_v2")) - } - } + pod := coretest.UnschedulablePod() + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + node := ExpectScheduled(ctx, env.Client, pod) + Expect(node.Labels[v1alpha2.AlternativeLabelTopologyZone]).ToNot(Equal(fakeZone1)) + Expect(node.Labels[v1.LabelInstanceTypeStable]).To(Equal("Standard_D2_v2")) + }) + It("should handle ZonalAllocationFailed on creating the VM", func() { + azureEnv.VirtualMachinesAPI.VirtualMachinesBehavior.VirtualMachineCreateOrUpdateBehavior.Error.Set( + &azcore.ResponseError{ErrorCode: sdkerrors.ZoneAllocationFailed}, + ) + coretest.ReplaceRequirements(nodePool, karpv1.NodeSelectorRequirementWithMinValues{ + NodeSelectorRequirement: v1.NodeSelectorRequirement{ + Key: v1.LabelInstanceTypeStable, + Operator: v1.NodeSelectorOpIn, + Values: []string{"Standard_D2_v2"}, + }}) + + ExpectApplied(ctx, env.Client, nodePool, nodeClass) + pod := coretest.UnschedulablePod() + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + ExpectNotScheduled(ctx, env.Client, pod) + + By("marking whatever zone was picked as unavailable - for both spot and on-demand") + zone, err := utils.GetZone(&azureEnv.VirtualMachinesAPI.VirtualMachineCreateOrUpdateBehavior.CalledWithInput.Pop().VM) + Expect(err).ToNot(HaveOccurred()) + Expect(azureEnv.UnavailableOfferingsCache.IsUnavailable("Standard_D2_v2", zone, karpv1.CapacityTypeSpot)).To(BeTrue()) + Expect(azureEnv.UnavailableOfferingsCache.IsUnavailable("Standard_D2_v2", zone, karpv1.CapacityTypeOnDemand)).To(BeTrue()) + + By("successfully scheduling in a different zone on retry") + ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) + node := ExpectScheduled(ctx, env.Client, pod) + Expect(node.Labels[v1alpha2.AlternativeLabelTopologyZone]).ToNot(Equal(zone)) }) DescribeTable("Should not return unavailable offerings", func(azEnv *test.Environment) { @@ -641,8 +662,8 @@ var _ = Describe("InstanceType Provider", func() { ) It("should launch instances in a different zone than preferred", func() { - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeOnDemand) - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeSpot) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fakeZone1, karpv1.CapacityTypeOnDemand) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "ZonalAllocationFailure", "Standard_D2_v2", fakeZone1, karpv1.CapacityTypeSpot) ExpectApplied(ctx, env.Client, nodeClass, nodePool) pod := coretest.UnschedulablePod(coretest.PodOptions{ @@ -651,18 +672,18 @@ var _ = Describe("InstanceType Provider", func() { pod.Spec.Affinity = &v1.Affinity{NodeAffinity: &v1.NodeAffinity{PreferredDuringSchedulingIgnoredDuringExecution: []v1.PreferredSchedulingTerm{ { Weight: 1, Preference: v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{ - {Key: v1.LabelTopologyZone, Operator: v1.NodeSelectorOpIn, Values: []string{fmt.Sprintf("%s-1", fake.Region)}}, + {Key: v1.LabelTopologyZone, Operator: v1.NodeSelectorOpIn, Values: []string{fakeZone1}}, }}, }, }}} ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) node := ExpectScheduled(ctx, env.Client, pod) - Expect(node.Labels["karpenter.kubernetes.azure/zone"]).ToNot(Equal(fmt.Sprintf("%s-1", fake.Region))) - Expect(node.Labels["node.kubernetes.io/instance-type"]).To(Equal("Standard_D2_v2")) + Expect(node.Labels[v1alpha2.AlternativeLabelTopologyZone]).ToNot(Equal(fakeZone1)) + Expect(node.Labels[v1.LabelInstanceTypeStable]).To(Equal("Standard_D2_v2")) }) It("should launch smaller instances than optimal if larger instance launch results in Insufficient Capacity Error", func() { - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "SubscriptionQuotaReached", "Standard_F16s_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeOnDemand) - azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "SubscriptionQuotaReached", "Standard_F16s_v2", fmt.Sprintf("%s-1", fake.Region), karpv1.CapacityTypeSpot) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "SubscriptionQuotaReached", "Standard_F16s_v2", fakeZone1, karpv1.CapacityTypeOnDemand) + azureEnv.UnavailableOfferingsCache.MarkUnavailable(ctx, "SubscriptionQuotaReached", "Standard_F16s_v2", fakeZone1, karpv1.CapacityTypeSpot) coretest.ReplaceRequirements(nodePool, karpv1.NodeSelectorRequirementWithMinValues{ NodeSelectorRequirement: v1.NodeSelectorRequirement{ Key: v1.LabelInstanceTypeStable, @@ -676,7 +697,7 @@ var _ = Describe("InstanceType Provider", func() { Requests: v1.ResourceList{v1.ResourceCPU: resource.MustParse("1")}, }, NodeSelector: map[string]string{ - v1.LabelTopologyZone: fmt.Sprintf("%s-1", fake.Region), + v1.LabelTopologyZone: fakeZone1, }, })) } @@ -731,8 +752,8 @@ var _ = Describe("InstanceType Provider", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, coreProvisioner, pod) ExpectNotScheduled(ctx, env.Client, pod) - for _, zone := range []string{"1", "2", "3"} { - ExpectUnavailable(azureEnv, sku, zone, capacityType) + for _, zoneID := range []string{"1", "2", "3"} { + ExpectUnavailable(azureEnv, sku, utils.MakeZone(fake.Region, zoneID), capacityType) } } @@ -793,7 +814,7 @@ var _ = Describe("InstanceType Provider", func() { // Well known v1.LabelTopologyRegion: fake.Region, karpv1.NodePoolLabelKey: nodePool.Name, - v1.LabelTopologyZone: fmt.Sprintf("%s-1", fake.Region), + v1.LabelTopologyZone: fakeZone1, v1.LabelInstanceTypeStable: "Standard_NC24ads_A100_v4", v1.LabelOSStable: "linux", v1.LabelArchStable: "amd64", @@ -814,11 +835,11 @@ var _ = Describe("InstanceType Provider", func() { v1alpha2.LabelSKUAccelerator: "A100", // Deprecated Labels v1.LabelFailureDomainBetaRegion: fake.Region, - v1.LabelFailureDomainBetaZone: fmt.Sprintf("%s-1", fake.Region), + v1.LabelFailureDomainBetaZone: fakeZone1, "beta.kubernetes.io/arch": "amd64", "beta.kubernetes.io/os": "linux", v1.LabelInstanceType: "Standard_NC24ads_A100_v4", - "topology.disk.csi.azure.com/zone": fmt.Sprintf("%s-1", fake.Region), + "topology.disk.csi.azure.com/zone": fakeZone1, v1.LabelWindowsBuild: "window", // Cluster Label v1alpha2.AKSLabelCluster: "test-cluster", diff --git a/pkg/test/environment.go b/pkg/test/environment.go index 11d3faf0e..bcf8ab48f 100644 --- a/pkg/test/environment.go +++ b/pkg/test/environment.go @@ -92,15 +92,16 @@ func NewRegionalEnvironment(ctx context.Context, env *coretest.Environment, regi // API virtualMachinesAPI := &fake.VirtualMachinesAPI{} - azureResourceGraphAPI := &fake.AzureResourceGraphAPI{AzureResourceGraphBehavior: fake.AzureResourceGraphBehavior{VirtualMachinesAPI: virtualMachinesAPI, ResourceGroup: resourceGroup}} - virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} + networkInterfacesAPI := &fake.NetworkInterfacesAPI{} + virtualMachinesExtensionsAPI := &fake.VirtualMachineExtensionsAPI{} pricingAPI := &fake.PricingAPI{} skuClientSingleton := &fake.MockSkuClientSingleton{SKUClient: &fake.ResourceSKUsAPI{Location: region}} communityImageVersionsAPI := &fake.CommunityGalleryImageVersionsAPI{} loadBalancersAPI := &fake.LoadBalancersAPI{} nodeImageVersionsAPI := &fake.NodeImageVersionsAPI{} + azureResourceGraphAPI := fake.NewAzureResourceGraphAPI(resourceGroup, virtualMachinesAPI, networkInterfacesAPI) // Cache kubernetesVersionCache := cache.New(azurecache.KubernetesVersionTTL, azurecache.DefaultCleanupInterval) instanceTypeCache := cache.New(instancetype.InstanceTypesCacheTTL, azurecache.DefaultCleanupInterval) diff --git a/pkg/test/expectations/expectations.go b/pkg/test/expectations/expectations.go index f16b7fcfe..d6f1e5632 100644 --- a/pkg/test/expectations/expectations.go +++ b/pkg/test/expectations/expectations.go @@ -21,16 +21,14 @@ import ( "fmt" "strings" + "github.com/Azure/karpenter-provider-azure/pkg/test" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - - "github.com/Azure/karpenter-provider-azure/pkg/fake" - "github.com/Azure/karpenter-provider-azure/pkg/test" ) func ExpectUnavailable(env *test.Environment, instanceType string, zone string, capacityType string) { GinkgoHelper() - Expect(env.UnavailableOfferingsCache.IsUnavailable(instanceType, fmt.Sprintf("%s-%s", fake.Region, zone), capacityType)).To(BeTrue()) + Expect(env.UnavailableOfferingsCache.IsUnavailable(instanceType, zone, capacityType)).To(BeTrue()) } func ExpectKubeletFlags(env *test.Environment, customData string, expectedFlags map[string]string) { @@ -55,3 +53,8 @@ func ExpectDecodedCustomData(env *test.Environment) string { return decodedString } + +func ExpectNoError(err error) { + GinkgoHelper() + Expect(err).To(BeNil()) +} diff --git a/pkg/test/networkinterfaces.go b/pkg/test/networkinterfaces.go new file mode 100644 index 000000000..c879a7495 --- /dev/null +++ b/pkg/test/networkinterfaces.go @@ -0,0 +1,83 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/Azure/karpenter-provider-azure/pkg/fake" + "github.com/imdario/mergo" + "github.com/samber/lo" +) + +// InterfaceOptions customizes an Azure Network Interface for testing. +type InterfaceOptions struct { + Name string + NodepoolName string + Location string + Properties *armnetwork.InterfacePropertiesFormat + Tags map[string]*string +} + +// Interface creates a test Azure Network Interface with defaults that can be overridden by InterfaceOptions. +// Overrides are applied in order, with last-write-wins semantics. +func Interface(overrides ...InterfaceOptions) *armnetwork.Interface { + options := InterfaceOptions{} + for _, o := range overrides { + if err := mergo.Merge(&options, o, mergo.WithOverride); err != nil { + panic(fmt.Sprintf("Failed to merge Interface options: %s", err)) + } + } + + // Provide default values if none are set + if options.Name == "" { + options.Name = RandomName("aks") + } + if options.NodepoolName == "" { + options.NodepoolName = "default" + } + if options.Location == "" { + options.Location = fake.Region + } + if options.Tags == nil { + options.Tags = ManagedTags(options.NodepoolName) + } + if options.Properties == nil { + options.Properties = &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ + { + Name: lo.ToPtr("ipConfig"), + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: lo.ToPtr(armnetwork.IPAllocationMethodDynamic), + Subnet: &armnetwork.Subnet{ID: lo.ToPtr("/subscriptions/.../resourceGroups/.../providers/Microsoft.Network/virtualNetworks/.../subnets/default")}, + }, + }, + }, + } + } + + nic := &armnetwork.Interface{ + ID: lo.ToPtr(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/test-resourceGroup/providers/Microsoft.Network/networkInterfaces/%s", options.Name)), + Name: &options.Name, + Location: &options.Location, + Properties: options.Properties, + Tags: options.Tags, + } + + return nic +} diff --git a/pkg/test/utils.go b/pkg/test/utils.go new file mode 100644 index 000000000..1c1af00f9 --- /dev/null +++ b/pkg/test/utils.go @@ -0,0 +1,35 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "github.com/samber/lo" + k8srand "k8s.io/apimachinery/pkg/util/rand" +) + +// RandomName returns a pseudo-random resource name with a given prefix. +func RandomName(prefix string) string { + // You could make this more robust by including additional random characters. + return prefix + "-" + k8srand.String(10) +} + +func ManagedTags(nodepoolName string) map[string]*string { + return map[string]*string{ + "karpenter.sh_cluster": lo.ToPtr("test-cluster"), + "karpenter.sh_nodepool": lo.ToPtr(nodepoolName), + } +} diff --git a/pkg/test/virtualmachines.go b/pkg/test/virtualmachines.go new file mode 100644 index 000000000..7ad01321a --- /dev/null +++ b/pkg/test/virtualmachines.go @@ -0,0 +1,78 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/karpenter-provider-azure/pkg/fake" + "github.com/imdario/mergo" + "github.com/samber/lo" +) + +// VirtualMachineOptions customizes an Azure Virtual Machine for testing. +type VirtualMachineOptions struct { + Name string + NodepoolName string + Location string + Properties *armcompute.VirtualMachineProperties + Tags map[string]*string +} + +// VirtualMachine creates a test Azure Virtual Machine with defaults that can be overridden by VirtualMachineOptions. +// Overrides are applied in order, with last-write-wins semantics. +func VirtualMachine(overrides ...VirtualMachineOptions) *armcompute.VirtualMachine { + options := VirtualMachineOptions{} + for _, o := range overrides { + if err := mergo.Merge(&options, o, mergo.WithOverride); err != nil { + panic(fmt.Sprintf("Failed to merge VirtualMachine options: %s", err)) + } + } + + // Provide default values if none are set + if options.Name == "" { + options.Name = RandomName("aks") + } + if options.NodepoolName == "" { + options.NodepoolName = "default" + } + if options.Location == "" { + options.Location = fake.Region + } + if options.Properties == nil { + options.Properties = &armcompute.VirtualMachineProperties{} + } + if options.Tags == nil { + options.Tags = ManagedTags(options.NodepoolName) + } + if options.Properties.TimeCreated == nil { + options.Properties.TimeCreated = lo.ToPtr(time.Now()) + } + + // Construct the basic VM + vm := &armcompute.VirtualMachine{ + ID: lo.ToPtr(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/test-resourceGroup/providers/Microsoft.Compute/virtualMachines/%s", options.Name)), + Name: lo.ToPtr(options.Name), + Location: lo.ToPtr(options.Location), + Properties: options.Properties, + Tags: options.Tags, + } + + return vm +} diff --git a/pkg/utils/zone.go b/pkg/utils/zone.go new file mode 100644 index 000000000..a1efe304e --- /dev/null +++ b/pkg/utils/zone.go @@ -0,0 +1,61 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" +) + +// MakeZone returns the zone value in format of -. +func MakeZone(location string, zoneID string) string { + if zoneID == "" { + return "" + } + return fmt.Sprintf("%s-%s", strings.ToLower(location), zoneID) +} + +// VM Zones field expects just the zone number, without region +func MakeVMZone(zone string) []*string { + if zone == "" { + return []*string{} + } + zoneNum := zone[len(zone)-1:] + return []*string{&zoneNum} +} + +// GetZone returns the zone for the given virtual machine, or an empty string if there is no zone specified +func GetZone(vm *armcompute.VirtualMachine) (string, error) { + if vm == nil { + return "", fmt.Errorf("cannot pass in a nil virtual machine") + } + if vm.Zones == nil { + return "", nil + } + if len(vm.Zones) == 1 { + if vm.Location == nil { + return "", fmt.Errorf("virtual machine is missing location") + } + return MakeZone(*vm.Location, *(vm.Zones)[0]), nil + } + if len(vm.Zones) > 1 { + return "", fmt.Errorf("virtual machine has multiple zones") + } + return "", nil +} diff --git a/pkg/utils/zone_test.go b/pkg/utils/zone_test.go new file mode 100644 index 000000000..ff9d9c353 --- /dev/null +++ b/pkg/utils/zone_test.go @@ -0,0 +1,86 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils_test + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/Azure/karpenter-provider-azure/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func TestGetZone(t *testing.T) { + tc := []struct { + testName string + input *armcompute.VirtualMachine + expectedZone string + expectedError string + }{ + { + testName: "invalid virtual machine struct", + input: nil, + expectedError: "cannot pass in a nil virtual machine", + }, + { + testName: "happy case", + input: &armcompute.VirtualMachine{ + Location: to.Ptr("region"), + Zones: []*string{to.Ptr("1")}, + }, + expectedZone: "region-1", + }, + { + testName: "missing Location", + input: &armcompute.VirtualMachine{ + Zones: []*string{to.Ptr("1")}, + }, + expectedError: "virtual machine is missing location", + }, + { + testName: "multiple zones", + input: &armcompute.VirtualMachine{ + Zones: []*string{to.Ptr("1"), to.Ptr("2")}, + }, + expectedError: "virtual machine has multiple zones", + }, + { + testName: "empty Zones", + input: &armcompute.VirtualMachine{ + Zones: []*string{}, + }, + expectedZone: "", + }, + { + testName: "nil Zones", + input: &armcompute.VirtualMachine{}, + expectedZone: "", + }, + } + + for _, c := range tc { + zone, err := utils.GetZone(c.input) + assert.Equal(t, c.expectedZone, zone, c.testName) + if err == nil && c.expectedError != "" { + assert.Fail(t, "expected error but got nil", c.testName) + } + if err != nil { + assert.Equal(t, c.expectedError, err.Error(), c.testName) + } + } +} diff --git a/test/pkg/environment/common/expectations.go b/test/pkg/environment/common/expectations.go index af3876eee..75274e0e8 100644 --- a/test/pkg/environment/common/expectations.go +++ b/test/pkg/environment/common/expectations.go @@ -366,7 +366,12 @@ func (env *Environment) EventuallyExpectKarpenterRestarted() { GinkgoHelper() By("rolling out the new karpenter deployment") env.EventuallyExpectRollout("karpenter", "kube-system") - env.ExpectKarpenterLeaseOwnerChanged() + + if !lo.ContainsBy(env.ExpectSettings(), func(v corev1.EnvVar) bool { + return v.Name == "DISABLE_LEADER_ELECTION" && v.Value == "true" + }) { + env.ExpectKarpenterLeaseOwnerChanged() + } } func (env *Environment) ExpectKarpenterLeaseOwnerChanged() { diff --git a/test/suites/drift/suite_test.go b/test/suites/drift/suite_test.go index b2d06e2c4..4d7fce780 100644 --- a/test/suites/drift/suite_test.go +++ b/test/suites/drift/suite_test.go @@ -61,8 +61,6 @@ var _ = Describe("Drift", func() { var pod *corev1.Pod BeforeEach(func() { - env.ExpectSettingsOverridden(corev1.EnvVar{Name: "FEATURE_GATES", Value: "Drift=true"}) - coretest.ReplaceRequirements(nodePool, karpv1.NodeSelectorRequirementWithMinValues{ NodeSelectorRequirement: corev1.NodeSelectorRequirement{ Key: corev1.LabelInstanceTypeStable,