Add capability to add custom tags to batch
Fixes #1627
Limess committed Nov 8, 2023
1 parent 6badc1d commit d1199b4
Showing 4 changed files with 53 additions and 31 deletions.
2 changes: 2 additions & 0 deletions metaflow/metaflow_config.py
@@ -244,6 +244,8 @@
 # in all Metaflow deployments. Hopefully, some day we can flip the
 # default to True.
 BATCH_EMIT_TAGS = from_conf("BATCH_EMIT_TAGS", False)
+# Default tags to add to AWS Batch jobs. These are in addition to the defaults set when BATCH_EMIT_TAGS is true.
+BATCH_DEFAULT_TAGS = from_conf("BATCH_DEFAULT_TAGS", {})

 ###
 # AWS Step Functions configuration
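
Note: BATCH_DEFAULT_TAGS is read through from_conf, so it can be set once per deployment. A minimal sketch of how it might look in ~/.metaflowconfig/config.json (tag keys and values are illustrative, and this assumes your Metaflow version accepts a JSON object for this key):

```json
{
    "METAFLOW_BATCH_DEFAULT_TAGS": {
        "team": "data-platform",
        "cost-center": "12345"
    }
}
```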
44 changes: 30 additions & 14 deletions metaflow/plugins/aws/batch/batch.py
@@ -7,31 +7,30 @@
 import time
 
 from metaflow import util
-from metaflow.plugins.datatools.s3.s3tail import S3Tail
-from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import (
-    SERVICE_INTERNAL_URL,
-    DATATOOLS_S3ROOT,
-    DATASTORE_SYSROOT_S3,
-    DEFAULT_METADATA,
-    SERVICE_HEADERS,
+    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    BATCH_DEFAULT_TAGS,
     BATCH_EMIT_TAGS,
     CARD_S3ROOT,
-    S3_ENDPOINT_URL,
+    DATASTORE_SYSROOT_S3,
+    DATATOOLS_S3ROOT,
+    DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
-    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    S3_ENDPOINT_URL,
     S3_SERVER_SIDE_ENCRYPTION,
+    SERVICE_HEADERS,
+    SERVICE_INTERNAL_URL,
 )
-
 from metaflow.metaflow_config_funcs import config_values
-
 from metaflow.mflog import (
-    export_mflog_env_vars,
+    BASH_SAVE_LOGS,
     bash_capture_logs,
+    export_mflog_env_vars,
     tail_logs,
-    BASH_SAVE_LOGS,
 )
+from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
+from metaflow.plugins.datatools.s3.s3tail import S3Tail
 
 from .batch_client import BatchClient
 
@@ -63,7 +62,7 @@ def _command(self, environment, code_package_url, step_name, step_cmds, task_spe
             datastore_type="s3",
             stdout_path=STDOUT_PATH,
             stderr_path=STDERR_PATH,
-            **task_spec
+            **task_spec,
         )
         init_cmds = environment.get_package_commands(code_package_url, "s3")
         init_expr = " && ".join(init_cmds)
@@ -186,6 +185,7 @@ def create_job(
         attrs={},
         host_volumes=None,
         use_tmpfs=None,
+        tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -317,6 +317,20 @@ def create_job(
                 if key in attrs:
                     k, v = sanitize_batch_tag(key, attrs.get(key))
                     job.tag(k, v)
+
+        if not isinstance(BATCH_DEFAULT_TAGS, dict):
+            raise BatchException(
+                "The BATCH_DEFAULT_TAGS config option must be a dictionary of key-value tags."
+            )
+        for name, value in BATCH_DEFAULT_TAGS.items():
+            job.tag(name, value)
+
+        # add custom tags last to allow override of defaults
+        if tags is not None:
+            if not isinstance(tags, dict):
+                raise BatchException("tags must be a dictionary of key-value tags.")
+            for name, value in tags.items():
+                job.tag(name, value)
         return job

     def launch_job(
@@ -342,6 +356,7 @@
         efa=None,
         host_volumes=None,
         use_tmpfs=None,
+        tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -380,6 +395,7 @@
             attrs=attrs,
             host_volumes=host_volumes,
             use_tmpfs=use_tmpfs,
+            tags=tags,
             tmpfs_tempdir=tmpfs_tempdir,
             tmpfs_size=tmpfs_size,
             tmpfs_path=tmpfs_path,
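
Note: with this change, create_job layers tags from three sources, and the in-diff comment confirms that per-step tags are applied last so they win on key collisions. A self-contained sketch of that precedence (resolve_tags is a hypothetical helper for illustration, not part of Metaflow):

```python
# Hypothetical helper mirroring the order in which create_job applies tags:
# system tags (BATCH_EMIT_TAGS), then BATCH_DEFAULT_TAGS, then per-step tags.
def resolve_tags(system_tags, default_tags, step_tags):
    resolved = dict(system_tags)      # emitted when BATCH_EMIT_TAGS is true
    resolved.update(default_tags)     # deployment-wide BATCH_DEFAULT_TAGS
    resolved.update(step_tags or {})  # custom tags added last, so they override
    return resolved


# "env" from the step overrides the deployment-wide default.
assert resolve_tags(
    {"metaflow.flow_name": "MyFlow"},
    {"env": "dev", "team": "ml"},
    {"env": "prod"},
) == {"metaflow.flow_name": "MyFlow", "env": "prod", "team": "ml"}
```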
14 changes: 8 additions & 6 deletions metaflow/plugins/aws/batch/batch_cli.py
@@ -1,12 +1,11 @@
-from metaflow._vendor import click
 import os
 import sys
 import time
 import traceback

-from metaflow import util
-from metaflow import R
-from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY
+from metaflow import R, util
+from metaflow._vendor import click
+from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
 from metaflow.mflog import TASK_LOG_SOURCE
@@ -146,6 +145,7 @@ def kill(ctx, run_id, user, my_runs):
     help="Activate designated number of elastic fabric adapter devices. "
     "EFA driver must be installed and instance type compatible with EFA",
 )
+@click.option("--aws-tags", multiple=True, default=None, help="AWS tags.")
 @click.option("--use-tmpfs", is_flag=True, help="tmpfs requirement for AWS Batch.")
 @click.option("--tmpfs-tempdir", is_flag=True, help="tmpfs requirement for AWS Batch.")
 @click.option("--tmpfs-size", help="tmpfs requirement for AWS Batch.")
@@ -179,13 +179,14 @@ def step(
     swappiness=None,
     inferentia=None,
     efa=None,
+    aws_tags=None,
     use_tmpfs=None,
     tmpfs_tempdir=None,
     tmpfs_size=None,
     tmpfs_path=None,
     host_volumes=None,
     num_parallel=None,
-    **kwargs
+    **kwargs,
 ):
     def echo(msg, stream="stderr", batch_id=None, **kwargs):
         msg = util.to_unicode(msg)
@@ -311,12 +312,13 @@ def _sync_metadata():
             attrs=attrs,
             host_volumes=host_volumes,
             use_tmpfs=use_tmpfs,
+            tags=aws_tags,
             tmpfs_tempdir=tmpfs_tempdir,
             tmpfs_size=tmpfs_size,
             tmpfs_path=tmpfs_path,
             num_parallel=num_parallel,
         )
-    except Exception as e:
+    except Exception:
         traceback.print_exc()
         _sync_metadata()
         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
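
Note: because --aws-tags is declared with multiple=True, click collects repeated uses of the flag into a tuple of strings. A standalone sketch of that behavior (a toy command for illustration, not Metaflow's CLI):

```python
import click


@click.command()
@click.option("--aws-tags", multiple=True, default=None, help="AWS tags.")
def demo(aws_tags):
    # demo --aws-tags team=ml --aws-tags env=prod
    # prints ('team=ml', 'env=prod'); an empty tuple when the flag is absent.
    click.echo(repr(aws_tags))


if __name__ == "__main__":
    demo()
```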
24 changes: 13 additions & 11 deletions metaflow/plugins/aws/batch/batch_decorator.py
@@ -1,34 +1,32 @@
 import os
-import sys
 import platform
-import requests
+import sys
 import time
 
-from metaflow import util
-from metaflow import R, current
+import requests
 
+from metaflow import R, current
 from metaflow.decorators import StepDecorator
-from metaflow.plugins.resources_decorator import ResourcesDecorator
-from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.metadata import MetaDatum
 from metaflow.metadata.util import sync_local_metadata_to_datastore
 from metaflow.metaflow_config import (
-    ECS_S3_ACCESS_IAM_ROLE,
-    BATCH_JOB_QUEUE,
     BATCH_CONTAINER_IMAGE,
     BATCH_CONTAINER_REGISTRY,
-    ECS_FARGATE_EXECUTION_ROLE,
+    BATCH_JOB_QUEUE,
     DATASTORE_LOCAL_DIR,
+    ECS_FARGATE_EXECUTION_ROLE,
+    ECS_S3_ACCESS_IAM_ROLE,
 )
+from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.sidecar import Sidecar
 from metaflow.unbounded_foreach import UBF_CONTROL
 
-from .batch import BatchException
 from ..aws_utils import (
     compute_resource_attributes,
     get_docker_registry,
     get_ec2_instance_metadata,
 )
+from .batch import BatchException


 class BatchDecorator(StepDecorator):
@@ -73,6 +71,9 @@ class BatchDecorator(StepDecorator):
         aggressively. Accepted values are whole numbers between 0 and 100.
     use_tmpfs: bool, default: False
         This enables an explicit tmpfs mount for this step.
+    tags: map, optional
+        Sets arbitrary AWS tags on the AWS Batch compute environment.
+        Set as string key-value pairs.
     tmpfs_tempdir: bool, default: True
         sets METAFLOW_TEMPDIR to tmpfs_path if set for this step.
     tmpfs_size: int, optional
@@ -103,6 +104,7 @@ class BatchDecorator(StepDecorator):
         "efa": None,
         "host_volumes": None,
         "use_tmpfs": False,
+        "tags": None,
         "tmpfs_tempdir": True,
         "tmpfs_size": None,
         "tmpfs_path": "/metaflow_temp",
@@ -346,7 +348,7 @@ def _wait_for_mapper_tasks(self, flow, step_name):
                         len(flow._control_mapper_tasks),
                     )
                 )
-            except Exception as e:
+            except Exception:
                 pass
         raise Exception(
             "Batch secondary workers did not finish in %s seconds" % TIMEOUT
