Add unit tests for container collection batch tasks (#64)
* Add entrypoint to existing tests

* Add docstrings and unit tests for make batch job task

* Add moto dependency

* Add docstrings and unit tests for register batch job task

* Add docstrings and unit tests for submit batch job task

* Add docstrings and unit tests for check batch job task

* Add docstrings and unit tests for terminate batch job task

* Add docstrings to constants

* Add docstrings and unit tests for get batch logs task
jessicasyu committed Aug 28, 2024
1 parent 81acde5 commit a1e5777
Showing 19 changed files with 916 additions and 10 deletions.
95 changes: 94 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ furo = "^2023.5.20"
 myst-parser = "^2.0.0"
 sphinx-copybutton = "^0.5.2"
 tox = "^4.5.1"
+moto = {extras = ["ec2", "batch", "logs"], version = "^5.0.11"}
 
 [tool.isort]
 profile = "black"
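With moto pinned as a dev dependency, the new unit tests can exercise the Batch tasks against mocked AWS APIs instead of a live account. A minimal sketch of the pattern, assuming moto 5.x's `mock_aws` decorator (the test name and job parameters are illustrative, not taken from this commit):

```python
import boto3
from moto import mock_aws


@mock_aws
def test_register_job_definition_returns_arn():
    # moto intercepts the boto3 Batch client, so no real AWS calls are made.
    client = boto3.client("batch", region_name="us-east-1")
    response = client.register_job_definition(
        jobDefinitionName="test-job",  # illustrative name
        type="container",
        containerProperties={"image": "alpine:latest", "vcpus": 1, "memory": 256},
    )
    assert response["jobDefinitionArn"].startswith("arn:aws:batch")
```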
41 changes: 35 additions & 6 deletions src/container_collection/batch/check_batch_job.py
@@ -2,16 +2,42 @@
 from typing import Union
 
 import boto3
-import prefect
-from prefect.server.schemas.states import Failed, State
+from prefect.context import TaskRunContext
+from prefect.states import Failed, State
 
 RETRIES_EXCEEDED_EXIT_CODE = 80
 """Exit code used when task run retries exceed the maximum retries."""
 
 
-def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]:
-    task_run = prefect.context.get_run_context().task_run  # type: ignore
+def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State, bool]:
+    """
+    Check for exit code of an AWS Batch job.
+
+    If this task is running within a Prefect flow, it will use the task run
+    context to get the current run count. While the run count is below the
+    maximum number of retries, the task will continue to attempt to get the exit
+    code, and can be called with a retry delay to periodically check the status
+    of jobs.
+
+    If this task is not running within a Prefect flow, the ``max_retries``
+    parameter is ignored. Jobs that are still running will throw an exception.
+
+    Parameters
+    ----------
+    job_arn
+        Job ARN.
+    max_retries
+        Maximum number of retries.
+
+    Returns
+    -------
+    :
+        Exit code if the job is complete, otherwise throws an exception.
+    """
+
+    context = TaskRunContext.get()
 
-    if task_run.run_count > max_retries:
+    if context is not None and context.task_run.run_count > max_retries:
         return RETRIES_EXCEEDED_EXIT_CODE
 
     client = boto3.client("batch")
@@ -30,8 +56,11 @@ def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]:
     response = client.describe_jobs(jobs=[job_arn])["jobs"]
     status = response[0]["status"]
 
-    if status == "RUNNING":
+    # For jobs that are running, throw the appropriate exception.
+    if context is not None and status == "RUNNING":
         return Failed()
+    if status == "RUNNING":
+        raise RuntimeError("Job is in RUNNING state and does not have exit code.")
 
     exitcode = response[0]["attempts"][0]["container"]["exitCode"]
     return exitcode
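Inside a flow, returning `Failed()` while the job is still RUNNING lets Prefect's retry machinery do the polling. A hedged sketch of how the task might be wired up (the task options, flow name, and delay values are assumptions, not part of this diff):

```python
from prefect import flow, task

from container_collection.batch.check_batch_job import check_batch_job

# Assumed wiring: wrap the plain function as a Prefect task so a Failed()
# return triggers a retry after a delay, polling until the job completes.
check_task = task(check_batch_job, retries=100, retry_delay_seconds=60)


@flow
def wait_for_job(job_arn: str) -> int:
    # Returns the container exit code, or RETRIES_EXCEEDED_EXIT_CODE (80)
    # once the run count passes max_retries.
    return check_task(job_arn, max_retries=100)
```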
26 changes: 24 additions & 2 deletions src/container_collection/batch/get_batch_logs.py
@@ -1,7 +1,26 @@
 import boto3
 
+LOG_GROUP_NAME = "/aws/batch/job"
+"""AWS Batch log group name."""
+
 
 def get_batch_logs(job_arn: str, log_filter: str) -> str:
+    """
+    Get logs for AWS Batch job.
+
+    Parameters
+    ----------
+    job_arn
+        Job ARN.
+    log_filter
+        Filter for log events.
+
+    Returns
+    -------
+    :
+        All filtered log events.
+    """
+
     client = boto3.client("batch")
     response = client.describe_jobs(jobs=[job_arn])["jobs"][0]
     log_stream = response["container"]["logStreamName"]
@@ -10,19 +29,22 @@ def get_batch_logs(job_arn: str, log_filter: str) -> str:
     log_events: list[str] = []
 
     response = client.filter_log_events(
-        logGroupName="/aws/batch/job",
+        logGroupName=LOG_GROUP_NAME,
         logStreamNames=[log_stream],
         filterPattern=log_filter,
     )
 
-    while "nextToken" in response:
+    if response["events"]:
         log_events = log_events + [event["message"] for event in response["events"]]
 
+    while "nextToken" in response:
         response = client.filter_log_events(
             logGroupName="/aws/batch/job",
             logStreamNames=[log_stream],
             filterPattern=log_filter,
             nextToken=response["nextToken"],
         )
 
+        log_events = log_events + [event["message"] for event in response["events"]]
+
     return "\n".join(log_events)
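The appends were also reordered: previously the list was only extended inside the `while` loop, so a single-page response with no `nextToken` produced an empty result. A hedged sketch of a regression test for that case against moto's mocked CloudWatch Logs (group, stream, and message values are illustrative):

```python
import time

import boto3
from moto import mock_aws


@mock_aws
def test_filter_log_events_single_page():
    # Illustrative regression sketch: one page of events, no nextToken.
    client = boto3.client("logs", region_name="us-east-1")
    client.create_log_group(logGroupName="/aws/batch/job")
    client.create_log_stream(
        logGroupName="/aws/batch/job", logStreamName="job/default/test"
    )
    client.put_log_events(
        logGroupName="/aws/batch/job",
        logStreamName="job/default/test",
        logEvents=[{"timestamp": int(time.time() * 1000), "message": "done"}],
    )

    response = client.filter_log_events(
        logGroupName="/aws/batch/job", logStreamNames=["job/default/test"]
    )
    # The fixed task appends this first (and only) page before paginating.
    assert [event["message"] for event in response["events"]] == ["done"]
```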
38 changes: 38 additions & 0 deletions src/container_collection/batch/make_batch_job.py
@@ -9,6 +9,44 @@ def make_batch_job(
     environment: Optional[list[dict[str, str]]] = None,
     job_role_arn: Optional[str] = None,
 ) -> dict:
+    """
+    Create batch job definition.
+
+    Docker images on the Docker Hub registry are available by default, and can
+    be specified using ``image:tag``. Otherwise, use ``repository/image:tag``.
+
+    Environment variables are passed as key-value pairs using the following
+    structure:
+
+    .. code-block:: python
+
+        [
+            { "name" : "envName1", "value" : "envValue1" },
+            { "name" : "envName2", "value" : "envValue2" },
+            ...
+        ]
+
+    Parameters
+    ----------
+    name
+        Job definition name.
+    image
+        Docker image.
+    vcpus
+        Number of vCPUs to reserve for the container.
+    memory
+        Memory limit available to the container.
+    environment
+        List of environment variables as key-value pairs.
+    job_role_arn
+        ARN for IAM role for the job container.
+
+    Returns
+    -------
+    :
+        Job definition.
+    """
+
     container_properties = {
         "image": image,
         "vcpus": vcpus,
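A hedged sketch of building a definition with the task, using the parameters listed in the docstring (all argument values are illustrative):

```python
from container_collection.batch.make_batch_job import make_batch_job

# Illustrative values; a Docker Hub image could be given as just "image:tag".
job_definition = make_batch_job(
    name="my-simulation",
    image="example-registry/simulation:latest",
    vcpus=1,
    memory=4096,
    environment=[{"name": "SIMULATION_MODE", "value": "batch"}],
    job_role_arn="arn:aws:iam::123456789012:role/BatchJobRole",
)
```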
19 changes: 19 additions & 0 deletions src/container_collection/batch/register_batch_job.py
@@ -3,6 +3,24 @@
 
 
 def register_batch_job(job_definition: dict) -> str:
+    """
+    Register job definition to AWS Batch.
+
+    If a definition for the given job definition name already exists, and the
+    contents of the definition are not changed, then the method will return the
+    existing job definition ARN rather than creating a new revision.
+
+    Parameters
+    ----------
+    job_definition
+        Batch job definition.
+
+    Returns
+    -------
+    :
+        Job definition ARN.
+    """
+
     client = boto3.client("batch")
 
     response = client.describe_job_definitions(
@@ -17,4 +35,5 @@ def register_batch_job(job_definition: dict) -> str:
         return existing_definition["jobDefinitionArn"]
 
     response = client.register_job_definition(**job_definition)
+
     return response["jobDefinitionArn"]
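A hedged sketch of a moto-backed test for the revision-reuse behavior described in the docstring (the definition shape follows the AWS Batch `register_job_definition` API and is not taken from this diff):

```python
import os

from moto import mock_aws

from container_collection.batch.register_batch_job import register_batch_job


@mock_aws
def test_register_batch_job_reuses_unchanged_definition():
    # The task builds its own boto3 client, so give it a default region.
    os.environ["AWS_DEFAULT_REGION"] = "us-east-1"

    job_definition = {
        "jobDefinitionName": "test-job",  # illustrative
        "type": "container",
        "containerProperties": {"image": "alpine:latest", "vcpus": 1, "memory": 256},
    }

    # An unchanged definition should return the existing ARN, not a new revision.
    assert register_batch_job(job_definition) == register_batch_job(job_definition)
```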
22 changes: 22 additions & 0 deletions src/container_collection/batch/submit_batch_job.py
@@ -4,6 +4,28 @@
 def submit_batch_job(
     name: str, job_definition_arn: str, user: str, queue: str, size: int
 ) -> list[str]:
+    """
+    Submit job to AWS Batch.
+
+    Parameters
+    ----------
+    name
+        Job name.
+    job_definition_arn
+        Job definition ARN.
+    user
+        User name prefix for job name.
+    queue
+        Job queue.
+    size
+        Number of jobs in array.
+
+    Returns
+    -------
+    :
+        List of job ARNs.
+    """
+
     job_submission = {
         "jobName": f"{user}_{name}",
         "jobQueue": queue,
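A hedged sketch of submitting an array job with the task (queue, ARN, and user values are illustrative):

```python
from container_collection.batch.submit_batch_job import submit_batch_job

# Submits an array job named "<user>_<name>"; given the list[str] return
# type, presumably one ARN per array index.
job_arns = submit_batch_job(
    name="my-simulation",
    job_definition_arn="arn:aws:batch:us-east-1:123456789012:job-definition/test-job:1",
    user="demo",
    queue="my-job-queue",
    size=10,
)
```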
16 changes: 15 additions & 1 deletion src/container_collection/batch/terminate_batch_job.py
@@ -2,8 +2,22 @@
 
 import boto3
 
+TERMINATION_REASON = "Termination requested by workflow."
+"""Reason sent for terminating jobs from a workflow."""
+
 
 def terminate_batch_job(job_arn: str) -> None:
+    """
+    Terminate job on AWS Batch.
+
+    Task will sleep for 1 minute after sending the termination request.
+
+    Parameters
+    ----------
+    job_arn
+        Job ARN.
+    """
+
     client = boto3.client("batch")
-    client.terminate_job(jobId=job_arn, reason="Prefect workflow termination")
+    client.terminate_job(jobId=job_arn, reason=TERMINATION_REASON)
     sleep(60)
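A hedged sketch of using the task (the job ARN is illustrative):

```python
from container_collection.batch.terminate_batch_job import terminate_batch_job

# Sends the termination request, then sleeps 60 seconds so AWS Batch has time
# to act before any downstream cleanup runs.
terminate_batch_job("arn:aws:batch:us-east-1:123456789012:job/example-job-id")
```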