From ff2850b5b33ec71ecf3ff7a424c1f4bb6c50a428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bindewald=2C=20Andr=C3=A9=20=28UIT=29?= Date: Mon, 13 Jan 2025 16:11:40 +0100 Subject: [PATCH] refactor: Typed data source --- .../compute/compute_skus_data_source.go | 390 ++++++++++++------ internal/services/compute/parse/skus.go | 57 +++ internal/services/compute/parse/skus_test.go | 82 ++++ internal/services/compute/registration.go | 2 +- internal/services/compute/validate/skus_id.go | 26 ++ .../services/compute/validate/skus_id_test.go | 54 +++ 6 files changed, 478 insertions(+), 133 deletions(-) create mode 100644 internal/services/compute/parse/skus.go create mode 100644 internal/services/compute/parse/skus_test.go create mode 100644 internal/services/compute/validate/skus_id.go create mode 100644 internal/services/compute/validate/skus_id_test.go diff --git a/internal/services/compute/compute_skus_data_source.go b/internal/services/compute/compute_skus_data_source.go index 3f7298b0b7da..3c2bdec4c661 100644 --- a/internal/services/compute/compute_skus_data_source.go +++ b/internal/services/compute/compute_skus_data_source.go @@ -4,180 +4,306 @@ package compute import ( - "fmt" + "context" "slices" "strings" "time" - "github.com/google/uuid" + "github.com/hashicorp/go-azure-helpers/lang/pointer" "github.com/hashicorp/go-azure-helpers/resourcemanager/commonids" "github.com/hashicorp/go-azure-helpers/resourcemanager/commonschema" "github.com/hashicorp/go-azure-helpers/resourcemanager/location" "github.com/hashicorp/go-azure-sdk/resource-manager/compute/2021-07-01/skus" - "github.com/hashicorp/terraform-provider-azurerm/internal/clients" + "github.com/hashicorp/terraform-provider-azurerm/internal/sdk" + "github.com/hashicorp/terraform-provider-azurerm/internal/services/compute/parse" "github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk" "github.com/hashicorp/terraform-provider-azurerm/internal/tf/validation" - "github.com/hashicorp/terraform-provider-azurerm/internal/timeouts" ) -func dataSourceComputeSkus() *pluginsdk.Resource { - return &pluginsdk.Resource{ - Read: dataSourceComputeSkusRead, +type ComputeSkusDataSource struct{} - Timeouts: &pluginsdk.ResourceTimeout{ - Read: pluginsdk.DefaultTimeout(5 * time.Minute), +var _ sdk.DataSource = ComputeSkusDataSource{} + +type ComputeSkusDataSourceModel struct { + Name string `tfschema:"name"` + Location string `tfschema:"location"` + IncludeCapabilities bool `tfschema:"include_capabilities"` + Skus []ComputeSkusSkuModel `tfschema:"skus"` +} + +type ComputeSkusSkuModel struct { + Name string `tfschema:"name"` + ResourceType string `tfschema:"resource_type"` + Size string `tfschema:"size"` + Tier string `tfschema:"tier"` + LocationRestrictions []string `tfschema:"location_restrictions"` + ZoneRestrictions []string `tfschema:"zone_restrictions"` + Capabilities map[string]string `tfschema:"capabilities"` + Zones []string `tfschema:"zones"` +} + +func (ds ComputeSkusDataSource) Arguments() map[string]*pluginsdk.Schema { + return map[string]*pluginsdk.Schema{ + "name": { + Type: pluginsdk.TypeString, + Optional: true, + ValidateFunc: validation.StringIsNotEmpty, + }, + "location": commonschema.Location(), + "include_capabilities": { + Type: pluginsdk.TypeBool, + Optional: true, + Default: false, }, + } +} - Schema: map[string]*pluginsdk.Schema{ - "name": { - Type: pluginsdk.TypeString, - Optional: true, - ValidateFunc: validation.StringIsNotEmpty, - }, - "location": commonschema.Location(), - "include_capabilities": { - Type: pluginsdk.TypeBool, - Optional: true, - Default: false, - }, - "skus": { - Type: pluginsdk.TypeList, - Computed: true, - Elem: &pluginsdk.Resource{ - Schema: map[string]*pluginsdk.Schema{ - "name": { - Type: pluginsdk.TypeString, - Computed: true, - }, - "resource_type": { - Type: pluginsdk.TypeString, - Computed: true, - }, - "size": { - Type: pluginsdk.TypeString, - Computed: true, - }, - "tier": { - Type: pluginsdk.TypeString, - Computed: true, - }, - "location_restrictions": { - Type: pluginsdk.TypeList, - Computed: true, - Elem: &pluginsdk.Schema{ - Type: pluginsdk.TypeString, - }, +func (ds ComputeSkusDataSource) Attributes() map[string]*pluginsdk.Schema { + return map[string]*pluginsdk.Schema{ + "skus": { + Type: pluginsdk.TypeList, + Computed: true, + Elem: &pluginsdk.Resource{ + Schema: map[string]*pluginsdk.Schema{ + "name": { + Type: pluginsdk.TypeString, + Computed: true, + }, + "resource_type": { + Type: pluginsdk.TypeString, + Computed: true, + }, + "size": { + Type: pluginsdk.TypeString, + Computed: true, + }, + "tier": { + Type: pluginsdk.TypeString, + Computed: true, + }, + "location_restrictions": { + Type: pluginsdk.TypeList, + Computed: true, + Elem: &pluginsdk.Schema{ + Type: pluginsdk.TypeString, }, - "zone_restrictions": { - Type: pluginsdk.TypeList, - Computed: true, - Elem: &pluginsdk.Schema{ - Type: pluginsdk.TypeString, - }, + }, + "zone_restrictions": { + Type: pluginsdk.TypeList, + Computed: true, + Elem: &pluginsdk.Schema{ + Type: pluginsdk.TypeString, }, - "capabilities": { - Type: pluginsdk.TypeMap, - Optional: true, - Elem: &pluginsdk.Schema{ - Type: pluginsdk.TypeString, - }, + }, + "capabilities": { + Type: pluginsdk.TypeMap, + Computed: true, + Elem: &pluginsdk.Schema{ + Type: pluginsdk.TypeString, }, - "zones": commonschema.ZonesMultipleComputed(), }, + "zones": commonschema.ZonesMultipleComputed(), }, }, }, } } -func dataSourceComputeSkusRead(d *pluginsdk.ResourceData, meta interface{}) error { - client := meta.(*clients.Client).Compute.SkusClient - subscriptionId := meta.(*clients.Client).Account.SubscriptionId +func (ds ComputeSkusDataSource) ModelObject() interface{} { + return &ComputeSkusDataSourceModel{} +} - ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d) - defer cancel() +func (ds ComputeSkusDataSource) ResourceType() string { + return "azurerm_compute_skus" +} - resp, err := client.ResourceSkusList(ctx, commonids.NewSubscriptionID(subscriptionId), skus.DefaultResourceSkusListOperationOptions()) - if err != nil { - return fmt.Errorf("retrieving SKUs: %+v", err) - } +func (ds ComputeSkusDataSource) Read() sdk.ResourceFunc { + return sdk.ResourceFunc{ + Timeout: 5 * time.Minute, + Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error { + var state ComputeSkusDataSourceModel + if err := metadata.Decode(&state); err != nil { + return err + } - name := d.Get("name").(string) - loc := location.Normalize(d.Get("location").(string)) - availableSkus := make([]map[string]interface{}, 0) + subscriptionId := metadata.Client.Account.SubscriptionId + name := state.Name + loc := location.Normalize(state.Location) + availableSkus := make([]ComputeSkusSkuModel, 0) + id := parse.NewSkusID(subscriptionId) - if model := resp.Model; model != nil { - for _, sku := range *model { - // the API does not allow filtering by name - if name != "" { - if !strings.EqualFold(*sku.Name, name) { - continue - } + resp, err := metadata.Client.Compute.SkusClient.ResourceSkusList(ctx, commonids.NewSubscriptionID(subscriptionId), skus.DefaultResourceSkusListOperationOptions()) + if err != nil { + return err } - // while the API accepts OData filters, the location filter is currently - // not working, thus we need to filter the results manually - locationsNormalized := make([]string, len(*sku.Locations)) - for _, v := range *sku.Locations { - locationsNormalized = append(locationsNormalized, location.Normalize(v)) - } - if !slices.Contains(locationsNormalized, loc) { - continue - } + if model := resp.Model; model != nil { + for _, sku := range *model { + // the API does not allow filtering by name + if name != "" { + if !strings.EqualFold(*sku.Name, name) { + continue + } + } - var zones []string - var locationRestrictions []string - var zoneRestrictions []string - capabilities := make(map[string]string) + // while the API accepts OData filters, the location filter is currently + // not working, thus we need to filter the results manually + locationsNormalized := make([]string, len(*sku.Locations)) + for _, v := range *sku.Locations { + locationsNormalized = append(locationsNormalized, location.Normalize(v)) + } + if !slices.Contains(locationsNormalized, loc) { + continue + } - if sku.Restrictions != nil && len(*sku.Restrictions) > 0 { - for _, restriction := range *sku.Restrictions { - restrictionType := *restriction.Type + var zones []string + var locationRestrictions []string + var zoneRestrictions []string + capabilities := make(map[string]string) - switch restrictionType { - case skus.ResourceSkuRestrictionsTypeLocation: - restrictedLocationsNormalized := make([]string, 0) - for _, v := range *restriction.RestrictionInfo.Locations { - restrictedLocationsNormalized = append(restrictedLocationsNormalized, location.Normalize(v)) - } - locationRestrictions = restrictedLocationsNormalized + if sku.Restrictions != nil && len(*sku.Restrictions) > 0 { + for _, restriction := range *sku.Restrictions { + restrictionType := *restriction.Type + + switch restrictionType { + case skus.ResourceSkuRestrictionsTypeLocation: + restrictedLocationsNormalized := make([]string, 0) + for _, v := range *restriction.RestrictionInfo.Locations { + restrictedLocationsNormalized = append(restrictedLocationsNormalized, location.Normalize(v)) + } + locationRestrictions = restrictedLocationsNormalized - case skus.ResourceSkuRestrictionsTypeZone: - zoneRestrictions = *restriction.RestrictionInfo.Zones + case skus.ResourceSkuRestrictionsTypeZone: + zoneRestrictions = *restriction.RestrictionInfo.Zones + } + } } - } - } - if sku.LocationInfo != nil && len(*sku.LocationInfo) > 0 { - for _, locationInfo := range *sku.LocationInfo { - if location.Normalize(*locationInfo.Location) == loc { - zones = *locationInfo.Zones + if sku.LocationInfo != nil && len(*sku.LocationInfo) > 0 { + for _, locationInfo := range *sku.LocationInfo { + if location.Normalize(*locationInfo.Location) == loc { + zones = *locationInfo.Zones + } + } } - } - } - if d.Get("include_capabilities").(bool) { - if sku.Capabilities != nil && len(*sku.Capabilities) > 0 { - for _, capability := range *sku.Capabilities { - capabilities[*capability.Name] = *capability.Value + if state.IncludeCapabilities { + if sku.Capabilities != nil && len(*sku.Capabilities) > 0 { + for _, capability := range *sku.Capabilities { + capabilities[*capability.Name] = *capability.Value + } + } } + + availableSkus = append(availableSkus, ComputeSkusSkuModel{ + Name: pointer.From(sku.Name), + ResourceType: pointer.From(sku.ResourceType), + Size: pointer.From(sku.Size), + Tier: pointer.From(sku.Tier), + LocationRestrictions: locationRestrictions, + ZoneRestrictions: zoneRestrictions, + Zones: zones, + Capabilities: capabilities, + }) } + + state.Skus = availableSkus } - availableSkus = append(availableSkus, map[string]interface{}{ - "name": sku.Name, - "resource_type": sku.ResourceType, - "size": sku.Size, - "tier": sku.Tier, - "location_restrictions": locationRestrictions, - "zone_restrictions": zoneRestrictions, - "zones": zones, - "capabilities": capabilities, - }) - } - d.SetId(uuid.New().String()) - d.Set("skus", availableSkus) + metadata.SetID(id) + return metadata.Encode(&state) + }, } - - return nil } + +// func dataSourceComputeSkusRead(d *pluginsdk.ResourceData, meta interface{}) error { +// client := meta.(*clients.Client).Compute.SkusClient +// subscriptionId := meta.(*clients.Client).Account.SubscriptionId + +// ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d) +// defer cancel() + +// resp, err := client.ResourceSkusList(ctx, commonids.NewSubscriptionID(subscriptionId), skus.DefaultResourceSkusListOperationOptions()) +// if err != nil { +// return fmt.Errorf("retrieving SKUs: %+v", err) +// } + +// name := d.Get("name").(string) +// loc := location.Normalize(d.Get("location").(string)) +// availableSkus := make([]map[string]interface{}, 0) + +// if model := resp.Model; model != nil { +// for _, sku := range *model { +// // the API does not allow filtering by name +// if name != "" { +// if !strings.EqualFold(*sku.Name, name) { +// continue +// } +// } + +// // while the API accepts OData filters, the location filter is currently +// // not working, thus we need to filter the results manually +// locationsNormalized := make([]string, len(*sku.Locations)) +// for _, v := range *sku.Locations { +// locationsNormalized = append(locationsNormalized, location.Normalize(v)) +// } +// if !slices.Contains(locationsNormalized, loc) { +// continue +// } + +// var zones []string +// var locationRestrictions []string +// var zoneRestrictions []string +// capabilities := make(map[string]string) + +// if sku.Restrictions != nil && len(*sku.Restrictions) > 0 { +// for _, restriction := range *sku.Restrictions { +// restrictionType := *restriction.Type + +// switch restrictionType { +// case skus.ResourceSkuRestrictionsTypeLocation: +// restrictedLocationsNormalized := make([]string, 0) +// for _, v := range *restriction.RestrictionInfo.Locations { +// restrictedLocationsNormalized = append(restrictedLocationsNormalized, location.Normalize(v)) +// } +// locationRestrictions = restrictedLocationsNormalized + +// case skus.ResourceSkuRestrictionsTypeZone: +// zoneRestrictions = *restriction.RestrictionInfo.Zones +// } +// } +// } + +// if sku.LocationInfo != nil && len(*sku.LocationInfo) > 0 { +// for _, locationInfo := range *sku.LocationInfo { +// if location.Normalize(*locationInfo.Location) == loc { +// zones = *locationInfo.Zones +// } +// } +// } + +// if d.Get("include_capabilities").(bool) { +// if sku.Capabilities != nil && len(*sku.Capabilities) > 0 { +// for _, capability := range *sku.Capabilities { +// capabilities[*capability.Name] = *capability.Value +// } +// } +// } + +// availableSkus = append(availableSkus, map[string]interface{}{ +// "name": sku.Name, +// "resource_type": sku.ResourceType, +// "size": sku.Size, +// "tier": sku.Tier, +// "location_restrictions": locationRestrictions, +// "zone_restrictions": zoneRestrictions, +// "zones": zones, +// "capabilities": capabilities, +// }) +// } +// d.SetId(uuid.New().String()) +// d.Set("skus", availableSkus) +// } + +// return nil +// } diff --git a/internal/services/compute/parse/skus.go b/internal/services/compute/parse/skus.go new file mode 100644 index 000000000000..1f795ca9549e --- /dev/null +++ b/internal/services/compute/parse/skus.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parse + +// NOTE: this file is generated via 'go:generate' - manual changes will be overwritten + +import ( + "errors" + "fmt" + "strings" + + "github.com/hashicorp/go-azure-helpers/resourcemanager/resourceids" +) + +type SkusId struct { + SubscriptionId string +} + +func NewSkusID(subscriptionId string) SkusId { + return SkusId{ + SubscriptionId: subscriptionId, + } +} + +func (id SkusId) String() string { + segments := []string{} + segmentsStr := strings.Join(segments, " / ") + return fmt.Sprintf("%s: (%s)", "Skus", segmentsStr) +} + +func (id SkusId) ID() string { + fmtString := "/subscriptions/%s" + return fmt.Sprintf(fmtString, id.SubscriptionId) +} + +// SkusID parses a Skus ID into an SkusId struct +func SkusID(input string) (*SkusId, error) { + id, err := resourceids.ParseAzureResourceID(input) + if err != nil { + return nil, fmt.Errorf("parsing %q as an Skus ID: %+v", input, err) + } + + resourceId := SkusId{ + SubscriptionId: id.SubscriptionID, + } + + if resourceId.SubscriptionId == "" { + return nil, errors.New("ID was missing the 'subscriptions' element") + } + + if err := id.ValidateNoEmptySegments(input); err != nil { + return nil, err + } + + return &resourceId, nil +} diff --git a/internal/services/compute/parse/skus_test.go b/internal/services/compute/parse/skus_test.go new file mode 100644 index 000000000000..b7103876b46f --- /dev/null +++ b/internal/services/compute/parse/skus_test.go @@ -0,0 +1,82 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parse + +// NOTE: this file is generated via 'go:generate' - manual changes will be overwritten + +import ( + "testing" + + "github.com/hashicorp/go-azure-helpers/resourcemanager/resourceids" +) + +var _ resourceids.Id = SkusId{} + +func TestSkusIDFormatter(t *testing.T) { + actual := NewSkusID("707acc15-6870-4327-99cb-acf3b7fd1633").ID() + expected := "/subscriptions/707acc15-6870-4327-99cb-acf3b7fd1633" + if actual != expected { + t.Fatalf("Expected %q but got %q", expected, actual) + } +} + +func TestSkusID(t *testing.T) { + testData := []struct { + Input string + Error bool + Expected *SkusId + }{ + { + // empty + Input: "", + Error: true, + }, + + { + // missing SubscriptionId + Input: "/", + Error: true, + }, + + { + // missing value for SubscriptionId + Input: "/subscriptions/", + Error: true, + }, + + { + // valid + Input: "/subscriptions/707acc15-6870-4327-99cb-acf3b7fd1633", + Expected: &SkusId{ + SubscriptionId: "707acc15-6870-4327-99cb-acf3b7fd1633", + }, + }, + + { + // upper-cased + Input: "/SUBSCRIPTIONS/707ACC15-6870-4327-99CB-ACF3B7FD1633", + Error: true, + }, + } + + for _, v := range testData { + t.Logf("[DEBUG] Testing %q", v.Input) + + actual, err := SkusID(v.Input) + if err != nil { + if v.Error { + continue + } + + t.Fatalf("Expect a value but got an error: %s", err) + } + if v.Error { + t.Fatal("Expect an error but didn't get one") + } + + if actual.SubscriptionId != v.Expected.SubscriptionId { + t.Fatalf("Expected %q but got %q for SubscriptionId", v.Expected.SubscriptionId, actual.SubscriptionId) + } + } +} diff --git a/internal/services/compute/registration.go b/internal/services/compute/registration.go index c8783a12ea81..35982f82ca75 100644 --- a/internal/services/compute/registration.go +++ b/internal/services/compute/registration.go @@ -26,7 +26,6 @@ func (r Registration) WebsiteCategories() []string { func (r Registration) SupportedDataSources() map[string]*pluginsdk.Resource { return map[string]*pluginsdk.Resource{ "azurerm_availability_set": dataSourceAvailabilitySet(), - "azurerm_compute_skus": dataSourceComputeSkus(), "azurerm_dedicated_host": dataSourceDedicatedHost(), "azurerm_dedicated_host_group": dataSourceDedicatedHostGroup(), "azurerm_disk_encryption_set": dataSourceDiskEncryptionSet(), @@ -83,6 +82,7 @@ func (r Registration) SupportedResources() map[string]*pluginsdk.Resource { func (r Registration) DataSources() []sdk.DataSource { return []sdk.DataSource{ + ComputeSkusDataSource{}, OrchestratedVirtualMachineScaleSetDataSource{}, } } diff --git a/internal/services/compute/validate/skus_id.go b/internal/services/compute/validate/skus_id.go new file mode 100644 index 000000000000..35d8aff961b0 --- /dev/null +++ b/internal/services/compute/validate/skus_id.go @@ -0,0 +1,26 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package validate + +// NOTE: this file is generated via 'go:generate' - manual changes will be overwritten + +import ( + "fmt" + + "github.com/hashicorp/terraform-provider-azurerm/internal/services/compute/parse" +) + +func SkusID(input interface{}, key string) (warnings []string, errors []error) { + v, ok := input.(string) + if !ok { + errors = append(errors, fmt.Errorf("expected %q to be a string", key)) + return + } + + if _, err := parse.SkusID(v); err != nil { + errors = append(errors, err) + } + + return +} diff --git a/internal/services/compute/validate/skus_id_test.go b/internal/services/compute/validate/skus_id_test.go new file mode 100644 index 000000000000..0f2c7f1fe7bc --- /dev/null +++ b/internal/services/compute/validate/skus_id_test.go @@ -0,0 +1,54 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package validate + +// NOTE: this file is generated via 'go:generate' - manual changes will be overwritten + +import "testing" + +func TestSkusID(t *testing.T) { + cases := []struct { + Input string + Valid bool + }{ + { + // empty + Input: "", + Valid: false, + }, + + { + // missing SubscriptionId + Input: "/", + Valid: false, + }, + + { + // missing value for SubscriptionId + Input: "/subscriptions/", + Valid: false, + }, + + { + // valid + Input: "/subscriptions/707acc15-6870-4327-99cb-acf3b7fd1633", + Valid: true, + }, + + { + // upper-cased + Input: "/SUBSCRIPTIONS/707ACC15-6870-4327-99CB-ACF3B7FD1633", + Valid: false, + }, + } + for _, tc := range cases { + t.Logf("[DEBUG] Testing Value %s", tc.Input) + _, errors := SkusID(tc.Input, "test") + valid := len(errors) == 0 + + if tc.Valid != valid { + t.Fatalf("Expected %t but got %t", tc.Valid, valid) + } + } +}