From aeeb792cd24eed35864955dbbe603216dc4311cf Mon Sep 17 00:00:00 2001 From: jessicasyu <15913767+jessicasyu@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:50:24 -0400 Subject: [PATCH] Add docstrings and unit tests for check batch job task --- .../batch/check_batch_job.py | 40 ++++- .../batch/test_check_batch_job.py | 146 ++++++++++++++++++ 2 files changed, 180 insertions(+), 6 deletions(-) create mode 100644 tests/container_collection/batch/test_check_batch_job.py diff --git a/src/container_collection/batch/check_batch_job.py b/src/container_collection/batch/check_batch_job.py index 1a349ba..6dc256e 100644 --- a/src/container_collection/batch/check_batch_job.py +++ b/src/container_collection/batch/check_batch_job.py @@ -2,16 +2,41 @@ from typing import Union import boto3 -import prefect -from prefect.server.schemas.states import Failed, State +from prefect.context import TaskRunContext +from prefect.states import Failed, State RETRIES_EXCEEDED_EXIT_CODE = 80 -def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]: - task_run = prefect.context.get_run_context().task_run # type: ignore +def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State, bool]: + """ + Check for exit code of an AWS Batch job. - if task_run.run_count > max_retries: + If this task is running within a Prefect flow, it will use the task run + context to get the current run count. While the run count is below the + maximum number of retries, the task will continue to attempt to get the exit + code, and can be called with a retry delay to periodically check the status + of jobs. + + If this task is not running within a Prefect flow, the ``max_retries`` + parameters is ignored. Jobs that are still running will throw an exception. + + Parameters + ---------- + job_arn + Job ARN. + max_retries + Maximum number of retries. + + Returns + ------- + : + Exit code if the job is complete, otherwise throws an exception. + """ + + context = TaskRunContext.get() + + if context is not None and context.task_run.run_count > max_retries: return RETRIES_EXCEEDED_EXIT_CODE client = boto3.client("batch") @@ -30,8 +55,11 @@ def check_batch_job(job_arn: str, max_retries: int) -> Union[int, State]: response = client.describe_jobs(jobs=[job_arn])["jobs"] status = response[0]["status"] - if status == "RUNNING": + # For jobs that are running, throw the appropriate exception. + if context is not None and status == "RUNNING": return Failed() + if status == "RUNNING": + raise RuntimeError("Job is in RUNNING state and does not have exit code.") exitcode = response[0]["attempts"][0]["container"]["exitCode"] return exitcode diff --git a/tests/container_collection/batch/test_check_batch_job.py b/tests/container_collection/batch/test_check_batch_job.py new file mode 100644 index 0000000..f352395 --- /dev/null +++ b/tests/container_collection/batch/test_check_batch_job.py @@ -0,0 +1,146 @@ +import os +import sys +import unittest +from typing import Optional +from unittest import mock + +import boto3 +from prefect import flow +from prefect.exceptions import FailedRun +from prefect.testing.utilities import prefect_test_harness + +from container_collection.batch import check_batch_job as check_batch_job_task +from container_collection.batch.check_batch_job import RETRIES_EXCEEDED_EXIT_CODE, check_batch_job + +SUCCEEDED_EXIT_CODE = 0 +FAILED_EXIT_CODE = 1 + + +def make_describe_jobs_response(status: Optional[str], exit_code: Optional[int]): + if status is None: + return {"jobs": []} + if status in ("SUCCEEDED", "FAILED"): + return {"jobs": [{"status": status, "attempts": [{"container": {"exitCode": exit_code}}]}]} + return {"jobs": [{"status": status}]} + + +def make_boto_mock(statuses: list[Optional[str]], exit_code: Optional[int] = None): + batch_mock = mock.MagicMock() + boto3_mock = mock.MagicMock(spec=boto3) + boto3_mock.client.return_value = batch_mock + batch_mock.describe_jobs.side_effect = [ + make_describe_jobs_response(status, exit_code) for status in statuses + ] + return boto3_mock + + +@flow +def run_task_under_flow(max_retries: int): + return check_batch_job_task("job-arn", max_retries) + + +@mock.patch.dict( + os.environ, + {"PREFECT_LOGGING_LEVEL": "CRITICAL"}, +) +class TestCheckBatchJob(unittest.TestCase): + @classmethod + def setUpClass(cls): + with prefect_test_harness(): + yield + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["RUNNING"]), + ) + def test_check_batch_job_as_method_running_throws_exception(self): + with self.assertRaises(RuntimeError): + check_batch_job("job-arn", 0) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock([None, "PENDING", "RUNNING"]), + ) + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], "sleep", lambda _: None + ) + def test_check_batch_job_as_method_running_with_waits_throws_exception(self): + with self.assertRaises(RuntimeError): + check_batch_job("job-arn", 0) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["SUCCEEDED"], SUCCEEDED_EXIT_CODE), + ) + def test_check_batch_job_as_method_succeeded(self): + exit_code = check_batch_job("job-arn", 0) + self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["FAILED"], FAILED_EXIT_CODE), + ) + def test_check_batch_job_as_method_failed(self): + exit_code = check_batch_job("job-arn", 0) + self.assertEqual(FAILED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["RUNNING"]), + ) + def test_check_batch_job_as_task_running_below_max_retries_throws_failed_run(self): + max_retries = 1 + with self.assertRaises(FailedRun): + run_task_under_flow(max_retries) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock([None, "PENDING", "RUNNING"]), + ) + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], "sleep", lambda _: None + ) + def test_check_batch_job_as_task_running_below_max_retries_with_waits_throws_failed_run(self): + max_retries = 1 + with self.assertRaises(FailedRun): + run_task_under_flow(max_retries) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock([None]), + ) + def test_check_batch_job_as_task_max_retries_exceeded(self): + max_retries = 0 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(RETRIES_EXCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["SUCCEEDED"], SUCCEEDED_EXIT_CODE), + ) + def test_check_batch_job_as_task_succeeded(self): + max_retries = 1 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.batch.check_batch_job"], + "boto3", + make_boto_mock(["FAILED"], FAILED_EXIT_CODE), + ) + def test_check_batch_job_as_task_failed(self): + max_retries = 1 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(FAILED_EXIT_CODE, exit_code) + + +if __name__ == "__main__": + unittest.main()