diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index 016b115d8fe..1b63259fae7 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -244,6 +244,8 @@
 # in all Metaflow deployments. Hopefully, some day we can flip the
 # default to True.
 BATCH_EMIT_TAGS = from_conf("BATCH_EMIT_TAGS", False)
+# Default tags to add to AWS Batch jobs. These are in addition to the defaults set when BATCH_EMIT_TAGS is true.
+BATCH_DEFAULT_TAGS = from_conf("BATCH_DEFAULT_TAGS", {})

 ###
 # AWS Step Functions configuration
diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py
index ebeede29fa4..5cac570f92b 100644
--- a/metaflow/plugins/aws/batch/batch.py
+++ b/metaflow/plugins/aws/batch/batch.py
@@ -7,31 +7,30 @@
 import time

 from metaflow import util
-from metaflow.plugins.datatools.s3.s3tail import S3Tail
-from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import (
-    SERVICE_INTERNAL_URL,
-    DATATOOLS_S3ROOT,
-    DATASTORE_SYSROOT_S3,
-    DEFAULT_METADATA,
-    SERVICE_HEADERS,
+    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    BATCH_DEFAULT_TAGS,
     BATCH_EMIT_TAGS,
     CARD_S3ROOT,
-    S3_ENDPOINT_URL,
+    DATASTORE_SYSROOT_S3,
+    DATATOOLS_S3ROOT,
+    DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
-    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    S3_ENDPOINT_URL,
     S3_SERVER_SIDE_ENCRYPTION,
+    SERVICE_HEADERS,
+    SERVICE_INTERNAL_URL,
 )
-
 from metaflow.metaflow_config_funcs import config_values
-
 from metaflow.mflog import (
-    export_mflog_env_vars,
+    BASH_SAVE_LOGS,
     bash_capture_logs,
+    export_mflog_env_vars,
     tail_logs,
-    BASH_SAVE_LOGS,
 )
+from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
+from metaflow.plugins.datatools.s3.s3tail import S3Tail

 from .batch_client import BatchClient

@@ -63,7 +62,7 @@ def _command(self, environment, code_package_url, step_name, step_cmds, task_spe
             datastore_type="s3",
             stdout_path=STDOUT_PATH,
             stderr_path=STDERR_PATH,
-            **task_spec
+            **task_spec,
         )
         init_cmds = environment.get_package_commands(code_package_url, "s3")
         init_expr = " && ".join(init_cmds)
@@ -186,6 +185,7 @@ def create_job(
         attrs={},
         host_volumes=None,
         use_tmpfs=None,
+        tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -317,6 +317,20 @@ def create_job(
                 if key in attrs:
                     k, v = sanitize_batch_tag(key, attrs.get(key))
                     job.tag(k, v)
+
+        if not isinstance(BATCH_DEFAULT_TAGS, dict):
+            raise BatchException(
+                "The BATCH_DEFAULT_TAGS config option must be a dictionary of key-value tags."
+            )
+        for name, value in BATCH_DEFAULT_TAGS.items():
+            job.tag(name, value)
+
+        # Add user-supplied tags last so they can override the default tags.
+        if tags is not None:
+            if not isinstance(tags, dict):
+                raise BatchException("tags must be a dictionary of key-value tags.")
+            for name, value in tags.items():
+                job.tag(name, value)
         return job

     def launch_job(
@@ -342,6 +356,7 @@
         efa=None,
         host_volumes=None,
         use_tmpfs=None,
+        tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -380,6 +395,7 @@
             attrs=attrs,
             host_volumes=host_volumes,
             use_tmpfs=use_tmpfs,
+            tags=tags,
             tmpfs_tempdir=tmpfs_tempdir,
             tmpfs_size=tmpfs_size,
             tmpfs_path=tmpfs_path,
diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 6d8fd186af1..0ac6ac56716 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -1,12 +1,11 @@
-from metaflow._vendor import click
 import os
 import sys
 import time
 import traceback

-from metaflow import util
-from metaflow import R
-from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY
+from metaflow import R, util
+from metaflow._vendor import click
+from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
 from metaflow.mflog import TASK_LOG_SOURCE
@@ -146,6 +145,7 @@ def kill(ctx, run_id, user, my_runs):
     help="Activate designated number of elastic fabric adapter devices. "
     "EFA driver must be installed and instance type compatible with EFA",
 )
+@click.option("--aws-tags", multiple=True, default=None, help="AWS tags to attach to the Batch job, as KEY=VALUE pairs.")
 @click.option("--use-tmpfs", is_flag=True, help="tmpfs requirement for AWS Batch.")
 @click.option("--tmpfs-tempdir", is_flag=True, help="tmpfs requirement for AWS Batch.")
 @click.option("--tmpfs-size", help="tmpfs requirement for AWS Batch.")
@@ -179,13 +179,14 @@ def step(
     swappiness=None,
     inferentia=None,
     efa=None,
+    aws_tags=None,
     use_tmpfs=None,
     tmpfs_tempdir=None,
     tmpfs_size=None,
     tmpfs_path=None,
     host_volumes=None,
     num_parallel=None,
-    **kwargs
+    **kwargs,
 ):
     def echo(msg, stream="stderr", batch_id=None, **kwargs):
         msg = util.to_unicode(msg)
@@ -311,12 +312,14 @@ def _sync_metadata():
             attrs=attrs,
             host_volumes=host_volumes,
             use_tmpfs=use_tmpfs,
+            # each repeated --aws-tags option is a KEY=VALUE pair; an empty tuple means no tags
+            tags=dict(t.split("=", 1) for t in aws_tags) if aws_tags else None,
             tmpfs_tempdir=tmpfs_tempdir,
             tmpfs_size=tmpfs_size,
             tmpfs_path=tmpfs_path,
             num_parallel=num_parallel,
         )
-    except Exception as e:
+    except Exception:
         traceback.print_exc()
         _sync_metadata()
         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py
index c62c70c5385..6465d4aa753 100644
--- a/metaflow/plugins/aws/batch/batch_decorator.py
+++ b/metaflow/plugins/aws/batch/batch_decorator.py
@@ -1,34 +1,32 @@
 import os
-import sys
 import platform
-import requests
+import sys
 import time

-from metaflow import util
-from metaflow import R, current
+import requests
+from metaflow import R, current
 from metaflow.decorators import StepDecorator
-from metaflow.plugins.resources_decorator import ResourcesDecorator
-from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.metadata import MetaDatum
 from metaflow.metadata.util import sync_local_metadata_to_datastore
 from metaflow.metaflow_config import (
-    ECS_S3_ACCESS_IAM_ROLE,
-    BATCH_JOB_QUEUE,
     BATCH_CONTAINER_IMAGE,
     BATCH_CONTAINER_REGISTRY,
-    ECS_FARGATE_EXECUTION_ROLE,
+    BATCH_JOB_QUEUE,
     DATASTORE_LOCAL_DIR,
+    ECS_FARGATE_EXECUTION_ROLE,
+    ECS_S3_ACCESS_IAM_ROLE,
 )
+from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.sidecar import Sidecar
 from metaflow.unbounded_foreach import UBF_CONTROL

-from .batch import BatchException
 from ..aws_utils import (
     compute_resource_attributes,
     get_docker_registry,
     get_ec2_instance_metadata,
 )
+from .batch import BatchException


 class BatchDecorator(StepDecorator):
@@ -73,6 +71,9 @@ class BatchDecorator(StepDecorator):
         aggressively. Accepted values are whole numbers between 0 and 100.
     use_tmpfs: bool, default: False
         This enables an explicit tmpfs mount for this step.
+    tags: Dict[str, str], optional
+        Sets arbitrary AWS tags on the AWS Batch job submitted for this step,
+        as string key-value pairs.
     tmpfs_tempdir: bool, default: True
         sets METAFLOW_TEMPDIR to tmpfs_path if set for this step.
     tmpfs_size: int, optional
@@ -103,6 +104,7 @@ class BatchDecorator(StepDecorator):
         "efa": None,
         "host_volumes": None,
         "use_tmpfs": False,
+        "tags": None,
         "tmpfs_tempdir": True,
         "tmpfs_size": None,
         "tmpfs_path": "/metaflow_temp",
@@ -346,7 +348,7 @@ def _wait_for_mapper_tasks(self, flow, step_name):
                         len(flow._control_mapper_tasks),
                     )
                 )
-        except Exception as e:
+        except Exception:
             pass
         raise Exception(
             "Batch secondary workers did not finish in %s seconds" % TIMEOUT
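
Taken together, the diff introduces two layers of AWS Batch job tagging: the deployment-wide BATCH_DEFAULT_TAGS config option, and a per-step `tags` attribute on @batch, which create_job applies last so step-level tags override the defaults. A minimal usage sketch follows; the flow name and tag values are illustrative, it assumes BATCH_DEFAULT_TAGS is supplied as a JSON object in the Metaflow config file (values read from the JSON config keep their types, whereas environment variables arrive as strings), and the KEY=VALUE format for the internal --aws-tags option follows the parsing shown in batch_cli.py above.

from metaflow import FlowSpec, batch, step

# Assumed deployment-wide defaults, e.g. in ~/.metaflowconfig/config.json:
#   { "METAFLOW_BATCH_DEFAULT_TAGS": {"env": "prod", "team": "data-platform"} }


class TaggedFlow(FlowSpec):
    # Step-level tags are applied after BATCH_DEFAULT_TAGS in create_job,
    # so "team" here overrides the deployment-wide default for this step only.
    @batch(cpu=1, memory=4096, tags={"team": "ml-platform"})
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    TaggedFlow()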