From 96af342bd4736752a1aeb238f93a378440977af2 Mon Sep 17 00:00:00 2001 From: "Jessica S. Yu" <15913767+jessicasyu@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:29:01 -0400 Subject: [PATCH] Add unit tests for container collection fargate tasks (#65) * Add docstrings and unit tests for make fargate task task * Add docstrings and unit tests for register fargate task task * Add docstrings and unit tests for submit fargate task task * Add kwargs to submit batch job task * Add docstrings and unit tests for check fargate task task * Add docstrings and unit tests for terminate fargate task task * Bump mypy (1.3.0 -> 1.10.0) --- poetry.lock | 63 ++++---- pyproject.toml | 3 +- .../batch/register_batch_job.py | 2 - .../batch/submit_batch_job.py | 20 ++- .../fargate/check_fargate_task.py | 41 ++++- .../fargate/make_fargate_task.py | 40 ++++- .../fargate/register_fargate_task.py | 18 +++ .../fargate/submit_fargate_task.py | 44 ++++- .../fargate/terminate_fargate_task.py | 19 ++- .../fargate/test_check_fargate_task.py | 151 ++++++++++++++++++ .../fargate/test_make_fargate_task.py | 55 +++++++ .../fargate/test_register_fargate_task.py | 131 +++++++++++++++ .../fargate/test_submit_fargate_task.py | 114 +++++++++++++ .../fargate/test_terminate_fargate_task.py | 36 +++++ 14 files changed, 687 insertions(+), 50 deletions(-) create mode 100644 tests/container_collection/fargate/test_check_fargate_task.py create mode 100644 tests/container_collection/fargate/test_make_fargate_task.py create mode 100644 tests/container_collection/fargate/test_register_fargate_task.py create mode 100644 tests/container_collection/fargate/test_submit_fargate_task.py create mode 100644 tests/container_collection/fargate/test_terminate_fargate_task.py diff --git a/poetry.lock b/poetry.lock index 5d22506..164b5af 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1608,48 +1608,49 @@ xray = ["aws-xray-sdk (>=0.93,!=0.96)", "setuptools"] [[package]] name = "mypy" -version = "1.3.0" +version = "1.11.2" description = "Optional static typing for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mypy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c1eb485cea53f4f5284e5baf92902cd0088b24984f4209e25981cc359d64448d"}, - {file = "mypy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c99c3ecf223cf2952638da9cd82793d8f3c0c5fa8b6ae2b2d9ed1e1ff51ba85"}, - {file = "mypy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550a8b3a19bb6589679a7c3c31f64312e7ff482a816c96e0cecec9ad3a7564dd"}, - {file = "mypy-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cbc07246253b9e3d7d74c9ff948cd0fd7a71afcc2b77c7f0a59c26e9395cb152"}, - {file = "mypy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:a22435632710a4fcf8acf86cbd0d69f68ac389a3892cb23fbad176d1cddaf228"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6e33bb8b2613614a33dff70565f4c803f889ebd2f859466e42b46e1df76018dd"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7d23370d2a6b7a71dc65d1266f9a34e4cde9e8e21511322415db4b26f46f6b8c"}, - {file = "mypy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:658fe7b674769a0770d4b26cb4d6f005e88a442fe82446f020be8e5f5efb2fae"}, - {file = "mypy-1.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e42d29e324cdda61daaec2336c42512e59c7c375340bd202efa1fe0f7b8f8ca"}, - {file = "mypy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:d0b6c62206e04061e27009481cb0ec966f7d6172b5b936f3ead3d74f29fe3dcf"}, - {file = "mypy-1.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:76ec771e2342f1b558c36d49900dfe81d140361dd0d2df6cd71b3db1be155409"}, - {file = "mypy-1.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc95f8386314272bbc817026f8ce8f4f0d2ef7ae44f947c4664efac9adec929"}, - {file = "mypy-1.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:faff86aa10c1aa4a10e1a301de160f3d8fc8703b88c7e98de46b531ff1276a9a"}, - {file = "mypy-1.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8c5979d0deb27e0f4479bee18ea0f83732a893e81b78e62e2dda3e7e518c92ee"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c5d2cc54175bab47011b09688b418db71403aefad07cbcd62d44010543fc143f"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:87df44954c31d86df96c8bd6e80dfcd773473e877ac6176a8e29898bfb3501cb"}, - {file = "mypy-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473117e310febe632ddf10e745a355714e771ffe534f06db40702775056614c4"}, - {file = "mypy-1.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:74bc9b6e0e79808bf8678d7678b2ae3736ea72d56eede3820bd3849823e7f305"}, - {file = "mypy-1.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:44797d031a41516fcf5cbfa652265bb994e53e51994c1bd649ffcd0c3a7eccbf"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ddae0f39ca146972ff6bb4399f3b2943884a774b8771ea0a8f50e971f5ea5ba8"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c4c42c60a8103ead4c1c060ac3cdd3ff01e18fddce6f1016e08939647a0e703"}, - {file = "mypy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e86c2c6852f62f8f2b24cb7a613ebe8e0c7dc1402c61d36a609174f63e0ff017"}, - {file = "mypy-1.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f9dca1e257d4cc129517779226753dbefb4f2266c4eaad610fc15c6a7e14283e"}, - {file = "mypy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:95d8d31a7713510685b05fbb18d6ac287a56c8f6554d88c19e73f724a445448a"}, - {file = "mypy-1.3.0-py3-none-any.whl", hash = "sha256:a8763e72d5d9574d45ce5881962bc8e9046bf7b375b0abf031f3e6811732a897"}, - {file = "mypy-1.3.0.tar.gz", hash = "sha256:e1f4d16e296f5135624b34e8fb741eb0eadedca90862405b1f1fde2040b9bd11"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] install-types = ["pip"] -python2 = ["typed-ast (>=1.4.0,<2)"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -3711,4 +3712,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8539c9c4bc7ae283e800ad68a6674f40af5b4175a4521b582b664cb2135b5bd2" +content-hash = "80b06c5cc42cde0fb898aec773b0b2988451dd60b87040d7bdad23268ead2700" diff --git a/pyproject.toml b/pyproject.toml index 34404ca..29d6195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ tabulate = "^0.9.0" [tool.poetry.group.dev.dependencies] black = "^24.3.0" isort = "^5.12.0" -mypy = "^1.3.0" +mypy = "^1.10.0" pylint = "^2.16.2" pytest = "^7.3.0" pytest-cov = "^4.0.0" @@ -63,6 +63,7 @@ module = [ "deepdiff.*", "docker.*", "pandas.*", + "ruamel.*", "tabulate.*", ] ignore_missing_imports = true diff --git a/src/container_collection/batch/register_batch_job.py b/src/container_collection/batch/register_batch_job.py index 3e885c0..615e1de 100644 --- a/src/container_collection/batch/register_batch_job.py +++ b/src/container_collection/batch/register_batch_job.py @@ -22,7 +22,6 @@ def register_batch_job(job_definition: dict) -> str: """ client = boto3.client("batch") - response = client.describe_job_definitions( jobDefinitionName=job_definition["jobDefinitionName"], ) @@ -35,5 +34,4 @@ def register_batch_job(job_definition: dict) -> str: return existing_definition["jobDefinitionArn"] response = client.register_job_definition(**job_definition) - return response["jobDefinitionArn"] diff --git a/src/container_collection/batch/submit_batch_job.py b/src/container_collection/batch/submit_batch_job.py index af81b1d..0c9066e 100644 --- a/src/container_collection/batch/submit_batch_job.py +++ b/src/container_collection/batch/submit_batch_job.py @@ -1,11 +1,18 @@ +from typing import Any + import boto3 def submit_batch_job( - name: str, job_definition_arn: str, user: str, queue: str, size: int + name: str, + job_definition_arn: str, + user: str, + queue: str, + size: int, + **kwargs: Any, ) -> list[str]: """ - Submit job to AWS Batch. + Submit AWS Batch job. Parameters ---------- @@ -19,6 +26,10 @@ def submit_batch_job( Job queue. size Number of jobs in array. + **kwargs + Additional parameters for job submission. The keyword arguments are + passed to `boto3` Batch client method `submit_job`. + Returns ------- @@ -26,16 +37,17 @@ def submit_batch_job( List of job ARNs. """ - job_submission = { + default_job_submission = { "jobName": f"{user}_{name}", "jobQueue": queue, "jobDefinition": job_definition_arn, } if size > 1: - job_submission["arrayProperties"] = {"size": size} # type: ignore + default_job_submission["arrayProperties"] = {"size": size} # type: ignore client = boto3.client("batch") + job_submission = default_job_submission | kwargs response = client.submit_job(**job_submission) if size > 1: diff --git a/src/container_collection/fargate/check_fargate_task.py b/src/container_collection/fargate/check_fargate_task.py index e52428b..ec2fc5d 100644 --- a/src/container_collection/fargate/check_fargate_task.py +++ b/src/container_collection/fargate/check_fargate_task.py @@ -2,16 +2,44 @@ from typing import Union import boto3 -import prefect -from prefect.server.schemas.states import Failed, State +from prefect.context import TaskRunContext +from prefect.states import Failed, State RETRIES_EXCEEDED_EXIT_CODE = 80 +"""Exit code used when task run retries exceed the maximum retries.""" def check_fargate_task(cluster: str, task_arn: str, max_retries: int) -> Union[int, State]: - task_run = prefect.context.get_run_context().task_run # type: ignore + """ + Check for exit code of an AWS Fargate task. - if task_run.run_count > max_retries: + If this task is running within a Prefect flow, it will use the task run + context to get the current run count. While the run count is below the + maximum number of retries, the task will continue to attempt to get the exit + code, and can be called with a retry delay to periodically check the status + of jobs. + + If this task is not running within a Prefect flow, the ``max_retries`` + parameters is ignored. Tasks that are still running will throw an exception. + + Parameters + ---------- + cluster + ECS cluster name. + task_arn : str + Task ARN. + max_retries + Maximum number of retries. + + Returns + ------- + : + Exit code if the job is complete, otherwise throws an exception. + """ + + context = TaskRunContext.get() + + if context is not None and context.task_run.run_count > max_retries: return RETRIES_EXCEEDED_EXIT_CODE client = boto3.client("ecs") @@ -30,8 +58,11 @@ def check_fargate_task(cluster: str, task_arn: str, max_retries: int) -> Union[i response = client.describe_tasks(cluster=cluster, tasks=[task_arn])["tasks"] status = response[0]["lastStatus"] - if status == "RUNNING": + # For tasks that are running, throw the appropriate exception. + if context is not None and status == "RUNNING": return Failed() + if status == "RUNNING": + raise RuntimeError("Task is in RUNNING state and does not have exit code.") exitcode = response[0]["containers"][0]["exitCode"] return exitcode diff --git a/src/container_collection/fargate/make_fargate_task.py b/src/container_collection/fargate/make_fargate_task.py index dbc61fa..f842e0c 100644 --- a/src/container_collection/fargate/make_fargate_task.py +++ b/src/container_collection/fargate/make_fargate_task.py @@ -1,17 +1,49 @@ def make_fargate_task( name: str, image: str, - account: str, region: str, user: str, vcpus: int, memory: int, + task_role_arn: str, + execution_role_arn: str, ) -> dict: + """ + Create Fargate task definition. + + Docker images on the Docker Hub registry are available by default, and can + be specified using ``image:tag``. Otherwise, use ``repository/image:tag``. + + Parameters + ---------- + name + Task definition name. + image + Docker image. + region + Logging region. + user + User name prefix for task name. + vcpus + Hard limit of CPU units to present to the task. + memory + Hard limit of memory to present to the task. + task_role_arn + ARN for IAM role for the task container. + execution_role_arn : str + ARN for IAM role for the container agent. + + Returns + ------- + : + Task definition. + """ + return { "containerDefinitions": [ { "name": f"{user}_{name}", - "image": f"{account}.dkr.ecr.{region}.amazonaws.com/{user}/{image}", + "image": image, "essential": True, "portMappings": [], "environment": [], @@ -31,8 +63,8 @@ def make_fargate_task( "family": f"{user}_{name}", "networkMode": "awsvpc", "requiresCompatibilities": ["FARGATE"], - "taskRoleArn": f"arn:aws:iam::{account}:role/BatchJobRole", - "executionRoleArn": f"arn:aws:iam::{account}:role/ecsTaskExecutionRole", + "taskRoleArn": task_role_arn, + "executionRoleArn": execution_role_arn, "cpu": str(vcpus), "memory": str(memory), } diff --git a/src/container_collection/fargate/register_fargate_task.py b/src/container_collection/fargate/register_fargate_task.py index 59aa72d..7ed2bf3 100644 --- a/src/container_collection/fargate/register_fargate_task.py +++ b/src/container_collection/fargate/register_fargate_task.py @@ -3,6 +3,24 @@ def register_fargate_task(task_definition: dict) -> str: + """ + Register task definition to ECS Fargate. + + If a definition for the given task definition name already exists, and the + contents of the definition are not changed, then the method will return the + existing task definition ARN rather than creating a new revision. + + Parameters + ---------- + task_definition + Fargate task definition. + + Returns + ------- + : + Task definition ARN. + """ + client = boto3.client("ecs") response = client.list_task_definitions(familyPrefix=task_definition["family"]) diff --git a/src/container_collection/fargate/submit_fargate_task.py b/src/container_collection/fargate/submit_fargate_task.py index 70edbfa..c0ffe2d 100644 --- a/src/container_collection/fargate/submit_fargate_task.py +++ b/src/container_collection/fargate/submit_fargate_task.py @@ -1,3 +1,5 @@ +from typing import Any + import boto3 @@ -9,8 +11,38 @@ def submit_fargate_task( security_groups: list[str], subnets: list[str], command: list[str], + **kwargs: Any, ) -> list[str]: - task_submission = { + """ + Submit task to AWS Fargate. + + Parameters + ---------- + name + Task name. + task_definition_arn + Task definition ARN. + user + User name prefix for task name. + cluster + ECS cluster name. + security_groups + List of security groups. + subnets + List of subnets. + command + Command list passed to container. + **kwargs + Additional parameters for task submission. The keyword arguments are + passed to `boto3` ECS client method `run_task`. + + Returns + ------- + : + Task ARN. + """ + + default_task_submission = { "taskDefinition": task_definition_arn, "capacityProviderStrategy": [ {"capacityProvider": "FARGATE", "weight": 1}, @@ -26,10 +58,18 @@ def submit_fargate_task( "securityGroups": security_groups, } }, - "overrides": {"containerOverrides": [{"name": f"{user}_{name}", "command": command}]}, + "overrides": { + "containerOverrides": [ + { + "name": f"{user}_{name}", + "command": command, + } + ] + }, } client = boto3.client("ecs") + task_submission = default_task_submission | kwargs response = client.run_task(**task_submission) return response["tasks"][0]["taskArn"] diff --git a/src/container_collection/fargate/terminate_fargate_task.py b/src/container_collection/fargate/terminate_fargate_task.py index 15e87cc..35c5888 100644 --- a/src/container_collection/fargate/terminate_fargate_task.py +++ b/src/container_collection/fargate/terminate_fargate_task.py @@ -1,6 +1,23 @@ +from time import sleep + import boto3 +TERMINATION_REASON = "Termination requested by workflow." +"""Reason sent for terminating jobs from a workflow.""" + def terminate_fargate_task(cluster: str, task_arn: str) -> None: + """ + Terminate task on AWS Fargate. + + Parameters + ---------- + cluster + ECS cluster name. + task_arn + Task ARN. + """ + client = boto3.client("ecs") - client.stop_task(cluster=cluster, task=task_arn, reason="Prefect workflow termination") + client.stop_task(cluster=cluster, task=task_arn, reason=TERMINATION_REASON) + sleep(60) diff --git a/tests/container_collection/fargate/test_check_fargate_task.py b/tests/container_collection/fargate/test_check_fargate_task.py new file mode 100644 index 0000000..1e25f8b --- /dev/null +++ b/tests/container_collection/fargate/test_check_fargate_task.py @@ -0,0 +1,151 @@ +import os +import sys +import unittest +from typing import Optional +from unittest import mock + +import boto3 +from prefect import flow +from prefect.exceptions import FailedRun +from prefect.testing.utilities import prefect_test_harness + +from container_collection.fargate import check_fargate_task as check_fargate_task_task +from container_collection.fargate.check_fargate_task import ( + RETRIES_EXCEEDED_EXIT_CODE, + check_fargate_task, +) + +SUCCEEDED_EXIT_CODE = 0 +FAILED_EXIT_CODE = 1 + + +def make_describe_tasks_response(status: Optional[str], exit_code: Optional[int]): + if status is None: + return {"tasks": []} + if status == "STOPPED": + return {"tasks": [{"lastStatus": status, "containers": [{"exitCode": exit_code}]}]} + return {"tasks": [{"lastStatus": status}]} + + +def make_boto_mock(statuses: list[Optional[str]], exit_code: Optional[int] = None): + ecs_mock = mock.MagicMock() + boto3_mock = mock.MagicMock(spec=boto3) + boto3_mock.client.return_value = ecs_mock + ecs_mock.describe_tasks.side_effect = [ + make_describe_tasks_response(status, exit_code) for status in statuses + ] + return boto3_mock + + +@flow +def run_task_under_flow(max_retries: int): + return check_fargate_task_task("cluster", "task-arn", max_retries) + + +@mock.patch.dict( + os.environ, + {"PREFECT_LOGGING_LEVEL": "CRITICAL"}, +) +class TestCheckFargateTask(unittest.TestCase): + @classmethod + def setUpClass(cls): + with prefect_test_harness(): + yield + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["RUNNING"]), + ) + def test_check_fargate_task_as_method_running_throws_exception(self): + with self.assertRaises(RuntimeError): + check_fargate_task("cluster", "task-arn", 0) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock([None, "PENDING", "RUNNING"]), + ) + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], "sleep", lambda _: None + ) + def test_check_fargate_task_as_method_running_with_waits_throws_exception(self): + with self.assertRaises(RuntimeError): + check_fargate_task("cluster", "task-arn", 0) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["STOPPED"], SUCCEEDED_EXIT_CODE), + ) + def test_check_fargate_task_as_method_succeeded(self): + exit_code = check_fargate_task("cluster", "task-arn", 0) + self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["STOPPED"], FAILED_EXIT_CODE), + ) + def test_check_fargate_task_as_method_failed(self): + exit_code = check_fargate_task("cluster", "task-arn", 0) + self.assertEqual(FAILED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["RUNNING"]), + ) + def test_check_fargate_task_as_task_running_below_max_retries_throws_failed_run(self): + max_retries = 1 + with self.assertRaises(FailedRun): + run_task_under_flow(max_retries) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock([None, "PENDING", "RUNNING"]), + ) + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], "sleep", lambda _: None + ) + def test_check_fargate_task_as_task_running_below_max_retries_with_waits_throws_failed_run( + self, + ): + max_retries = 1 + with self.assertRaises(FailedRun): + run_task_under_flow(max_retries) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock([None]), + ) + def test_check_fargate_task_as_task_max_retries_exceeded(self): + max_retries = 0 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(RETRIES_EXCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["STOPPED"], SUCCEEDED_EXIT_CODE), + ) + def test_check_fargate_task_as_task_succeeded(self): + max_retries = 1 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(SUCCEEDED_EXIT_CODE, exit_code) + + @mock.patch.object( + sys.modules["container_collection.fargate.check_fargate_task"], + "boto3", + make_boto_mock(["STOPPED"], FAILED_EXIT_CODE), + ) + def test_check_fargate_task_as_task_failed(self): + max_retries = 1 + exit_code = run_task_under_flow(max_retries) + self.assertEqual(FAILED_EXIT_CODE, exit_code) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/container_collection/fargate/test_make_fargate_task.py b/tests/container_collection/fargate/test_make_fargate_task.py new file mode 100644 index 0000000..251f851 --- /dev/null +++ b/tests/container_collection/fargate/test_make_fargate_task.py @@ -0,0 +1,55 @@ +import unittest + +from container_collection.fargate.make_fargate_task import make_fargate_task + + +class TestMakeFargateTask(unittest.TestCase): + def test_make_fargate_task(self): + name = "job_name" + image = "image_name" + region = "region_name" + user = "user_name" + vcpus = 10 + memory = 20 + task_role_arn = "task_role_arn" + execution_role_arn = "execution_role_arn" + + expected_task = { + "containerDefinitions": [ + { + "name": f"{user}_{name}", + "image": image, + "essential": True, + "portMappings": [], + "environment": [], + "mountPoints": [], + "volumesFrom": [], + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": f"/ecs/{user}_{name}", + "awslogs-region": region, + "awslogs-stream-prefix": "ecs", + "awslogs-create-group": "true", + }, + }, + } + ], + "family": f"{user}_{name}", + "networkMode": "awsvpc", + "requiresCompatibilities": ["FARGATE"], + "taskRoleArn": task_role_arn, + "executionRoleArn": execution_role_arn, + "cpu": str(vcpus), + "memory": str(memory), + } + + task = make_fargate_task( + name, image, region, user, vcpus, memory, task_role_arn, execution_role_arn + ) + + self.assertDictEqual(expected_task, task) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/container_collection/fargate/test_register_fargate_task.py b/tests/container_collection/fargate/test_register_fargate_task.py new file mode 100644 index 0000000..366c3ef --- /dev/null +++ b/tests/container_collection/fargate/test_register_fargate_task.py @@ -0,0 +1,131 @@ +import os +import unittest +from unittest import mock + +import boto3 +from moto import mock_aws + +from container_collection.fargate.register_fargate_task import register_fargate_task + +ACCOUNT = "123123123123" +REGION = "default-region" + + +@mock.patch.dict( + os.environ, + { + "MOTO_ALLOW_NONEXISTENT_REGION": "True", + "MOTO_ACCOUNT_ID": ACCOUNT, + "AWS_DEFAULT_REGION": REGION, + }, +) +class TestRegisterFargateTask(unittest.TestCase): + def setUp(self) -> None: + self.name = "task-definition-name" + self.image = "jobimage:latest" + self.vcpus = "1" + self.memory = "256" + + @mock_aws + def test_register_fargate_task_new_definition(self): + task_definition = { + "family": self.name, + "containerDefinitions": [ + { + "name": self.name, + "image": self.image, + } + ], + "cpu": self.vcpus, + "memory": self.memory, + } + + expected_task_arn = f"arn:aws:ecs:{REGION}:{ACCOUNT}:task-definition/{self.name}:1" + + task_arn = register_fargate_task(task_definition) + + ecs_client = boto3.client("ecs") + response = ecs_client.describe_task_definition(taskDefinition=self.name) + registered_task = response["taskDefinition"] + + self.assertEqual(expected_task_arn, task_arn) + self.assertEqual(self.name, registered_task["family"]) + self.assertEqual(self.image, registered_task["containerDefinitions"][0]["image"]) + self.assertEqual(self.vcpus, registered_task["cpu"]) + self.assertEqual(self.memory, registered_task["memory"]) + + @mock_aws + def test_register_fargate_task_updated_definition(self): + vcpus_modified = str(int(self.vcpus) + 1) + + first_task_definition = { + "family": self.name, + "containerDefinitions": [ + { + "name": self.name, + "image": self.image, + } + ], + "cpu": self.vcpus, + "memory": self.memory, + } + + second_task_definition = { + "family": self.name, + "containerDefinitions": [ + { + "name": self.name, + "image": self.image, + } + ], + "cpu": vcpus_modified, + "memory": self.memory, + } + + expected_task_arn = f"arn:aws:ecs:{REGION}:{ACCOUNT}:task-definition/{self.name}:2" + + ecs_client = boto3.client("ecs") + ecs_client.register_task_definition(**first_task_definition) + task_arn = register_fargate_task(second_task_definition) + + response = ecs_client.describe_task_definition(taskDefinition=self.name) + registered_task = response["taskDefinition"] + + self.assertEqual(expected_task_arn, task_arn) + self.assertEqual(self.name, registered_task["family"]) + self.assertEqual(self.image, registered_task["containerDefinitions"][0]["image"]) + self.assertEqual(vcpus_modified, registered_task["cpu"]) + self.assertEqual(self.memory, registered_task["memory"]) + + @mock_aws + def test_register_fargate_task_existing_definition(self): + task_definition = { + "family": self.name, + "containerDefinitions": [ + { + "name": self.name, + "image": self.image, + } + ], + "cpu": self.vcpus, + "memory": self.memory, + } + + expected_task_arn = f"arn:aws:ecs:{REGION}:{ACCOUNT}:task-definition/{self.name}:1" + + ecs_client = boto3.client("ecs") + ecs_client.register_task_definition(**task_definition) + task_arn = register_fargate_task(task_definition) + + response = ecs_client.describe_task_definition(taskDefinition=self.name) + registered_task = response["taskDefinition"] + + self.assertEqual(expected_task_arn, task_arn) + self.assertEqual(self.name, registered_task["family"]) + self.assertEqual(self.image, registered_task["containerDefinitions"][0]["image"]) + self.assertEqual(self.vcpus, registered_task["cpu"]) + self.assertEqual(self.memory, registered_task["memory"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/container_collection/fargate/test_submit_fargate_task.py b/tests/container_collection/fargate/test_submit_fargate_task.py new file mode 100644 index 0000000..6651a6f --- /dev/null +++ b/tests/container_collection/fargate/test_submit_fargate_task.py @@ -0,0 +1,114 @@ +import os +import unittest +from unittest import mock + +import boto3 +from moto import mock_aws + +from container_collection.fargate.submit_fargate_task import submit_fargate_task + +ACCOUNT = "123123123123" +REGION = "default-region" +CLUSTER = "cluster-name" + + +@mock.patch.dict( + os.environ, + { + "MOTO_ALLOW_NONEXISTENT_REGION": "True", + "MOTO_ACCOUNT_ID": ACCOUNT, + "AWS_DEFAULT_REGION": REGION, + }, +) +class TestSubmitFargateTask(unittest.TestCase): + def setUp(self) -> None: + # Note that the memory specified in the container definition is not + # actually required for Fargate, but is used in the moto library to + # determine resource requirements. + self.task_definition = { + "family": "task-definition-name", + "containerDefinitions": [ + { + "name": "task-definition-name", + "image": "jobimage:latest", + "memory": 256, + } + ], + "requiresCompatibilities": ["FARGATE"], + "cpu": "1", + "memory": "256", + } + + arn_prefix = f"arn:aws:ecs:{REGION}:{ACCOUNT}" + self.task_definition_arn = f"{arn_prefix}:task-definition/task-definition-name:1" + self.cluster_arn = f"{arn_prefix}:cluster/{CLUSTER}" + + @mock_aws + def test_submit_fargate_task(self): + name = "task-name" + user = "user" + command = ["command", "string"] + cluster = "cluster-name" + + initialize_infrastructure() + ec2_client = boto3.client("ec2") + ecs_client = boto3.client("ecs") + ecs_client.register_task_definition(**self.task_definition) + + security_group_id = ec2_client.describe_security_groups()["SecurityGroups"][0]["GroupId"] + subnet_id = ec2_client.describe_subnets()["Subnets"][0]["SubnetId"] + + task_arn = submit_fargate_task( + name, + self.task_definition_arn, + user, + cluster, + [security_group_id], + [subnet_id], + command, + launchType="FARGATE", + capacityProviderStrategy=[], + ) + + response = ecs_client.describe_tasks(cluster=cluster, tasks=[task_arn]) + task = response["tasks"][0] + self.assertEqual(1, len(response["tasks"])) + self.assertEqual(f"{user}_{name}", task["overrides"]["containerOverrides"][0]["name"]) + self.assertEqual(command, task["overrides"]["containerOverrides"][0]["command"]) + self.assertEqual(task_arn, task["taskArn"]) + self.assertEqual(self.cluster_arn, task["clusterArn"]) + self.assertEqual(self.task_definition_arn, task["taskDefinitionArn"]) + + +def initialize_infrastructure() -> None: + # Create clients. + ec2_client = boto3.client("ec2") + ecs_client = boto3.client("ecs") + + # Create VPC. + vpc = ec2_client.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + + # Create subnet. + ec2_client.create_subnet(AvailabilityZone="us-east-1a", CidrBlock="10.0.0.0/18", VpcId=vpc_id) + + # Create security group. + ec2_client.create_security_group( + Description="test security group description", + GroupName="security-group-name", + VpcId=vpc_id, + ) + + # Create ECS Fargate cluster. + ecs_client.create_cluster( + clusterName=CLUSTER, + capacityProviders=["FARGATE", "FARGATE_SPOT"], + defaultCapacityProviderStrategy=[ + {"capacityProvider": "FARGATE_SPOT", "weight": 5, "base": 0}, + {"capacityProvider": "FARGATE", "weight": 5, "base": 0}, + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/container_collection/fargate/test_terminate_fargate_task.py b/tests/container_collection/fargate/test_terminate_fargate_task.py new file mode 100644 index 0000000..75137b8 --- /dev/null +++ b/tests/container_collection/fargate/test_terminate_fargate_task.py @@ -0,0 +1,36 @@ +import sys +import unittest +from unittest import mock + +import boto3 + +from container_collection.fargate.terminate_fargate_task import ( + TERMINATION_REASON, + terminate_fargate_task, +) + + +class TestTerminateFargateTask(unittest.TestCase): + @mock.patch.object( + sys.modules["container_collection.fargate.terminate_fargate_task"], + "boto3", + return_value=mock.MagicMock(spec=boto3), + ) + @mock.patch.object( + sys.modules["container_collection.fargate.terminate_fargate_task"], "sleep", lambda _: None + ) + def test_terminate_fargate_task(self, boto3_mock): + ecs_mock = mock.MagicMock() + boto3_mock.client.return_value = ecs_mock + + task_arn = "task-arn" + cluster = "cluster" + terminate_fargate_task(cluster, task_arn) + + ecs_mock.stop_task.assert_called_with( + cluster=cluster, task=task_arn, reason=TERMINATION_REASON + ) + + +if __name__ == "__main__": + unittest.main()