Skip to content

Commit

Permalink
Add docstrings and unit tests for check batch job task
Browse files Browse the repository at this point in the history
  • Loading branch information
jessicasyu committed Aug 27, 2024
1 parent 5862bc2 commit aeeb792
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 6 deletions.
40 changes: 34 additions & 6 deletions src/container_collection/batch/check_batch_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,41 @@
from typing import Union

import boto3
import prefect
from prefect.server.schemas.states import Failed, State
from prefect.context import TaskRunContext
from prefect.states import Failed, State

RETRIES_EXCEEDED_EXIT_CODE = 80


def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]:
task_run = prefect.context.get_run_context().task_run # type: ignore
def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State, bool]:
"""
Check for exit code of an AWS Batch job.
if task_run.run_count > max_retries:
If this task is running within a Prefect flow, it will use the task run
context to get the current run count. While the run count is below the
maximum number of retries, the task will continue to attempt to get the exit
code, and can be called with a retry delay to periodically check the status
of jobs.
If this task is not running within a Prefect flow, the ``max_retries``
parameter is ignored. Jobs that are still running will throw an exception.
Parameters
----------
job_arn
Job ARN.
max_retries
Maximum number of retries.
Returns
-------
:
Exit code if the job is complete, otherwise throws an exception.
"""

context = TaskRunContext.get()

if context is not None and context.task_run.run_count > max_retries:
return RETRIES_EXCEEDED_EXIT_CODE

client = boto3.client("batch")
Expand All @@ -30,8 +55,11 @@ def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]:
response = client.describe_jobs(jobs=[job_arn])["jobs"]
status = response[0]["status"]

if status == "RUNNING":
# For jobs that are running, throw the appropriate exception.
if context is not None and status == "RUNNING":
return Failed()
if status == "RUNNING":
raise RuntimeError("Job is in RUNNING state and does not have exit code.")

exitcode = response[0]["attempts"][0]["container"]["exitCode"]
return exitcode
146 changes: 146 additions & 0 deletions tests/container_collection/batch/test_check_batch_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
import sys
import unittest
from typing import Optional
from unittest import mock

import boto3
from prefect import flow
from prefect.exceptions import FailedRun
from prefect.testing.utilities import prefect_test_harness

from container_collection.batch import check_batch_job as check_batch_job_task
from container_collection.batch.check_batch_job import RETRIES_EXCEEDED_EXIT_CODE, check_batch_job

# Exit code placed in mocked responses for jobs with SUCCEEDED status.
SUCCEEDED_EXIT_CODE = 0
# Exit code placed in mocked responses for jobs with FAILED status.
FAILED_EXIT_CODE = 1


def make_describe_jobs_response(status: Optional[str], exit_code: Optional[int]):
    """
    Build a fake ``describe_jobs`` response for the mocked Batch client.

    Parameters
    ----------
    status
        Job status to report, or None to produce a response with no jobs.
    exit_code
        Container exit code; only included when the status is SUCCEEDED or
        FAILED.

    Returns
    -------
    :
        Dictionary with a ``jobs`` list holding at most one job entry.
    """

    if status is None:
        return {"jobs": []}

    job: dict = {"status": status}

    # Only jobs in a terminal state carry an attempt with a container exit code.
    if status in ("SUCCEEDED", "FAILED"):
        job["attempts"] = [{"container": {"exitCode": exit_code}}]

    return {"jobs": [job]}


def make_boto_mock(statuses: list[Optional[str]], exit_code: Optional[int] = None):
    """
    Create a mock of the ``boto3`` module whose Batch client replays statuses.

    Parameters
    ----------
    statuses
        Sequence of job statuses; each call to ``describe_jobs`` on the mocked
        client returns the response built for the next status in the list.
    exit_code
        Exit code passed through to each generated response.

    Returns
    -------
    :
        Mock object standing in for the ``boto3`` module.
    """

    batch_mock = mock.MagicMock()

    # A list side effect yields one response per call, in order, and raises
    # StopIteration if describe_jobs is called more times than statuses given.
    batch_mock.describe_jobs.side_effect = [
        make_describe_jobs_response(status, exit_code) for status in statuses
    ]

    boto3_mock = mock.MagicMock(spec=boto3)
    boto3_mock.client.return_value = batch_mock
    return boto3_mock


@flow
def run_task_under_flow(max_retries: int):
    """Call the check batch job task for ``job-arn`` from inside a flow run."""
    return check_batch_job_task("job-arn", max_retries)


@mock.patch.dict(
    os.environ,
    {"PREFECT_LOGGING_LEVEL": "CRITICAL"},
)
class TestCheckBatchJob(unittest.TestCase):
    """
    Tests for ``check_batch_job`` called directly and called from a flow.

    Each test replaces the ``boto3`` name in the module under test with a mock
    whose Batch client replays a fixed sequence of job statuses.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # unittest calls setUpClass() and discards the return value, so a
        # generator-style fixture (``with ...: yield``) would never run. Enter
        # the harness explicitly and register its exit as a class cleanup.
        harness = prefect_test_harness()
        harness.__enter__()
        cls.addClassCleanup(harness.__exit__, None, None, None)

    # --- Called as a plain method ------------------------------------------

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["RUNNING"]),
    )
    def test_check_batch_job_as_method_running_throws_exception(self):
        with self.assertRaises(RuntimeError):
            check_batch_job("job-arn", 0)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock([None, "PENDING", "RUNNING"]),
    )
    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"], "sleep", lambda _: None
    )
    def test_check_batch_job_as_method_running_with_waits_throws_exception(self):
        # ``sleep`` is replaced with a no-op so the test does not actually wait.
        with self.assertRaises(RuntimeError):
            check_batch_job("job-arn", 0)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["SUCCEEDED"], SUCCEEDED_EXIT_CODE),
    )
    def test_check_batch_job_as_method_succeeded(self):
        exit_code = check_batch_job("job-arn", 0)
        self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["FAILED"], FAILED_EXIT_CODE),
    )
    def test_check_batch_job_as_method_failed(self):
        exit_code = check_batch_job("job-arn", 0)
        self.assertEqual(FAILED_EXIT_CODE, exit_code)

    # --- Called as a task inside a flow ------------------------------------

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["RUNNING"]),
    )
    def test_check_batch_job_as_task_running_below_max_retries_throws_failed_run(self):
        max_retries = 1
        with self.assertRaises(FailedRun):
            run_task_under_flow(max_retries)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock([None, "PENDING", "RUNNING"]),
    )
    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"], "sleep", lambda _: None
    )
    def test_check_batch_job_as_task_running_below_max_retries_with_waits_throws_failed_run(self):
        # ``sleep`` is replaced with a no-op so the test does not actually wait.
        max_retries = 1
        with self.assertRaises(FailedRun):
            run_task_under_flow(max_retries)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock([None]),
    )
    def test_check_batch_job_as_task_max_retries_exceeded(self):
        max_retries = 0
        exit_code = run_task_under_flow(max_retries)
        self.assertEqual(RETRIES_EXCEEDED_EXIT_CODE, exit_code)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["SUCCEEDED"], SUCCEEDED_EXIT_CODE),
    )
    def test_check_batch_job_as_task_succeeded(self):
        max_retries = 1
        exit_code = run_task_under_flow(max_retries)
        self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code)

    @mock.patch.object(
        sys.modules["container_collection.batch.check_batch_job"],
        "boto3",
        make_boto_mock(["FAILED"], FAILED_EXIT_CODE),
    )
    def test_check_batch_job_as_task_failed(self):
        max_retries = 1
        exit_code = run_task_under_flow(max_retries)
        self.assertEqual(FAILED_EXIT_CODE, exit_code)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit aeeb792

Please sign in to comment.