diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 5e08ba64..78623fdd 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -371,6 +371,7 @@ class Config(GKEJob.Config): accelerator: AcceleratorConfig = AcceleratorConfig() reservation: Optional[str] = None + enable_ondemand: Optional[bool] = None enable_tpu_ici_resiliency: Optional[bool] = None location_hint: Optional[str] = None @@ -387,6 +388,13 @@ def define_flags(cls, fv: flags.FlagValues): "not all TPU types support this flag.", **common_kwargs, ) + flags.DEFINE_boolean( + "enable_ondemand", + None, + "Allow the job to run using on-demand quota when no reservation is provided. " + "Without this flag a job can only land on spot quota.", + **common_kwargs, + ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: @@ -508,8 +516,9 @@ def _build_pod(self) -> Nested[Any]: logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation) selector.update({"cloud.google.com/reservation-name": cfg.reservation}) else: - logging.info("Found tier=%s in env. Using spot quota", tier) - selector.update({"cloud.google.com/gke-spot": "true"}) + if not cfg.enable_ondemand: + logging.info("Found tier=%s in env. Using spot quota", tier) + selector.update({"cloud.google.com/gke-spot": "true"}) tolerations.append( { "key": "cloud.google.com/gke-spot",