From f27c6df130c04a0a631c6653715ca4c804c45c76 Mon Sep 17 00:00:00 2001 From: Florian Hussonnois Date: Tue, 5 Mar 2024 13:31:38 +0100 Subject: [PATCH] chore: refactor to use AWS SDK httpClientBuilder --- .../kestra/plugin/aws/AbstractConnection.java | 83 +++++++++++++++---- .../io/kestra/plugin/aws/athena/Query.java | 18 +--- .../plugin/aws/dynamodb/AbstractDynamoDb.java | 18 +--- .../kestra/plugin/aws/ecr/GetAuthToken.java | 21 ++--- .../plugin/aws/eventbridge/PutEvents.java | 16 +--- .../kestra/plugin/aws/kinesis/PutRecords.java | 16 +--- .../io/kestra/plugin/aws/lambda/Invoke.java | 18 +--- .../io/kestra/plugin/aws/s3/AbstractS3.java | 51 ++++-------- .../io/kestra/plugin/aws/sns/AbstractSns.java | 24 ++---- .../io/kestra/plugin/aws/sqs/AbstractSqs.java | 22 +---- 10 files changed, 114 insertions(+), 173 deletions(-) diff --git a/src/main/java/io/kestra/plugin/aws/AbstractConnection.java b/src/main/java/io/kestra/plugin/aws/AbstractConnection.java index 403f04dd..5e462be4 100644 --- a/src/main/java/io/kestra/plugin/aws/AbstractConnection.java +++ b/src/main/java/io/kestra/plugin/aws/AbstractConnection.java @@ -3,6 +3,7 @@ import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.tasks.Task; import io.kestra.core.runners.RunContext; +import io.kestra.core.utils.Rethrow; import jakarta.annotation.Nullable; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -17,6 +18,11 @@ import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.AwsClient; +import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder; +import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; +import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder; +import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sts.StsClient; import software.amazon.awssdk.services.sts.StsClientBuilder; @@ -25,6 +31,7 @@ import java.net.URI; import java.time.Duration; +import java.util.Optional; @SuperBuilder @ToString @@ -50,9 +57,13 @@ public abstract class AbstractConnection extends Task implements AbstractConnect @Builder.Default protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION; - protected AwsCredentialsProvider credentials(final RunContext runContext) throws IllegalVariableEvaluationException { - - final AwsClientConfig awsClientConfig = awsClientConfig(runContext); + /** + * Factory method for constructing a new {@link AwsCredentialsProvider} for the given AWS Client config. + * + * @param awsClientConfig The AwsClientConfig. + * @return a new {@link AwsCredentialsProvider} instance. + */ + protected static AwsCredentialsProvider credentialsProvider(final AwsClientConfig awsClientConfig) { // StsAssumeRoleCredentialsProvider if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn)) { @@ -121,25 +132,69 @@ private static StsClient stsClient(final AwsClientConfig awsClientConfig) { return builder.build(); } - private AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException { + /** + * Configures and returns the given {@link AwsSyncClientBuilder}. + */ + protected & AwsSyncClientBuilder> B configureSyncClient( + final AwsClientConfig clientConfig, final B builder) throws IllegalVariableEvaluationException { + + builder + // Use the httpClientBuilder to delegate the lifecycle management of the HTTP client to the AWS SDK + .httpClientBuilder(serviceDefaults -> ApacheHttpClient.builder().build()) + .credentialsProvider(credentialsProvider(clientConfig)); + + return configureClient(clientConfig, builder); + } + /** + * Configures and returns the given {@link AwsAsyncClientBuilder}. + */ + protected & AwsAsyncClientBuilder> B configureAsyncClient( + final AwsClientConfig clientConfig, final B builder) { + + builder.credentialsProvider(credentialsProvider(clientConfig)); + return configureClient(clientConfig, builder); + } + + /** + * Configures and returns the given {@link AwsClientBuilder}. + */ + protected > B configureClient( + final AwsClientConfig clientConfig, final B builder) { + + builder.credentialsProvider(credentialsProvider(clientConfig)); + + if (clientConfig.region() != null) { + builder.region(Region.of(clientConfig.region())); + } + if (clientConfig.endpointOverride() != null) { + builder.endpointOverride(URI.create(clientConfig.endpointOverride())); + } + return builder; + } + + protected AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException { return new AwsClientConfig( - runContext.render(this.accessKeyId), - runContext.render(this.secretKeyId), - runContext.render(this.sessionToken), - runContext.render(this.stsRoleArn), - runContext.render(this.stsRoleExternalId), - runContext.render(this.stsRoleSessionName), - runContext.render(this.stsEndpointOverride), + renderStringConfig(runContext, this.accessKeyId), + renderStringConfig(runContext, this.secretKeyId), + renderStringConfig(runContext, this.sessionToken), + renderStringConfig(runContext, this.stsRoleArn), + renderStringConfig(runContext, this.stsRoleExternalId), + renderStringConfig(runContext, this.stsRoleSessionName), + renderStringConfig(runContext, this.stsEndpointOverride), stsRoleSessionDuration, - runContext.render(this.region), - runContext.render(this.endpointOverride) + renderStringConfig(runContext, this.region), + renderStringConfig(runContext, this.endpointOverride) ); } + private String renderStringConfig(final RunContext runContext, final String config) throws IllegalVariableEvaluationException { + return Optional.ofNullable(config).map(Rethrow.throwFunction(runContext::render)).orElse(null); + } + /** * Common AWS Client configuration properties. */ - private record AwsClientConfig( + public record AwsClientConfig( @Nullable String accessKeyId, @Nullable String secretKeyId, @Nullable String sessionToken, diff --git a/src/main/java/io/kestra/plugin/aws/athena/Query.java b/src/main/java/io/kestra/plugin/aws/athena/Query.java index 2736284d..b4a47a02 100644 --- a/src/main/java/io/kestra/plugin/aws/athena/Query.java +++ b/src/main/java/io/kestra/plugin/aws/athena/Query.java @@ -19,8 +19,6 @@ import lombok.ToString; import lombok.experimental.SuperBuilder; import org.apache.commons.lang3.tuple.Pair; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.athena.model.*; @@ -208,19 +206,9 @@ else if (fetchType == FetchType.STORE) { } } - private AthenaClient client(RunContext runContext) throws IllegalVariableEvaluationException { - var builder = AthenaClient.builder() - .httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + private AthenaClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, AthenaClient.builder()).build(); } public QueryExecutionStatistics waitForQueryToComplete(AthenaClient client, String queryExecutionId) throws InterruptedException { diff --git a/src/main/java/io/kestra/plugin/aws/dynamodb/AbstractDynamoDb.java b/src/main/java/io/kestra/plugin/aws/dynamodb/AbstractDynamoDb.java index 27e5fd04..06e5cc34 100644 --- a/src/main/java/io/kestra/plugin/aws/dynamodb/AbstractDynamoDb.java +++ b/src/main/java/io/kestra/plugin/aws/dynamodb/AbstractDynamoDb.java @@ -15,8 +15,6 @@ import lombok.ToString; import lombok.experimental.SuperBuilder; import org.apache.commons.lang3.tuple.Pair; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -44,19 +42,9 @@ public abstract class AbstractDynamoDb extends AbstractConnection { @NotNull protected String tableName; - protected DynamoDbClient client(RunContext runContext) throws IllegalVariableEvaluationException { - var builder = DynamoDbClient.builder() - .httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + protected DynamoDbClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, DynamoDbClient.builder()).build(); } protected Map objectMapFrom(Map fields) { diff --git a/src/main/java/io/kestra/plugin/aws/ecr/GetAuthToken.java b/src/main/java/io/kestra/plugin/aws/ecr/GetAuthToken.java index 6be9f0a3..700c9e81 100644 --- a/src/main/java/io/kestra/plugin/aws/ecr/GetAuthToken.java +++ b/src/main/java/io/kestra/plugin/aws/ecr/GetAuthToken.java @@ -1,5 +1,6 @@ package io.kestra.plugin.aws.ecr; +import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.annotations.Example; import io.kestra.core.models.annotations.Plugin; import io.kestra.core.models.tasks.Output; @@ -10,12 +11,9 @@ import io.swagger.v3.oas.annotations.media.Schema; import lombok.*; import lombok.experimental.SuperBuilder; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.ecr.EcrClient; -import software.amazon.awssdk.services.ecr.EcrClientBuilder; import software.amazon.awssdk.services.ecr.model.AuthorizationData; -import java.net.URI; import java.util.Base64; import java.util.List; @@ -43,17 +41,7 @@ public class GetAuthToken extends AbstractConnection implements RunnableTask authorizationData = client.getAuthorizationToken().authorizationData(); String encodedToken = authorizationData.get(0).authorizationToken(); @@ -70,6 +58,11 @@ public TokenOutput run(RunContext runContext) throws Exception { } } + private EcrClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, EcrClient.builder()).build(); + } + @Builder @Getter public static class TokenOutput implements Output { diff --git a/src/main/java/io/kestra/plugin/aws/eventbridge/PutEvents.java b/src/main/java/io/kestra/plugin/aws/eventbridge/PutEvents.java index d0652a23..09293520 100644 --- a/src/main/java/io/kestra/plugin/aws/eventbridge/PutEvents.java +++ b/src/main/java/io/kestra/plugin/aws/eventbridge/PutEvents.java @@ -21,7 +21,6 @@ import lombok.experimental.SuperBuilder; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.eventbridge.EventBridgeClient; import software.amazon.awssdk.services.eventbridge.model.PutEventsRequest; import software.amazon.awssdk.services.eventbridge.model.PutEventsRequestEntry; @@ -164,18 +163,9 @@ private PutEventsResponse putEvents(RunContext runContext, List entryList } } - protected EventBridgeClient client(RunContext runContext) throws IllegalVariableEvaluationException { - var builder = EventBridgeClient.builder() - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + private EventBridgeClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, EventBridgeClient.builder()).build(); } @SuppressWarnings("unchecked") diff --git a/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java b/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java index 2fd6ab69..dfd835df 100644 --- a/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java +++ b/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java @@ -30,7 +30,6 @@ import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; import jakarta.validation.constraints.NotNull; -import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry; import java.io.*; import java.net.URI; @@ -206,18 +205,9 @@ private File writeOutputFile(RunContext runContext, PutRecordsResponse putRecord return tempFile; } - protected KinesisClient client(RunContext runContext) throws IllegalVariableEvaluationException { - KinesisClientBuilder builder = KinesisClient.builder() - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + protected KinesisClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, KinesisClient.builder()).build(); } @Builder diff --git a/src/main/java/io/kestra/plugin/aws/lambda/Invoke.java b/src/main/java/io/kestra/plugin/aws/lambda/Invoke.java index 96c98ffc..1f4ceacb 100644 --- a/src/main/java/io/kestra/plugin/aws/lambda/Invoke.java +++ b/src/main/java/io/kestra/plugin/aws/lambda/Invoke.java @@ -35,10 +35,7 @@ import lombok.experimental.SuperBuilder; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaClient; -import software.amazon.awssdk.services.lambda.LambdaClientBuilder; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; import software.amazon.awssdk.services.lambda.model.LambdaException; @@ -132,18 +129,9 @@ public Output run(RunContext runContext) throws Exception { } @VisibleForTesting - LambdaClient client(RunContext runContext) throws IllegalVariableEvaluationException { - LambdaClientBuilder builder = LambdaClient.builder().httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + LambdaClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, LambdaClient.builder()).build(); } @VisibleForTesting diff --git a/src/main/java/io/kestra/plugin/aws/s3/AbstractS3.java b/src/main/java/io/kestra/plugin/aws/s3/AbstractS3.java index c7027b77..0813c854 100644 --- a/src/main/java/io/kestra/plugin/aws/s3/AbstractS3.java +++ b/src/main/java/io/kestra/plugin/aws/s3/AbstractS3.java @@ -8,9 +8,10 @@ import lombok.NoArgsConstructor; import lombok.ToString; import lombok.experimental.SuperBuilder; -import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.*; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3CrtAsyncClientBuilder; import java.net.URI; @@ -21,49 +22,25 @@ @NoArgsConstructor public abstract class AbstractS3 extends AbstractConnection { - protected S3Client client(RunContext runContext) throws IllegalVariableEvaluationException { - S3ClientBuilder s3ClientBuilder = S3Client.builder() - .httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - s3ClientBuilder.region(Region.of(runContext.render(this.region))); - } - - if (this.endpointOverride != null) { - s3ClientBuilder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return s3ClientBuilder.build(); + protected S3Client client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, S3Client.builder()).build(); } - protected S3AsyncClient asyncClient(RunContext runContext) throws IllegalVariableEvaluationException { - + protected S3AsyncClient asyncClient(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); if (this.getCompatibilityMode()) { - S3AsyncClientBuilder s3ClientBuilder = S3AsyncClient.builder() - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - s3ClientBuilder.region(Region.of(runContext.render(this.region))); - } - - if (this.endpointOverride != null) { - s3ClientBuilder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - return s3ClientBuilder.build(); - + return configureAsyncClient(clientConfig, S3AsyncClient.builder()).build(); } else { S3CrtAsyncClientBuilder s3ClientBuilder = S3AsyncClient.crtBuilder() - .credentialsProvider(this.credentials(runContext)); + .credentialsProvider(credentialsProvider(clientConfig)); - if (this.region != null) { - s3ClientBuilder.region(Region.of(runContext.render(this.region))); + if (clientConfig.region() != null) { + s3ClientBuilder.region(Region.of(clientConfig.region())); } - - if (this.endpointOverride != null) { - s3ClientBuilder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); + if (clientConfig.endpointOverride() != null) { + s3ClientBuilder.endpointOverride(URI.create(clientConfig.endpointOverride())); } - return s3ClientBuilder.build(); } diff --git a/src/main/java/io/kestra/plugin/aws/sns/AbstractSns.java b/src/main/java/io/kestra/plugin/aws/sns/AbstractSns.java index 653c0a08..2ae1325a 100644 --- a/src/main/java/io/kestra/plugin/aws/sns/AbstractSns.java +++ b/src/main/java/io/kestra/plugin/aws/sns/AbstractSns.java @@ -5,41 +5,27 @@ import io.kestra.core.runners.RunContext; import io.kestra.plugin.aws.AbstractConnection; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.ToString; import lombok.experimental.SuperBuilder; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sns.SnsClient; -import java.net.URI; -import jakarta.validation.constraints.NotNull; - @SuperBuilder @ToString @EqualsAndHashCode @Getter @NoArgsConstructor -abstract class AbstractSns extends AbstractConnection { +abstract class AbstractSns extends AbstractConnection { @Schema(title = "The SNS topic ARN. The topic must already exist.") @PluginProperty(dynamic = true) @NotNull private String topicArn; - protected SnsClient client(RunContext runContext) throws IllegalVariableEvaluationException { - var builder = SnsClient.builder() - .httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + protected SnsClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, SnsClient.builder()).build(); } } diff --git a/src/main/java/io/kestra/plugin/aws/sqs/AbstractSqs.java b/src/main/java/io/kestra/plugin/aws/sqs/AbstractSqs.java index 24c744d7..19cf3092 100644 --- a/src/main/java/io/kestra/plugin/aws/sqs/AbstractSqs.java +++ b/src/main/java/io/kestra/plugin/aws/sqs/AbstractSqs.java @@ -8,32 +8,18 @@ import lombok.NoArgsConstructor; import lombok.ToString; import lombok.experimental.SuperBuilder; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sqs.SqsClient; -import java.net.URI; - @SuperBuilder @ToString @EqualsAndHashCode @Getter @NoArgsConstructor -abstract class AbstractSqs extends AbstractConnection implements SqsConnectionInterface { +abstract class AbstractSqs extends AbstractConnection implements SqsConnectionInterface { private String queueUrl; - protected SqsClient client(RunContext runContext) throws IllegalVariableEvaluationException { - var builder = SqsClient.builder() - .httpClient(ApacheHttpClient.create()) - .credentialsProvider(this.credentials(runContext)); - - if (this.region != null) { - builder.region(Region.of(runContext.render(this.region))); - } - if (this.endpointOverride != null) { - builder.endpointOverride(URI.create(runContext.render(this.endpointOverride))); - } - - return builder.build(); + protected SqsClient client(final RunContext runContext) throws IllegalVariableEvaluationException { + final AwsClientConfig clientConfig = awsClientConfig(runContext); + return configureSyncClient(clientConfig, SqsClient.builder()).build(); } }