diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index a590b40527b..6b924b15f39 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -352,6 +352,7 @@ KUBERNETES_PERSISTENT_VOLUME_CLAIMS = from_conf( "KUBERNETES_PERSISTENT_VOLUME_CLAIMS", "" ) +KUBERNETES_EPHEMERAL_VOLUME_CLAIMS = from_conf("KUBERNETES_EPHEMERAL_VOLUME_CLAIMS", "") KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "") # Default labels for kubernetes pods KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "") diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index 304fa9f3bd9..6724a56fa5b 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -667,6 +667,7 @@ def _visit(node, workflow, exit_node=None): "use_tmpfs", "tmpfs_size", "persistent_volume_claims", + "ephemeral_volume_claims", "image_pull_policy", ]: if kube_deco[attr]: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index ea0d6c6798e..0e2e049c0ad 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -58,6 +58,7 @@ parse_kube_keyvalue_list, validate_kube_labels, ) +from metaflow.plugins.kubernetes.kube_utils import VOLUME_CLAIM_TEMPLATE_DEFAULTS from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.user_configs.config_options import ConfigInput @@ -1951,6 +1952,7 @@ def _container_templates(self): tmpfs_path=tmpfs_path, timeout_in_seconds=run_time_limit, persistent_volume_claims=resources["persistent_volume_claims"], + ephemeral_volume_claims=resources["ephemeral_volume_claims"], shared_memory=shared_memory, port=port, qos=resources["qos"], @@ -2084,6 +2086,7 @@ def _container_templates(self): ) .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory) .pvc_volumes(resources.get("persistent_volume_claims")) + 
def ephemeral_volume_claims(self, claims=None):
    """
    Attach generic ephemeral volumes, each backed by a volume claim template.

    One entry is appended to ``self.payload["volumes"]`` per claim. The
    list is created if the payload does not have one yet.

    Parameters
    ----------
    claims : Optional[Dict[str, Dict]]
        Mapping of volume name to that volume's settings. Keys read here:

        - ``"spec"`` (optional): overrides for the claim template spec.
          These are merged over ``VOLUME_CLAIM_TEMPLATE_DEFAULTS`` one
          top-level key at a time, so supplying ``"resources"`` replaces
          the default ``"resources"`` entry as a whole.
        - ``"metadata"`` (optional): emitted as the claim template's
          metadata. Kubernetes documents that only labels and
          annotations are accepted there.

        Any other key (for example ``"path"``) is not read by this
        method. e.g.
        {"claim-1": {"path": "/mnt/path1",
                     "spec": {"storageClassName": "my-storage-class"},
                     "metadata": {"labels": {"team": "abc123"}}}}

    Returns
    -------
    The same object, so that calls can be chained.
    """
    if claims is None:
        return self
    if "volumes" not in self.payload:
        self.payload["volumes"] = []
    for name, values in claims.items():
        claim_template = {
            "spec": {
                **VOLUME_CLAIM_TEMPLATE_DEFAULTS,
                # `or {}` tolerates an explicit `"spec": None`, which would
                # otherwise fail with a TypeError when unpacked.
                **(values.get("spec") or {}),
            },
        }
        # Forward metadata only when the caller supplied some, so the
        # generated payload is unchanged for claims that omit it.
        metadata = values.get("metadata")
        if metadata:
            claim_template["metadata"] = metadata
        self.payload["volumes"].append(
            {
                "name": name,
                "ephemeral": {"volumeClaimTemplate": claim_template},
            }
        )
    return self
["ReadWriteOnce"], + "resources": {"requests": {"storage": "1Gi"}}, +} + + def parse_cli_options(flow_name, run_id, user, my_runs, echo): if user and my_runs: raise CommandException("--user and --my-runs are mutually exclusive.") diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 195dbf041ce..a43fa41f037 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -191,6 +191,7 @@ def create_jobset( run_time_limit=None, env=None, persistent_volume_claims=None, + ephemeral_volume_claims=None, tolerations=None, labels=None, shared_memory=None, @@ -226,6 +227,7 @@ def create_jobset( tmpfs_size=tmpfs_size, tmpfs_path=tmpfs_path, persistent_volume_claims=persistent_volume_claims, + ephemeral_volume_claims=ephemeral_volume_claims, shared_memory=shared_memory, port=port, num_parallel=num_parallel, @@ -485,6 +487,7 @@ def create_job_object( run_time_limit=None, env=None, persistent_volume_claims=None, + ephemeral_volume_claims=None, tolerations=None, labels=None, shared_memory=None, @@ -529,6 +532,7 @@ def create_job_object( tmpfs_size=tmpfs_size, tmpfs_path=tmpfs_path, persistent_volume_claims=persistent_volume_claims, + ephemeral_volume_claims=ephemeral_volume_claims, shared_memory=shared_memory, port=port, qos=qos, diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index def709cadea..51d5888cffc 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -109,6 +109,9 @@ def kubernetes(): @click.option( "--persistent-volume-claims", type=JSONTypeClass(), default=None, multiple=False ) +@click.option( + "--ephemeral-volume-claims", type=JSONTypeClass(), default=None, multiple=False +) @click.option( "--tolerations", default=None, @@ -156,6 +159,7 @@ def step( tmpfs_path=None, run_time_limit=None, persistent_volume_claims=None, + ephemeral_volume_claims=None, 
tolerations=None, shared_memory=None, port=None, @@ -297,6 +301,7 @@ def _sync_metadata(): run_time_limit=run_time_limit, env=env, persistent_volume_claims=persistent_volume_claims, + ephemeral_volume_claims=ephemeral_volume_claims, tolerations=tolerations, shared_memory=shared_memory, port=port, diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 53f08daf051..54970bbdc27 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -22,6 +22,7 @@ KUBERNETES_NAMESPACE, KUBERNETES_NODE_SELECTOR, KUBERNETES_PERSISTENT_VOLUME_CLAIMS, + KUBERNETES_EPHEMERAL_VOLUME_CLAIMS, KUBERNETES_PORT, KUBERNETES_SERVICE_ACCOUNT, KUBERNETES_SHARED_MEMORY, @@ -102,6 +103,9 @@ class KubernetesDecorator(StepDecorator): persistent_volume_claims : Dict[str, str], optional, default None A map (dictionary) of persistent volumes to be mounted to the pod for this step. The map is from persistent volumes to the path to which the volume is to be mounted, e.g., `{'pvc-name': '/path/to/mount/on'}`. + ephemeral_volume_claims: Dict[str, Any], optional, default None + A map (dictionary) of ephemeral volumes to be mounted to the pod for this step. The map is a name + to a dictionary containing the key 'path' (required), 'metadata' (optional), and 'spec' (optional). 
shared_memory: int, optional Shared memory size (in MiB) required for this step port: int, optional @@ -136,6 +140,7 @@ class KubernetesDecorator(StepDecorator): "tmpfs_size": None, "tmpfs_path": "/metaflow_temp", "persistent_volume_claims": None, # e.g., {"pvc-name": "/mnt/vol", "another-pvc": "/mnt/vol2"} + "ephemeral_volume_claims": None, # e.g., {"ephemeral-name": {"path": "/mnt/vol", "spec": {"storageClassName": "my_storage_class"}}} "shared_memory": None, "port": None, "compute_pool": None, @@ -171,6 +176,13 @@ def init(self): self.attributes["persistent_volume_claims"] = json.loads( KUBERNETES_PERSISTENT_VOLUME_CLAIMS ) + if ( + not self.attributes["ephemeral_volume_claims"] + and KUBERNETES_EPHEMERAL_VOLUME_CLAIMS + ): + self.attributes["ephemeral_volume_claims"] = json.loads( + KUBERNETES_EPHEMERAL_VOLUME_CLAIMS + ) if not self.attributes["image_pull_policy"] and KUBERNETES_IMAGE_PULL_POLICY: self.attributes["image_pull_policy"] = KUBERNETES_IMAGE_PULL_POLICY @@ -426,7 +438,11 @@ def runtime_step_cli( "=".join([key, str(val)]) if val else key for key, val in v.items() ] - elif k in ["tolerations", "persistent_volume_claims"]: + elif k in [ + "tolerations", + "persistent_volume_claims", + "ephemeral_volume_claims", + ]: cli_args.command_options[k] = json.dumps(v) else: cli_args.command_options[k] = v diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index 1728cdfd674..c1b39d870e7 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -15,7 +15,7 @@ KubernetesJobSet, ) # We need this import for Kubernetes Client. 
-from .kube_utils import qos_requests_and_limits +from .kube_utils import qos_requests_and_limits, VOLUME_CLAIM_TEMPLATE_DEFAULTS class KubernetesJobException(MetaflowException): @@ -205,6 +205,18 @@ def create_job_spec(self): ] if self._kwargs["persistent_volume_claims"] is not None else [] + ) + + ( + [ + client.V1VolumeMount( + mount_path=vals["path"], name=name + ) + for name, vals in self._kwargs[ + "ephemeral_volume_claims" + ].items() + ] + if self._kwargs["ephemeral_volume_claims"] is not None + else [] ), ) ], @@ -266,6 +278,26 @@ def create_job_spec(self): ] if self._kwargs["persistent_volume_claims"] is not None else [] + ) + + ( + [ + client.V1Volume( + name=name, + ephemeral=client.V1EphemeralVolumeSource( + volume_claim_template=client.V1PersistentVolumeClaimTemplate( + spec={ + **VOLUME_CLAIM_TEMPLATE_DEFAULTS, + **vals.get("spec", {}), + }, + ) + ), + ) + for name, vals in self._kwargs[ + "ephemeral_volume_claims" + ].items() + ] + if self._kwargs["ephemeral_volume_claims"] is not None + else [] ), ), ), diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index e7236aea746..5519d65ea7d 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -9,7 +9,7 @@ from metaflow.tracing import inject_tracing_vars from metaflow.metaflow_config import KUBERNETES_SECRETS -from .kube_utils import qos_requests_and_limits +from .kube_utils import qos_requests_and_limits, VOLUME_CLAIM_TEMPLATE_DEFAULTS class KubernetesJobsetException(MetaflowException): @@ -707,6 +707,19 @@ def dump(self): if self._kwargs["persistent_volume_claims"] is not None else [] + ) + + ( + [ + client.V1VolumeMount( + mount_path=vals["path"], name=name + ) + for name, vals in self._kwargs[ + "ephemeral_volume_claims" + ].items() + ] + if self._kwargs["ephemeral_volume_claims"] + is not None + else [] ), ) ], @@ -772,6 +785,27 @@ def dump(self): if 
self._kwargs["persistent_volume_claims"] is not None else [] + ) + + ( + [ + client.V1Volume( + name=name, + ephemeral=client.V1EphemeralVolumeSource( + volume_claim_template=client.V1PersistentVolumeClaimTemplate( + spec={ + **VOLUME_CLAIM_TEMPLATE_DEFAULTS, + **vals.get("spec", {}), + }, + ) + ), + ) + for name, vals in self._kwargs[ + "ephemeral_volume_claims" + ].items() + ] + if self._kwargs["ephemeral_volume_claims"] + is not None + else [] ), ), ),