Skip to content

Commit

Permalink
fix(runner): output dir property as a guard
Browse files (browse the repository at this point in the history)
  • Loading branch information
brian-mulier-p committed Apr 9, 2024
1 parent: 740f4e7 · commit: de71d29
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions src/main/java/io/kestra/plugin/aws/runner/AwsBatchTaskRunner.java
Original file line number | Diff line number | Diff line change
Expand Up @@ -227,8 +227,9 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
throw new IllegalArgumentException("You must provide a S3Bucket to use `inputFiles` or `namespaceFiles`");
}
boolean hasFilesToDownload = !ListUtils.isEmpty(filesToDownload);
if (hasFilesToDownload && !hasS3Bucket) {
throw new IllegalArgumentException("You must provide a S3Bucket to use `outputFiles`");
boolean outputDirectoryEnabled = taskCommands.outputDirectoryEnabled();
if ((hasFilesToDownload || outputDirectoryEnabled) && !hasS3Bucket) {
throw new IllegalArgumentException("You must provide a S3Bucket to use `outputFiles` or `{{ outputDir }}`");
}

Logger logger = runContext.logger();
Expand Down Expand Up @@ -342,7 +343,12 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
.sourceVolume(kestraVolume)
.build();

if (hasS3Bucket) {
if (hasFilesToUpload || outputDirectoryEnabled) {
Stream<String> commands = ListUtils.emptyOnNull(filesToUpload).stream()
.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath));
if (outputDirectoryEnabled) {
commands = Stream.concat(commands, Stream.of("mkdir " + batchOutputDirectory));
}
containers.add(
withResources(
TaskContainerProperties.builder()
Expand All @@ -352,11 +358,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
.command(ScriptService.scriptCommands(
List.of("/bin/sh", "-c"),
null,
Stream.concat(
ListUtils.emptyOnNull(filesToUpload).stream()
.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)),
Stream.of("mkdir " + batchOutputDirectory)
).toList()
commands.toList()
))
.name(inputFilesContainerName),
baseSideContainerMemory,
Expand All @@ -367,6 +369,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
int sideContainersMemoryAllocations = hasS3Bucket ? baseSideContainerMemory * 2 : 0;
float sideContainersCpuAllocations = hasS3Bucket ? baseSideContainerCpu * 2 : 0;

boolean needsOutputFilesContainer = hasFilesToDownload || outputDirectoryEnabled;
TaskContainerProperties.Builder mainContainerBuilder = withResources(
TaskContainerProperties.builder()
.image(taskCommands.getContainerImage())
Expand All @@ -383,19 +386,24 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
.map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
.toArray(KeyValuePair[]::new)
)
.essential(!hasS3Bucket),
.essential(!needsOutputFilesContainer),
Integer.parseInt(resources.getRequest().getMemory()) - sideContainersMemoryAllocations,
Float.parseFloat(resources.getRequest().getCpu()) - sideContainersCpuAllocations
);

if (hasS3Bucket) {
if (needsOutputFilesContainer) {
mainContainerBuilder.dependsOn(TaskContainerDependency.builder().containerName(inputFilesContainerName).condition("SUCCESS").build());
mainContainerBuilder.mountPoints(volumeMount);
}

containers.add(mainContainerBuilder.build());

if (hasS3Bucket) {
if (needsOutputFilesContainer) {
Stream<String> commands = filesToDownload.stream()
.map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath));
if (outputDirectoryEnabled) {
commands = Stream.concat(commands, Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive"));
}
containers.add(
withResources(
TaskContainerProperties.builder()
Expand All @@ -404,11 +412,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
.command(ScriptService.scriptCommands(
List.of("/bin/sh", "-c"),
null,
Stream.concat(
filesToDownload.stream()
.map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)),
Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive")
).toList()
commands.toList()
))
.dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build())
.name(outputFilesContainerName),
Expand Down Expand Up @@ -517,7 +521,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
throw new TaskException(exitCode, logConsumer.getStdOutCount(), logConsumer.getStdErrCount());
}

if (hasS3Bucket) {
if (hasFilesToDownload || outputDirectoryEnabled) {
try (S3TransferManager transferManager = transferManager(runContext)) {
filesToDownload.stream().map(relativePath -> transferManager.downloadFile(
DownloadFileRequest.builder()
Expand All @@ -530,15 +534,17 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
)).map(FileDownload::completionFuture)
.forEach(throwConsumer(CompletableFuture::get));

transferManager.downloadDirectory(DownloadDirectoryRequest.builder()
.bucket(renderedBucket)
.destination(taskCommands.getOutputDirectory())
.listObjectsV2RequestTransformer(builder -> builder
.prefix(batchOutputDirectory.toString().substring(1))
)
.build())
.completionFuture()
.get();
if (outputDirectoryEnabled) {
transferManager.downloadDirectory(DownloadDirectoryRequest.builder()
.bucket(renderedBucket)
.destination(taskCommands.getOutputDirectory())
.listObjectsV2RequestTransformer(builder -> builder
.prefix(batchOutputDirectory.toString().substring(1))
)
.build())
.completionFuture()
.get();
}
}
}
} finally {
Expand All @@ -560,9 +566,12 @@ public Map<String, Object> runnerAdditionalVars(RunContext runContext, TaskComma
additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);

if (this.bucket != null) {
additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory);
}

if (taskCommands.outputDirectoryEnabled()) {
Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory);
}

return additionalVars;
Expand Down

0 comments on commit de71d29

Please sign in to comment.