Skip to content

Commit

Permalink
Add unit tests for container collection fargate tasks (#65)
Browse files Browse the repository at this point in the history
* Add docstrings and unit tests for make fargate task task

* Add docstrings and unit tests for register fargate task task

* Add docstrings and unit tests for submit fargate task task

* Add kwargs to submit batch job task

* Add docstrings and unit tests for check fargate task task

* Add docstrings and unit tests for terminate fargate task task

* Bump mypy (1.3.0 -> 1.10.0)
  • Loading branch information
jessicasyu committed Aug 29, 2024
1 parent a1e5777 commit 96af342
Show file tree
Hide file tree
Showing 14 changed files with 687 additions and 50 deletions.
63 changes: 32 additions & 31 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ tabulate = "^0.9.0"
[tool.poetry.group.dev.dependencies]
black = "^24.3.0"
isort = "^5.12.0"
mypy = "^1.3.0"
mypy = "^1.10.0"
pylint = "^2.16.2"
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
Expand Down Expand Up @@ -63,6 +63,7 @@ module = [
"deepdiff.*",
"docker.*",
"pandas.*",
"ruamel.*",
"tabulate.*",
]
ignore_missing_imports = true
Expand Down
2 changes: 0 additions & 2 deletions src/container_collection/batch/register_batch_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def register_batch_job(job_definition: dict) -> str:
"""

client = boto3.client("batch")

response = client.describe_job_definitions(
jobDefinitionName=job_definition["jobDefinitionName"],
)
Expand All @@ -35,5 +34,4 @@ def register_batch_job(job_definition: dict) -> str:
return existing_definition["jobDefinitionArn"]

response = client.register_job_definition(**job_definition)

return response["jobDefinitionArn"]
20 changes: 16 additions & 4 deletions src/container_collection/batch/submit_batch_job.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Any

import boto3


def submit_batch_job(
name: str, job_definition_arn: str, user: str, queue: str, size: int
name: str,
job_definition_arn: str,
user: str,
queue: str,
size: int,
**kwargs: Any,
) -> list[str]:
"""
Submit job to AWS Batch.
Submit AWS Batch job.
Parameters
----------
Expand All @@ -19,23 +26,28 @@ def submit_batch_job(
Job queue.
size
Number of jobs in array.
**kwargs
Additional parameters for job submission. The keyword arguments are
passed to `boto3` Batch client method `submit_job`.
Returns
-------
:
List of job ARNs.
"""

job_submission = {
default_job_submission = {
"jobName": f"{user}_{name}",
"jobQueue": queue,
"jobDefinition": job_definition_arn,
}

if size > 1:
job_submission["arrayProperties"] = {"size": size} # type: ignore
default_job_submission["arrayProperties"] = {"size": size} # type: ignore

client = boto3.client("batch")
job_submission = default_job_submission | kwargs
response = client.submit_job(**job_submission)

if size > 1:
Expand Down
41 changes: 36 additions & 5 deletions src/container_collection/fargate/check_fargate_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,44 @@
from typing import Union

import boto3
import prefect
from prefect.server.schemas.states import Failed, State
from prefect.context import TaskRunContext
from prefect.states import Failed, State

RETRIES_EXCEEDED_EXIT_CODE = 80
"""Exit code used when task run retries exceed the maximum retries."""


def check_fargate_task(cluster: str, task_arn: str, max_retries: int) -> Union[int, State]:
task_run = prefect.context.get_run_context().task_run # type: ignore
"""
Check for exit code of an AWS Fargate task.
if task_run.run_count > max_retries:
If this task is running within a Prefect flow, it will use the task run
context to get the current run count. While the run count is below the
maximum number of retries, the task will continue to attempt to get the exit
code, and can be called with a retry delay to periodically check the status
of jobs.
If this task is not running within a Prefect flow, the ``max_retries``
parameters is ignored. Tasks that are still running will throw an exception.
Parameters
----------
cluster
ECS cluster name.
task_arn : str
Task ARN.
max_retries
Maximum number of retries.
Returns
-------
:
Exit code if the job is complete, otherwise throws an exception.
"""

context = TaskRunContext.get()

if context is not None and context.task_run.run_count > max_retries:
return RETRIES_EXCEEDED_EXIT_CODE

client = boto3.client("ecs")
Expand All @@ -30,8 +58,11 @@ def check_fargate_task(cluster: str, task_arn: str, max_retries: int) -> Union[i
response = client.describe_tasks(cluster=cluster, tasks=[task_arn])["tasks"]
status = response[0]["lastStatus"]

if status == "RUNNING":
# For tasks that are running, throw the appropriate exception.
if context is not None and status == "RUNNING":
return Failed()
if status == "RUNNING":
raise RuntimeError("Task is in RUNNING state and does not have exit code.")

exitcode = response[0]["containers"][0]["exitCode"]
return exitcode
40 changes: 36 additions & 4 deletions src/container_collection/fargate/make_fargate_task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,49 @@
def make_fargate_task(
name: str,
image: str,
account: str,
region: str,
user: str,
vcpus: int,
memory: int,
task_role_arn: str,
execution_role_arn: str,
) -> dict:
"""
Create Fargate task definition.
Docker images on the Docker Hub registry are available by default, and can
be specified using ``image:tag``. Otherwise, use ``repository/image:tag``.
Parameters
----------
name
Task definition name.
image
Docker image.
region
Logging region.
user
User name prefix for task name.
vcpus
Hard limit of CPU units to present to the task.
memory
Hard limit of memory to present to the task.
task_role_arn
ARN for IAM role for the task container.
execution_role_arn : str
ARN for IAM role for the container agent.
Returns
-------
:
Task definition.
"""

return {
"containerDefinitions": [
{
"name": f"{user}_{name}",
"image": f"{account}.dkr.ecr.{region}.amazonaws.com/{user}/{image}",
"image": image,
"essential": True,
"portMappings": [],
"environment": [],
Expand All @@ -31,8 +63,8 @@ def make_fargate_task(
"family": f"{user}_{name}",
"networkMode": "awsvpc",
"requiresCompatibilities": ["FARGATE"],
"taskRoleArn": f"arn:aws:iam::{account}:role/BatchJobRole",
"executionRoleArn": f"arn:aws:iam::{account}:role/ecsTaskExecutionRole",
"taskRoleArn": task_role_arn,
"executionRoleArn": execution_role_arn,
"cpu": str(vcpus),
"memory": str(memory),
}
18 changes: 18 additions & 0 deletions src/container_collection/fargate/register_fargate_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@


def register_fargate_task(task_definition: dict) -> str:
"""
Register task definition to ECS Fargate.
If a definition for the given task definition name already exists, and the
contents of the definition are not changed, then the method will return the
existing task definition ARN rather than creating a new revision.
Parameters
----------
task_definition
Fargate task definition.
Returns
-------
:
Task definition ARN.
"""

client = boto3.client("ecs")
response = client.list_task_definitions(familyPrefix=task_definition["family"])

Expand Down
44 changes: 42 additions & 2 deletions src/container_collection/fargate/submit_fargate_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import boto3


Expand All @@ -9,8 +11,38 @@ def submit_fargate_task(
security_groups: list[str],
subnets: list[str],
command: list[str],
**kwargs: Any,
) -> list[str]:
task_submission = {
"""
Submit task to AWS Fargate.
Parameters
----------
name
Task name.
task_definition_arn
Task definition ARN.
user
User name prefix for task name.
cluster
ECS cluster name.
security_groups
List of security groups.
subnets
List of subnets.
command
Command list passed to container.
**kwargs
Additional parameters for task submission. The keyword arguments are
passed to `boto3` ECS client method `run_task`.
Returns
-------
:
Task ARN.
"""

default_task_submission = {
"taskDefinition": task_definition_arn,
"capacityProviderStrategy": [
{"capacityProvider": "FARGATE", "weight": 1},
Expand All @@ -26,10 +58,18 @@ def submit_fargate_task(
"securityGroups": security_groups,
}
},
"overrides": {"containerOverrides": [{"name": f"{user}_{name}", "command": command}]},
"overrides": {
"containerOverrides": [
{
"name": f"{user}_{name}",
"command": command,
}
]
},
}

client = boto3.client("ecs")
task_submission = default_task_submission | kwargs
response = client.run_task(**task_submission)

return response["tasks"][0]["taskArn"]
Loading

0 comments on commit 96af342

Please sign in to comment.