Skip to content

Commit

Permalink
feat: adds run context manager with heartbeat support
Browse files Browse the repository at this point in the history
  • Loading branch information
akhileshh committed Jan 19, 2024
1 parent 046b1ae commit a74730f
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 0 deletions.
2 changes: 2 additions & 0 deletions zetta_utils/run/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import RUN_DB, RunInfo, register, run_ctx_manager
from .resource import Resource, register_resource, ResourceTypes, ResourceKeys, RESOURCE_DB
83 changes: 83 additions & 0 deletions zetta_utils/run/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import time
from contextlib import contextmanager
from enum import Enum
from typing import Optional

import attrs

from zetta_utils import log
from zetta_utils.cloud_management.resource_allocation.k8s import ClusterInfo
from zetta_utils.common import RepeatTimer
from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
from zetta_utils.layer.db_layer.datastore import DatastoreBackend
from zetta_utils.mazepa import id_generation
from zetta_utils.parsing import json

logger = log.get_logger("zetta_utils")

DEFAULT_PROJECT = "zetta-research"
RUN_DB_NAME = "run-info"
RUN_INFO_PATH = "gs://zetta_utils_runs"
RUN_DB = build_db_layer(DatastoreBackend(namespace=RUN_DB_NAME, project=DEFAULT_PROJECT))
RUN_ID = None


class RunInfo(Enum):
    """Column keys used for a run's row in RUN_DB."""

    ZETTA_USER = "zetta_user"
    HEARTBEAT = "heartbeat"
    CLUSTERS = "clusters"
    STATE = "state"


class RunState(Enum):
    """Lifecycle states recorded under the RunInfo.STATE column."""

    RUNNING = "running"
    TIMEOUT = "timeout"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


def register(clusters: list[ClusterInfo]) -> None:  # pragma: no cover
    """
    Record this run's info (user, heartbeat, clusters, state) in the
    database so the garbage collector can track it.
    """
    serialized_clusters = json.dumps([attrs.asdict(cluster) for cluster in clusters])
    row: DBRowDataT = {
        RunInfo.ZETTA_USER.value: os.environ["ZETTA_USER"],
        RunInfo.HEARTBEAT.value: time.time(),
        RunInfo.CLUSTERS.value: serialized_clusters,
        RunInfo.STATE.value: RunState.RUNNING.value,
    }
    _update_run_info(row)


def _update_run_info(info: DBRowDataT) -> None:  # pragma: no cover
    """Write the given columns to the current run's row in RUN_DB."""
    RUN_DB[(f"run-{RUN_ID}", tuple(info))] = info


@contextmanager
def run_ctx_manager(run_id: Optional[str] = None, heartbeat_interval: int = 5):
    """
    Set the module-level RUN_ID for the duration of the context and
    periodically write a heartbeat to the run's row in RUN_DB.

    :param run_id: Identifier for this run; a unique id is generated when None.
    :param heartbeat_interval: Seconds between heartbeats; <= 0 disables them.
    """

    def _send_heartbeat():
        info: DBRowDataT = {RunInfo.HEARTBEAT.value: time.time()}
        _update_run_info(info)

    heartbeat = None
    if run_id is None:
        run_id = id_generation.get_unique_id(slug_len=4, add_uuid=False, max_len=50)

    global RUN_ID  # pylint: disable=global-statement
    RUN_ID = run_id
    try:
        if heartbeat_interval > 0:
            heartbeat = RepeatTimer(heartbeat_interval, _send_heartbeat)
            heartbeat.start()
        yield
    except Exception:
        # Mark the run failed so the garbage collector will reclaim it.
        info: DBRowDataT = {RunInfo.STATE.value: RunState.FAILED.value}
        _update_run_info(info)
        # Plain `raise` preserves the original traceback and exception
        # context; the previous `raise e from None` suppressed both.
        raise
    finally:
        # Cancel the heartbeat BEFORE clearing RUN_ID; otherwise the timer
        # thread can fire one last time and write to the row "run-None".
        if heartbeat:
            heartbeat.cancel()
        RUN_ID = None
176 changes: 176 additions & 0 deletions zetta_utils/run/gc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
Garbage collection for run resources.
"""

import json
import logging
import os
import time
from typing import Any, Dict, List, Mapping

import taskqueue
from boto3.exceptions import Boto3Error
from google.api_core.exceptions import GoogleAPICallError
from kubernetes.client.exceptions import ApiException as K8sApiException

from kubernetes import client as k8s_client # type: ignore
from zetta_utils.cloud_management.resource_allocation.k8s import (
ClusterInfo,
get_cluster_data,
)
from zetta_utils.log import get_logger
from zetta_utils.message_queues.sqs import utils as sqs_utils
from zetta_utils.run import (
RESOURCE_DB,
RUN_DB,
Resource,
ResourceKeys,
ResourceTypes,
RunInfo,
)

logger = get_logger("zetta_utils")


def _delete_db_entry(client: Any, entry_id: str, columns: List[str]):
parent_key = client.key("Row", entry_id)
for column in columns:
col_key = client.key("Column", column, parent=parent_key)
client.delete(col_key)


def _delete_run_entry(run_id: str):  # pragma: no cover
    """Remove every RunInfo column of a run's row from RUN_DB."""
    datastore_client = RUN_DB.backend.client  # type: ignore
    _delete_db_entry(datastore_client, run_id, [info.value for info in RunInfo])


def _delete_resource_entry(resource_id: str):  # pragma: no cover
    """Remove every ResourceKeys column of a resource's row from RESOURCE_DB."""
    datastore_client = RESOURCE_DB.backend.client  # type: ignore
    _delete_db_entry(datastore_client, resource_id, [key.value for key in ResourceKeys])


def _get_stale_run_ids() -> list[str]:  # pragma: no cover
    """Return ids of runs whose last heartbeat is older than the lookback window."""
    datastore_client = RUN_DB.backend.client  # type: ignore

    lookback_seconds = int(os.environ["EXECUTION_HEARTBEAT_LOOKBACK"])
    cutoff = time.time() - lookback_seconds

    query = datastore_client.query(
        kind="Column", filters=[("heartbeat", "<", cutoff)]
    )
    query.keys_only()
    return [entity.key.parent.id_or_name for entity in query.fetch()]


def _read_clusters(run_id_key: str) -> list[ClusterInfo]:  # pragma: no cover
    """Load the cluster infos recorded for a run; empty list if none were recorded."""
    column = "clusters"
    try:
        raw_clusters = RUN_DB[(run_id_key, (column,))][column]
    except KeyError:
        return []
    return [ClusterInfo(**cluster) for cluster in json.loads(raw_clusters)]


def _read_run_resources(run_id: str) -> Dict[str, Resource]:
    """Fetch all resources registered under the given run id, keyed by resource id."""
    datastore_client = RESOURCE_DB.backend.client  # type: ignore

    query = datastore_client.query(kind="Column").add_filter("run_id", "=", run_id)
    query.keys_only()
    resource_ids = [entity.key.parent.id_or_name for entity in query.fetch()]

    rows = RESOURCE_DB[(resource_ids, ("type", "name"))]
    return {
        resource_id: Resource(run_id=run_id, **row)
        for resource_id, row in zip(resource_ids, rows)
    }


def _delete_k8s_resources(run_id: str, resources: Dict[str, Resource]) -> bool:  # pragma: no cover
    """
    Delete the k8s deployments/secrets registered for a run, on every cluster
    the run used. Returns True iff all resources were removed or no longer
    exist; DB entries for already-gone resources are purged along the way.
    """
    success = True
    logger.info(f"Deleting k8s resources from run {run_id}")
    for cluster in _read_clusters(run_id):
        try:
            configuration, _ = get_cluster_data(cluster)
        except GoogleAPICallError as exc:
            if exc.code == 404:
                # cluster does not exist; discard its resource entries
                for resource_id in resources.keys():
                    _delete_resource_entry(resource_id)
            else:
                success = False
                logger.warning(f"Failed to get cluster data: {exc}")
            # `configuration` is unbound on this path; the original fell
            # through to set_default() and crashed with UnboundLocalError.
            continue

        k8s_client.Configuration.set_default(configuration)

        k8s_apps_v1_api = k8s_client.AppsV1Api()
        k8s_core_v1_api = k8s_client.CoreV1Api()
        for resource_id, resource in resources.items():
            try:
                if resource.type == ResourceTypes.K8S_DEPLOYMENT.value:
                    logger.info(f"Deleting k8s deployment `{resource.name}`")
                    k8s_apps_v1_api.delete_namespaced_deployment(
                        name=resource.name, namespace="default"
                    )
                elif resource.type == ResourceTypes.K8S_SECRET.value:
                    logger.info(f"Deleting k8s secret `{resource.name}`")
                    k8s_core_v1_api.delete_namespaced_secret(
                        name=resource.name, namespace="default"
                    )
            except K8sApiException as exc:
                if exc.status == 404:
                    # already gone: treat as success (do NOT reset the flag,
                    # which previously masked earlier failures) and drop entry
                    logger.info(f"Resource does not exist: `{resource.name}`: {exc}")
                    _delete_resource_entry(resource_id)
                else:
                    # record the failure but keep going so the remaining
                    # resources still get cleaned up (the original raised a
                    # new empty K8sApiException here, aborting the cleanup)
                    success = False
                    logger.warning(f"Failed to delete k8s resource `{resource.name}`: {exc}")
    return success


def _delete_sqs_queues(resources: Dict[str, Resource]) -> bool:  # pragma: no cover
    """Delete the SQS queues registered for a run; returns True iff all were removed."""
    success = True
    for resource_id, resource in resources.items():
        if resource.type != ResourceTypes.SQS_QUEUE.value:
            continue
        if resource.region == "" or resource.region is None:
            region = taskqueue.secrets.AWS_DEFAULT_REGION
        else:
            region = resource.region
        sqs_client = sqs_utils.get_sqs_client(region_name=region)
        try:
            logger.info(f"Deleting SQS queue `{resource.name}`")
            response = sqs_client.get_queue_url(QueueName=resource.name)
            sqs_client.delete_queue(QueueUrl=response["QueueUrl"])
        except sqs_client.exceptions.QueueDoesNotExist as exc:
            logger.info(f"Queue does not exist: `{resource.name}`: {exc}")
            _delete_resource_entry(resource_id)
        except Boto3Error as exc:
            success = False
            logger.warning(f"Failed to delete queue `{resource.name}`: {exc}")
    return success


def cleanup_run(run_id: str):
    """Delete all tracked resources of a run; drop the run entry only if fully cleaned."""
    resources = _read_run_resources(run_id)
    k8s_ok = _delete_k8s_resources(run_id, resources)
    sqs_ok = _delete_sqs_queues(resources)

    if k8s_ok and sqs_ok:
        _delete_run_entry(run_id)
        logger.info(f"`{run_id}` run cleanup complete.")
    else:
        logger.info(f"`{run_id}` run cleanup incomplete.")


if __name__ == "__main__":  # pragma: no cover
    # Entry point: clean up every run whose heartbeat has gone stale.
    logger.setLevel(logging.INFO)
    for stale_id in _get_stale_run_ids():
        logger.info(f"Cleaning up run `{stale_id}`")
        cleanup_run(stale_id)
47 changes: 47 additions & 0 deletions zetta_utils/run/resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import uuid
from enum import Enum

import attrs

from zetta_utils.layer.db_layer import build_db_layer
from zetta_utils.layer.db_layer.datastore import DatastoreBackend

DEFAULT_PROJECT = "zetta-research"
RESOURCE_DB_NAME = "run-resource"

RESOURCE_DB = build_db_layer(
DatastoreBackend(
namespace=RESOURCE_DB_NAME,
project=DEFAULT_PROJECT, # all resources must be tracked in the default project
)
)


class ResourceTypes(Enum):
    """Kinds of cloud resources a run can register for later cleanup."""

    K8S_CONFIGMAP = "k8s_configmap"
    K8S_DEPLOYMENT = "k8s_deployment"
    K8S_JOB = "k8s_job"
    K8S_SECRET = "k8s_secret"
    K8S_SERVICE = "k8s_service"
    SQS_QUEUE = "sqs_queue"


class ResourceKeys(Enum):
    """Column keys used for a resource's row in RESOURCE_DB."""

    RUN_ID = "run_id"
    TYPE = "type"
    NAME = "name"


@attrs.frozen
class Resource:
    """Immutable record of a cloud resource created by a run."""

    run_id: str  # id of the run that created this resource
    type: str  # a ResourceTypes value
    name: str  # resource name (e.g. k8s object name or SQS queue name)
    region: str = ""  # cloud region; empty means use the default region


def register_resource(resource: Resource) -> None:
    """Persist a resource record in RESOURCE_DB under a fresh random row key."""
    payload = attrs.asdict(resource)
    RESOURCE_DB[(str(uuid.uuid4()), tuple(payload))] = payload

0 comments on commit a74730f

Please sign in to comment.