Skip to content

Commit

Permalink
fix: moving filter for gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Bryce-Soghigian committed Jan 5, 2024
1 parent 50ba327 commit 3d3a65e
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 37 deletions.
5 changes: 5 additions & 0 deletions pkg/apis/v1alpha2/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,8 @@ var (

NodeClaimLinkedAnnotationKey = v1alpha5.MachineLinkedAnnotationKey // still using the one from v1alpha5
)

// Image family names. These are the string values compared against an
// AKSNodeClass's Spec.ImageFamily when choosing an image family.
const (
// Ubuntu2204ImageFamily is the name of the Ubuntu 22.04 image family.
Ubuntu2204ImageFamily = "Ubuntu2204"
// AzureLinuxImageFamily is the name of the Azure Linux image family.
AzureLinuxImageFamily = "AzureLinux"
)
3 changes: 1 addition & 2 deletions pkg/providers/imagefamily/azlinux.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
)

const (
AzureLinuxImageFamily = "AzureLinux"
AzureLinuxGen2CommunityImage = "V2gen2"
AzureLinuxGen1CommunityImage = "V2"
AzureLinuxGen2ArmCommunityImage = "V2gen2arm64"
Expand All @@ -41,7 +40,7 @@ type AzureLinux struct {
}

func (u AzureLinux) Name() string {
return AzureLinuxImageFamily
return v1alpha2.AzureLinuxImageFamily
}

func (u AzureLinux) DefaultImages() []DefaultImageOutput {
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/imagefamily/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func (r Resolver) Resolve(ctx context.Context, nodeClass *v1alpha2.AKSNodeClass,

func getImageFamily(familyName *string, parameters *template.StaticParameters) ImageFamily {
switch lo.FromPtr(familyName) {
case Ubuntu2204ImageFamily:
case v1alpha2.Ubuntu2204ImageFamily:
return &Ubuntu2204{Options: parameters}
case AzureLinuxImageFamily:
case v1alpha2.AzureLinuxImageFamily:
return &AzureLinux{Options: parameters}
default:
return &Ubuntu2204{Options: parameters}
Expand Down
3 changes: 1 addition & 2 deletions pkg/providers/imagefamily/ubuntu_2204.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
)

const (
Ubuntu2204ImageFamily = "Ubuntu2204"
Ubuntu2204Gen2CommunityImage = "2204gen2containerd"
Ubuntu2204Gen1CommunityImage = "2204containerd"
Ubuntu2204Gen2ArmCommunityImage = "2204gen2arm64containerd"
Expand All @@ -41,7 +40,7 @@ type Ubuntu2204 struct {
}

func (u Ubuntu2204) Name() string {
return Ubuntu2204ImageFamily
return v1alpha2.Ubuntu2204ImageFamily
}

func (u Ubuntu2204) DefaultImages() []DefaultImageOutput {
Expand Down
45 changes: 26 additions & 19 deletions pkg/providers/instancetype/instancetypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ import (
const (
InstanceTypesCacheKey = "types"
InstanceTypesCacheTTL = 23 * time.Hour

Ubuntu2204ImageFamily = "Ubuntu2204"
AzureLinuxImageFamily = "AzureLinux"
)

type Provider struct {
Expand Down Expand Up @@ -78,8 +75,7 @@ func (p *Provider) List(
p.Lock()
defer p.Unlock()
// Get SKUs from Azure
imageFamily := lo.FromPtr(nodeClass.Spec.ImageFamily)
skus, err := p.getInstanceTypes(ctx, imageFamily)
skus, err := p.getInstanceTypes(ctx)
if err != nil {
return nil, err
}
Expand All @@ -103,6 +99,10 @@ func (p *Provider) List(
if len(instanceType.Offerings) == 0 {
continue
}

if !p.isInstanceTypeSupportedByImageFamily(sku.GetName(), lo.FromPtr(nodeClass.Spec.ImageFamily)) {
continue
}
result = append(result, instanceType)
}
return result, nil
Expand Down Expand Up @@ -143,8 +143,23 @@ func (p *Provider) createOfferings(sku *skewer.SKU, zones sets.Set[string]) []cl
return offerings
}

// isInstanceTypeSupportedByImageFamily reports whether the SKU named skuName
// may be offered to a node class that uses the given imageFamily.
//
// SKUs that appear on neither the Nvidia-enabled nor the Mariner-enabled GPU
// list are always accepted. A SKU on either list is accepted only when the
// list matching imageFamily contains it; any other imageFamily value
// (including the empty string) rejects it.
func (p *Provider) isInstanceTypeSupportedByImageFamily(skuName, imageFamily string) bool {
	// Image-family gating currently applies to GPU SKUs only, so anything
	// that is on neither GPU list passes straight through.
	if !utils.IsNvidiaEnabledSKU(skuName) && !utils.IsMarinerEnabledGPUSKU(skuName) {
		return true
	}

	// The SKU is on at least one GPU list; it must be on the list that
	// corresponds to the requested image family.
	switch {
	case imageFamily == v1alpha2.Ubuntu2204ImageFamily:
		return utils.IsNvidiaEnabledSKU(skuName)
	case imageFamily == v1alpha2.AzureLinuxImageFamily:
		return utils.IsMarinerEnabledGPUSKU(skuName)
	}

	// Unrecognised image family: GPU SKUs are not offered.
	return false
}

// getInstanceTypes retrieves all instance types from skewer using some opinionated filters
func (p *Provider) getInstanceTypes(ctx context.Context, imageFamily string) (map[string]*skewer.SKU, error) {
func (p *Provider) getInstanceTypes(ctx context.Context) (map[string]*skewer.SKU, error) {
if cached, ok := p.cache.Get(InstanceTypesCacheKey); ok {
return cached.(map[string]*skewer.SKU), nil
}
Expand All @@ -164,7 +179,7 @@ func (p *Provider) getInstanceTypes(ctx context.Context, imageFamily string) (ma
continue
}

if !skus[i].HasLocationRestriction(p.region) && p.isSupported(&skus[i], vmsize, imageFamily) {
if !skus[i].HasLocationRestriction(p.region) && p.isSupported(&skus[i], vmsize) {
instanceTypes[skus[i].GetName()] = &skus[i]
}
}
Expand All @@ -175,11 +190,11 @@ func (p *Provider) getInstanceTypes(ctx context.Context, imageFamily string) (ma
}

// isSupported indicates SKU is supported by AKS, based on SKU properties
func (p *Provider) isSupported(sku *skewer.SKU, vmsize *skewer.VMSizeType, imageFamily string) bool {
func (p *Provider) isSupported(sku *skewer.SKU, vmsize *skewer.VMSizeType) bool {
return p.hasMinimumCPU(sku) &&
p.hasMinimumMemory(sku) &&
!p.isUnsupportedByAKS(sku) &&
!p.isUnsupportedGPU(sku, imageFamily) &&
!p.isUnsupportedGPU(sku) &&
!p.hasConstrainedCPUs(vmsize) &&
!p.isConfidential(sku)
}
Expand All @@ -202,21 +217,13 @@ func (p *Provider) isUnsupportedByAKS(sku *skewer.SKU) bool {
}

// GPU SKUs AKS does not support
func (p *Provider) isUnsupportedGPU(sku *skewer.SKU, imageFamily string) bool {
func (p *Provider) isUnsupportedGPU(sku *skewer.SKU) bool {
name := lo.FromPtr(sku.Name)
gpu, err := sku.GPU()
if err != nil || gpu <= 0 {
return false
}

switch imageFamily {
case Ubuntu2204ImageFamily:
return !utils.IsNvidiaEnabledSKU(name)
case AzureLinuxImageFamily:
return !utils.IsMarinerEnabledGPUSKU(name)
default:
return false
}
return !utils.IsMarinerEnabledGPUSKU(name) && !utils.IsNvidiaEnabledSKU(name)
}

// SKU with constrained CPUs
Expand Down
6 changes: 0 additions & 6 deletions pkg/providers/instancetype/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ var _ = Describe("InstanceType Provider", func() {
os.Setenv("AZURE_SUBNET_NAME", "test-subnet-name")

nodeClass = test.AKSNodeClass()
// Sometimes we use nodeClass without applying it, when simulating the List() call.
// In that case, we need to set the default values for the node class.
nodeClass.Spec.OSDiskSizeGB = lo.ToPtr[int32](128)
nodePool = coretest.NodePool(corev1beta1.NodePool{
Spec: corev1beta1.NodePoolSpec{
Template: corev1beta1.NodeClaimTemplate{
Expand Down Expand Up @@ -268,7 +265,6 @@ var _ = Describe("InstanceType Provider", func() {
It("should use ephemeral disk if supported, and has space of at least 128GB by default", func() {
// Create a Provisioner that selects a sku that supports ephemeral
// SKU Standard_D64s_v3 has 1600GB of CacheDisk space, so we expect we can create an ephemeral disk with size 128GB

np := coretest.NodePool()
np.Spec.Template.Spec.Requirements = append(np.Spec.Template.Spec.Requirements, v1.NodeSelectorRequirement{
Key: "node.kubernetes.io/instance-type",
Expand Down Expand Up @@ -362,7 +358,6 @@ var _ = Describe("InstanceType Provider", func() {
}

It("should support provisioning with kubeletConfig, computeResources & maxPods not specified", func() {

nodePool.Spec.Template.Spec.Kubelet = kubeletConfig
ExpectApplied(ctx, env.Client, nodePool, nodeClass)

Expand Down Expand Up @@ -672,7 +667,6 @@ var _ = Describe("InstanceType Provider", func() {
It("should propagate all values to requirements from skewer", func() {
var gpuNode *corecloudprovider.InstanceType
var normalNode *corecloudprovider.InstanceType

for _, instanceType := range instanceTypes {
if instanceType.Name == "Standard_D2_v2" {
normalNode = instanceType
Expand Down
13 changes: 7 additions & 6 deletions pkg/test/aksnodeclass.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@ limitations under the License.
package test

import (
"context"
"fmt"

"github.com/imdario/mergo"

"github.com/aws/karpenter-core/pkg/test"

"github.com/Azure/karpenter/pkg/apis/v1alpha2"
"github.com/aws/karpenter-core/pkg/test"
"github.com/imdario/mergo"
"github.com/samber/lo"
)

func AKSNodeClass(overrides ...v1alpha2.AKSNodeClass) *v1alpha2.AKSNodeClass {
Expand All @@ -40,6 +38,9 @@ func AKSNodeClass(overrides ...v1alpha2.AKSNodeClass) *v1alpha2.AKSNodeClass {
Spec: options.Spec,
Status: options.Status,
}
nc.SetDefaults(context.Background())
// In a real cluster these defaults are applied by the API server's defaulting. We also set them here because tests
// sometimes use a test.AKSNodeClass without applying it, in which case the defaults must be set explicitly.
nc.Spec.OSDiskSizeGB = lo.ToPtr[int32](128)
nc.Spec.ImageFamily = lo.ToPtr(v1alpha2.Ubuntu2204ImageFamily)
return nc
}

0 comments on commit 3d3a65e

Please sign in to comment.