Skip to content

Commit

Permalink
feat: add support for StsAssumeRoleCredentialsProvider (#348)
Browse files Browse the repository at this point in the history
This commit add support for StsAssumeRoleCredentialsProvider allowing
to assume an IAM role and obtain temporary credentials.

Fix: #348
  • Loading branch information
fhussonnois committed Mar 6, 2024
1 parent b76ab23 commit 8bf8905
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 25 deletions.
149 changes: 126 additions & 23 deletions src/main/java/io/kestra/plugin/aws/AbstractConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,153 @@
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
import jakarta.annotation.Nullable;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.*;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.net.URI;
import java.time.Duration;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
public abstract class AbstractConnection extends Task implements AbstractConnectionInterface {
protected String accessKeyId;

protected String secretKeyId;
protected String region;
protected String endpointOverride;
private Boolean compatibilityMode;

// Configuration for StaticCredentialsProvider
protected String accessKeyId;
protected String secretKeyId;
protected String sessionToken;

protected String region;
// Configuration for AWS STS AssumeRole
protected String stsRoleArn;
protected String stsRoleExternalId;
protected String stsRoleSessionName;
protected String stsEndpointOverride;
@Builder.Default
protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION;

protected String endpointOverride;
protected AwsCredentialsProvider credentials(final RunContext runContext) throws IllegalVariableEvaluationException {

private Boolean compatibilityMode;
final AwsClientConfig awsClientConfig = awsClientConfig(runContext);

// StsAssumeRoleCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn)) {
return stsAssumeRoleCredentialsProvider(awsClientConfig);
}

// StaticCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.accessKeyId) &&
StringUtils.isNotEmpty(awsClientConfig.secretKeyId)) {
return staticCredentialsProvider(awsClientConfig);
}

// Otherwise, use DefaultCredentialsProvider
return DefaultCredentialsProvider.builder().build();
}

protected AwsCredentialsProvider credentials(RunContext runContext) throws IllegalVariableEvaluationException {
String accessKeyId = runContext.render(this.accessKeyId);
String secretKeyId = runContext.render(this.secretKeyId);
String sessionToken = runContext.render(this.sessionToken);

if (sessionToken != null) {
StaticCredentialsProvider.create(AwsSessionCredentials.create(
accessKeyId,
secretKeyId,
sessionToken
));
} else if (accessKeyId != null && secretKeyId != null) {
return StaticCredentialsProvider.create(AwsBasicCredentials.create(
accessKeyId,
secretKeyId
));
private static StaticCredentialsProvider staticCredentialsProvider(final AwsClientConfig awsClientConfig) {
final AwsCredentials credentials;
if (StringUtils.isNotEmpty(awsClientConfig.sessionToken())) {
credentials = AwsSessionCredentials.create(
awsClientConfig.accessKeyId,
awsClientConfig.secretKeyId,
awsClientConfig.sessionToken
);
} else {
credentials = AwsBasicCredentials.create(
awsClientConfig.accessKeyId,
awsClientConfig.secretKeyId
);
}
return StaticCredentialsProvider.create(credentials);
}

private static StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider(final AwsClientConfig awsClientConfig) {

String roleSessionName = awsClientConfig.stsRoleSessionName();
roleSessionName = roleSessionName != null ? roleSessionName : "kestra-plugin-s3-" + System.currentTimeMillis();

return DefaultCredentialsProvider.builder()
final AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder()
.roleArn(awsClientConfig.stsRoleArn())
.roleSessionName(roleSessionName)
.durationSeconds((int) awsClientConfig.stsRoleSessionDuration().toSeconds())
.externalId(awsClientConfig.stsRoleExternalId())
.build();

return StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient(awsClientConfig))
.refreshRequest(assumeRoleRequest)
.build();
}

private static StsClient stsClient(final AwsClientConfig awsClientConfig) {
StsClientBuilder builder = StsClient.builder();

final String stsEndpointOverride = awsClientConfig.stsEndpointOverride;
if (stsEndpointOverride != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.endpointOverride(URI.create(stsEndpointOverride)));
}

final String regionString = awsClientConfig.region;
if (regionString != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.region(Region.of(regionString)));
}
return builder.build();
}

private AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException {
return new AwsClientConfig(
runContext.render(this.accessKeyId),
runContext.render(this.secretKeyId),
runContext.render(this.sessionToken),
runContext.render(this.stsRoleArn),
runContext.render(this.stsRoleExternalId),
runContext.render(this.stsRoleSessionName),
runContext.render(this.stsEndpointOverride),
stsRoleSessionDuration,
runContext.render(this.region),
runContext.render(this.endpointOverride)
);
}

/**
* Common AWS Client configuration properties.
*/
private record AwsClientConfig(
@Nullable String accessKeyId,
@Nullable String secretKeyId,
@Nullable String sessionToken,
@Nullable String stsRoleArn,
@Nullable String stsRoleExternalId,
@Nullable String stsRoleSessionName,
@Nullable String stsEndpointOverride,
Duration stsRoleSessionDuration,
@Nullable String region,
@Nullable String endpointOverride
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import io.kestra.core.models.annotations.PluginProperty;
import io.swagger.v3.oas.annotations.media.Schema;

import java.time.Duration;

public interface AbstractConnectionInterface {

Duration AWS_MIN_STS_ROLE_SESSION_DURATION = Duration.ofSeconds(900);

@Schema(
title = "Access Key Id in order to connect to AWS.",
description = "If no connection is defined, we will use the `DefaultCredentialsProvider` to fetch the value."
Expand All @@ -25,6 +30,40 @@ public interface AbstractConnectionInterface {
@PluginProperty(dynamic = true)
String getSessionToken();

@Schema(
title = "AWS STS Role.",
description = "The Amazon Resource Name (ARN) of the role to assume. If set the task will use the `StsAssumeRoleCredentialsProvider`. Otherwise, the `StaticCredentialsProvider` will be used with the provided Access Key Id and Secret Key."
)
@PluginProperty(dynamic = true)

String getStsRoleArn();

@Schema(
title = "AWS STS External Id.",
description = " A unique identifier that might be required when you assume a role in another account. This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty(dynamic = true)

String getStsRoleExternalId();
@Schema(
title = "AWS STS Session name. This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty(dynamic = true)
String getStsRoleSessionName();

@Schema(
title = "AWS STS Session duration.",
description = "The duration of the role session (default: 15 minutes, i.e., PT15M). This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty
Duration getStsRoleSessionDuration();

@Schema(
title = "The AWS STS endpoint with which the SDKClient should communicate."
)
@PluginProperty(dynamic = true)
String getStsEndpointOverride();

@Schema(
title = "AWS region with which the SDK should communicate."
)
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/io/kestra/plugin/aws/s3/Downloads.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public List.Output run(RunContext runContext) throws Exception {
.expectedBucketOwner(this.expectedBucketOwner)
.regexp(this.regexp)
.filter(this.filter)
.stsRoleArn(this.stsRoleArn)
.stsRoleSessionName(this.stsRoleSessionName)
.stsRoleExternalId(this.stsRoleExternalId)
.stsRoleSessionDuration(this.stsRoleSessionDuration)
.stsEndpointOverride(this.stsEndpointOverride)
.build();
List.Output run = task.run(runContext);

Expand Down
12 changes: 11 additions & 1 deletion src/main/java/io/kestra/plugin/aws/s3/S3Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static Pair<GetObjectResponse, URI> download(RunContext runContext, S3Asy
.build()
);

GetObjectResponse response =download.completionFuture().get().response();
GetObjectResponse response = download.completionFuture().get().response();

runContext.metric(Counter.of("file.size", response.contentLength()));

Expand Down Expand Up @@ -68,6 +68,11 @@ static void performAction(
.secretKeyId(abstractConnection.getSecretKeyId())
.key(object.getKey())
.bucket(abstractS3Object.getBucket())
.stsRoleArn(abstractConnection.getStsRoleArn())
.stsRoleExternalId(abstractConnection.getStsRoleExternalId())
.stsRoleSessionName(abstractConnection.getStsRoleSessionName())
.stsRoleSessionDuration(abstractConnection.getStsRoleSessionDuration())
.stsEndpointOverride(abstractConnection.getStsEndpointOverride())
.build();
delete.run(runContext);
}
Expand All @@ -80,6 +85,11 @@ static void performAction(
.endpointOverride(abstractS3.getEndpointOverride())
.accessKeyId(abstractConnection.getAccessKeyId())
.secretKeyId(abstractConnection.getSecretKeyId())
.stsRoleArn(abstractConnection.getStsRoleArn())
.stsRoleExternalId(abstractConnection.getStsRoleExternalId())
.stsRoleSessionName(abstractConnection.getStsRoleSessionName())
.stsRoleSessionDuration(abstractConnection.getStsRoleSessionDuration())
.stsEndpointOverride(abstractConnection.getStsEndpointOverride())
.from(Copy.CopyObjectFrom.builder()
.bucket(abstractS3Object.getBucket())
.key(object.getKey())
Expand Down
19 changes: 18 additions & 1 deletion src/main/java/io/kestra/plugin/aws/s3/Trigger.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
import io.kestra.plugin.aws.AbstractConnectionInterface;
import io.kestra.plugin.aws.s3.models.S3Object;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;

import java.time.Duration;
Expand Down Expand Up @@ -145,6 +149,14 @@ public class Trigger extends AbstractTrigger implements PollingTriggerInterface,

private Copy.CopyObject moveTo;

// Configuration for AWS STS AssumeRole
protected String stsRoleArn;
protected String stsRoleExternalId;
protected String stsRoleSessionName;
protected String stsEndpointOverride;
@Builder.Default
protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION;

@Override
public Optional<Execution> evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception {
RunContext runContext = conditionContext.getRunContext();
Expand All @@ -166,6 +178,11 @@ public Optional<Execution> evaluate(ConditionContext conditionContext, TriggerCo
.expectedBucketOwner(this.expectedBucketOwner)
.regexp(this.regexp)
.filter(this.filter)
.stsRoleArn(this.stsRoleArn)
.stsRoleSessionName(this.stsRoleSessionName)
.stsRoleExternalId(this.stsRoleExternalId)
.stsRoleSessionDuration(this.stsRoleSessionDuration)
.stsEndpointOverride(this.stsEndpointOverride)
.build();
List.Output run = task.run(runContext);

Expand Down
14 changes: 14 additions & 0 deletions src/main/java/io/kestra/plugin/aws/sqs/Trigger.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.kestra.core.models.triggers.TriggerContext;
import io.kestra.core.models.triggers.TriggerOutput;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.aws.AbstractConnectionInterface;
import io.kestra.plugin.aws.sqs.model.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
Expand Down Expand Up @@ -75,6 +76,14 @@ public class Trigger extends AbstractTrigger implements PollingTriggerInterface,
@Schema(title = "The serializer/deserializer to use.")
private SerdeType serdeType = SerdeType.STRING;

// Configuration for AWS STS AssumeRole
protected String stsRoleArn;
protected String stsRoleExternalId;
protected String stsRoleSessionName;
protected String stsEndpointOverride;
@Builder.Default
protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION;

@Override
public Optional<Execution> evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception {
RunContext runContext = conditionContext.getRunContext();
Expand All @@ -90,6 +99,11 @@ public Optional<Execution> evaluate(ConditionContext conditionContext, TriggerCo
.maxRecords(this.maxRecords)
.maxDuration(this.maxDuration)
.serdeType(this.serdeType)
.stsRoleArn(this.stsRoleArn)
.stsRoleSessionName(this.stsRoleSessionName)
.stsRoleExternalId(this.stsRoleExternalId)
.stsRoleSessionDuration(this.stsRoleSessionDuration)
.stsEndpointOverride(this.stsEndpointOverride)
.build();

Consume.Output run = task.run(runContext);
Expand Down

0 comments on commit 8bf8905

Please sign in to comment.