Skip to content

Commit

Permalink
Update awsconfig
Browse files Browse the repository at this point in the history
* Add a Cache for caching credentials, similar to SDK v1 session cache.
* Add a Provider interface that provides aws.Config
* Simplified role chaining options

Unlike our SDK v1 session cache, the SDK v2 implementation in this PR does not include region as a cache key.
There are regional AWS STS endpoints for lower latency calls, but the lowest latency path is to just grab credentials from the cache if we already have them - the region they were originally taken from doesn't matter.
  • Loading branch information
GavinFrazar committed Dec 23, 2024
1 parent 60aaa6d commit 75e8974
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 72 deletions.
165 changes: 99 additions & 66 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package awsconfig
import (
"context"
"log/slog"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -47,16 +48,20 @@ const (
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// assumeRoleAPIClientFunc creates an STS AssumeRole API client from an AWS
// config.
type assumeRoleAPIClientFunc func(aws.Config) stscreds.AssumeRoleAPIClient

// assumeRole is an AWS role ARN to assume, optionally with an external ID.
type assumeRole struct {
// roleARN is the ARN of the AWS IAM role to assume.
roleARN string
// externalID is the optional external ID sent with the assume-role request;
// it is left unset on the request when empty.
externalID string
}

// options is a struct of additional options for assuming an AWS role
// when constructing an underlying AWS config.
type options struct {
// baseConfig is a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
// assumeRoleARN is the AWS IAM Role ARN to assume.
assumeRoleARN string
// assumeRoleExternalID is used to assume an external AWS IAM Role.
assumeRoleExternalID string
// assumeRoles are AWS IAM roles that should be assumed one by one in order,
// as a chain of assumed roles.
assumeRoles []assumeRole
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
Expand All @@ -67,22 +72,40 @@ type options struct {
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// newAssumeRoleAPIClientFn is an internal-only option used in tests to fake
// an STS client for assuming roles.
newAssumeRoleAPIClientFn assumeRoleAPIClientFunc
}

func (a *options) checkAndSetDefaults() error {
switch a.credentialsSource {
// buildOptions applies every option function, in order, to a zero-valued
// options struct, then validates the result and fills in defaults.
// It returns a wrapped error if validation fails.
func buildOptions(optFns ...OptionsFn) (*options, error) {
	built := new(options)
	for i := range optFns {
		optFns[i](built)
	}
	err := built.checkAndSetDefaults()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return built, nil
}

func (o *options) checkAndSetDefaults() error {
switch o.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
if o.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
if a.integration == "" {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}

if o.newAssumeRoleAPIClientFn == nil {
o.newAssumeRoleAPIClientFn = newAssumeRoleAPIClient
}

return nil
}

Expand All @@ -93,8 +116,14 @@ type OptionsFn func(*options)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
if roleARN == "" {
// ignore empty role ARN for caller convenience.
return
}
options.assumeRoles = append(options.assumeRoles, assumeRole{
roleARN: roleARN,
externalID: externalID,
})
}
}

Expand Down Expand Up @@ -148,94 +177,98 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O

// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
var options options
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
options.baseConfig = &cfg
func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
opts, err := buildOptions(optFns...)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
if options.assumeRoleARN == "" {
return *options.baseConfig, nil

cfg, err := getBaseConfig(ctx, region, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRole(ctx, region, options)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.newAssumeRoleAPIClientFn)
}

// ambientConfigProvider loads a new config using the environment variables.
func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
// loadDefaultConfig loads a new config.
func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) {
configOpts := buildConfigOptions(region, cred, opts)
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
return cfg, trace.Wrap(err)
}

func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error {
configOpts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
if options.customRetryer != nil {
opts = append(opts, config.WithRetryer(options.customRetryer))
if opts.customRetryer != nil {
configOpts = append(configOpts, config.WithRetryer(opts.customRetryer))
}
if options.maxRetries != nil {
opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
if opts.maxRetries != nil {
configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries))
}
return opts
return configOpts
}

// getConfigForRegion returns AWS config for the specified region.
func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.integrationCredentialsProvider == nil {
if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
}

cfg, err := ambientConfigProvider(region, cred, options)
cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
}

// getConfigForRole returns an AWS config for the specified region and role.
func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
// getConfigForRoleChain returns cfg with its credentials replaced by a provider
// that assumes each of the given roles in order, where every role is assumed
// using the credentials of the one before it. With no roles, cfg is returned
// unchanged.
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []assumeRole, newSTSCltFn assumeRoleAPIClientFunc) (aws.Config, error) {
	if len(roles) == 0 {
		return cfg, nil
	}

	// Each STS client is built from the config as it stands at that point in
	// the chain, so it uses the previous role's credentials.
	for _, role := range roles {
		stsClt := newSTSCltFn(cfg)
		cfg.Credentials = getAssumeRoleProvider(stsClt, role)
	}

	// no point caching every assumed role in the chain, we can just cache
	// the last one.
	cached := aws.NewCredentialsCache(cfg.Credentials, func(o *aws.CredentialsCacheOptions) {
		// expire early to avoid expiration race.
		o.ExpiryWindow = 5 * time.Minute
	})
	cfg.Credentials = cached

	// Retrieve once up front so a failure to assume the chain is reported to
	// the caller here.
	if _, err := cached.Retrieve(ctx); err != nil {
		return aws.Config{}, trace.Wrap(err)
	}
	return cfg, nil
}

stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) {
// newAssumeRoleAPIClient returns the STS client built by newSTSClient as an
// stscreds.AssumeRoleAPIClient. checkAndSetDefaults installs it as the default
// newAssumeRoleAPIClientFn when no override is set.
func newAssumeRoleAPIClient(cfg aws.Config) stscreds.AssumeRoleAPIClient {
return newSTSClient(cfg)
}

func newSTSClient(cfg aws.Config) *sts.Client {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
if options.assumeRoleExternalID != "" {
aro.ExternalID = aws.String(options.assumeRoleExternalID)
}

func getAssumeRoleProvider(clt stscreds.AssumeRoleAPIClient, role assumeRole) aws.CredentialsProvider {
return stscreds.NewAssumeRoleProvider(clt, role.roleARN, func(aro *stscreds.AssumeRoleOptions) {
if role.externalID != "" {
aro.ExternalID = aws.String(role.externalID)
}
})
if _, err := cred.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}

opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
}
91 changes: 85 additions & 6 deletions lib/cloud/awsconfig/awsconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ package awsconfig

import (
"context"
"fmt"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
Expand All @@ -29,26 +34,68 @@ type mockCredentialProvider struct {
cred aws.Credentials
}

func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
func (m *mockCredentialProvider) Retrieve(_ context.Context) (aws.Credentials, error) {
return m.cred, nil
}

// mockAssumeRoleAPIClient is a fake AssumeRole API client for tests.
type mockAssumeRoleAPIClient struct{}

// AssumeRole returns canned credentials whose access key ID encodes the
// requested role ARN and external ID, so a test can tell which role was
// assumed. It never returns an error.
func (m *mockAssumeRoleAPIClient) AssumeRole(_ context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
	roleARN := aws.ToString(params.RoleArn)
	externalID := aws.ToString(params.ExternalId)

	user := &ststypes.AssumedRoleUser{
		Arn:           params.RoleArn,
		AssumedRoleId: aws.String("role-id"),
	}
	creds := &ststypes.Credentials{
		AccessKeyId:     aws.String(fmt.Sprintf("role: %s, externalID: %s", roleARN, externalID)),
		Expiration:      aws.Time(time.Time{}),
		SecretAccessKey: aws.String("fake-secret-access-key"),
		SessionToken:    aws.String("fake-session-token"),
	}
	return &sts.AssumeRoleOutput{AssumedRoleUser: user, Credentials: creds}, nil
}

// TestGetConfigIntegration runs the shared GetConfig test suite against both
// the plain GetConfig function and a credentials cache.
func TestGetConfigIntegration(t *testing.T) {
	t.Parallel()

	cache, err := NewCache()
	require.NoError(t, err)

	providers := []struct {
		name     string
		provider Provider
	}{
		{name: "uncached", provider: ProviderFunc(GetConfig)},
		{name: "cached", provider: cache},
	}
	for _, tc := range providers {
		t.Run(tc.name, func(t *testing.T) {
			testGetConfigIntegration(t, tc.provider)
		})
	}
}

func testGetConfigIntegration(t *testing.T, provider Provider) {
dummyIntegration := "integration-test"
dummyRegion := "test-region-123"

t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) {
ctx := context.Background()
_, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
_, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err)
require.ErrorContains(t, err, "missing aws integration credential provider")
})

t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) {
ctx := context.Background()

cfg, err := GetConfig(ctx, dummyRegion,
cfg, err := provider.GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(dummyIntegration),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
if region == dummyRegion && integration == dummyIntegration {
Expand All @@ -66,10 +113,42 @@ func TestGetConfigIntegration(t *testing.T) {
require.Equal(t, "foo-bar", creds.SessionToken)
})

t.Run("with an integration credential provider assuming a role, must return assumed role credentials", func(t *testing.T) {
ctx := context.Background()

cfg, err := provider.GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(dummyIntegration),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
if region == dummyRegion && integration == dummyIntegration {
return &mockCredentialProvider{
cred: aws.Credentials{
SessionToken: "foo-bar",
},
}, nil
}
return nil, trace.NotFound("no creds in region %q with integration %q", region, integration)
}),
WithAssumeRole("roleA", "abc123"),
func(o *options) {
o.newAssumeRoleAPIClientFn = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
creds, err := cfg.Credentials.Retrieve(context.Background())
require.NoError(t, err)
require.Equal(t, "foo-bar", creds.SessionToken)
return &mockAssumeRoleAPIClient{}
}
},
)
require.NoError(t, err)
creds, err := cfg.Credentials.Retrieve(ctx)
require.NoError(t, err)
require.Equal(t, "role: roleA, externalID: abc123", creds.AccessKeyID)
require.Equal(t, "fake-session-token", creds.SessionToken)
})

t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetConfig(ctx, dummyRegion,
_, err := provider.GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(""),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
Expand All @@ -81,7 +160,7 @@ func TestGetConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetConfig(ctx, dummyRegion,
_, err := provider.GetConfig(ctx, dummyRegion,
WithAmbientCredentials(),
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
Expand All @@ -93,7 +172,7 @@ func TestGetConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but no credential source", func(t *testing.T) {
ctx := context.Background()

_, err := GetConfig(ctx, dummyRegion,
_, err := provider.GetConfig(ctx, dummyRegion,
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
Expand Down
Loading

0 comments on commit 75e8974

Please sign in to comment.