From 29b39cfdecc3792331ebc3f7447fe189cf51ac2f Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 6 Feb 2025 12:28:19 +0000 Subject: [PATCH] feat(gcp_sqs_config): adds support for min and max replicas --- .../cloud_management/resource_allocation/k8s/keda.py | 7 ++++++- .../resource_allocation/k8s/keda_deprecated.py | 7 ++++++- .../configurations/deprecated/execute_on_gcp_with_sqs.py | 6 +++++- .../configurations/execute_on_gcp_with_sqs.py | 5 ++++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/keda.py b/zetta_utils/cloud_management/resource_allocation/k8s/keda.py index 3b054e818..bbabf278b 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/keda.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/keda.py @@ -207,17 +207,22 @@ def scaled_job_ctx_mngr( cluster_info: ClusterInfo, job_spec: k8s_client.V1JobSpec, secrets: list[k8s_client.V1Secret], - max_replicas: int, + replicas: int, sqs_trigger_name: str, queue: SQSQueue, + max_replicas: int = 0, namespace: str | None = "default", ): + if max_replicas == 0: + max_replicas = replicas + replicas = 0 configuration, _ = get_cluster_data(cluster_info) with secrets_ctx_mngr(run_id, secrets, cluster_info, namespace=namespace): manifest = _get_scaled_job_manifest( f"{run_id}-{group_name}", [_get_sqs_trigger(sqs_trigger_name, queue)], job_spec=job_spec, + min_replicas=replicas, max_replicas=max_replicas, ) so_name = manifest["metadata"]["name"] diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/keda_deprecated.py b/zetta_utils/cloud_management/resource_allocation/k8s/keda_deprecated.py index 2418eda71..9ae342b5c 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/keda_deprecated.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/keda_deprecated.py @@ -207,10 +207,14 @@ def scaled_job_ctx_mngr( cluster_info: ClusterInfo, job_spec: k8s_client.V1JobSpec, secrets: list[k8s_client.V1Secret], - max_replicas: int, + replicas: int, queue: SQSQueue, + max_replicas: int = 0, namespace: str | None = "default", ): + if max_replicas == 0: + max_replicas = replicas + replicas = 0 configuration, _ = get_cluster_data(cluster_info) with secrets_ctx_mngr(run_id, secrets, cluster_info, namespace=namespace): with sqs_trigger_ctx_mngr(run_id, cluster_info, namespace) as trigger_name: @@ -218,6 +222,7 @@ def scaled_job_ctx_mngr( run_id, [_get_sqs_trigger(trigger_name, queue)], job_spec=job_spec, + min_replicas=replicas, max_replicas=max_replicas, ) so_name = manifest["metadata"]["name"] diff --git a/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py index 1acb290e3..7676cedcc 100644 --- a/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py @@ -66,6 +66,7 @@ def get_gcp_with_sqs_config( semaphores_spec: dict[SemaphoreType, int] | None = None, provisioning_model: Literal["standard", "spot"] = "spot", idle_worker_timeout: int = 300, + max_worker_replicas: int = 0, ) -> tuple[PushMessageQueue[Task], PullMessageQueue[OutcomeReport], list[AbstractContextManager]]: work_queue_name = f"run-{execution_id}-work" outcome_queue_name = f"run-{execution_id}-outcome" @@ -114,7 +115,8 @@ def get_gcp_with_sqs_config( cluster_info=worker_cluster, job_spec=job_spec, secrets=secrets, - max_replicas=worker_replicas, + replicas=worker_replicas, + max_replicas=max_worker_replicas, queue=task_queue, ) ctx_managers.append(scaled_job_ctx_mngr) @@ -171,6 +173,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals provisioning_model: Literal["standard", "spot"] = "spot", sqs_based_scaling: bool = True, idle_worker_timeout: int = 300, + max_worker_replicas: int = 0, ): if debug and not local_test: raise ValueError("`debug` can only be set to `True` when `local_test` is also `True`.") @@ -240,6 +243,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals provisioning_model=provisioning_model, sqs_based_scaling=sqs_based_scaling, idle_worker_timeout=idle_worker_timeout, + max_worker_replicas=max_worker_replicas, ) with ExitStack() as stack: diff --git a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py index 903ad8678..a1a36574a 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py @@ -61,6 +61,7 @@ def _ensure_required_env_vars(): class WorkerGroup: replicas: int resource_limits: dict[str, int | float | str] + max_replicas: int = 0 queue_tags: list[str] | None = None num_procs: int = 1 sqs_based_scaling: bool = True @@ -75,6 +76,7 @@ class WorkerGroup: class WorkerGroupDict(TypedDict, total=False): replicas: int resource_limits: dict[str, int | float | str] + max_replicas: NotRequired[int] queue_tags: NotRequired[list[str]] num_procs: NotRequired[int] sqs_based_scaling: NotRequired[bool] @@ -134,7 +136,8 @@ def _get_group_taskqueue_and_contexts( job_spec=job_spec, secrets=[], sqs_trigger_name=sqs_trigger_name, - max_replicas=group.replicas, + replicas=group.replicas, + max_replicas=group.max_replicas, queue=task_queue, ) ctx_managers.append(scaled_job_ctx_mngr)