From 410e3df05135ec118b08cf7e7cbc26745fa4386b Mon Sep 17 00:00:00 2001 From: Elie Date: Tue, 6 Jul 2021 17:08:47 +0200 Subject: [PATCH] Fix kms_key crash --- pkg/remote/aws/repository/kms_repository.go | 62 ++++++++-- .../aws/repository/kms_repository_test.go | 111 +++++++++++++++--- pkg/remote/common/details_fetcher.go | 8 ++ 3 files changed, 157 insertions(+), 24 deletions(-) diff --git a/pkg/remote/aws/repository/kms_repository.go b/pkg/remote/aws/repository/kms_repository.go index 67eaa4f3f..a59b40d95 100644 --- a/pkg/remote/aws/repository/kms_repository.go +++ b/pkg/remote/aws/repository/kms_repository.go @@ -1,12 +1,16 @@ package repository import ( + "fmt" "strings" + "sync" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/kms/kmsiface" "github.com/cloudskiff/driftctl/pkg/remote/cache" + "github.com/sirupsen/logrus" ) type KMSRepository interface { @@ -15,14 +19,16 @@ type KMSRepository interface { } type kmsRepository struct { - client kmsiface.KMSAPI - cache cache.Cache + client kmsiface.KMSAPI + cache cache.Cache + describeKeyLock *sync.Mutex } func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository { return &kmsRepository{ kms.New(session), c, + &sync.Mutex{}, } } @@ -68,20 +74,49 @@ func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) { return nil, err } - result := r.filterAliases(aliases) + result, err := r.filterAliases(aliases) + if err != nil { + return nil, err + } r.cache.Put("kmsListAllAliases", result) return result, nil } +func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) { + var results interface{} + // Since this method can be call in parallel, we should lock and unlock if we want to be sure to hit the cache + r.describeKeyLock.Lock() + defer r.describeKeyLock.Unlock() + cacheKey := fmt.Sprintf("kmsDescribeKey-%s", *keyId) + results = r.cache.Get(cacheKey) + if results == nil { + var err error + results, err = r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId}) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, results) + } + describeKey := results.(*kms.DescribeKeyOutput) + if aws.StringValue(describeKey.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion { + return nil, nil + } + return describeKey, nil +} + func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) { var customerKeys []*kms.KeyListEntry for _, key := range keys { - k, err := r.client.DescribeKey(&kms.DescribeKeyInput{ - KeyId: key.KeyId, - }) + k, err := r.describeKey(key.KeyId) if err != nil { return nil, err } + if k == nil { + logrus.WithFields(logrus.Fields{ + "id": *key.KeyId, + }).Debug("Ignored kms key from listing since it is pending from deletion") + continue + } if k.KeyMetadata.KeyManager != nil && *k.KeyMetadata.KeyManager != "AWS" { customerKeys = append(customerKeys, key) } @@ -89,12 +124,23 @@ func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntr return customerKeys, nil } -func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) []*kms.AliasListEntry { +func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) { var customerAliases []*kms.AliasListEntry for _, alias := range aliases { if alias.AliasName != nil && !strings.HasPrefix(*alias.AliasName, "alias/aws/") { + k, err := r.describeKey(alias.TargetKeyId) + if err != nil { + return nil, err + } + if k == nil { + logrus.WithFields(logrus.Fields{ + "id": *alias.TargetKeyId, + "alias": *alias.AliasName, + }).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key") + continue + } customerAliases = append(customerAliases, alias) } } - return customerAliases + return customerAliases, nil } diff --git a/pkg/remote/aws/repository/kms_repository_test.go b/pkg/remote/aws/repository/kms_repository_test.go index 9850abfdf..fe9294290 100644 --- a/pkg/remote/aws/repository/kms_repository_test.go +++ b/pkg/remote/aws/repository/kms_repository_test.go @@ -2,6 +2,7 @@ package repository import ( "strings" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -21,6 +22,45 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { want []*kms.KeyListEntry wantErr error }{ + { + name: "List only enabled keys", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListKeysPages", + &kms.ListKeysInput{}, + mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool { + callback(&kms.ListKeysOutput{ + Keys: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + {KeyId: aws.String("2")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("1"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("1"), + KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("2"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("2"), + KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStatePendingDeletion), + }, + }, nil).Once() + }, + want: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + }, + }, { name: "List only customer keys", mocks: func(client *awstest.MockFakeKMS) { @@ -43,6 +83,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyMetadata: &kms.KeyMetadata{ KeyId: aws.String("1"), KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStateEnabled), }, }, nil).Once() client.On("DescribeKey", @@ -52,6 +93,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyMetadata: &kms.KeyMetadata{ KeyId: aws.String("2"), KeyManager: aws.String("AWS"), + KeyState: aws.String(kms.KeyStateEnabled), }, }, nil).Once() client.On("DescribeKey", @@ -61,6 +103,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyMetadata: &kms.KeyMetadata{ KeyId: aws.String("3"), KeyManager: aws.String("AWS"), + KeyState: aws.String(kms.KeyStateEnabled), }, }, nil).Once() }, @@ -75,8 +118,9 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { client := awstest.MockFakeKMS{} tt.mocks(&client) r := &kmsRepository{ - client: &client, - cache: store, + client: &client, + cache: store, + describeKeyLock: &sync.Mutex{}, } got, err := r.ListAllKeys() assert.Equal(t, tt.wantErr, err) @@ -108,6 +152,35 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) { want []*kms.AliasListEntry wantErr error }{ + { + name: "List only aliases for enabled keys", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListAliasesPages", + &kms.ListAliasesInput{}, + mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { + callback(&kms.ListAliasesOutput{ + Aliases: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-1")}).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStatePendingDeletion), + }, + }, nil) + client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-2")}).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil) + }, + want: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, + }, + }, { name: "List only customer aliases", mocks: func(client *awstest.MockFakeKMS) { @@ -116,24 +189,29 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) { mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { callback(&kms.ListAliasesOutput{ Aliases: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/1")}, - {AliasName: aws.String("alias/foo/2")}, - {AliasName: aws.String("alias/aw/3")}, - {AliasName: aws.String("alias/aws/4")}, - {AliasName: aws.String("alias/aws/5")}, - {AliasName: aws.String("alias/awss/6")}, - {AliasName: aws.String("alias/aws7")}, + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, + {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, + {AliasName: aws.String("alias/aws/4"), TargetKeyId: aws.String("key-id-4")}, + {AliasName: aws.String("alias/aws/5"), TargetKeyId: aws.String("key-id-5")}, + {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, + {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, }, }, true) return true })).Return(nil).Once() + client.On("DescribeKey", mock.Anything).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil) }, want: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/1")}, - {AliasName: aws.String("alias/foo/2")}, - {AliasName: aws.String("alias/aw/3")}, - {AliasName: aws.String("alias/awss/6")}, - {AliasName: aws.String("alias/aws7")}, + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, + {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, + {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, + {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, }, }, } @@ -143,8 +221,9 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) { client := awstest.MockFakeKMS{} tt.mocks(&client) r := &kmsRepository{ - client: &client, - cache: store, + client: &client, + cache: store, + describeKeyLock: &sync.Mutex{}, } got, err := r.ListAllAliases() assert.Equal(t, tt.wantErr, err) diff --git a/pkg/remote/common/details_fetcher.go b/pkg/remote/common/details_fetcher.go index 64a9b42e3..ed9e2478e 100644 --- a/pkg/remote/common/details_fetcher.go +++ b/pkg/remote/common/details_fetcher.go @@ -3,6 +3,7 @@ package common import ( "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/terraform" + "github.com/sirupsen/logrus" ) type DetailsFetcher interface { @@ -31,6 +32,13 @@ func (f *GenericDetailsFetcher) ReadDetails(res resource.Resource) (resource.Res if err != nil { return nil, err } + if ctyVal.IsNull() { + logrus.WithFields(logrus.Fields{ + "type": f.resType, + "id": res.TerraformId(), + }).Debug("Got null while reading resource details") + return nil, nil + } deserializedRes, err := f.deserializer.DeserializeOne(string(f.resType), *ctyVal) if err != nil { return nil, err