Skip to content

Commit

Permalink
chore: refactor to use AWS SDK httpClientBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
fhussonnois committed Mar 6, 2024
1 parent 8bf8905 commit f27c6df
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 173 deletions.
83 changes: 69 additions & 14 deletions src/main/java/io/kestra/plugin/aws/AbstractConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
import io.kestra.core.utils.Rethrow;
import jakarta.annotation.Nullable;
import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -17,6 +18,11 @@
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.awscore.AwsClient;
import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
Expand All @@ -25,6 +31,7 @@

import java.net.URI;
import java.time.Duration;
import java.util.Optional;

@SuperBuilder
@ToString
Expand All @@ -50,9 +57,13 @@ public abstract class AbstractConnection extends Task implements AbstractConnect
@Builder.Default
protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION;

protected AwsCredentialsProvider credentials(final RunContext runContext) throws IllegalVariableEvaluationException {

final AwsClientConfig awsClientConfig = awsClientConfig(runContext);
/**
* Factory method for constructing a new {@link AwsCredentialsProvider} for the given AWS Client config.
*
* @param awsClientConfig The AwsClientConfig.
* @return a new {@link AwsCredentialsProvider} instance.
*/
protected static AwsCredentialsProvider credentialsProvider(final AwsClientConfig awsClientConfig) {

// StsAssumeRoleCredentialsProvider
if (StringUtils.isNotEmpty(awsClientConfig.stsRoleArn)) {
Expand Down Expand Up @@ -121,25 +132,69 @@ private static StsClient stsClient(final AwsClientConfig awsClientConfig) {
return builder.build();
}

private AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException {
/**
* Configures and returns the given {@link AwsSyncClientBuilder}.
*/
protected <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsSyncClientBuilder<B, C>> B configureSyncClient(
final AwsClientConfig clientConfig, final B builder) throws IllegalVariableEvaluationException {

builder
// Use the httpClientBuilder to delegate the lifecycle management of the HTTP client to the AWS SDK
.httpClientBuilder(serviceDefaults -> ApacheHttpClient.builder().build())
.credentialsProvider(credentialsProvider(clientConfig));

return configureClient(clientConfig, builder);
}
/**
* Configures and returns the given {@link AwsAsyncClientBuilder}.
*/
protected <C extends AwsClient, B extends AwsClientBuilder<B, C> & AwsAsyncClientBuilder<B, C>> B configureAsyncClient(
final AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(credentialsProvider(clientConfig));
return configureClient(clientConfig, builder);
}

/**
* Configures and returns the given {@link AwsClientBuilder}.
*/
protected <C extends AwsClient, B extends AwsClientBuilder<B, C> > B configureClient(
final AwsClientConfig clientConfig, final B builder) {

builder.credentialsProvider(credentialsProvider(clientConfig));

if (clientConfig.region() != null) {
builder.region(Region.of(clientConfig.region()));
}
if (clientConfig.endpointOverride() != null) {
builder.endpointOverride(URI.create(clientConfig.endpointOverride()));
}
return builder;
}

protected AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException {
return new AwsClientConfig(
runContext.render(this.accessKeyId),
runContext.render(this.secretKeyId),
runContext.render(this.sessionToken),
runContext.render(this.stsRoleArn),
runContext.render(this.stsRoleExternalId),
runContext.render(this.stsRoleSessionName),
runContext.render(this.stsEndpointOverride),
renderStringConfig(runContext, this.accessKeyId),
renderStringConfig(runContext, this.secretKeyId),
renderStringConfig(runContext, this.sessionToken),
renderStringConfig(runContext, this.stsRoleArn),
renderStringConfig(runContext, this.stsRoleExternalId),
renderStringConfig(runContext, this.stsRoleSessionName),
renderStringConfig(runContext, this.stsEndpointOverride),
stsRoleSessionDuration,
runContext.render(this.region),
runContext.render(this.endpointOverride)
renderStringConfig(runContext, this.region),
renderStringConfig(runContext, this.endpointOverride)
);
}

private String renderStringConfig(final RunContext runContext, final String config) throws IllegalVariableEvaluationException {
return Optional.ofNullable(config).map(Rethrow.throwFunction(runContext::render)).orElse(null);
}

/**
* Common AWS Client configuration properties.
*/
private record AwsClientConfig(
public record AwsClientConfig(
@Nullable String accessKeyId,
@Nullable String secretKeyId,
@Nullable String sessionToken,
Expand Down
18 changes: 3 additions & 15 deletions src/main/java/io/kestra/plugin/aws/athena/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.commons.lang3.tuple.Pair;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.athena.AthenaClient;
import software.amazon.awssdk.services.athena.model.*;

Expand Down Expand Up @@ -208,19 +206,9 @@ else if (fetchType == FetchType.STORE) {
}
}

private AthenaClient client(RunContext runContext) throws IllegalVariableEvaluationException {
var builder = AthenaClient.builder()
.httpClient(ApacheHttpClient.create())
.credentialsProvider(this.credentials(runContext));

if (this.region != null) {
builder.region(Region.of(runContext.render(this.region)));
}
if (this.endpointOverride != null) {
builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

return builder.build();
private AthenaClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, AthenaClient.builder()).build();
}

public QueryExecutionStatistics waitForQueryToComplete(AthenaClient client, String queryExecutionId) throws InterruptedException {
Expand Down
18 changes: 3 additions & 15 deletions src/main/java/io/kestra/plugin/aws/dynamodb/AbstractDynamoDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.commons.lang3.tuple.Pair;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;

Expand Down Expand Up @@ -44,19 +42,9 @@ public abstract class AbstractDynamoDb extends AbstractConnection {
@NotNull
protected String tableName;

protected DynamoDbClient client(RunContext runContext) throws IllegalVariableEvaluationException {
var builder = DynamoDbClient.builder()
.httpClient(ApacheHttpClient.create())
.credentialsProvider(this.credentials(runContext));

if (this.region != null) {
builder.region(Region.of(runContext.render(this.region)));
}
if (this.endpointOverride != null) {
builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

return builder.build();
protected DynamoDbClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, DynamoDbClient.builder()).build();
}

protected Map<String, Object> objectMapFrom(Map<String, AttributeValue> fields) {
Expand Down
21 changes: 7 additions & 14 deletions src/main/java/io/kestra/plugin/aws/ecr/GetAuthToken.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.kestra.plugin.aws.ecr;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.tasks.Output;
Expand All @@ -10,12 +11,9 @@
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.ecr.EcrClient;
import software.amazon.awssdk.services.ecr.EcrClientBuilder;
import software.amazon.awssdk.services.ecr.model.AuthorizationData;

import java.net.URI;
import java.util.Base64;
import java.util.List;

Expand Down Expand Up @@ -43,17 +41,7 @@ public class GetAuthToken extends AbstractConnection implements RunnableTask<Get

@Override
public TokenOutput run(RunContext runContext) throws Exception {
EcrClientBuilder ecrClientBuilder = EcrClient.builder().credentialsProvider(this.credentials(runContext));

if (this.region != null) {
ecrClientBuilder.region(Region.of(runContext.render(this.region)));
}

if (this.endpointOverride != null) {
ecrClientBuilder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

try (EcrClient client = ecrClientBuilder.build()) {
try (EcrClient client = client(runContext)) {
List<AuthorizationData> authorizationData = client.getAuthorizationToken().authorizationData();

String encodedToken = authorizationData.get(0).authorizationToken();
Expand All @@ -70,6 +58,11 @@ public TokenOutput run(RunContext runContext) throws Exception {
}
}

private EcrClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, EcrClient.builder()).build();
}

@Builder
@Getter
public static class TokenOutput implements Output {
Expand Down
16 changes: 3 additions & 13 deletions src/main/java/io/kestra/plugin/aws/eventbridge/PutEvents.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import lombok.experimental.SuperBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.eventbridge.EventBridgeClient;
import software.amazon.awssdk.services.eventbridge.model.PutEventsRequest;
import software.amazon.awssdk.services.eventbridge.model.PutEventsRequestEntry;
Expand Down Expand Up @@ -164,18 +163,9 @@ private PutEventsResponse putEvents(RunContext runContext, List<Entry> entryList
}
}

protected EventBridgeClient client(RunContext runContext) throws IllegalVariableEvaluationException {
var builder = EventBridgeClient.builder()
.credentialsProvider(this.credentials(runContext));

if (this.region != null) {
builder.region(Region.of(runContext.render(this.region)));
}
if (this.endpointOverride != null) {
builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

return builder.build();
private EventBridgeClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, EventBridgeClient.builder()).build();
}

@SuppressWarnings("unchecked")
Expand Down
16 changes: 3 additions & 13 deletions src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;

import jakarta.validation.constraints.NotNull;
import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry;

import java.io.*;
import java.net.URI;
Expand Down Expand Up @@ -206,18 +205,9 @@ private File writeOutputFile(RunContext runContext, PutRecordsResponse putRecord
return tempFile;
}

protected KinesisClient client(RunContext runContext) throws IllegalVariableEvaluationException {
KinesisClientBuilder builder = KinesisClient.builder()
.credentialsProvider(this.credentials(runContext));

if (this.region != null) {
builder.region(Region.of(runContext.render(this.region)));
}
if (this.endpointOverride != null) {
builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

return builder.build();
protected KinesisClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, KinesisClient.builder()).build();
}

@Builder
Expand Down
18 changes: 3 additions & 15 deletions src/main/java/io/kestra/plugin/aws/lambda/Invoke.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaClient;
import software.amazon.awssdk.services.lambda.LambdaClientBuilder;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.lambda.model.LambdaException;
Expand Down Expand Up @@ -132,18 +129,9 @@ public Output run(RunContext runContext) throws Exception {
}

@VisibleForTesting
LambdaClient client(RunContext runContext) throws IllegalVariableEvaluationException {
LambdaClientBuilder builder = LambdaClient.builder().httpClient(ApacheHttpClient.create())
.credentialsProvider(this.credentials(runContext));

if (this.region != null) {
builder.region(Region.of(runContext.render(this.region)));
}
if (this.endpointOverride != null) {
builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
}

return builder.build();
LambdaClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
return configureSyncClient(clientConfig, LambdaClient.builder()).build();
}

@VisibleForTesting
Expand Down
Loading

0 comments on commit f27c6df

Please sign in to comment.