From 2beb29d49c0865bf7e9cd40d653f68f3b9bf48fe Mon Sep 17 00:00:00 2001 From: K Tamil Vanan Date: Thu, 23 Jan 2025 22:05:13 +0530 Subject: [PATCH] feat: add support for aws ecr authentication Signed-off-by: K Tamil Vanan --- .../config-sync-ecr-credential-helper.json | 39 ++++ go.mod | 2 +- pkg/extensions/config/sync/config.go | 19 +- pkg/extensions/sync/ecr_credential_helper.go | 195 ++++++++++++++++++ pkg/extensions/sync/remote.go | 7 + pkg/extensions/sync/service.go | 106 ++++++++-- pkg/extensions/sync/sync.go | 20 ++ pkg/extensions/sync/sync_internal_test.go | 26 +++ pkg/test/mocks/sync_remote_mock.go | 3 + 9 files changed, 387 insertions(+), 30 deletions(-) create mode 100644 examples/config-sync-ecr-credential-helper.json create mode 100644 pkg/extensions/sync/ecr_credential_helper.go diff --git a/examples/config-sync-ecr-credential-helper.json b/examples/config-sync-ecr-credential-helper.json new file mode 100644 index 000000000..aa606bb3c --- /dev/null +++ b/examples/config-sync-ecr-credential-helper.json @@ -0,0 +1,39 @@ +{ + "distSpecVersion": "1.1.0", + "storage": { + "rootDirectory": "/tmp/zot", + "dedupe": false, + "storageDriver": { + "name": "s3", + "region": "REGION_NAME", + "bucket": "BUGKET_NAME", + "rootdirectory": "/ROOTDIR", + "secure": true, + "skipverify": false + } + }, + "http": { + "address": "0.0.0.0", + "port": "8080" + }, + "log": { + "level": "debug" + }, + "extensions": { + "sync": { + "credentialsFile": "", + "DownloadDir": "/tmp/zot", + "registries": [ + { + "urls": [ + "https://ACCOUNTID.dkr.ecr.REGION.amazonaws.com" + ], + "onDemand": true, + "maxRetries": 5, + "retryDelay": "2m", + "credentialHelper": "ecr" + } + ] + } + } +} \ No newline at end of file diff --git a/go.mod b/go.mod index 433e7c679..a55732f51 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.29.1 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.25 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.5 + github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.13 github.com/aws/aws-secretsmanager-caching-go v1.2.0 github.com/aws/smithy-go v1.22.1 @@ -158,7 +159,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.12 // indirect github.com/aws/aws-sdk-go-v2/service/ebs v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/ec2 v1.193.0 // indirect - github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 // indirect github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.9 // indirect diff --git a/pkg/extensions/config/sync/config.go b/pkg/extensions/config/sync/config.go index ec888a084..180420ee4 100644 --- a/pkg/extensions/config/sync/config.go +++ b/pkg/extensions/config/sync/config.go @@ -23,15 +23,16 @@ type Config struct { } type RegistryConfig struct { - URLs []string - PollInterval time.Duration - Content []Content - TLSVerify *bool - OnDemand bool - CertDir string - MaxRetries *int - RetryDelay *time.Duration - OnlySigned *bool + URLs []string + PollInterval time.Duration + Content []Content + TLSVerify *bool + OnDemand bool + CertDir string + MaxRetries *int + RetryDelay *time.Duration + OnlySigned *bool + CredentialHelper string } type Content struct { diff --git a/pkg/extensions/sync/ecr_credential_helper.go b/pkg/extensions/sync/ecr_credential_helper.go new file mode 100644 index 000000000..486f58985 --- /dev/null +++ b/pkg/extensions/sync/ecr_credential_helper.go @@ -0,0 +1,195 @@ +//go:build sync +// +build sync + +package sync + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecr" + + syncconf "zotregistry.dev/zot/pkg/extensions/config/sync" + "zotregistry.dev/zot/pkg/log" +) + +// ECR tokens are valid for 12 hours. The expiryWindow variable is set to 1 hour, +// meaning if the remaining validity of the token is less than 1 hour, it will be considered expired. +const ( + expiryWindow int = 1 + ecrURLSplitPartsCount int = 6 + mockExpiryDuration int = 12 + usernameTokenParts int = 2 +) + +var ( + errInvalidURLFormat = errors.New("invalid ECR URL is received") + errInvalidTokenFormat = errors.New("invalid token format received from ECR") + errUnableToLoadAWSConfig = errors.New("unable to load AWS config for region") + errUnableToGetECRAuthToken = errors.New("unable to get ECR authorization token for account") + errUnableToDecodeECRToken = errors.New("unable to decode ECR token") + errFailedToGetECRCredentials = errors.New("failed to get ECR credentials") +) + +type ecrCredential struct { + username string + password string + expiry time.Time + account string + region string +} + +type ecrCredentialsHelper struct { + credentials map[string]ecrCredential + log log.Logger + getCredentialsFunc func(string) (ecrCredential, error) +} + +func NewECRCredentialHelper(log log.Logger, getCredentialsFunc func(string) (ecrCredential, error)) CredentialHelper { + return &ecrCredentialsHelper{ + credentials: make(map[string]ecrCredential), + log: log, + getCredentialsFunc: getCredentialsFunc, + } +} + +// extractAccountAndRegion extracts the account ID and region from the given ECR URL. +// Example URL format: account.dkr.ecr.region.amazonaws.com. +func extractAccountAndRegion(url string) (string, string, error) { + parts := strings.Split(url, ".") + if len(parts) < ecrURLSplitPartsCount { + return "", "", fmt.Errorf("%w: %s", errInvalidURLFormat, url) + } + + accountID := parts[0] // First part is the account ID + + region := parts[3] // Fourth part is the region + + return accountID, region, nil +} + +// getMockECRCredentials provides mock credentials for testing purposes. +func getMockECRCredentials(remoteAddress string) (ecrCredential, error) { + // Extract account ID and region from the URL. + accountID, region, err := extractAccountAndRegion(remoteAddress) + if err != nil { + return ecrCredential{}, fmt.Errorf("%w %s: %w", errInvalidTokenFormat, remoteAddress, err) + } + expiry := time.Now().Add(time.Duration(mockExpiryDuration) * time.Hour) + + return ecrCredential{ + username: "mockUsername", + password: "mockPassword", + expiry: expiry, + account: accountID, + region: region, + }, nil +} + +// getECRCredentials retrieves actual ECR credentials using AWS SDK. +func getECRCredentials(remoteAddress string) (ecrCredential, error) { + // Extract account ID and region from the URL. + accountID, region, err := extractAccountAndRegion(remoteAddress) + if err != nil { + return ecrCredential{}, fmt.Errorf("%w %s: %w", errInvalidTokenFormat, remoteAddress, err) + } + + // Load the AWS config for the specific region. + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return ecrCredential{}, fmt.Errorf("%w %s: %w", errUnableToLoadAWSConfig, region, err) + } + + // Create an ECR client + ecrClient := ecr.NewFromConfig(cfg) + + // Fetch the ECR authorization token. + ecrAuth, err := ecrClient.GetAuthorizationToken(context.TODO(), &ecr.GetAuthorizationTokenInput{ + RegistryIds: []string{accountID}, // Filter by the account ID. + }) + if err != nil { + return ecrCredential{}, fmt.Errorf("%w %s: %w", errUnableToGetECRAuthToken, accountID, err) + } + + // Decode the base64-encoded ECR token. + authToken := *ecrAuth.AuthorizationData[0].AuthorizationToken + + decodedToken, err := base64.StdEncoding.DecodeString(authToken) + if err != nil { + return ecrCredential{}, fmt.Errorf("%w: %w", errUnableToDecodeECRToken, err) + } + + // Split the decoded token into username and password (username is "AWS"). + tokenParts := strings.Split(string(decodedToken), ":") + if len(tokenParts) != usernameTokenParts { + return ecrCredential{}, fmt.Errorf("%w", errInvalidTokenFormat) + } + + expiry := *ecrAuth.AuthorizationData[0].ExpiresAt + username := tokenParts[0] + password := tokenParts[1] + + return ecrCredential{username: username, password: password, expiry: expiry, account: accountID, region: region}, nil +} + +// GetECRCredentials retrieves the ECR credentials (username and password) from AWS ECR. +func (credHelper *ecrCredentialsHelper) GetCredentials(urls []string) (syncconf.CredentialsFile, error) { + ecrCredentials := make(syncconf.CredentialsFile) + + for _, url := range urls { + remoteAddress := StripRegistryTransport(url) + + // Use the injected credential retrieval function. + ecrCred, err := credHelper.getCredentialsFunc(remoteAddress) + if err != nil { + return syncconf.CredentialsFile{}, fmt.Errorf("%w %s: %w", errFailedToGetECRCredentials, url, err) + } + // Store the credentials in the map using the base URL as the key. + ecrCredentials[remoteAddress] = syncconf.Credentials{ + Username: ecrCred.username, + Password: ecrCred.password, + } + credHelper.credentials[remoteAddress] = ecrCred + } + + return ecrCredentials, nil +} + +// AreCredentialsValid checks if the credentials for a given remote address are still valid. +func (credHelper *ecrCredentialsHelper) AreCredentialsValid(remoteAddress string) bool { + expiry := credHelper.credentials[remoteAddress].expiry + expiryDuration := time.Duration(expiryWindow) * time.Hour + + if time.Until(expiry) <= expiryDuration { + credHelper.log.Info(). + Str("url", remoteAddress). + Msg("the credentials are close to expiring") + + return false + } + + credHelper.log.Info(). + Str("url", remoteAddress). + Msg("the credentials are valid") + + return true +} + +// RefreshCredentials refreshes the ECR credentials for the given remote address. +func (credHelper *ecrCredentialsHelper) RefreshCredentials( + remoteAddress string, +) (syncconf.Credentials, error) { + credHelper.log.Info().Str("url", remoteAddress).Msg("refreshing the ECR credentials") + + ecrCred, err := credHelper.getCredentialsFunc(remoteAddress) + if err != nil { + return syncconf.Credentials{}, fmt.Errorf("%w %s: %w", errFailedToGetECRCredentials, remoteAddress, err) + } + + return syncconf.Credentials{Username: ecrCred.username, Password: ecrCred.password}, nil +} diff --git a/pkg/extensions/sync/remote.go b/pkg/extensions/sync/remote.go index bf85f62f3..174ae94c2 100644 --- a/pkg/extensions/sync/remote.go +++ b/pkg/extensions/sync/remote.go @@ -44,6 +44,13 @@ func NewRemoteRegistry(client *client.Client, logger log.Logger) Remote { return registry } +func (registry *RemoteRegistry) SetUpstreamAuthConfig(username, password string) { + registry.context.DockerAuthConfig = &types.DockerAuthConfig{ + Username: username, + Password: password, + } +} + func (registry *RemoteRegistry) GetContext() *types.SystemContext { return registry.context } diff --git a/pkg/extensions/sync/service.go b/pkg/extensions/sync/service.go index 4f1fca23d..65c678b4e 100644 --- a/pkg/extensions/sync/service.go +++ b/pkg/extensions/sync/service.go @@ -27,19 +27,20 @@ import ( ) type BaseService struct { - config syncconf.RegistryConfig - credentials syncconf.CredentialsFile - clusterConfig *config.ClusterConfig - remote Remote - destination Destination - retryOptions *retry.RetryOptions - contentManager ContentManager - storeController storage.StoreController - metaDB mTypes.MetaDB - repositories []string - references references.References - client *client.Client - log log.Logger + config syncconf.RegistryConfig + credentials syncconf.CredentialsFile + credentialHelper CredentialHelper + clusterConfig *config.ClusterConfig + remote Remote + destination Destination + retryOptions *retry.RetryOptions + contentManager ContentManager + storeController storage.StoreController + metaDB mTypes.MetaDB + repositories []string + references references.References + client *client.Client + log log.Logger } func New( @@ -60,16 +61,40 @@ func New( var err error var credentialsFile syncconf.CredentialsFile - if credentialsFilepath != "" { + + if service.config.CredentialHelper == "" && credentialsFilepath != "" { + // Only load credentials from file if CredentialHelper is not set + log.Info().Msgf("using file-based credentials because CredentialHelper is not set") + credentialsFile, err = getFileCredentials(credentialsFilepath) if err != nil { - log.Error().Str("errortype", common.TypeOf(err)).Str("path", credentialsFilepath). - Err(err).Msg("couldn't get registry credentials from configured path") + log.Error(). + Str("errortype", common.TypeOf(err)). + Str("path", credentialsFilepath). + Err(err). + Msg("couldn't get registry credentials from configured path") + } + service.credentialHelper = nil + service.credentials = credentialsFile + } else if service.config.CredentialHelper != "" { + log.Info().Msgf("using credentials helper, because CredentialHelper is set to %s", service.config.CredentialHelper) + + switch service.config.CredentialHelper { + case "ecr": + // Logic to fetch credentials for ECR + log.Info().Msg("fetch the credentials using AWS ECR Auth Token.") + service.credentialHelper = NewECRCredentialHelper(log, getECRCredentials) + + creds, err := service.credentialHelper.GetCredentials(service.config.URLs) + if err != nil { + log.Error().Err(err).Msg("failed to retrieve credentials using ECR credentials helper.") + } + service.credentials = creds + default: + log.Warn().Msgf("unsupported CredentialHelper: %s", service.config.CredentialHelper) } } - service.credentials = credentialsFile - // load the cluster config into the object // can be nil if the user did not configure cluster config service.clusterConfig = clusterConfig @@ -102,7 +127,6 @@ func New( service.retryOptions = retryOptions service.storeController = storeController - // try to set next client. if err := service.SetNextAvailableClient(); err != nil { // if it's a ping issue, it will be retried @@ -126,9 +150,51 @@ func New( return service, nil } +// refreshRegistryTemporaryCredentials refreshes the temporary credentials for the registry if necessary. +// It checks whether a CredentialHelper is configured and if the current credentials have expired. +// If the credentials are expired, it attempts to refresh them and updates the service configuration. +func (service *BaseService) refreshRegistryTemporaryCredentials() error { + // Exit early if no CredentialHelper is configured. + if service.config.CredentialHelper == "" { + return nil + } + + // Strip the transport protocol (e.g., https:// or http://) from the remote address. + remoteAddress := StripRegistryTransport(service.client.GetHostname()) + + // Exit early if the credentials are valid. + if service.credentialHelper.AreCredentialsValid(remoteAddress) { + return nil + } + + // Attempt to refresh the credentials using the CredentialHelper. + credentials, err := service.credentialHelper.RefreshCredentials(remoteAddress) + if err != nil { + service.log.Error(). + Err(err). + Str("url", remoteAddress). + Msg("failed to refresh the credentials") + + return err + } + + service.log.Info(). + Str("url", remoteAddress). + Msg("refreshing the upstream remote registry credentials") + + // Update the service's credentials map with the new set of credentials. + service.credentials[remoteAddress] = credentials + + // Set the upstream authentication context using the refreshed credentials. + service.remote.SetUpstreamAuthConfig(credentials.Username, credentials.Password) + + // Return nil to indicate the operation completed successfully. + return nil +} + func (service *BaseService) SetNextAvailableClient() error { if service.client != nil && service.client.Ping() { - return nil + return service.refreshRegistryTemporaryCredentials() } found := false diff --git a/pkg/extensions/sync/sync.go b/pkg/extensions/sync/sync.go index 1afd11172..7f5867c12 100644 --- a/pkg/extensions/sync/sync.go +++ b/pkg/extensions/sync/sync.go @@ -13,6 +13,7 @@ import ( "github.com/containers/image/v5/types" "github.com/opencontainers/go-digest" + syncconf "zotregistry.dev/zot/pkg/extensions/config/sync" "zotregistry.dev/zot/pkg/log" "zotregistry.dev/zot/pkg/scheduler" ) @@ -48,6 +49,22 @@ type Registry interface { GetContext() *types.SystemContext } +// The CredentialHelper interface should be implemented by registries that use temporary tokens. +// This interface defines methods to: +// - Check if the credentials for a registry are still valid. +// - Retrieve credentials for the specified registry URLs. +// - Refresh credentials for a given registry URL. +type CredentialHelper interface { + // Validates whether the credentials for the specified registry URL have expired. + AreCredentialsValid(url string) bool + + // Retrieves credentials for the provided list of registry URLs. + GetCredentials(urls []string) (syncconf.CredentialsFile, error) + + // Refreshes credentials for the specified registry URL. + RefreshCredentials(url string) (syncconf.Credentials, error) +} + /* Temporary oci layout, sync first pulls an image to this oci layout (using oci:// transport) then moves them into ImageStore. @@ -68,6 +85,9 @@ type Remote interface { // In the case of public dockerhub images 'library' namespace is added to the repo names of images // eg: alpine -> library/alpine GetDockerRemoteRepo(repo string) string + // SetUpstreamAuthConfig sets the upstream credentials used when the credential helper is set. + // This method refreshes the authentication configuration with the provided username and password. + SetUpstreamAuthConfig(username, password string) } // Local registry. diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index 3609cb4a1..81c71aa8c 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -676,3 +676,29 @@ func TestConvertDockerLayersToOCI(t *testing.T) { So(dockerLayers[3].MediaType, ShouldEqual, ispec.MediaTypeImageLayerGzip) }) } + +func TestECRCredentialsHelper(t *testing.T) { + Convey("Test ECR Credentials Helper", t, func() { + logs := log.Logger{Logger: zerolog.New(os.Stdout)} + // use getMockECRCredentials for testing purposes + credentialHelper := NewECRCredentialHelper(logs, getMockECRCredentials) + url := "https://mockAccount.dkr.ecr.mockRegion.amazonaws.com" + + Convey("Test Valid Credentials Retrieval", func() { + creds, err := credentialHelper.GetCredentials([]string{url}) + So(err, ShouldBeNil) + So(creds, ShouldNotBeNil) + So(creds[url].Username, ShouldEqual, "mockUsername") + So(creds[url].Password, ShouldEqual, "mockPassword") + }) + + Convey("Test Credentials are valid", func() { + So(credentialHelper.AreCredentialsValid(url), ShouldBeTrue) + }) + + Convey("Test Credentials Refresh", func() { + _, err := credentialHelper.RefreshCredentials(url) + So(err, ShouldBeNil) + }) + }) +} diff --git a/pkg/test/mocks/sync_remote_mock.go b/pkg/test/mocks/sync_remote_mock.go index c22d74fd8..3fb5d01c0 100644 --- a/pkg/test/mocks/sync_remote_mock.go +++ b/pkg/test/mocks/sync_remote_mock.go @@ -75,3 +75,6 @@ func (remote SyncRemote) GetManifestContent(imageReference types.ImageReference) return nil, "", "", nil } + +func (remote SyncRemote) SetUpstreamAuthConfig(username, password string) { +}