Skip to content

Commit

Permalink
Merge pull request #809 from cloudskiff/fix_kms_key_crash
Browse files Browse the repository at this point in the history
Fix kms_key crash
  • Loading branch information
sundowndev authored Jul 7, 2021
2 parents 8857bd1 + 410e3df commit 6e04cf2
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 24 deletions.
62 changes: 54 additions & 8 deletions pkg/remote/aws/repository/kms_repository.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package repository

import (
"fmt"
"strings"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
"github.com/sirupsen/logrus"
)

type KMSRepository interface {
Expand All @@ -15,14 +19,16 @@ type KMSRepository interface {
}

type kmsRepository struct {
client kmsiface.KMSAPI
cache cache.Cache
client kmsiface.KMSAPI
cache cache.Cache
describeKeyLock *sync.Mutex
}

func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository {
return &kmsRepository{
kms.New(session),
c,
&sync.Mutex{},
}
}

Expand Down Expand Up @@ -68,33 +74,73 @@ func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) {
return nil, err
}

result := r.filterAliases(aliases)
result, err := r.filterAliases(aliases)
if err != nil {
return nil, err
}
r.cache.Put("kmsListAllAliases", result)
return result, nil
}

func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) {
var results interface{}
// Since this method can be call in parallel, we should lock and unlock if we want to be sure to hit the cache
r.describeKeyLock.Lock()
defer r.describeKeyLock.Unlock()
cacheKey := fmt.Sprintf("kmsDescribeKey-%s", *keyId)
results = r.cache.Get(cacheKey)
if results == nil {
var err error
results, err = r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, results)
}
describeKey := results.(*kms.DescribeKeyOutput)
if aws.StringValue(describeKey.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion {
return nil, nil
}
return describeKey, nil
}

func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) {
var customerKeys []*kms.KeyListEntry
for _, key := range keys {
k, err := r.client.DescribeKey(&kms.DescribeKeyInput{
KeyId: key.KeyId,
})
k, err := r.describeKey(key.KeyId)
if err != nil {
return nil, err
}
if k == nil {
logrus.WithFields(logrus.Fields{
"id": *key.KeyId,
}).Debug("Ignored kms key from listing since it is pending from deletion")
continue
}
if k.KeyMetadata.KeyManager != nil && *k.KeyMetadata.KeyManager != "AWS" {
customerKeys = append(customerKeys, key)
}
}
return customerKeys, nil
}

func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) []*kms.AliasListEntry {
func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) {
var customerAliases []*kms.AliasListEntry
for _, alias := range aliases {
if alias.AliasName != nil && !strings.HasPrefix(*alias.AliasName, "alias/aws/") {
k, err := r.describeKey(alias.TargetKeyId)
if err != nil {
return nil, err
}
if k == nil {
logrus.WithFields(logrus.Fields{
"id": *alias.TargetKeyId,
"alias": *alias.AliasName,
}).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key")
continue
}
customerAliases = append(customerAliases, alias)
}
}
return customerAliases
return customerAliases, nil
}
111 changes: 95 additions & 16 deletions pkg/remote/aws/repository/kms_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package repository

import (
"strings"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -21,6 +22,45 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
want []*kms.KeyListEntry
wantErr error
}{
{
name: "List only enabled keys",
mocks: func(client *awstest.MockFakeKMS) {
client.On("ListKeysPages",
&kms.ListKeysInput{},
mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool {
callback(&kms.ListKeysOutput{
Keys: []*kms.KeyListEntry{
{KeyId: aws.String("1")},
{KeyId: aws.String("2")},
},
}, true)
return true
})).Return(nil).Once()
client.On("DescribeKey",
&kms.DescribeKeyInput{
KeyId: aws.String("1"),
}).Return(&kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
KeyId: aws.String("1"),
KeyManager: aws.String("CUSTOMER"),
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil).Once()
client.On("DescribeKey",
&kms.DescribeKeyInput{
KeyId: aws.String("2"),
}).Return(&kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
KeyId: aws.String("2"),
KeyManager: aws.String("CUSTOMER"),
KeyState: aws.String(kms.KeyStatePendingDeletion),
},
}, nil).Once()
},
want: []*kms.KeyListEntry{
{KeyId: aws.String("1")},
},
},
{
name: "List only customer keys",
mocks: func(client *awstest.MockFakeKMS) {
Expand All @@ -43,6 +83,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
KeyMetadata: &kms.KeyMetadata{
KeyId: aws.String("1"),
KeyManager: aws.String("CUSTOMER"),
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil).Once()
client.On("DescribeKey",
Expand All @@ -52,6 +93,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
KeyMetadata: &kms.KeyMetadata{
KeyId: aws.String("2"),
KeyManager: aws.String("AWS"),
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil).Once()
client.On("DescribeKey",
Expand All @@ -61,6 +103,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
KeyMetadata: &kms.KeyMetadata{
KeyId: aws.String("3"),
KeyManager: aws.String("AWS"),
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil).Once()
},
Expand All @@ -75,8 +118,9 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
client := awstest.MockFakeKMS{}
tt.mocks(&client)
r := &kmsRepository{
client: &client,
cache: store,
client: &client,
cache: store,
describeKeyLock: &sync.Mutex{},
}
got, err := r.ListAllKeys()
assert.Equal(t, tt.wantErr, err)
Expand Down Expand Up @@ -108,6 +152,35 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
want []*kms.AliasListEntry
wantErr error
}{
{
name: "List only aliases for enabled keys",
mocks: func(client *awstest.MockFakeKMS) {
client.On("ListAliasesPages",
&kms.ListAliasesInput{},
mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool {
callback(&kms.ListAliasesOutput{
Aliases: []*kms.AliasListEntry{
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
{AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")},
},
}, true)
return true
})).Return(nil).Once()
client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-1")}).Return(&kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
KeyState: aws.String(kms.KeyStatePendingDeletion),
},
}, nil)
client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-2")}).Return(&kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil)
},
want: []*kms.AliasListEntry{
{AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")},
},
},
{
name: "List only customer aliases",
mocks: func(client *awstest.MockFakeKMS) {
Expand All @@ -116,24 +189,29 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool {
callback(&kms.ListAliasesOutput{
Aliases: []*kms.AliasListEntry{
{AliasName: aws.String("alias/1")},
{AliasName: aws.String("alias/foo/2")},
{AliasName: aws.String("alias/aw/3")},
{AliasName: aws.String("alias/aws/4")},
{AliasName: aws.String("alias/aws/5")},
{AliasName: aws.String("alias/awss/6")},
{AliasName: aws.String("alias/aws7")},
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
{AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")},
{AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")},
{AliasName: aws.String("alias/aws/4"), TargetKeyId: aws.String("key-id-4")},
{AliasName: aws.String("alias/aws/5"), TargetKeyId: aws.String("key-id-5")},
{AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")},
{AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")},
},
}, true)
return true
})).Return(nil).Once()
client.On("DescribeKey", mock.Anything).Return(&kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
KeyState: aws.String(kms.KeyStateEnabled),
},
}, nil)
},
want: []*kms.AliasListEntry{
{AliasName: aws.String("alias/1")},
{AliasName: aws.String("alias/foo/2")},
{AliasName: aws.String("alias/aw/3")},
{AliasName: aws.String("alias/awss/6")},
{AliasName: aws.String("alias/aws7")},
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
{AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")},
{AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")},
{AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")},
{AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")},
},
},
}
Expand All @@ -143,8 +221,9 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
client := awstest.MockFakeKMS{}
tt.mocks(&client)
r := &kmsRepository{
client: &client,
cache: store,
client: &client,
cache: store,
describeKeyLock: &sync.Mutex{},
}
got, err := r.ListAllAliases()
assert.Equal(t, tt.wantErr, err)
Expand Down
8 changes: 8 additions & 0 deletions pkg/remote/common/details_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package common
import (
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/sirupsen/logrus"
)

type DetailsFetcher interface {
Expand Down Expand Up @@ -31,6 +32,13 @@ func (f *GenericDetailsFetcher) ReadDetails(res resource.Resource) (resource.Res
if err != nil {
return nil, err
}
if ctyVal.IsNull() {
logrus.WithFields(logrus.Fields{
"type": f.resType,
"id": res.TerraformId(),
}).Debug("Got null while reading resource details")
return nil, nil
}
deserializedRes, err := f.deserializer.DeserializeOne(string(f.resType), *ctyVal)
if err != nil {
return nil, err
Expand Down

0 comments on commit 6e04cf2

Please sign in to comment.