diff --git a/zetta_utils/run/__init__.py b/zetta_utils/run/__init__.py
new file mode 100644
index 000000000..2f828e95f
--- /dev/null
+++ b/zetta_utils/run/__init__.py
@@ -0,0 +1,2 @@
+from .base import RUN_DB, RunInfo, register, run_ctx_manager
+from .resource import RESOURCE_DB, Resource, ResourceKeys, ResourceTypes, register_resource
diff --git a/zetta_utils/run/base.py b/zetta_utils/run/base.py
new file mode 100644
index 000000000..a4712f5a7
--- /dev/null
+++ b/zetta_utils/run/base.py
@@ -0,0 +1,83 @@
+import os
+import time
+from contextlib import contextmanager
+from enum import Enum
+from typing import Optional
+
+import attrs
+
+from zetta_utils import log
+from zetta_utils.cloud_management.resource_allocation.k8s import ClusterInfo
+from zetta_utils.common import RepeatTimer
+from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
+from zetta_utils.layer.db_layer.datastore import DatastoreBackend
+from zetta_utils.mazepa import id_generation
+from zetta_utils.parsing import json
+
+logger = log.get_logger("zetta_utils")
+
+DEFAULT_PROJECT = "zetta-research"
+RUN_DB_NAME = "run-info"
+RUN_INFO_PATH = "gs://zetta_utils_runs"
+RUN_DB = build_db_layer(DatastoreBackend(namespace=RUN_DB_NAME, project=DEFAULT_PROJECT))
+RUN_ID = None
+
+
+class RunInfo(Enum):
+    ZETTA_USER = "zetta_user"
+    HEARTBEAT = "heartbeat"
+    CLUSTERS = "clusters"
+    STATE = "state"
+
+
+class RunState(Enum):
+    RUNNING = "running"
+    TIMEOUT = "timeout"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+
+
+def register(clusters: list[ClusterInfo]) -> None:  # pragma: no cover
+    """
+    Register run info to the database for the garbage collector.
+    """
+    info: DBRowDataT = {
+        RunInfo.ZETTA_USER.value: os.environ["ZETTA_USER"],
+        RunInfo.HEARTBEAT.value: time.time(),
+        RunInfo.CLUSTERS.value: json.dumps([attrs.asdict(cluster) for cluster in clusters]),
+        RunInfo.STATE.value: RunState.RUNNING.value,
+    }
+    _update_run_info(info)
+
+
+def _update_run_info(info: DBRowDataT) -> None:  # pragma: no cover
+    row_key = f"run-{RUN_ID}"
+    col_keys = tuple(info.keys())
+    RUN_DB[(row_key, col_keys)] = info
+
+
+@contextmanager
+def run_ctx_manager(run_id: Optional[str] = None, heartbeat_interval: int = 5):
+    def _send_heartbeat():
+        info: DBRowDataT = {RunInfo.HEARTBEAT.value: time.time()}
+        _update_run_info(info)
+
+    heartbeat = None
+    if run_id is None:
+        run_id = id_generation.get_unique_id(slug_len=4, add_uuid=False, max_len=50)
+
+    global RUN_ID  # pylint: disable=global-statement
+    RUN_ID = run_id
+    try:
+        if heartbeat_interval > 0:
+            heartbeat = RepeatTimer(heartbeat_interval, _send_heartbeat)
+            heartbeat.start()
+        yield
+    except Exception:
+        info: DBRowDataT = {RunInfo.STATE.value: RunState.FAILED.value}
+        _update_run_info(info)
+        raise
+    finally:
+        RUN_ID = None
+        if heartbeat:
+            heartbeat.cancel()
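+
+
+# A minimal usage sketch (illustrative comment only, not executed; assumes
+# `ZETTA_USER` is set and Datastore credentials are available):
+#
+#     from zetta_utils import run
+#
+#     with run.run_ctx_manager(heartbeat_interval=5):
+#         run.register(clusters=[...])  # list[ClusterInfo] for this run
+#         ...  # do the work; a heartbeat is written every 5 seconds
+#     # on an exception, the run state is set to FAILED before re-raising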
+""" + +import json +import logging +import os +import time +from typing import Any, Dict, List, Mapping + +import taskqueue +from boto3.exceptions import Boto3Error +from google.api_core.exceptions import GoogleAPICallError +from kubernetes.client.exceptions import ApiException as K8sApiException + +from kubernetes import client as k8s_client # type: ignore +from zetta_utils.cloud_management.resource_allocation.k8s import ( + ClusterInfo, + get_cluster_data, +) +from zetta_utils.log import get_logger +from zetta_utils.message_queues.sqs import utils as sqs_utils +from zetta_utils.run import ( + RESOURCE_DB, + RUN_DB, + Resource, + ResourceKeys, + ResourceTypes, + RunInfo, +) + +logger = get_logger("zetta_utils") + + +def _delete_db_entry(client: Any, entry_id: str, columns: List[str]): + parent_key = client.key("Row", entry_id) + for column in columns: + col_key = client.key("Column", column, parent=parent_key) + client.delete(col_key) + + +def _delete_run_entry(run_id: str): # pragma: no cover + client = RUN_DB.backend.client # type: ignore + columns = [key.value for key in list(RunInfo)] + _delete_db_entry(client, run_id, columns) + + +def _delete_resource_entry(resource_id: str): # pragma: no cover + client = RESOURCE_DB.backend.client # type: ignore + columns = [key.value for key in list(ResourceKeys)] + _delete_db_entry(client, resource_id, columns) + + +def _get_stale_run_ids() -> list[str]: # pragma: no cover + client = RUN_DB.backend.client # type: ignore + + lookback = int(os.environ["EXECUTION_HEARTBEAT_LOOKBACK"]) + time_diff = time.time() - lookback + + filters = [("heartbeat", "<", time_diff)] + query = client.query(kind="Column", filters=filters) + query.keys_only() + + entities = list(query.fetch()) + return [entity.key.parent.id_or_name for entity in entities] + + +def _read_clusters(run_id_key: str) -> list[ClusterInfo]: # pragma: no cover + col_keys = ("clusters",) + try: + clusters_str = RUN_DB[(run_id_key, col_keys)][col_keys[0]] + except KeyError: + return [] + clusters: list[Mapping] = json.loads(clusters_str) + return [ClusterInfo(**cluster) for cluster in clusters] + + +def _read_run_resources(run_id: str) -> Dict[str, Resource]: + client = RESOURCE_DB.backend.client # type: ignore + + query = client.query(kind="Column") + query = query.add_filter("run_id", "=", run_id) + query.keys_only() + + entities = list(query.fetch()) + resouce_ids = [entity.key.parent.id_or_name for entity in entities] + + col_keys = ("type", "name") + resources = RESOURCE_DB[(resouce_ids, col_keys)] + resources = [Resource(run_id=run_id, **res) for res in resources] + return dict(zip(resouce_ids, resources)) + + +def _delete_k8s_resources(run_id: str, resources: Dict[str, Resource]) -> bool: # pragma: no cover + success = True + logger.info(f"Deleting k8s resources from run {run_id}") + clusters = _read_clusters(run_id) + for cluster in clusters: + try: + configuration, _ = get_cluster_data(cluster) + except GoogleAPICallError as exc: + # cluster does not exit, discard resource entries + if exc.code == 404: + for resource_id in resources.keys(): + _delete_resource_entry(resource_id) + continue + + k8s_client.Configuration.set_default(configuration) + + k8s_apps_v1_api = k8s_client.AppsV1Api() + k8s_core_v1_api = k8s_client.CoreV1Api() + for resource_id, resource in resources.items(): + try: + if resource.type == ResourceTypes.K8S_DEPLOYMENT.value: + logger.info(f"Deleting k8s deployment `{resource.name}`") + k8s_apps_v1_api.delete_namespaced_deployment( + name=resource.name, 
namespace="default" + ) + elif resource.type == ResourceTypes.K8S_SECRET.value: + logger.info(f"Deleting k8s secret `{resource.name}`") + k8s_core_v1_api.delete_namespaced_secret( + name=resource.name, namespace="default" + ) + except K8sApiException as exc: + if exc.status == 404: + success = True + logger.info(f"Resource does not exist: `{resource.name}`: {exc}") + _delete_resource_entry(resource_id) + else: + success = False + logger.warning(f"Failed to delete k8s resource `{resource.name}`: {exc}") + raise K8sApiException() from exc + return success + + +def _delete_sqs_queues(resources: Dict[str, Resource]) -> bool: # pragma: no cover + success = True + for resource_id, resource in resources.items(): + if resource.type != ResourceTypes.SQS_QUEUE.value: + continue + region_name = resource.region + if resource.region == "" or resource.region is None: + region_name = taskqueue.secrets.AWS_DEFAULT_REGION + sqs_client = sqs_utils.get_sqs_client(region_name=region_name) + try: + logger.info(f"Deleting SQS queue `{resource.name}`") + queue_url = sqs_client.get_queue_url(QueueName=resource.name)["QueueUrl"] + sqs_client.delete_queue(QueueUrl=queue_url) + except sqs_client.exceptions.QueueDoesNotExist as exc: + logger.info(f"Queue does not exist: `{resource.name}`: {exc}") + _delete_resource_entry(resource_id) + except Boto3Error as exc: + success = False + logger.warning(f"Failed to delete queue `{resource.name}`: {exc}") + return success + + +def cleanup_run(run_id: str): + success = True + resources = _read_run_resources(run_id) + success &= _delete_k8s_resources(run_id, resources) + success &= _delete_sqs_queues(resources) + + if success is True: + _delete_run_entry(run_id) + logger.info(f"`{run_id}` run cleanup complete.") + else: + logger.info(f"`{run_id}` run cleanup incomplete.") + + +if __name__ == "__main__": # pragma: no cover + run_ids = _get_stale_run_ids() + logger.setLevel(logging.INFO) + for _id in run_ids: + logger.info(f"Cleaning up run `{_id}`") + cleanup_run(_id) diff --git a/zetta_utils/run/resource.py b/zetta_utils/run/resource.py new file mode 100644 index 000000000..594b28358 --- /dev/null +++ b/zetta_utils/run/resource.py @@ -0,0 +1,47 @@ +import uuid +from enum import Enum + +import attrs + +from zetta_utils.layer.db_layer import build_db_layer +from zetta_utils.layer.db_layer.datastore import DatastoreBackend + +DEFAULT_PROJECT = "zetta-research" +RESOURCE_DB_NAME = "run-resource" + +RESOURCE_DB = build_db_layer( + DatastoreBackend( + namespace=RESOURCE_DB_NAME, + project=DEFAULT_PROJECT, # all resources must be tracked in the default project + ) +) + + +class ResourceTypes(Enum): + K8S_CONFIGMAP = "k8s_configmap" + K8S_DEPLOYMENT = "k8s_deployment" + K8S_JOB = "k8s_job" + K8S_SECRET = "k8s_secret" + K8S_SERVICE = "k8s_service" + SQS_QUEUE = "sqs_queue" + + +class ResourceKeys(Enum): + RUN_ID = "run_id" + TYPE = "type" + NAME = "name" + + +@attrs.frozen +class Resource: + run_id: str + type: str + name: str + region: str = "" + + +def register_resource(resource: Resource) -> None: + _resource = attrs.asdict(resource) + row_key = str(uuid.uuid4()) + col_keys = tuple(_resource.keys()) + RESOURCE_DB[(row_key, col_keys)] = _resource