diff --git a/build.gradle b/build.gradle
index b4ad426..122f956 100644
--- a/build.gradle
+++ b/build.gradle
@@ -56,7 +56,7 @@ dependencies {
     // AWS libs: versions are managed by the Micronaut BOM
     api platform("io.micronaut.platform:micronaut-platform:$micronautVersion")
     api 'software.amazon.awssdk:cloudwatchlogs'
-    api 'software.amazon.awssdk:batch'
+    api 'software.amazon.awssdk:batch:2.25.14'
     api 'software.amazon.awssdk:s3'
     api 'software.amazon.awssdk:s3-transfer-manager'
     api 'software.amazon.awssdk.crt:aws-crt:0.29.10' //used by s3-transfer-manager
diff --git a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
index fe8dd9e..db17e38 100644
--- a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
+++ b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
@@ -1,6 +1,7 @@
 package io.kestra.plugin.aws.runner;
 
 import com.google.common.annotations.Beta;
+import com.google.common.collect.Iterables;
 import io.kestra.core.exceptions.IllegalVariableEvaluationException;
 import io.kestra.core.models.annotations.PluginProperty;
 import io.kestra.core.models.script.*;
@@ -39,6 +40,8 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
 
 import static io.kestra.core.utils.Rethrow.throwConsumer;
 
@@ -50,8 +53,9 @@
 @NoArgsConstructor
 @Beta
 @Schema(title = "AWS Batch script runner", description = """
-    Run a script in a container on an AWS Batch Compute Environment.
+    Run a script in a container on an AWS Batch Compute Environment (only Fargate and EC2 compute environments are supported; for EKS, use the Kubernetes Script Runner instead).
     To use `inputFiles`, `outputFiles` and `namespaceFiles` properties, you must provide a `s3Bucket`.
+    Doing so uploads the files to the bucket before the script runs and downloads them after it completes.
     This runner will wait for the task to succeed or fail up to a max `waitUntilCompletion` duration.
     It will return with an exit code according to the following mapping:
     - SUCCEEDED: 0
@@ -72,6 +76,8 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
         JobStatus.SUBMITTED, 6,
         JobStatus.UNKNOWN_TO_SDK_VERSION, -1
     );
+    public static final String S3_WORKING_DIR_KEY = "s3WorkingDir";
+    public static final String WORKING_DIR_KEY = "workingDir";
 
     @NotNull
     @PluginProperty(dynamic = true)
@@ -139,6 +145,7 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
         description = "If using a Fargate compute environments, resources requests must match this table: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html"
     )
     @PluginProperty
+    @NotNull
     @Builder.Default
     private Resources resources = Resources.builder()
         .request(
@@ -186,52 +193,45 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<St
+        ComputeEnvironmentDetail computeEnvironmentDetail = client.describeComputeEnvironments(DescribeComputeEnvironmentsRequest.builder()
+                .computeEnvironments(runContext.render(computeEnvironmentArn))
+                .build())
+            .computeEnvironments()
+            .stream()
+            .findFirst()
+            .orElseThrow(() -> new IllegalArgumentException("Compute environment not found: " + computeEnvironmentArn));
 
-        if (this.resources != null && this.resources.getRequest() != null) {
-            containerPropsBuilder
-                .resourceRequirements(
-                    ResourceRequirement.builder()
-                        .type(ResourceType.MEMORY)
-                        .value(this.resources.getRequest().getMemory())
-                        .build(),
-                    ResourceRequirement.builder()
-                        .type(ResourceType.VCPU)
-                        .value(this.resources.getRequest().getCpu())
-                        .build()
-                );
+        String kestraVolume = "kestra";
+        if (computeEnvironmentDetail.containerOrchestrationType() != OrchestrationType.ECS) {
+            throw new IllegalArgumentException("Only ECS compute environments are supported");
         }
 
-        if (this.executionRoleArn != null) {
-            containerPropsBuilder.executionRoleArn(runContext.render(this.executionRoleArn));
-        }
+        PlatformCapability platformCapability = switch (computeEnvironmentDetail.computeResources().type()) {
+            case FARGATE:
+            case FARGATE_SPOT:
+                yield PlatformCapability.FARGATE;
+            case EC2:
+            case SPOT:
+                yield PlatformCapability.EC2;
+            default:
+                yield null;
+        };
 
-        if (commands.getEnv() != null) {
-            containerPropsBuilder
-                .environment(
-                    commands.getEnv().entrySet().stream()
-                        .map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
-                        .toArray(KeyValuePair[]::new)
-                );
-        }
+        RegisterJobDefinitionRequest.Builder jobDefBuilder = RegisterJobDefinitionRequest.builder()
+            .jobDefinitionName(IdUtils.create())
+            .type(JobDefinitionType.CONTAINER)
+            .platformCapabilities(platformCapability);
 
         String renderedBucket = runContext.render(s3Bucket);
-        String workingDirKey = IdUtils.create();
-        String outputDirKey = IdUtils.create();
+        String workingDirName = IdUtils.create();
 
         Map<String, Object> additionalVars = Optional.ofNullable(renderedBucket)
             .map(bucket -> Map.of(
-                "workingDir", "s3://" + bucket + "/" + workingDirKey,
-                "outputDir", "s3://" + bucket + "/" + outputDirKey
+                S3_WORKING_DIR_KEY, "s3://" + bucket + "/" + workingDirName,
+                WORKING_DIR_KEY, "/" + workingDirName,
+                "outputDir", "/" + workingDirName
             ))
             .orElse(Collections.emptyMap());
 
-        List<String> command = ScriptService.uploadInputFiles(runContext, runContext.render(commands.getCommands(), additionalVars));
-
         if (filesToUpload != null && !filesToUpload.isEmpty()) {
             try (S3TransferManager transferManager = transferManager(runContext)) {
@@ -241,7 +241,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<St
                         .map(relativePath -> UploadFileRequest.builder()
                             .putObjectRequest(PutObjectRequest.builder()
                                 .bucket(renderedBucket)
-                                .key(workingDirKey + "/" + relativePath)
+                                .key(workingDirName + "/" + relativePath)
                                 .build())
                             .source(runContext.resolve(Path.of(relativePath)))
                             .build())
@@ -252,39 +252,132 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<St
-        ComputeEnvironmentDetail computeEnvironmentDetail = client.describeComputeEnvironments(DescribeComputeEnvironmentsRequest.builder()
-                .computeEnvironments(runContext.render(computeEnvironmentArn))
-                .build())
-            .computeEnvironments()
-            .stream()
-            .findFirst()
-            .orElseThrow(() -> new IllegalArgumentException("Compute environment not found: " + computeEnvironmentArn));
-
-        PlatformCapability platformCapability = switch (computeEnvironmentDetail.computeResources().type()) {
-            case FARGATE:
-            case FARGATE_SPOT:
-                yield PlatformCapability.FARGATE;
-            case EC2:
-            case SPOT:
-                yield PlatformCapability.EC2;
-            default:
-                yield null;
-        };
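+        // Build the job definition as a multi-container ECS task: optional "inputFiles"
+        // and "outputFiles" side containers and the "main" container share a "kestra"
+        // volume that is mounted as the working directory.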
+        EcsTaskProperties.Builder taskPropertiesBuilder = EcsTaskProperties.builder()
+            .volumes(
+                Volume.builder()
+                    .name(kestraVolume)
+                    .build()
+            );
 
         if (platformCapability == PlatformCapability.FARGATE) {
-            containerPropsBuilder.networkConfiguration(
-                NetworkConfiguration.builder()
-                    .assignPublicIp(AssignPublicIp.ENABLED)
+            taskPropertiesBuilder
+                .networkConfiguration(
+                    NetworkConfiguration.builder()
+                        .assignPublicIp(AssignPublicIp.ENABLED)
+                        .build()
+                );
+        }
+
+        if (this.executionRoleArn != null) {
+            taskPropertiesBuilder.executionRoleArn(runContext.render(this.executionRoleArn));
+        }
+
+        if (this.jobRoleArn != null) {
+            taskPropertiesBuilder.taskRoleArn(runContext.render(this.jobRoleArn));
+        }
+
+        List<TaskContainerProperties> containers = new ArrayList<>();
+        String inputFilesContainerName = "inputFiles";
+        String mainContainerName = "main";
+        String outputFilesContainerName = "outputFiles";
+
+        boolean hasFilesToUpload = filesToUpload != null && !filesToUpload.isEmpty();
+        boolean hasFilesToDownload = filesToDownload != null && !filesToDownload.isEmpty();
+
+        int baseSideContainerMemory = 128;
+        float baseSideContainerCpu = 0.1f;
+        Object s3WorkingDir = additionalVars.get(S3_WORKING_DIR_KEY);
+        if (hasFilesToUpload) {
+            // Side container that copies the input files from S3 into the shared volume
+            // before the main container starts.
+            containers.add(
+                withResources(
+                    TaskContainerProperties.builder()
+                        .image("ghcr.io/kestra-io/awsbatch:latest")
+                        .mountPoints(
+                            MountPoint.builder()
+                                .containerPath("/" + workingDirName)
+                                .sourceVolume(kestraVolume)
+                                .build()
+                        )
+                        .essential(false)
+                        .command(ScriptService.scriptCommands(
+                            List.of("/bin/sh", "-c"),
+                            null,
+                            filesToUpload.stream()
+                                .map(relativePath -> "aws s3 cp " + s3WorkingDir + "/" + relativePath + " /" + workingDirName + "/" + relativePath)
+                                .toList()
+                        ))
+                        .name(inputFilesContainerName),
+                    baseSideContainerMemory,
+                    baseSideContainerCpu).build()
+            );
+        }
+
+        int sideContainersMemoryAllocations = (hasFilesToUpload ? baseSideContainerMemory : 0) + (hasFilesToDownload ? baseSideContainerMemory : 0);
+        float sideContainersCpuAllocations = (hasFilesToUpload ? baseSideContainerCpu : 0) + (hasFilesToDownload ? baseSideContainerCpu : 0);
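+
+        // The memory/CPU requested via `resources` applies to the whole task, so the
+        // side containers' reservations are carved out of the main container's share.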
+        List<String> command = ScriptService.uploadInputFiles(runContext, runContext.render(commands.getCommands(), additionalVars));
+        TaskContainerProperties.Builder mainContainerBuilder = withResources(
+            TaskContainerProperties.builder()
+                .image(commands.getContainerImage())
+                .command(command)
+                .name(mainContainerName)
+                .logConfiguration(
+                    LogConfiguration.builder()
+                        .logDriver(LogDriver.AWSLOGS)
+                        .options(Map.of("awslogs-stream-prefix", jobName))
+                        .build()
+                )
+                .essential(!hasFilesToDownload),
+            Integer.parseInt(resources.getRequest().getMemory()) - sideContainersMemoryAllocations,
+            Float.parseFloat(resources.getRequest().getCpu()) - sideContainersCpuAllocations
+        );
+
+        if (hasFilesToUpload) {
+            // Only wait on the inputFiles container when it is actually part of the task.
+            mainContainerBuilder.dependsOn(TaskContainerDependency.builder().containerName(inputFilesContainerName).condition("COMPLETE").build());
+        }
+
+        if (commands.getEnv() != null) {
+            mainContainerBuilder
+                .environment(
+                    commands.getEnv().entrySet().stream()
+                        .map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
+                        .toArray(KeyValuePair[]::new)
+                );
+        }
+
+        if (hasFilesToUpload || hasFilesToDownload) {
+            mainContainerBuilder.mountPoints(
+                MountPoint.builder()
+                    .containerPath("/" + workingDirName)
+                    .sourceVolume(kestraVolume)
                     .build()
             );
         }
 
+        containers.add(mainContainerBuilder.build());
+
+        if (hasFilesToDownload) {
+            // Side container that copies the output files from the shared volume back to
+            // S3 once the main container has completed.
+            containers.add(
+                withResources(
+                    TaskContainerProperties.builder()
+                        .image("ghcr.io/kestra-io/awsbatch:latest")
+                        .mountPoints(
+                            MountPoint.builder()
+                                .containerPath("/" + workingDirName)
+                                .sourceVolume(kestraVolume)
+                                .build()
+                        )
+                        .command(ScriptService.scriptCommands(
+                            List.of("/bin/sh", "-c"),
+                            null,
+                            filesToDownload.stream()
+                                .map(relativePath -> "aws s3 cp /" + workingDirName + "/" + relativePath + " " + s3WorkingDir + "/" + relativePath)
+                                .toList()
+                        ))
+                        .dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("COMPLETE").build())
+                        .name(outputFilesContainerName),
+                    baseSideContainerMemory,
+                    baseSideContainerCpu).build()
+            );
+        }
+
+        taskPropertiesBuilder.containers(containers);
+
+        jobDefBuilder.ecsProperties(
+            EcsProperties.builder()
+                .taskProperties(taskPropertiesBuilder.build())
+                .build()
+        );
 
         logger.debug("Registering job definition");
-        RegisterJobDefinitionResponse registerJobDefinitionResponse = client.registerJobDefinition(
-            RegisterJobDefinitionRequest.builder()
-                .jobDefinitionName(IdUtils.create())
-                .type(JobDefinitionType.CONTAINER)
-                .platformCapabilities(platformCapability)
-                .containerProperties(containerPropsBuilder.build())
-                .build()
-        );
+        RegisterJobDefinitionResponse registerJobDefinitionResponse = client.registerJobDefinition(jobDefBuilder.build());
 
         String jobDefArn = registerJobDefinitionResponse.jobDefinitionArn();
         logger.debug("Job definition successfully registered: {}", jobDefArn);
@@ -383,7 +476,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<St
                 throw new TimeoutException();
             }
         } finally {
-            cleanupS3Resources(runContext, filesToUpload, filesToDownload, workingDirKey, outputDirKey);
+            cleanupS3Resources(runContext, filesToUpload, filesToDownload, workingDirName);
         }
     }
 
@@ -396,43 +489,40 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
-    private void cleanupS3Resources(RunContext runContext, List<String> filesToUpload, List<String> filesToDownload, String workingDirKey, String outputDirKey) throws IllegalVariableEvaluationException {
-        String renderedBucket = runContext.render(s3Bucket);
-        try(S3AsyncClient s3AsyncClient = asyncClient(runContext)) {
-            List<CompletableFuture<DeleteObjectsResponse>> deletions = new ArrayList<>();
-            if (filesToUpload != null && !filesToUpload.isEmpty()) {
-                deletions.add(s3AsyncClient.deleteObjects(
-                    DeleteObjectsRequest.builder()
-                        .bucket(renderedBucket)
-                        .delete(Delete.builder()
-                            .objects(
-                                filesToUpload.stream()
-                                    .map(relativePath -> ObjectIdentifier.builder()
-                                        .key(workingDirKey + "/" + relativePath)
-                                        .build()
-                                    )
-                                    .toList()
-                            ).build()
-                        ).build()
-                ));
-            }
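+    // Applies the given memory (MiB) and vCPU reservations to a container builder;
+    // Batch expects resourceRequirements values as strings.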
+    private TaskContainerProperties.Builder withResources(TaskContainerProperties.Builder builder, Integer memory, Float cpu) {
+        return builder
+            .resourceRequirements(
+                ResourceRequirement.builder()
+                    .type(ResourceType.MEMORY)
+                    .value(memory.toString())
+                    .build(),
+                ResourceRequirement.builder()
+                    .type(ResourceType.VCPU)
+                    .value(cpu.toString())
+                    .build()
+            );
+    }
 
-        if (filesToDownload != null && !filesToDownload.isEmpty()) {
-            deletions.add(s3AsyncClient.deleteObjects(
+    private void cleanupS3Resources(RunContext runContext, List<String> filesToUpload, List<String> filesToDownload, String workingDirName) throws IllegalVariableEvaluationException {
+        String renderedBucket = runContext.render(s3Bucket);
+        try (S3AsyncClient s3AsyncClient = asyncClient(runContext)) {
+            // S3's DeleteObjects call accepts at most 1,000 keys, hence the partitioning.
+            StreamSupport.stream(Iterables.partition(
+                    Stream.concat(
+                        Optional.ofNullable(filesToUpload).stream().flatMap(Collection::stream),
+                        Optional.ofNullable(filesToDownload).stream().flatMap(Collection::stream)
+                    ).map(file -> ObjectIdentifier.builder().key(workingDirName + "/" + file).build())
+                    .toList(),
+                    1000
+                ).spliterator(), false).parallel()
+                .map(objects -> s3AsyncClient.deleteObjects(
                     DeleteObjectsRequest.builder()
                         .bucket(renderedBucket)
                         .delete(Delete.builder()
-                            .objects(
-                                filesToDownload.stream()
-                                    .map(relativePath -> ObjectIdentifier.builder()
-                                        .key(outputDirKey + "/" + relativePath)
-                                        .build()
-                                    )
-                                    .toList()
-                            ).build()
-                        ).build()
-                ));
-            }
-
-            deletions.stream().parallel().forEach(throwConsumer(CompletableFuture::get));
+                            .objects(objects)
+                            .build())
+                        .build()
+                )).forEach(throwConsumer(CompletableFuture::get));
         } catch (Exception e) {
             runContext.logger().warn("Error while cleaning up S3: {}", e.getMessage());
         }
     }
diff --git a/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java b/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
index 388c6d0..d778586 100644
--- a/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
+++ b/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
@@ -24,7 +24,7 @@
 import static org.hamcrest.Matchers.is;
 
 @MicronautTest
-@Disabled("Need AWS credentials")
+@Disabled("Too costly to run on CI")
 public class AwsBatchScriptRunnerTest extends AbstractScriptRunnerTest {
     @Inject
     private RunContextFactory runContextFactory;
@@ -53,9 +53,8 @@ protected void inputAndOutputFiles() throws Exception {
         Map<String, Boolean> logsWithIsStdErr = new HashMap<>();
         CommandsWrapper commandsWrapper = new CommandsWrapper(runContext)
             .withCommands(ScriptService.scriptCommands(List.of("/bin/sh", "-c"), null, List.of(
-                "aws s3 cp {{workingDir}}/hello.txt hello.txt",
-                "cat hello.txt",
-                "aws s3 cp hello.txt {{outputDir}}/output.txt"
+                "cat {{workingDir}}/hello.txt",
+                "cat {{workingDir}}/hello.txt > {{workingDir}}/output.txt"
             )))
             .withContainerImage("ghcr.io/kestra-io/awsbatch:latest")
             .withLogConsumer(new AbstractLogConsumer() {
@@ -71,11 +70,7 @@ public void accept(String log, Boolean isStdErr) {
         assertThat(run.getExitCode(), is(0));
 
         // Verify logs
-        Map.Entry<String, Boolean> helloWorldEntry = logsWithIsStdErr.entrySet().stream()
-            .filter(e -> e.getKey().contains("Hello World"))
-            .findFirst()
-            .orElseThrow();
-        assertThat(helloWorldEntry.getValue(), is(false));
+        assertThat(logsWithIsStdErr.get("[JOB LOG] Hello World"), is(false));
 
         // Verify outputFiles
         File outputFile = runContext.resolve(Path.of("output.txt")).toFile();
@@ -94,6 +89,7 @@ protected ScriptRunner scriptRunner() {
             .executionRoleArn("arn:aws:iam::634784741179:role/AWS-Batch-Role-For-Fargate")
             .jobRoleArn("arn:aws:iam::634784741179:role/S3-Within-AWS-Batch")
             .waitUntilCompletion(Duration.ofMinutes(30))
+            .jobQueueArn("arn:aws:batch:eu-west-3:634784741179:job-queue/queue")
             .build();
     }
 }