Skip to content

Commit

Permalink
fix(batch-script-runner): post-review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brian-mulier-p committed Mar 22, 2024
1 parent df2f557 commit 8e92d74
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 223 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ dependencies {
// AWS libs: versions are managed by the Micronaut BOM
api platform("io.micronaut.platform:micronaut-platform:$micronautVersion")
api 'software.amazon.awssdk:cloudwatchlogs'
api 'software.amazon.awssdk:batch:2.25.14'
api 'software.amazon.awssdk:batch:2.25.14' // we can remove this after micronaut bump as long as it contains the RegisterJobDefinitionRequest.ecsProperties exists
api 'software.amazon.awssdk:s3'
api 'software.amazon.awssdk:s3-transfer-manager'
api 'software.amazon.awssdk.crt:aws-crt:0.29.10' //used by s3-transfer-manager
Expand Down
162 changes: 13 additions & 149 deletions src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,9 @@
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.runners.RunContext;
import io.kestra.core.utils.Rethrow;
import io.swagger.v3.oas.annotations.media.Schema;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.awscore.AwsClient;
import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.net.URI;

import java.time.Duration;
import java.util.Optional;

public interface AbstractConnectionInterface {

Expand Down Expand Up @@ -48,20 +33,19 @@ public interface AbstractConnectionInterface {
String getSessionToken();

@Schema(
title = "AWS STS Role.",
title = "AWS STS Role.",
description = "The Amazon Resource Name (ARN) of the role to assume. If set the task will use the `StsAssumeRoleCredentialsProvider`. Otherwise, the `StaticCredentialsProvider` will be used with the provided Access Key Id and Secret Key."
)
@PluginProperty(dynamic = true)

String getStsRoleArn();

@Schema(
title = "AWS STS External Id.",
title = "AWS STS External Id.",
description = " A unique identifier that might be required when you assume a role in another account. This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty(dynamic = true)

String getStsRoleExternalId();

@Schema(
title = "AWS STS Session name. This property is only used when an `stsRoleArn` is defined."
)
Expand Down Expand Up @@ -99,138 +83,18 @@ default Boolean getCompatibilityMode() {
return false;
}

static String renderStringConfig(final RunContext runContext, final String config) throws IllegalVariableEvaluationException {
return Optional.ofNullable(config).map(Rethrow.throwFunction(runContext::render)).orElse(null);
}

default AbstractConnection.AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException {
return new AbstractConnection.AwsClientConfig(
renderStringConfig(runContext, this.getAccessKeyId()),
renderStringConfig(runContext, this.getSecretKeyId()),
renderStringConfig(runContext, this.getSessionToken()),
renderStringConfig(runContext, this.getStsRoleArn()),
renderStringConfig(runContext, this.getStsRoleExternalId()),
renderStringConfig(runContext, this.getStsRoleSessionName()),
renderStringConfig(runContext, this.getStsEndpointOverride()),
runContext.render(this.getAccessKeyId()),
runContext.render(this.getSecretKeyId()),
runContext.render(this.getSessionToken()),
runContext.render(this.getStsRoleArn()),
runContext.render(this.getStsRoleExternalId()),
runContext.render(this.getStsRoleSessionName()),
runContext.render(this.getStsEndpointOverride()),
getStsRoleSessionDuration(),
renderStringConfig(runContext, this.getRegion()),
renderStringConfig(runContext, this.getEndpointOverride())
runContext.render(this.getRegion()),
runContext.render(this.getEndpointOverride())
);
}

/**
* Factory method for constructing a new {@link AwsCredentialsProvider} for the given AWS Client config.
*
* @param awsClientConfig The AwsClientConfig.
* @return a new {@link AwsCredentialsProvider} instance.
*/
static AwsCredentialsProvider credentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {

// StsAssumeRoleCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn())) {
return stsAssumeRoleCredentialsProvider(awsClientConfig);
}

// StaticCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.accessKeyId()) &&
StringUtils.isNotEmpty(awsClientConfig.secretKeyId())) {
return staticCredentialsProvider(awsClientConfig);
}

// Otherwise, use DefaultCredentialsProvider
return DefaultCredentialsProvider.builder().build();
}

static StaticCredentialsProvider staticCredentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {
final AwsCredentials credentials;
if (StringUtils.isNotEmpty(awsClientConfig.sessionToken())) {
credentials = AwsSessionCredentials.create(
awsClientConfig.accessKeyId(),
awsClientConfig.secretKeyId(),
awsClientConfig.sessionToken()
);
} else {
credentials = AwsBasicCredentials.create(
awsClientConfig.accessKeyId(),
awsClientConfig.secretKeyId()
);
}
return StaticCredentialsProvider.create(credentials);
}

static StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {

String roleSessionName = awsClientConfig.stsRoleSessionName();
roleSessionName = roleSessionName != null ? roleSessionName : "kestra-plugin-s3-" + System.currentTimeMillis();

final AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder()
.roleArn(awsClientConfig.stsRoleArn())
.roleSessionName(roleSessionName)
.durationSeconds((int) awsClientConfig.stsRoleSessionDuration().toSeconds())
.externalId(awsClientConfig.stsRoleExternalId())
.build();

return StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient(awsClientConfig))
.refreshRequest(assumeRoleRequest)
.build();
}

static StsClient stsClient(final AbstractConnection.AwsClientConfig awsClientConfig) {
StsClientBuilder builder = StsClient.builder();

final String stsEndpointOverride = awsClientConfig.stsEndpointOverride();
if (stsEndpointOverride != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.endpointOverride(URI.create(stsEndpointOverride)));
}

final String regionString = awsClientConfig.region();
if (regionString != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.region(Region.of(regionString)));
}
return builder.build();
}

/**
* Configures and returns the given {@link AwsSyncClientBuilder}.
*/
static <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsSyncClientBuilder<B, C>> B configureSyncClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) throws IllegalVariableEvaluationException {

builder
// Use the httpClientBuilder to delegate the lifecycle management of the HTTP client to the AWS SDK
.httpClientBuilder(serviceDefaults -> ApacheHttpClient.builder().build())
.credentialsProvider(AbstractConnectionInterface.credentialsProvider(clientConfig));

return configureClient(clientConfig, builder);
}

/**
* Configures and returns the given {@link AwsAsyncClientBuilder}.
*/
static <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsAsyncClientBuilder<B, C>> B configureAsyncClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(AbstractConnectionInterface.credentialsProvider(clientConfig));
return configureClient(clientConfig, builder);
}

/**
* Configures and returns the given {@link AwsClientBuilder}.
*/
static <C extends AwsClient, B extends AwsClientBuilder<B, C>> B configureClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(AbstractConnectionInterface.credentialsProvider(clientConfig));

if (clientConfig.region() != null) {
builder.region(Region.of(clientConfig.region()));
}
if (clientConfig.endpointOverride() != null) {
builder.endpointOverride(URI.create(clientConfig.endpointOverride()));
}
return builder;
}
}
135 changes: 135 additions & 0 deletions src/main/java/io/kestra/plugin/aws/ConnectionUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package io.kestra.plugin.aws;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.awscore.AwsClient;
import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.net.URI;

public class ConnectionUtils {
/**
* Factory method for constructing a new {@link AwsCredentialsProvider} for the given AWS Client config.
*
* @param awsClientConfig The AwsClientConfig.
* @return a new {@link AwsCredentialsProvider} instance.
*/
public static AwsCredentialsProvider credentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {

// StsAssumeRoleCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn())) {
return stsAssumeRoleCredentialsProvider(awsClientConfig);
}

// StaticCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.accessKeyId()) &&
StringUtils.isNotEmpty(awsClientConfig.secretKeyId())) {
return staticCredentialsProvider(awsClientConfig);
}

// Otherwise, use DefaultCredentialsProvider
return DefaultCredentialsProvider.builder().build();
}

public static StaticCredentialsProvider staticCredentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {
final AwsCredentials credentials;
if (StringUtils.isNotEmpty(awsClientConfig.sessionToken())) {
credentials = AwsSessionCredentials.create(
awsClientConfig.accessKeyId(),
awsClientConfig.secretKeyId(),
awsClientConfig.sessionToken()
);
} else {
credentials = AwsBasicCredentials.create(
awsClientConfig.accessKeyId(),
awsClientConfig.secretKeyId()
);
}
return StaticCredentialsProvider.create(credentials);
}

public static StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider(final AbstractConnection.AwsClientConfig awsClientConfig) {

String roleSessionName = awsClientConfig.stsRoleSessionName();
roleSessionName = roleSessionName != null ? roleSessionName : "kestra-plugin-s3-" + System.currentTimeMillis();

final AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder()
.roleArn(awsClientConfig.stsRoleArn())
.roleSessionName(roleSessionName)
.durationSeconds((int) awsClientConfig.stsRoleSessionDuration().toSeconds())
.externalId(awsClientConfig.stsRoleExternalId())
.build();

return StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient(awsClientConfig))
.refreshRequest(assumeRoleRequest)
.build();
}

public static StsClient stsClient(final AbstractConnection.AwsClientConfig awsClientConfig) {
StsClientBuilder builder = StsClient.builder();

final String stsEndpointOverride = awsClientConfig.stsEndpointOverride();
if (stsEndpointOverride != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.endpointOverride(URI.create(stsEndpointOverride)));
}

final String regionString = awsClientConfig.region();
if (regionString != null) {
builder.applyMutation(stsClientBuilder ->
stsClientBuilder.region(Region.of(regionString)));
}
return builder.build();
}

/**
* Configures and returns the given {@link AwsSyncClientBuilder}.
*/
public static <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsSyncClientBuilder<B, C>> B configureSyncClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) throws IllegalVariableEvaluationException {

builder
// Use the httpClientBuilder to delegate the lifecycle management of the HTTP client to the AWS SDK
.httpClientBuilder(serviceDefaults -> ApacheHttpClient.builder().build())
.credentialsProvider(ConnectionUtils.credentialsProvider(clientConfig));

return configureClient(clientConfig, builder);
}

/**
* Configures and returns the given {@link AwsAsyncClientBuilder}.
*/
public static <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsAsyncClientBuilder<B, C>> B configureAsyncClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(ConnectionUtils.credentialsProvider(clientConfig));
return configureClient(clientConfig, builder);
}

/**
* Configures and returns the given {@link AwsClientBuilder}.
*/
public static <C extends AwsClient, B extends AwsClientBuilder<B, C>> B configureClient(
final AbstractConnection.AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(ConnectionUtils.credentialsProvider(clientConfig));

if (clientConfig.region() != null) {
builder.region(Region.of(clientConfig.region()));
}
if (clientConfig.endpointOverride() != null) {
builder.endpointOverride(URI.create(clientConfig.endpointOverride()));
}
return builder;
}
}
12 changes: 4 additions & 8 deletions src/main/java/io/kestra/plugin/aws/athena/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
import io.kestra.plugin.aws.AbstractConnection;
import io.kestra.plugin.aws.AbstractConnectionInterface;
import io.kestra.plugin.aws.ConnectionUtils;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.apache.commons.lang3.tuple.Pair;
import software.amazon.awssdk.services.athena.AthenaClient;
Expand All @@ -36,7 +33,6 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import jakarta.validation.constraints.NotNull;

import static io.kestra.core.utils.Rethrow.throwConsumer;

Expand Down Expand Up @@ -209,7 +205,7 @@ else if (fetchType == FetchType.STORE) {

private AthenaClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
AwsClientConfig clientConfig = awsClientConfig(runContext);
return AbstractConnectionInterface.configureSyncClient(clientConfig, AthenaClient.builder()).build();
return ConnectionUtils.configureSyncClient(clientConfig, AthenaClient.builder()).build();
}

public QueryExecutionStatistics waitForQueryToComplete(AthenaClient client, String queryExecutionId) throws InterruptedException {
Expand Down
Loading

0 comments on commit 8e92d74

Please sign in to comment.