From 40e3585c3f33bbd6361a8374bc339d981f6373ea Mon Sep 17 00:00:00 2001 From: "brian.mulier" Date: Thu, 4 Apr 2024 15:56:46 +0200 Subject: [PATCH] fix(runner): use new ScriptRunner API --- .../aws/runner/AwsBatchScriptRunner.java | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java index af338a1..12ccafc 100644 --- a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java +++ b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java @@ -1,6 +1,5 @@ package io.kestra.plugin.aws.runner; -import com.fasterxml.jackson.annotation.JsonIgnore; import com.google.common.collect.Iterables; import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.annotations.Plugin; @@ -155,14 +154,6 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab @Builder.Default private Duration waitUntilCompletion = Duration.ofHours(1); - @JsonIgnore - @Getter(AccessLevel.NONE) - private Path batchWorkingDirectory; - - @JsonIgnore - @Getter(AccessLevel.NONE) - private Path batchOutputDirectory; - @Override public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List filesToUpload, List filesToDownload) throws Exception { boolean hasS3Bucket = this.bucket != null; @@ -228,6 +219,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li .type(JobDefinitionType.CONTAINER) .tags(ScriptService.labels(runContext, "kestra-", true, true)) .platformCapabilities(platformCapability); + Path batchWorkingDirectory = (Path) this.additionalVars(runContext, scriptCommands).get(ScriptService.VAR_WORKING_DIR); if (hasFilesToUpload) { try (S3TransferManager transferManager = transferManager(runContext)) { @@ -238,7 +230,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li .builder() .bucket(renderedBucket) // Use path to eventually deduplicate leading '/' - .key((this.batchWorkingDirectory + Path.of("/" + relativePath).toString()).substring(1)) + .key((batchWorkingDirectory + Path.of("/" + relativePath).toString()).substring(1)) .build() ) .source(runContext.resolve(Path.of(relativePath))) @@ -280,12 +272,14 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li int baseSideContainerMemory = 128; float baseSideContainerCpu = 0.1f; - Map additionalVars = additionalVars(scriptCommands); + Map additionalVars = this.additionalVars(runContext, scriptCommands); Object s3WorkingDir = additionalVars.get(ScriptService.VAR_BUCKET_PATH); + Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR); MountPoint volumeMount = MountPoint.builder() - .containerPath(this.batchWorkingDirectory.toString()) + .containerPath(batchWorkingDirectory.toString()) .sourceVolume(kestraVolume) .build(); + if (hasS3Bucket) { containers.add( withResources( @@ -298,8 +292,8 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li null, Stream.concat( ListUtils.emptyOnNull(filesToUpload).stream() - .map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + this.batchWorkingDirectory + Path.of("/" + relativePath)), - Stream.of("mkdir " + this.batchOutputDirectory) + .map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)), + Stream.of("mkdir " + batchOutputDirectory) ).toList() )) .name(inputFilesContainerName), @@ -323,7 +317,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li .build() ) .environment( - this.env(scriptCommands).entrySet().stream() + this.env(runContext, scriptCommands).entrySet().stream() .map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build()) .toArray(KeyValuePair[]::new) ) @@ -350,8 +344,8 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li null, Stream.concat( filesToDownload.stream() - .map(relativePath -> "aws s3 cp " + this.batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)), - Stream.of("aws s3 cp " + this.batchOutputDirectory + "/ " + s3WorkingDir + "/" + this.batchWorkingDirectory.relativize(this.batchOutputDirectory) + "/ --recursive") + .map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)), + Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive") ).toList() )) .dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build()) @@ -467,7 +461,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li DownloadFileRequest.builder() .getObjectRequest(GetObjectRequest.builder() .bucket(renderedBucket) - .key((this.batchWorkingDirectory + "/" + relativePath).substring(1)) + .key((batchWorkingDirectory + "/" + relativePath).substring(1)) .build()) .destination(scriptCommands.getWorkingDirectory().resolve(Path.of(relativePath.startsWith("/") ? relativePath.substring(1) : relativePath))) .build() @@ -478,7 +472,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li .bucket(renderedBucket) .destination(scriptCommands.getOutputDirectory()) .listObjectsV2RequestTransformer(builder -> builder - .prefix(this.batchOutputDirectory.toString().substring(1)) + .prefix(batchOutputDirectory.toString().substring(1)) ) .build()) .completionFuture() @@ -490,7 +484,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li // Manual close after cleanup to make sure we get all remaining logs cloudWatchLogsAsyncClient.close(); if (hasS3Bucket) { - cleanupS3Resources(runContext); + cleanupS3Resources(runContext, batchWorkingDirectory); } } @@ -498,21 +492,18 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li } @Override - public Map additionalVars(ScriptCommands scriptCommands) { - if (this.additionalVars == null) { - this.additionalVars = new HashMap<>(); - this.additionalVars.putAll(scriptCommands.getAdditionalVars()); - - if (bucket != null) { - this.batchWorkingDirectory = Path.of("/" + IdUtils.create()); - this.additionalVars.put(ScriptService.VAR_WORKING_DIR, this.batchWorkingDirectory); - this.batchOutputDirectory = this.batchWorkingDirectory.resolve(IdUtils.create()); - this.additionalVars.put(ScriptService.VAR_OUTPUT_DIR, this.batchOutputDirectory); - this.additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + bucket + this.batchWorkingDirectory); - } + public Map runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException { + Map additionalVars = new HashMap<>(); + Path batchWorkingDirectory = Path.of("/" + IdUtils.create()); + additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory); + + if (this.bucket != null) { + Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create()); + additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory); + additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory); } - return super.additionalVars(scriptCommands); + return additionalVars; } @Nullable @@ -543,12 +534,12 @@ private TaskContainerProperties.Builder withResources(TaskContainerProperties.Bu ); } - private void cleanupS3Resources(RunContext runContext) throws IllegalVariableEvaluationException { + private void cleanupS3Resources(RunContext runContext, Path batchWorkingDirectory) throws IllegalVariableEvaluationException { String renderedBucket = runContext.render(bucket); try (S3AsyncClient s3AsyncClient = asyncClient(runContext)) { ListObjectsV2Request listRequest = ListObjectsV2Request.builder() .bucket(renderedBucket) - .prefix(this.batchWorkingDirectory.toString()) + .prefix(batchWorkingDirectory.toString()) .build(); ListObjectsV2Response listResponse = s3AsyncClient.listObjectsV2(listRequest).get();