diff --git a/src/main/java/io/kestra/plugin/aws/AbstractConnection.java b/src/main/java/io/kestra/plugin/aws/AbstractConnection.java index 20740201..403f04dd 100644 --- a/src/main/java/io/kestra/plugin/aws/AbstractConnection.java +++ b/src/main/java/io/kestra/plugin/aws/AbstractConnection.java @@ -3,12 +3,28 @@ import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.tasks.Task; import io.kestra.core.runners.RunContext; +import jakarta.annotation.Nullable; +import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.ToString; import lombok.experimental.SuperBuilder; -import software.amazon.awssdk.auth.credentials.*; +import org.apache.commons.lang3.StringUtils; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; + +import java.net.URI; +import java.time.Duration; @SuperBuilder @ToString @@ -16,37 +32,124 @@ @Getter @NoArgsConstructor public abstract class AbstractConnection extends Task implements AbstractConnectionInterface { - protected String accessKeyId; - protected String secretKeyId; + protected String region; + protected String endpointOverride; + private Boolean compatibilityMode; + // Configuration for StaticCredentialsProvider + protected String accessKeyId; + protected String secretKeyId; protected String sessionToken; - protected String region; + // Configuration for AWS STS AssumeRole + protected String stsRoleArn; + protected String stsRoleExternalId; + protected String stsRoleSessionName; + protected String stsEndpointOverride; + @Builder.Default + protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION; - protected String endpointOverride; + protected AwsCredentialsProvider credentials(final RunContext runContext) throws IllegalVariableEvaluationException { - private Boolean compatibilityMode; + final AwsClientConfig awsClientConfig = awsClientConfig(runContext); + + // StsAssumeRoleCredentialsProvider + if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn)) { + return stsAssumeRoleCredentialsProvider(awsClientConfig); + } + + // StaticCredentialsProvider + if (StringUtils.isNotEmpty(awsClientConfig.accessKeyId) && + StringUtils.isNotEmpty(awsClientConfig.secretKeyId)) { + return staticCredentialsProvider(awsClientConfig); + } + + // Otherwise, use DefaultCredentialsProvider + return DefaultCredentialsProvider.builder().build(); + } - protected AwsCredentialsProvider credentials(RunContext runContext) throws IllegalVariableEvaluationException { - String accessKeyId = runContext.render(this.accessKeyId); - String secretKeyId = runContext.render(this.secretKeyId); - String sessionToken = runContext.render(this.sessionToken); - - if (sessionToken != null) { - StaticCredentialsProvider.create(AwsSessionCredentials.create( - accessKeyId, - secretKeyId, - sessionToken - )); - } else if (accessKeyId != null && secretKeyId != null) { - return StaticCredentialsProvider.create(AwsBasicCredentials.create( - accessKeyId, - secretKeyId - )); + private static StaticCredentialsProvider staticCredentialsProvider(final AwsClientConfig awsClientConfig) { + final AwsCredentials credentials; + if (StringUtils.isNotEmpty(awsClientConfig.sessionToken())) { + credentials = AwsSessionCredentials.create( + awsClientConfig.accessKeyId, + awsClientConfig.secretKeyId, + awsClientConfig.sessionToken + ); + } else { + credentials = AwsBasicCredentials.create( + awsClientConfig.accessKeyId, + awsClientConfig.secretKeyId + ); } + return StaticCredentialsProvider.create(credentials); + } + + private static StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider(final AwsClientConfig awsClientConfig) { + + String roleSessionName = awsClientConfig.stsRoleSessionName(); + roleSessionName = roleSessionName != null ? roleSessionName : "kestra-plugin-s3-" + System.currentTimeMillis(); - return DefaultCredentialsProvider.builder() + final AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder() + .roleArn(awsClientConfig.stsRoleArn()) + .roleSessionName(roleSessionName) + .durationSeconds((int) awsClientConfig.stsRoleSessionDuration().toSeconds()) + .externalId(awsClientConfig.stsRoleExternalId()) .build(); + + return StsAssumeRoleCredentialsProvider.builder() + .stsClient(stsClient(awsClientConfig)) + .refreshRequest(assumeRoleRequest) + .build(); + } + + private static StsClient stsClient(final AwsClientConfig awsClientConfig) { + StsClientBuilder builder = StsClient.builder(); + + final String stsEndpointOverride = awsClientConfig.stsEndpointOverride; + if (stsEndpointOverride != null) { + builder.applyMutation(stsClientBuilder -> + stsClientBuilder.endpointOverride(URI.create(stsEndpointOverride))); + } + + final String regionString = awsClientConfig.region; + if (regionString != null) { + builder.applyMutation(stsClientBuilder -> + stsClientBuilder.region(Region.of(regionString))); + } + return builder.build(); + } + + private AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException { + return new AwsClientConfig( + runContext.render(this.accessKeyId), + runContext.render(this.secretKeyId), + runContext.render(this.sessionToken), + runContext.render(this.stsRoleArn), + runContext.render(this.stsRoleExternalId), + runContext.render(this.stsRoleSessionName), + runContext.render(this.stsEndpointOverride), + stsRoleSessionDuration, + runContext.render(this.region), + runContext.render(this.endpointOverride) + ); + } + + /** + * Common AWS Client configuration properties. + */ + private record AwsClientConfig( + @Nullable String accessKeyId, + @Nullable String secretKeyId, + @Nullable String sessionToken, + @Nullable String stsRoleArn, + @Nullable String stsRoleExternalId, + @Nullable String stsRoleSessionName, + @Nullable String stsEndpointOverride, + Duration stsRoleSessionDuration, + @Nullable String region, + @Nullable String endpointOverride + ) { } } diff --git a/src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java b/src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java index 921ff2b2..179bc4a7 100644 --- a/src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java +++ b/src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java @@ -3,7 +3,12 @@ import io.kestra.core.models.annotations.PluginProperty; import io.swagger.v3.oas.annotations.media.Schema; +import java.time.Duration; + public interface AbstractConnectionInterface { + + Duration AWS_MIN_STS_ROLE_SESSION_DURATION = Duration.ofSeconds(900); + @Schema( title = "Access Key Id in order to connect to AWS.", description = "If no connection is defined, we will use the `DefaultCredentialsProvider` to fetch the value." @@ -25,6 +30,40 @@ public interface AbstractConnectionInterface { @PluginProperty(dynamic = true) String getSessionToken(); + @Schema( + title = "AWS STS Role.", + description = "The Amazon Resource Name (ARN) of the role to assume. If set the task will use the `StsAssumeRoleCredentialsProvider`. Otherwise, the `StaticCredentialsProvider` will be used with the provided Access Key Id and Secret Key." + ) + @PluginProperty(dynamic = true) + + String getStsRoleArn(); + + @Schema( + title = "AWS STS External Id.", + description = " A unique identifier that might be required when you assume a role in another account. This property is only used when an `stsRoleArn` is defined." + ) + @PluginProperty(dynamic = true) + + String getStsRoleExternalId(); + @Schema( + title = "AWS STS Session name. This property is only used when an `stsRoleArn` is defined." + ) + @PluginProperty(dynamic = true) + String getStsRoleSessionName(); + + @Schema( + title = "AWS STS Session duration.", + description = "The duration of the role session (default: 15 minutes, i.e., PT15M). This property is only used when an `stsRoleArn` is defined." + ) + @PluginProperty + Duration getStsRoleSessionDuration(); + + @Schema( + title = "The AWS STS endpoint with which the SDKClient should communicate." + ) + @PluginProperty(dynamic = true) + String getStsEndpointOverride(); + @Schema( title = "AWS region with which the SDK should communicate." ) diff --git a/src/main/java/io/kestra/plugin/aws/s3/Downloads.java b/src/main/java/io/kestra/plugin/aws/s3/Downloads.java index a7cce3be..9e8734a2 100644 --- a/src/main/java/io/kestra/plugin/aws/s3/Downloads.java +++ b/src/main/java/io/kestra/plugin/aws/s3/Downloads.java @@ -90,6 +90,11 @@ public List.Output run(RunContext runContext) throws Exception { .expectedBucketOwner(this.expectedBucketOwner) .regexp(this.regexp) .filter(this.filter) + .stsRoleArn(this.stsRoleArn) + .stsRoleSessionName(this.stsRoleSessionName) + .stsRoleExternalId(this.stsRoleExternalId) + .stsRoleSessionDuration(this.stsRoleSessionDuration) + .stsEndpointOverride(this.stsEndpointOverride) .build(); List.Output run = task.run(runContext); diff --git a/src/main/java/io/kestra/plugin/aws/s3/S3Service.java b/src/main/java/io/kestra/plugin/aws/s3/S3Service.java index 7349f343..9c25f5e6 100644 --- a/src/main/java/io/kestra/plugin/aws/s3/S3Service.java +++ b/src/main/java/io/kestra/plugin/aws/s3/S3Service.java @@ -40,7 +40,7 @@ public static Pair download(RunContext runContext, S3Asy .build() ); - GetObjectResponse response =download.completionFuture().get().response(); + GetObjectResponse response = download.completionFuture().get().response(); runContext.metric(Counter.of("file.size", response.contentLength())); @@ -68,6 +68,11 @@ static void performAction( .secretKeyId(abstractConnection.getSecretKeyId()) .key(object.getKey()) .bucket(abstractS3Object.getBucket()) + .stsRoleArn(abstractConnection.getStsRoleArn()) + .stsRoleExternalId(abstractConnection.getStsRoleExternalId()) + .stsRoleSessionName(abstractConnection.getStsRoleSessionName()) + .stsRoleSessionDuration(abstractConnection.getStsRoleSessionDuration()) + .stsEndpointOverride(abstractConnection.getStsEndpointOverride()) .build(); delete.run(runContext); } @@ -80,6 +85,11 @@ static void performAction( .endpointOverride(abstractS3.getEndpointOverride()) .accessKeyId(abstractConnection.getAccessKeyId()) .secretKeyId(abstractConnection.getSecretKeyId()) + .stsRoleArn(abstractConnection.getStsRoleArn()) + .stsRoleExternalId(abstractConnection.getStsRoleExternalId()) + .stsRoleSessionName(abstractConnection.getStsRoleSessionName()) + .stsRoleSessionDuration(abstractConnection.getStsRoleSessionDuration()) + .stsEndpointOverride(abstractConnection.getStsEndpointOverride()) .from(Copy.CopyObjectFrom.builder() .bucket(abstractS3Object.getBucket()) .key(object.getKey()) diff --git a/src/main/java/io/kestra/plugin/aws/s3/Trigger.java b/src/main/java/io/kestra/plugin/aws/s3/Trigger.java index a63af6ab..390b9fa6 100644 --- a/src/main/java/io/kestra/plugin/aws/s3/Trigger.java +++ b/src/main/java/io/kestra/plugin/aws/s3/Trigger.java @@ -14,7 +14,11 @@ import io.kestra.plugin.aws.AbstractConnectionInterface; import io.kestra.plugin.aws.s3.models.S3Object; import io.swagger.v3.oas.annotations.media.Schema; -import lombok.*; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; import lombok.experimental.SuperBuilder; import java.time.Duration; @@ -145,6 +149,14 @@ public class Trigger extends AbstractTrigger implements PollingTriggerInterface, private Copy.CopyObject moveTo; + // Configuration for AWS STS AssumeRole + protected String stsRoleArn; + protected String stsRoleExternalId; + protected String stsRoleSessionName; + protected String stsEndpointOverride; + @Builder.Default + protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION; + @Override public Optional evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception { RunContext runContext = conditionContext.getRunContext(); @@ -166,6 +178,11 @@ public Optional evaluate(ConditionContext conditionContext, TriggerCo .expectedBucketOwner(this.expectedBucketOwner) .regexp(this.regexp) .filter(this.filter) + .stsRoleArn(this.stsRoleArn) + .stsRoleSessionName(this.stsRoleSessionName) + .stsRoleExternalId(this.stsRoleExternalId) + .stsRoleSessionDuration(this.stsRoleSessionDuration) + .stsEndpointOverride(this.stsEndpointOverride) .build(); List.Output run = task.run(runContext); diff --git a/src/main/java/io/kestra/plugin/aws/sqs/Trigger.java b/src/main/java/io/kestra/plugin/aws/sqs/Trigger.java index dde7e762..ef428c04 100644 --- a/src/main/java/io/kestra/plugin/aws/sqs/Trigger.java +++ b/src/main/java/io/kestra/plugin/aws/sqs/Trigger.java @@ -12,6 +12,7 @@ import io.kestra.core.models.triggers.TriggerContext; import io.kestra.core.models.triggers.TriggerOutput; import io.kestra.core.runners.RunContext; +import io.kestra.plugin.aws.AbstractConnectionInterface; import io.kestra.plugin.aws.sqs.model.SerdeType; import io.swagger.v3.oas.annotations.media.Schema; import lombok.*; @@ -75,6 +76,14 @@ public class Trigger extends AbstractTrigger implements PollingTriggerInterface, @Schema(title = "The serializer/deserializer to use.") private SerdeType serdeType = SerdeType.STRING; + // Configuration for AWS STS AssumeRole + protected String stsRoleArn; + protected String stsRoleExternalId; + protected String stsRoleSessionName; + protected String stsEndpointOverride; + @Builder.Default + protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION; + @Override public Optional evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception { RunContext runContext = conditionContext.getRunContext(); @@ -90,6 +99,11 @@ public Optional evaluate(ConditionContext conditionContext, TriggerCo .maxRecords(this.maxRecords) .maxDuration(this.maxDuration) .serdeType(this.serdeType) + .stsRoleArn(this.stsRoleArn) + .stsRoleSessionName(this.stsRoleSessionName) + .stsRoleExternalId(this.stsRoleExternalId) + .stsRoleSessionDuration(this.stsRoleSessionDuration) + .stsEndpointOverride(this.stsEndpointOverride) .build(); Consume.Output run = task.run(runContext);