From c94dad7f16dc3ba3851d2f70deff423739325719 Mon Sep 17 00:00:00 2001 From: Martin Guibert Date: Wed, 5 Oct 2022 16:03:06 +0200 Subject: [PATCH] fix: inject account id to enumerator instead of repo --- enumeration/remote/aws/init.go | 4 +-- .../repository/mock_S3ControlRepository.go | 28 +++++-------------- .../aws/repository/s3control_repository.go | 14 +++------- .../repository/s3control_repository_test.go | 7 +++-- ..._account_public_access_block_enumerator.go | 23 ++++++++------- enumeration/remote/aws_s3_scanner_test.go | 10 +++---- enumeration/resource/resource_types.go | 3 +- 7 files changed, 35 insertions(+), 54 deletions(-) diff --git a/enumeration/remote/aws/init.go b/enumeration/remote/aws/init.go index f7f94cc0d..f6998430f 100644 --- a/enumeration/remote/aws/init.go +++ b/enumeration/remote/aws/init.go @@ -35,7 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter repositoryCache := cache.New(100) s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache) - s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache) + s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), repositoryCache) ec2repository := repository.NewEC2Repository(provider.session, repositoryCache) elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache) route53repository := repository.NewRoute53Repository(provider.session, repositoryCache) @@ -72,7 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter)) remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter)) + remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.accountId, alerter)) remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory)) remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer)) diff --git a/enumeration/remote/aws/repository/mock_S3ControlRepository.go b/enumeration/remote/aws/repository/mock_S3ControlRepository.go index 941e03a31..4b220bbbf 100644 --- a/enumeration/remote/aws/repository/mock_S3ControlRepository.go +++ b/enumeration/remote/aws/repository/mock_S3ControlRepository.go @@ -12,13 +12,13 @@ type MockS3ControlRepository struct { mock.Mock } -// DescribeAccountPublicAccessBlock provides a mock function with given fields: -func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) { - ret := _m.Called() +// DescribeAccountPublicAccessBlock provides a mock function with given fields: accountID +func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) { + ret := _m.Called(accountID) var r0 *s3control.PublicAccessBlockConfiguration - if rf, ok := ret.Get(0).(func() *s3control.PublicAccessBlockConfiguration); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(string) *s3control.PublicAccessBlockConfiguration); ok { + r0 = rf(accountID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*s3control.PublicAccessBlockConfiguration) @@ -26,8 +26,8 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(accountID) } else { r1 = ret.Error(1) } @@ -35,20 +35,6 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro return r0, r1 } -// GetAccountID provides a mock function with given fields: -func (_m *MockS3ControlRepository) GetAccountID() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - type mockConstructorTestingTNewMockS3ControlRepository interface { mock.TestingT Cleanup(func()) diff --git a/enumeration/remote/aws/repository/s3control_repository.go b/enumeration/remote/aws/repository/s3control_repository.go index 09a2feb39..74c1deff3 100644 --- a/enumeration/remote/aws/repository/s3control_repository.go +++ b/enumeration/remote/aws/repository/s3control_repository.go @@ -8,34 +8,28 @@ import ( ) type S3ControlRepository interface { - DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) - GetAccountID() string + DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) } type s3ControlRepository struct { clientFactory client.AwsClientFactoryInterface - accountId string cache cache.Cache } -func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository { +func NewS3ControlRepository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3ControlRepository { return &s3ControlRepository{ clientFactory: factory, - accountId: accountId, cache: c, } } -func (s *s3ControlRepository) GetAccountID() string { - return s.accountId -} -func (s *s3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) { +func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) { cacheKey := "S3DescribeAccountPublicAccessBlock" if v := s.cache.Get(cacheKey); v != nil { return v.(*s3control.PublicAccessBlockConfiguration), nil } out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{ - AccountId: aws.String(s.accountId), + AccountId: aws.String(accountID), }) if err != nil { diff --git a/enumeration/remote/aws/repository/s3control_repository_test.go b/enumeration/remote/aws/repository/s3control_repository_test.go index 7c5d12b8d..122c674b8 100644 --- a/enumeration/remote/aws/repository/s3control_repository_test.go +++ b/enumeration/remote/aws/repository/s3control_repository_test.go @@ -17,6 +17,7 @@ import ( ) func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) { + accountID := "123456" tests := []struct { name string @@ -65,14 +66,14 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) { tt.mocks(mockedClient) factory := client.MockAwsClientFactoryInterface{} factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once() - r := NewS3ControlRepository(&factory, "", store) - got, err := r.DescribeAccountPublicAccessBlock() + r := NewS3ControlRepository(&factory, store) + got, err := r.DescribeAccountPublicAccessBlock(accountID) factory.AssertExpectations(t) assert.Equal(t, tt.wantErr, err) if err == nil { // Check that results were cached - cachedData, err := r.DescribeAccountPublicAccessBlock() + cachedData, err := r.DescribeAccountPublicAccessBlock(accountID) assert.NoError(t, err) assert.Equal(t, got, cachedData) assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock")) diff --git a/enumeration/remote/aws/s3_account_public_access_block_enumerator.go b/enumeration/remote/aws/s3_account_public_access_block_enumerator.go index 8109c33f1..b8751bc02 100644 --- a/enumeration/remote/aws/s3_account_public_access_block_enumerator.go +++ b/enumeration/remote/aws/s3_account_public_access_block_enumerator.go @@ -5,24 +5,23 @@ import ( "github.com/snyk/driftctl/enumeration/alerter" "github.com/snyk/driftctl/enumeration/remote/aws/repository" remoteerror "github.com/snyk/driftctl/enumeration/remote/error" - tf "github.com/snyk/driftctl/enumeration/remote/terraform" "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/enumeration/resource/aws" ) type S3AccountPublicAccessBlockEnumerator struct { - repository repository.S3ControlRepository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface + repository repository.S3ControlRepository + factory resource.ResourceFactory + accountID string + alerter alerter.AlerterInterface } -func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator { +func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, accountId string, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator { return &S3AccountPublicAccessBlockEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, + repository: repo, + factory: factory, + accountID: accountId, + alerter: alerter, } } @@ -31,7 +30,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) SupportedType() resource.Resource } func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) { - accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock() + accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock(e.accountID) if err != nil { return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) } @@ -42,7 +41,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource results, e.factory.CreateAbstractResource( string(e.SupportedType()), - e.repository.GetAccountID(), + e.accountID, map[string]interface{}{ "block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls), "block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy), diff --git a/enumeration/remote/aws_s3_scanner_test.go b/enumeration/remote/aws_s3_scanner_test.go index 84f459e03..cf28d3250 100644 --- a/enumeration/remote/aws_s3_scanner_test.go +++ b/enumeration/remote/aws_s3_scanner_test.go @@ -1071,6 +1071,7 @@ func TestS3BucketAnalytic(t *testing.T) { func TestS3AccountPublicAccessBlock(t *testing.T) { dummyError := errors.New("this is an error") + accountID := "123456" tests := []struct { test string mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface) @@ -1080,8 +1081,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) { { test: "existing access block", mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) { - repository.On("GetAccountID").Return("123456") - repository.On("DescribeAccountPublicAccessBlock").Return(&s3control.PublicAccessBlockConfiguration{ + repository.On("DescribeAccountPublicAccessBlock", accountID).Return(&s3control.PublicAccessBlockConfiguration{ BlockPublicAcls: awssdk.Bool(false), BlockPublicPolicy: awssdk.Bool(true), IgnorePublicAcls: awssdk.Bool(false), @@ -1090,7 +1090,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) { }, assertExpected: func(t *testing.T, got []*resource.Resource) { assert.Len(t, got, 1) - assert.Equal(t, got[0].ResourceId(), "123456") + assert.Equal(t, got[0].ResourceId(), accountID) assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock) assert.Equal(t, got[0].Attributes(), &resource.Attributes{ "block_public_acls": false, @@ -1103,7 +1103,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) { { test: "cannot list access block", mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) { - repository.On("DescribeAccountPublicAccessBlock").Return(nil, dummyError) + repository.On("DescribeAccountPublicAccessBlock", accountID).Return(nil, dummyError) }, wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock), }, @@ -1125,7 +1125,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) { remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator( repo, factory, - tf.TerraformProviderConfig{DefaultAlias: "us-east-1"}, + accountID, alerter, )) diff --git a/enumeration/resource/resource_types.go b/enumeration/resource/resource_types.go index e65e15fc1..be2a250a3 100644 --- a/enumeration/resource/resource_types.go +++ b/enumeration/resource/resource_types.go @@ -99,7 +99,8 @@ var supportedTypes = map[string]ResourceTypeMeta{ "aws_s3_bucket_metric": {}, "aws_s3_bucket_notification": {}, "aws_s3_bucket_policy": {}, - "aws_s3_bucket_public_access_block": {}, "aws_security_group": {children: []ResourceType{ + "aws_s3_bucket_public_access_block": {}, + "aws_security_group": {children: []ResourceType{ "aws_security_group_rule", }}, "aws_s3_account_public_access_block": {},