diff --git a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchTaskRunner.java b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchTaskRunner.java index 5c8b96c..19522d1 100644 --- a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchTaskRunner.java +++ b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchTaskRunner.java @@ -227,8 +227,9 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List commands = ListUtils.emptyOnNull(filesToUpload).stream() + .map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)); + if (outputDirectoryEnabled) { + commands = Stream.concat(commands, Stream.of("mkdir " + batchOutputDirectory)); + } containers.add( withResources( TaskContainerProperties.builder() @@ -352,11 +358,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)), - Stream.of("mkdir " + batchOutputDirectory) - ).toList() + commands.toList() )) .name(inputFilesContainerName), baseSideContainerMemory, @@ -367,6 +369,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build()) .toArray(KeyValuePair[]::new) ) - .essential(!hasS3Bucket), + .essential(!needsOutputFilesContainer), Integer.parseInt(resources.getRequest().getMemory()) - sideContainersMemoryAllocations, Float.parseFloat(resources.getRequest().getCpu()) - sideContainersCpuAllocations ); - if (hasS3Bucket) { + if (needsOutputFilesContainer) { mainContainerBuilder.dependsOn(TaskContainerDependency.builder().containerName(inputFilesContainerName).condition("SUCCESS").build()); mainContainerBuilder.mountPoints(volumeMount); } containers.add(mainContainerBuilder.build()); - if (hasS3Bucket) { + if (needsOutputFilesContainer) { + Stream commands = filesToDownload.stream() + .map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)); + if (outputDirectoryEnabled) { + commands = Stream.concat(commands, Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive")); + } containers.add( withResources( TaskContainerProperties.builder() @@ -404,11 +412,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)), - Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive") - ).toList() + commands.toList() )) .dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build()) .name(outputFilesContainerName), @@ -517,7 +521,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List transferManager.downloadFile( DownloadFileRequest.builder() @@ -530,15 +534,17 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List builder - .prefix(batchOutputDirectory.toString().substring(1)) - ) - .build()) - .completionFuture() - .get(); + if (outputDirectoryEnabled) { + transferManager.downloadDirectory(DownloadDirectoryRequest.builder() + .bucket(renderedBucket) + .destination(taskCommands.getOutputDirectory()) + .listObjectsV2RequestTransformer(builder -> builder + .prefix(batchOutputDirectory.toString().substring(1)) + ) + .build()) + .completionFuture() + .get(); + } } } } finally { @@ -560,9 +566,12 @@ public Map runnerAdditionalVars(RunContext runContext, TaskComma additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory); if (this.bucket != null) { + additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory); + } + + if (taskCommands.outputDirectoryEnabled()) { Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create()); additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory); - additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory); } return additionalVars;