Skip to content

Commit

Permalink
fix(runner): use new ScriptRunner API
Browse files Browse the repository at this point in the history
  • Loading branch information
brian-mulier-p committed Apr 5, 2024
1 parent e88941e commit 40e3585
Showing 1 changed file with 26 additions and 35 deletions.
61 changes: 26 additions & 35 deletions src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.kestra.plugin.aws.runner;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.collect.Iterables;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Plugin;
Expand Down Expand Up @@ -155,14 +154,6 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
@Builder.Default
private Duration waitUntilCompletion = Duration.ofHours(1);

@JsonIgnore
@Getter(AccessLevel.NONE)
private Path batchWorkingDirectory;

@JsonIgnore
@Getter(AccessLevel.NONE)
private Path batchOutputDirectory;

@Override
public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
boolean hasS3Bucket = this.bucket != null;
Expand Down Expand Up @@ -228,6 +219,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
.type(JobDefinitionType.CONTAINER)
.tags(ScriptService.labels(runContext, "kestra-", true, true))
.platformCapabilities(platformCapability);
Path batchWorkingDirectory = (Path) this.additionalVars(runContext, scriptCommands).get(ScriptService.VAR_WORKING_DIR);

if (hasFilesToUpload) {
try (S3TransferManager transferManager = transferManager(runContext)) {
Expand All @@ -238,7 +230,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
.builder()
.bucket(renderedBucket)
// Use path to eventually deduplicate leading '/'
.key((this.batchWorkingDirectory + Path.of("/" + relativePath).toString()).substring(1))
.key((batchWorkingDirectory + Path.of("/" + relativePath).toString()).substring(1))
.build()
)
.source(runContext.resolve(Path.of(relativePath)))
Expand Down Expand Up @@ -280,12 +272,14 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li

int baseSideContainerMemory = 128;
float baseSideContainerCpu = 0.1f;
Map<String, Object> additionalVars = additionalVars(scriptCommands);
Map<String, Object> additionalVars = this.additionalVars(runContext, scriptCommands);
Object s3WorkingDir = additionalVars.get(ScriptService.VAR_BUCKET_PATH);
Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
MountPoint volumeMount = MountPoint.builder()
.containerPath(this.batchWorkingDirectory.toString())
.containerPath(batchWorkingDirectory.toString())
.sourceVolume(kestraVolume)
.build();

if (hasS3Bucket) {
containers.add(
withResources(
Expand All @@ -298,8 +292,8 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
null,
Stream.concat(
ListUtils.emptyOnNull(filesToUpload).stream()
.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + this.batchWorkingDirectory + Path.of("/" + relativePath)),
Stream.of("mkdir " + this.batchOutputDirectory)
.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)),
Stream.of("mkdir " + batchOutputDirectory)
).toList()
))
.name(inputFilesContainerName),
Expand All @@ -323,7 +317,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
.build()
)
.environment(
this.env(scriptCommands).entrySet().stream()
this.env(runContext, scriptCommands).entrySet().stream()
.map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
.toArray(KeyValuePair[]::new)
)
Expand All @@ -350,8 +344,8 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
null,
Stream.concat(
filesToDownload.stream()
.map(relativePath -> "aws s3 cp " + this.batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)),
Stream.of("aws s3 cp " + this.batchOutputDirectory + "/ " + s3WorkingDir + "/" + this.batchWorkingDirectory.relativize(this.batchOutputDirectory) + "/ --recursive")
.map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)),
Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive")
).toList()
))
.dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build())
Expand Down Expand Up @@ -467,7 +461,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
DownloadFileRequest.builder()
.getObjectRequest(GetObjectRequest.builder()
.bucket(renderedBucket)
.key((this.batchWorkingDirectory + "/" + relativePath).substring(1))
.key((batchWorkingDirectory + "/" + relativePath).substring(1))
.build())
.destination(scriptCommands.getWorkingDirectory().resolve(Path.of(relativePath.startsWith("/") ? relativePath.substring(1) : relativePath)))
.build()
Expand All @@ -478,7 +472,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
.bucket(renderedBucket)
.destination(scriptCommands.getOutputDirectory())
.listObjectsV2RequestTransformer(builder -> builder
.prefix(this.batchOutputDirectory.toString().substring(1))
.prefix(batchOutputDirectory.toString().substring(1))
)
.build())
.completionFuture()
Expand All @@ -490,29 +484,26 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
// Manual close after cleanup to make sure we get all remaining logs
cloudWatchLogsAsyncClient.close();
if (hasS3Bucket) {
cleanupS3Resources(runContext);
cleanupS3Resources(runContext, batchWorkingDirectory);
}
}

return new RunnerResult(0, logConsumer);
}

@Override
public Map<String, Object> additionalVars(ScriptCommands scriptCommands) {
if (this.additionalVars == null) {
this.additionalVars = new HashMap<>();
this.additionalVars.putAll(scriptCommands.getAdditionalVars());

if (bucket != null) {
this.batchWorkingDirectory = Path.of("/" + IdUtils.create());
this.additionalVars.put(ScriptService.VAR_WORKING_DIR, this.batchWorkingDirectory);
this.batchOutputDirectory = this.batchWorkingDirectory.resolve(IdUtils.create());
this.additionalVars.put(ScriptService.VAR_OUTPUT_DIR, this.batchOutputDirectory);
this.additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + bucket + this.batchWorkingDirectory);
}
public Map<String, Object> runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException {
Map<String, Object> additionalVars = new HashMap<>();
Path batchWorkingDirectory = Path.of("/" + IdUtils.create());
additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);

if (this.bucket != null) {
Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory);
}

return super.additionalVars(scriptCommands);
return additionalVars;
}

@Nullable
Expand Down Expand Up @@ -543,12 +534,12 @@ private TaskContainerProperties.Builder withResources(TaskContainerProperties.Bu
);
}

private void cleanupS3Resources(RunContext runContext) throws IllegalVariableEvaluationException {
private void cleanupS3Resources(RunContext runContext, Path batchWorkingDirectory) throws IllegalVariableEvaluationException {
String renderedBucket = runContext.render(bucket);
try (S3AsyncClient s3AsyncClient = asyncClient(runContext)) {
ListObjectsV2Request listRequest = ListObjectsV2Request.builder()
.bucket(renderedBucket)
.prefix(this.batchWorkingDirectory.toString())
.prefix(batchWorkingDirectory.toString())
.build();

ListObjectsV2Response listResponse = s3AsyncClient.listObjectsV2(listRequest).get();
Expand Down

0 comments on commit 40e3585

Please sign in to comment.