diff --git a/lib/web/integrations_awsoidc_test.go b/lib/web/integrations_awsoidc_test.go index bfa91cc4cbed1..cfd6fd3d53597 100644 --- a/lib/web/integrations_awsoidc_test.go +++ b/lib/web/integrations_awsoidc_test.go @@ -848,8 +848,6 @@ func TestBuildListDatabasesConfigureIAMScript(t *testing.T) { func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { t.Parallel() ctx := context.Background() - env := newWebPack(t, 1) - clt := env.proxies[0].client matchRegion := "us-east-1" matchAccountId := "123456789012" @@ -858,7 +856,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { AccountID: matchAccountId, } - upsertDbSvcFn := func(vpcId string, matcher []*types.DatabaseResourceMatcher) { + dbServiceFor := func(vpcId string, matcher []*types.DatabaseResourceMatcher) *types.DatabaseServiceV1 { if matcher == nil { matcher = []*types.DatabaseResourceMatcher{ { @@ -877,8 +875,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { ResourceMatchers: matcher, }) require.NoError(t, err) - _, err = env.server.Auth().UpsertDatabaseService(ctx, svc) - require.NoError(t, err) + return svc } extractKeysFn := func(resp *ui.AWSOIDCRequiredVPCSResponse) []string { @@ -900,11 +897,9 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { } // Double check we start with 0 db svcs. - s, err := env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{ - ResourceType: types.KindDatabaseService, - }) - require.NoError(t, err) - require.Empty(t, s.Resources) + clt := &mockGetResources{ + databaseServices: &proto.ListResourcesResponse{}, + } // All vpc's required. resp, err := awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss) @@ -912,12 +907,13 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { require.Len(t, resp.VPCMapOfSubnets, 5) require.ElementsMatch(t, vpcs, extractKeysFn(resp)) - // Insert two valid database services. - upsertDbSvcFn("vpc-1", nil) - upsertDbSvcFn("vpc-5", nil) + // Add some database services. + // Two valid database services. + validDBServiceVPC1 := dbServiceFor("vpc-1", nil) + validDBServiceVPC5 := dbServiceFor("vpc-5", nil) - // Insert two invalid database services. - upsertDbSvcFn("vpc-2", []*types.DatabaseResourceMatcher{ + // Two invalid database services. + invalidDBServiceVPC2 := dbServiceFor("vpc-2", []*types.DatabaseResourceMatcher{ { Labels: &types.Labels{ types.DiscoveryLabelAccountID: []string{matchAccountId}, @@ -926,7 +922,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { }, }, }) - upsertDbSvcFn("vpc-2a", []*types.DatabaseResourceMatcher{ + invalidDBServiceVPC2a := dbServiceFor("vpc-2a", []*types.DatabaseResourceMatcher{ { Labels: &types.Labels{ types.DiscoveryLabelAccountID: []string{matchAccountId}, @@ -937,12 +933,12 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { }, }) - // Double check services were created. - s, err = env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{ - ResourceType: types.KindDatabaseService, - }) - require.NoError(t, err) - require.Len(t, s.Resources, 4) + clt.databaseServices.Resources = append(clt.databaseServices.Resources, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC1}}, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC5}}, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: invalidDBServiceVPC2}}, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: invalidDBServiceVPC2a}}, + ) // Test that only 3 vpcs are required. resp, err = awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss) @@ -950,9 +946,14 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { require.ElementsMatch(t, []string{"vpc-2", "vpc-3", "vpc-4"}, extractKeysFn(resp)) // Insert the rest of db services - upsertDbSvcFn("vpc-2", nil) - upsertDbSvcFn("vpc-3", nil) - upsertDbSvcFn("vpc-4", nil) + validDBServiceVPC2 := dbServiceFor("vpc-2", nil) + validDBServiceVPC3 := dbServiceFor("vpc-3", nil) + validDBServiceVPC4 := dbServiceFor("vpc-4", nil) + clt.databaseServices.Resources = append(clt.databaseServices.Resources, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC2}}, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC3}}, + &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC4}}, + ) // Test no required vpcs. resp, err = awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss) @@ -989,9 +990,16 @@ func TestAWSOIDCRequiredVPCSHelper_CombinedSubnetsForAVpcID(t *testing.T) { } type mockGetResources struct { + databaseServices *proto.ListResourcesResponse } func (m *mockGetResources) GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { + switch req.ResourceType { + case types.KindDatabaseService: + if m.databaseServices != nil { + return m.databaseServices, nil + } + } return &proto.ListResourcesResponse{}, nil }