diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 1e3c55a3ba8..66303916b34 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -313,6 +313,10 @@ ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "") ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "") +## Kueue Support +KUEUE_ENABLED = from_conf("KUEUE_ENABLED", False) +KUEUE_LOCALQUEUE_NAME = from_conf("KUEUE_LOCALQUEUE_NAME", "") + ## # Argo Events Configuration ## diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a19033f97ba..f29b6c0b53b 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1368,15 +1368,34 @@ def _container_templates(self): tmpfs_size = resources["tmpfs_size"] tmpfs_path = resources["tmpfs_path"] tmpfs_tempdir = resources["tmpfs_tempdir"] - # Set shared_memory to 0 if it isn't specified. This results - # in Kubernetes using it's default value when the pod is created. - shared_memory = resources.get("shared_memory", 0) tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs) if tmpfs_enabled and tmpfs_tempdir: env["METAFLOW_TEMPDIR"] = tmpfs_path + # Set shared_memory to 0 if it isn't specified. This results + # in Kubernetes using it's default value when the pod is created. 
+ shared_memory = resources.get("shared_memory", 0) + + kueue_enabled = resources["kueue_enabled"] + kueue_localqueue_name = resources["kueue_localqueue_name"] + kueue_annotations = {} + kueue_labels = {} + if kueue_enabled: + kueue_annotations["kueue.x-k8s.io/retriable-in-group"] = "false" + kueue_annotations["kueue.x-k8s.io/pod-group-total-count"] = str( + 1 + ) # For now, might change with @parallel support + kueue_labels["kueue.x-k8s.io/queue-name"] = kueue_localqueue_name + kueue_labels["kueue.x-k8s.io/managed"] = "true" + kueue_labels["kueue.x-k8s.io/pod-group-name"] = ( + "{{workflow.name}}-" + node.name + ) + if node.is_inside_foreach: + kueue_labels["kueue.x-k8s.io/pod-group-name"] = \ + kueue_labels["kueue.x-k8s.io/pod-group-name"] + \ + "-{{inputs.parameters.split-index}}" # Create a ContainerTemplate for this node. Ideally, we would have # liked to inline this ContainerTemplate and avoid scanning the workflow # twice, but due to issues with variable substitution, we will have to @@ -1399,13 +1418,16 @@ def _container_templates(self): minutes_between_retries=minutes_between_retries, ) .metadata( - ObjectMeta().annotation("metaflow/step_name", node.name) + ObjectMeta() + .annotation("metaflow/step_name", node.name) # Unfortunately, we can't set the task_id since it is generated # inside the pod. However, it can be inferred from the annotation # set by argo-workflows - `workflows.argoproj.io/outputs` - refer # the field 'task-id' in 'parameters' # .annotation("metaflow/task_id", ...) 
.annotation("metaflow/attempt", retry_count) + .annotations(kueue_annotations) + .labels(kueue_labels) ) # Set emptyDir volume for state management .empty_dir_volume("out") diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index a69138da407..a2eb538b086 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -175,6 +175,8 @@ def create_job( tolerations=None, labels=None, shared_memory=None, + kueue_enabled=None, + kueue_localqueue_name=None, ): if env is None: env = {} @@ -183,6 +185,8 @@ def create_job( KubernetesClient() .job( generate_name="t-{uid}-".format(uid=str(uuid4())[:8]), + run_id=run_id, + task_id=task_id, namespace=namespace, service_account=service_account, secrets=secrets, @@ -215,6 +219,8 @@ def create_job( tmpfs_path=tmpfs_path, persistent_volume_claims=persistent_volume_claims, shared_memory=shared_memory, + kueue_enabled=kueue_enabled, + kueue_localqueue_name=kueue_localqueue_name, ) .environment_variable("METAFLOW_CODE_SHA", code_package_sha) .environment_variable("METAFLOW_CODE_URL", code_package_url) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 0ccd20148d9..c4b27965f38 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -1,4 +1,5 @@ import os + import sys import time import traceback @@ -108,6 +109,15 @@ def kubernetes(): multiple=False, ) @click.option("--shared-memory", default=None, help="Size of shared memory in MiB") +@click.option( + "--kueue-enabled", + is_flag=True, + default=None, + help="Whether to use Kueue for scheduling Kubernetes jobs/pods", +) +@click.option( + "--kueue-localqueue-name", help="Name of the LocalQueue configured with kueue" +) @click.pass_context def step( ctx, @@ -134,6 +144,8 @@ def step( persistent_volume_claims=None, tolerations=None, shared_memory=None, + kueue_enabled=None, + 
kueue_localqueue_name=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -248,6 +260,8 @@ def _sync_metadata(): persistent_volume_claims=persistent_volume_claims, tolerations=tolerations, shared_memory=shared_memory, + kueue_enabled=kueue_enabled, + kueue_localqueue_name=kueue_localqueue_name, ) except Exception as e: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 8bc5420424c..f0a9a891a0d 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -21,6 +21,8 @@ KUBERNETES_TOLERATIONS, KUBERNETES_SERVICE_ACCOUNT, KUBERNETES_SHARED_MEMORY, + KUEUE_ENABLED, + KUEUE_LOCALQUEUE_NAME, ) from metaflow.plugins.resources_decorator import ResourcesDecorator from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task @@ -90,6 +92,10 @@ class KubernetesDecorator(StepDecorator): volumes to the path to which the volume is to be mounted, e.g., `{'pvc-name': '/path/to/mount/on'}`. 
shared_memory: int, optional Shared memory size (in MiB) required for this step + kueue_enabled: bool, optional + Whether Kubernetes job/Argo workflow pod should be submitted using Kueue + kueue_localqueue_name: str, optional + The name of the localqueue object configured in Kueue to use for submitting jobs/pods """ name = "kubernetes" @@ -113,6 +119,8 @@ class KubernetesDecorator(StepDecorator): "tmpfs_path": "/metaflow_temp", "persistent_volume_claims": None, # e.g., {"pvc-name": "/mnt/vol", "another-pvc": "/mnt/vol2"} "shared_memory": None, + "kueue_enabled": None, + "kueue_localqueue_name": None, } package_url = None package_sha = None @@ -201,6 +209,15 @@ def __init__(self, attributes=None, statically_defined=False): if not self.attributes["shared_memory"]: self.attributes["shared_memory"] = KUBERNETES_SHARED_MEMORY + # Process config options related to KUEUE + if self.attributes["kueue_enabled"] is None: + self.attributes["kueue_enabled"] = KUEUE_ENABLED + if ( + "kueue_localqueue_name" not in self.attributes + or self.attributes["kueue_localqueue_name"] is None + ): + self.attributes["kueue_localqueue_name"] = KUEUE_LOCALQUEUE_NAME + # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger): # Executing Kubernetes jobs requires a non-local datastore. 
diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index 90634215b49..21a87071924 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -83,14 +83,31 @@ def create(self): else None ) + annotations = self._kwargs.get("annotations", {}) + labels = self._kwargs.get("labels", {}) + + kueue_enabled = bool(self._kwargs["kueue_enabled"]) + localqueue_name = self._kwargs["kueue_localqueue_name"] + if kueue_enabled: + labels["kueue.x-k8s.io/queue-name"] = localqueue_name + labels["kueue.x-k8s.io/pod-group-name"] = ( + self._kwargs["run_id"] + + "-" + + self._kwargs["step_name"] + + "-" + + self._kwargs["task_id"] + ) + annotations["kueue.x-k8s.io/retriable-in-group"] = "false" + annotations["kueue.x-k8s.io/pod-group-total-count"] = str(1) + self._job = client.V1Job( api_version="batch/v1", kind="Job", metadata=client.V1ObjectMeta( # Annotations are for humans - annotations=self._kwargs.get("annotations", {}), + annotations=annotations, # While labels are for Kubernetes - labels=self._kwargs.get("labels", {}), + labels=labels, generate_name=self._kwargs["generate_name"], namespace=self._kwargs["namespace"], # Defaults to `default` ),