Skip to content

Commit

Permalink
added separate option in decorator for TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
kulikovv committed Jan 15, 2025
1 parent 7d53b75 commit 814693f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
8 changes: 8 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def create_jobset(
cpu=None,
gpu=None,
gpu_vendor=None,
tpu=None,
tpu_vendor=None,
disk=None,
memory=None,
use_tmpfs=None,
Expand Down Expand Up @@ -210,6 +212,8 @@ def create_jobset(
disk=disk,
gpu=gpu,
gpu_vendor=gpu_vendor,
tpu=tpu,
tpu_vendor=tpu_vendor,
timeout_in_seconds=run_time_limit,
# Retries are handled by Metaflow runtime
retries=0,
Expand Down Expand Up @@ -472,6 +476,8 @@ def create_job_object(
cpu=None,
gpu=None,
gpu_vendor=None,
tpu=None,
tpu_vendor=None,
disk=None,
memory=None,
use_tmpfs=None,
Expand Down Expand Up @@ -515,6 +521,8 @@ def create_job_object(
disk=disk,
gpu=gpu,
gpu_vendor=gpu_vendor,
tpu=tpu,
tpu_vendor=tpu_vendor,
timeout_in_seconds=run_time_limit,
# Retries are handled by Metaflow runtime
retries=0,
Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def kubernetes():
@click.option("--memory", help="Memory requirement for Kubernetes pod.")
@click.option("--gpu", help="GPU requirement for Kubernetes pod.")
@click.option("--gpu-vendor", help="GPU vendor requirement for Kubernetes pod.")
@click.option("--tpu", help="TPU requirement for Kubernetes pod.")
@click.option("--tpu-vendor", help="TPU vendor requirement for Kubernetes pod.")
@click.option("--run-id", help="Passed to the top-level 'step'.")
@click.option("--task-id", help="Passed to the top-level 'step'.")
@click.option("--input-paths", help="Passed to the top-level 'step'.")
Expand Down Expand Up @@ -163,6 +165,8 @@ def step(
memory=None,
gpu=None,
gpu_vendor=None,
tpu=None,
tpu_vendor=None,
use_tmpfs=None,
tmpfs_tempdir=None,
tmpfs_size=None,
Expand Down Expand Up @@ -305,6 +309,8 @@ def _sync_metadata():
memory=memory,
gpu=gpu,
gpu_vendor=gpu_vendor,
tpu=tpu,
tpu_vendor=tpu_vendor,
use_tmpfs=use_tmpfs,
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_jobsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,15 @@ def dump(self):
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["gpu"] is not None
},
**{
"%s.com/tpu".lower()
% self._kwargs["tpu_vendor"]: str(
self._kwargs["tpu"]
)
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["tpu"] is not None
},
},
),
volume_mounts=(
Expand Down

0 comments on commit 814693f

Please sign in to comment.