Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move Query and abstract connection to dynamic properties #566

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/main/java/io/kestra/plugin/aws/AbstractConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
public abstract class AbstractConnection extends Task implements AbstractConnectionInterface {

protected Property<String> region;
protected String endpointOverride;
protected Boolean compatibilityMode;
protected Property<String> endpointOverride;
protected Property<Boolean> compatibilityMode;

// Configuration for StaticCredentialsProvider
protected String accessKeyId;
protected String secretKeyId;
protected String sessionToken;
protected Property<String> accessKeyId;
protected Property<String> secretKeyId;
protected Property<String> sessionToken;

// Configuration for AWS STS AssumeRole
protected String stsRoleArn;
protected String stsRoleExternalId;
protected String stsRoleSessionName;
protected String stsEndpointOverride;
protected Property<String> stsRoleArn;
protected Property<String> stsRoleExternalId;
protected Property<String> stsRoleSessionName;
protected Property<String> stsEndpointOverride;
@Builder.Default
protected Duration stsRoleSessionDuration = AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION;
protected Property<Duration> stsRoleSessionDuration = Property.of(AbstractConnectionInterface.AWS_MIN_STS_ROLE_SESSION_DURATION);

/**
* Common AWS Client configuration properties.
Expand Down
52 changes: 21 additions & 31 deletions src/main/java/io/kestra/plugin/aws/AbstractConnectionInterface.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,48 @@ public interface AbstractConnectionInterface {
title = "Access Key Id in order to connect to AWS.",
description = "If no credentials are defined, we will use the [default credentials provider chain](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials-chain.html) to fetch credentials."
)
@PluginProperty(dynamic = true)
String getAccessKeyId();
Property<String> getAccessKeyId();

@Schema(
title = "Secret Key Id in order to connect to AWS.",
description = "If no credentials are defined, we will use the [default credentials provider chain](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials-chain.html) to fetch credentials."
)
@PluginProperty(dynamic = true)
String getSecretKeyId();
Property<String> getSecretKeyId();

@Schema(
title = "AWS session token, retrieved from an AWS token service, used for authenticating that this user has received temporary permissions to access a given resource.",
description = "If no credentials are defined, we will use the [default credentials provider chain](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials-chain.html) to fetch credentials."
)
@PluginProperty(dynamic = true)
String getSessionToken();
Property<String> getSessionToken();

@Schema(
title = "AWS STS Role.",
description = "The Amazon Resource Name (ARN) of the role to assume. If set the task will use the `StsAssumeRoleCredentialsProvider`. If no credentials are defined, we will use the [default credentials provider chain](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials-chain.html) to fetch credentials."
)
@PluginProperty(dynamic = true)
String getStsRoleArn();
Property<String> getStsRoleArn();

@Schema(
title = "AWS STS External Id.",
description = " A unique identifier that might be required when you assume a role in another account. This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty(dynamic = true)
String getStsRoleExternalId();
Property<String> getStsRoleExternalId();

@Schema(
title = "AWS STS Session name.",
description = "This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty(dynamic = true)
String getStsRoleSessionName();
Property<String> getStsRoleSessionName();

@Schema(
title = "AWS STS Session duration.",
description = "The duration of the role session (default: 15 minutes, i.e., PT15M). This property is only used when an `stsRoleArn` is defined."
)
@PluginProperty
Duration getStsRoleSessionDuration();
Property<Duration> getStsRoleSessionDuration();

@Schema(
title = "The AWS STS endpoint with which the SDKClient should communicate."
)
@PluginProperty(dynamic = true)
String getStsEndpointOverride();
Property<String> getStsEndpointOverride();

@Schema(
title = "AWS region with which the SDK should communicate."
Expand All @@ -76,26 +68,24 @@ public interface AbstractConnectionInterface {
title = "The endpoint with which the SDK should communicate.",
description = "This property allows you to use a different S3 compatible storage backend."
)
@PluginProperty(dynamic = true)
String getEndpointOverride();
Property<String> getEndpointOverride();

@PluginProperty(dynamic = true)
default Boolean getCompatibilityMode() {
return false;
default Property<Boolean> getCompatibilityMode() {
return Property.of(false);
}

default AbstractConnection.AwsClientConfig awsClientConfig(final RunContext runContext) throws IllegalVariableEvaluationException {
return new AbstractConnection.AwsClientConfig(
runContext.render(this.getAccessKeyId()),
runContext.render(this.getSecretKeyId()),
runContext.render(this.getSessionToken()),
runContext.render(this.getStsRoleArn()),
runContext.render(this.getStsRoleExternalId()),
runContext.render(this.getStsRoleSessionName()),
runContext.render(this.getStsEndpointOverride()),
getStsRoleSessionDuration(),
this.getRegion() == null ? null : this.getRegion().as(runContext, String.class),
runContext.render(this.getEndpointOverride())
runContext.render(this.getAccessKeyId()).as(String.class).orElse(null),
runContext.render(this.getSecretKeyId()).as(String.class).orElse(null),
runContext.render(this.getSessionToken()).as(String.class).orElse(null),
runContext.render(this.getStsRoleArn()).as(String.class).orElse(null),
runContext.render(this.getStsRoleExternalId()).as(String.class).orElse(null),
runContext.render(this.getStsRoleSessionName()).as(String.class).orElse(null),
runContext.render(this.getStsEndpointOverride()).as(String.class).orElse(null),
runContext.render(this.getStsRoleSessionDuration()).as(Duration.class).orElse(null),
runContext.render(this.getRegion()).as(String.class).orElse(null),
runContext.render(this.getEndpointOverride()).as(String.class).orElse(null)
);
}
}
32 changes: 14 additions & 18 deletions src/main/java/io/kestra/plugin/aws/athena/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.Output;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.models.tasks.common.FetchType;
Expand Down Expand Up @@ -49,7 +50,7 @@
title = "Query an Athena table.",
description = """
The query will wait for completion, except if fetchMode is set to `NONE`, and will output converted rows.
Row conversion is based on the types listed [here](https://docs.aws.amazon.com/athena/latest/ug/data-types.html).
Row conversion is based on the types listed [here](https://docs.aws.amazon.com/athena/latest/ug/data-types.html).
Complex data types like array, map and struct will be converted to a string."""
)
@Plugin(
Expand Down Expand Up @@ -78,26 +79,22 @@
)
public class Query extends AbstractConnection implements RunnableTask<Query.QueryOutput> {
@Schema(title = "Athena catalog.")
@PluginProperty(dynamic = true)
private String catalog;
private Property<String> catalog;

@Schema(title = "Athena database.")
@NotNull
@PluginProperty(dynamic = true)
private String database;
private Property<String> database;

@Schema(
title = "Athena output location.",
description = "The query results will be stored in this output location. Must be an existing S3 bucket."
)
@NotNull
@PluginProperty(dynamic = true)
private String outputLocation;
private Property<String> outputLocation;

@Schema(title = "Athena SQL query.")
@NotNull
@PluginProperty(dynamic = true)
private String query;
private Property<String> query;

@Schema(
title = "The way you want to store the data.",
Expand All @@ -107,15 +104,13 @@ public class Query extends AbstractConnection implements RunnableTask<Query.Quer
+ "NONE does nothing — in this case, the task submits the query without waiting for its completion."
)
@NotNull
@PluginProperty
@Builder.Default
private FetchType fetchType = FetchType.STORE;
private Property<FetchType> fetchType = Property.of(FetchType.STORE);

@Schema(title = "Whether to skip the first row which is usually the header.")
@NotNull
@PluginProperty
@Builder.Default
private boolean skipHeader = true;
private Property<Boolean> skipHeader = Property.of(true);


private static DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
Expand All @@ -125,23 +120,24 @@ public class Query extends AbstractConnection implements RunnableTask<Query.Quer
public QueryOutput run(RunContext runContext) throws Exception {
// The QueryExecutionContext allows us to set the database.
var queryExecutionContext = QueryExecutionContext.builder()
.catalog(catalog != null ? runContext.render(catalog) : null)
.database(runContext.render(database))
.catalog(catalog != null ? runContext.render(catalog).as(String.class).orElseThrow() : null)
.database(runContext.render(database).as(String.class).orElseThrow())
.build();

// The result configuration specifies where the results of the query should go.
var resultConfiguration = ResultConfiguration.builder()
.outputLocation(runContext.render(outputLocation))
.outputLocation(runContext.render(outputLocation).as(String.class).orElseThrow())
.build();

var startQueryExecutionRequest = StartQueryExecutionRequest.builder()
.queryString(runContext.render(query))
.queryString(runContext.render(query).as(String.class).orElseThrow())
.queryExecutionContext(queryExecutionContext)
.resultConfiguration(resultConfiguration)
.build();

try (var client = client(runContext)) {
var startQueryExecution = client.startQueryExecution(startQueryExecutionRequest);
var fetchType = runContext.render(this.fetchType).as(FetchType.class).orElseThrow();
runContext.logger().info("Query created with Athena execution identifier {}", startQueryExecution.queryExecutionId());
if (fetchType == FetchType.NONE) {
return QueryOutput.builder().queryExecutionId(startQueryExecution.queryExecutionId()).build();
Expand Down Expand Up @@ -179,7 +175,7 @@ public QueryOutput run(RunContext runContext) throws Exception {
.build();
var getQueryResultsResults = client.getQueryResults(getQueryResult);
List<Row> results = getQueryResultsResults.resultSet().rows();
if (skipHeader && results != null && !results.isEmpty()) {
if (runContext.render(skipHeader).as(Boolean.class).orElseThrow() && results != null && !results.isEmpty()) {
// we skip the first row, this is usually needed as by default Athena returns the header as the first row
results = results.subList(1, results.size());
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/io/kestra/plugin/aws/auth/EksToken.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ public Output run(RunContext runContext) throws Exception {
if(this.getRegion() == null) {
throw new RuntimeException("Region is required");
}
final Region awsRegion = Region.of(this.getRegion().as(runContext, String.class));
final Region awsRegion = Region.of(runContext.render(this.getRegion()).as(String.class).orElseThrow());

SdkHttpFullRequest requestToSign = SdkHttpFullRequest
.builder()
.method(SdkHttpMethod.GET)
.uri(getStsRegionalEndpointUri(runContext, awsRegion))
.appendHeader("x-k8s-aws-id", this.clusterName.as(runContext, String.class))
.appendHeader("x-k8s-aws-id", runContext.render(this.clusterName).as(String.class).orElseThrow())
.appendRawQueryParameter("Action", "GetCallerIdentity")
.appendRawQueryParameter("Version", "2011-06-15")
.build();

ZonedDateTime expirationDate = ZonedDateTime.now().plusSeconds(expirationDuration.as(runContext, Long.class));
ZonedDateTime expirationDate = ZonedDateTime.now().plusSeconds(runContext.render(expirationDuration).as(Long.class).orElseThrow());
Aws4PresignerParams presignerParams = Aws4PresignerParams.builder()
.awsCredentials(ConnectionUtils.credentialsProvider(this.awsClientConfig(runContext)).resolveCredentials())
.signingRegion(awsRegion)
Expand Down
19 changes: 10 additions & 9 deletions src/main/java/io/kestra/plugin/aws/cli/AwsCLI.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import lombok.*;
import lombok.experimental.SuperBuilder;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -148,19 +149,19 @@ public ScriptOutput run(RunContext runContext) throws Exception {

// hack for missing env vars supports: https://github.com/aws/aws-cli/issues/5639
if (this.stsRoleArn != null) {
allCommands.add("aws configure set role_arn " + runContext.render(this.stsRoleArn));
allCommands.add("aws configure set role_arn " + runContext.render(this.stsRoleArn).as(String.class).orElseThrow());
}

if (this.stsRoleSessionName != null) {
allCommands.add("aws configure set role_session_name " + runContext.render(this.stsRoleSessionName));
allCommands.add("aws configure set role_session_name " + runContext.render(this.stsRoleSessionName).as(String.class).orElseThrow());
}

if (this.stsRoleExternalId != null) {
allCommands.add("aws configure set external_id " + runContext.render(this.stsRoleExternalId));
allCommands.add("aws configure set external_id " + runContext.render(this.stsRoleExternalId).as(String.class).orElseThrow());
}

if (this.stsRoleSessionDuration != null) {
allCommands.add("aws configure set duration_seconds " + stsRoleSessionDuration.getSeconds());
allCommands.add("aws configure set duration_seconds " + runContext.render(stsRoleSessionDuration).as(Duration.class).orElseThrow().getSeconds());
}

if (this.stsCredentialSource != null) {
Expand Down Expand Up @@ -206,23 +207,23 @@ private Map<String, String> getEnv(RunContext runContext) throws IllegalVariable
Map<String, String> envs = new HashMap<>();

if (this.accessKeyId != null) {
envs.put("AWS_ACCESS_KEY_ID", runContext.render(this.accessKeyId));
envs.put("AWS_ACCESS_KEY_ID", runContext.render(this.accessKeyId).as(String.class).orElseThrow());
}

if (this.secretKeyId != null) {
envs.put("AWS_SECRET_ACCESS_KEY", runContext.render(this.secretKeyId));
envs.put("AWS_SECRET_ACCESS_KEY", runContext.render(this.secretKeyId).as(String.class).orElseThrow());
}

if (this.region != null) {
envs.put("AWS_DEFAULT_REGION", this.region.as(runContext, String.class));
envs.put("AWS_DEFAULT_REGION", runContext.render(this.region).as(String.class).orElseThrow());
}

if (this.sessionToken != null) {
envs.put("AWS_SESSION_TOKEN", runContext.render(this.sessionToken));
envs.put("AWS_SESSION_TOKEN", runContext.render(this.sessionToken).as(String.class).orElseThrow());
}

if (this.endpointOverride != null) {
envs.put("AWS_ENDPOINT_URL", runContext.render(this.endpointOverride));
envs.put("AWS_ENDPOINT_URL", runContext.render(this.endpointOverride).as(String.class).orElseThrow());
}

envs.put("AWS_DEFAULT_OUTPUT", this.outputFormat.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.common.FetchOutput;
import io.kestra.core.models.tasks.common.FetchType;
import io.kestra.core.runners.RunContext;
Expand Down Expand Up @@ -40,9 +41,8 @@
@NoArgsConstructor
public abstract class AbstractDynamoDb extends AbstractConnection {
@Schema(title = "The DynamoDB table name.")
@PluginProperty(dynamic = true)
@NotNull
protected String tableName;
protected Property<String> tableName;

protected DynamoDbClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
Expand Down Expand Up @@ -109,7 +109,7 @@ protected AttributeValue objectFrom(Object value) {
return AttributeValue.fromS(value.toString());
}

protected FetchOutput fetchOutputs(List<Map<String, AttributeValue>> items, FetchType fetchType, RunContext runContext) throws IOException {
protected FetchOutput fetchOutputs(List<Map<String, AttributeValue>> items, FetchType fetchType, RunContext runContext) throws IOException, IllegalVariableEvaluationException {
var outputBuilder = FetchOutput.builder();
switch (fetchType) {
case FETCH:
Expand Down Expand Up @@ -139,7 +139,7 @@ protected FetchOutput fetchOutputs(List<Map<String, AttributeValue>> items, Fetc

runContext.metric(Counter.of(
"records", output.getSize(),
"tableName", getTableName()
"tableName", runContext.render(getTableName()).as(String.class).orElseThrow()
));

return output;
Expand Down
11 changes: 6 additions & 5 deletions src/main/java/io/kestra/plugin/aws/dynamodb/DeleteItem.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.models.tasks.VoidOutput;
import io.kestra.core.runners.RunContext;
Expand Down Expand Up @@ -41,7 +42,7 @@
secretKeyId: "<secret-key>"
region: "eu-central-1"
tableName: "persons"
key:
key:
id: "1"
"""
)
Expand All @@ -52,16 +53,16 @@ public class DeleteItem extends AbstractDynamoDb implements RunnableTask<VoidOut
title = "The DynamoDB item key.",
description = "The DynamoDB item identifier."
)
@PluginProperty
private Map<String, Object> key;
private Property<Map<String, Object>> key;

@Override
public VoidOutput run(RunContext runContext) throws Exception {
try (var dynamoDb = client(runContext)) {
Map<String, AttributeValue> key = valueMapFrom(getKey());
var renderedKey = runContext.render(this.key).asMap(String.class, Object.class);
Map<String, AttributeValue> key = valueMapFrom(renderedKey);

var deleteRequest = DeleteItemRequest.builder()
.tableName(runContext.render(this.getTableName()))
.tableName(runContext.render(this.getTableName()).as(String.class).orElseThrow())
.key(key)
.build();

Expand Down
Loading
Loading