-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adds run context manager with heartbeat support
- Loading branch information
Showing
4 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base import RUN_DB, RunInfo, register, run_ctx_manager | ||
from .resource import Resource, register_resource, ResourceTypes, ResourceKeys, RESOURCE_DB |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import os | ||
import time | ||
from contextlib import contextmanager | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
import attrs | ||
|
||
from zetta_utils import log | ||
from zetta_utils.cloud_management.resource_allocation.k8s import ClusterInfo | ||
from zetta_utils.common import RepeatTimer | ||
from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer | ||
from zetta_utils.layer.db_layer.datastore import DatastoreBackend | ||
from zetta_utils.mazepa import id_generation | ||
from zetta_utils.parsing import json | ||
|
||
logger = log.get_logger("zetta_utils")


# GCP project and Datastore namespace where run bookkeeping rows live.
DEFAULT_PROJECT = "zetta-research"
RUN_DB_NAME = "run-info"
# GCS prefix for run artifacts (not referenced in this module; kept for callers).
RUN_INFO_PATH = "gs://zetta_utils_runs"
# Datastore-backed table keyed by "run-<RUN_ID>" rows; columns are RunInfo values.
RUN_DB = build_db_layer(DatastoreBackend(namespace=RUN_DB_NAME, project=DEFAULT_PROJECT))
# Id of the currently active run; set/cleared by `run_ctx_manager`.
RUN_ID: Optional[str] = None
|
||
|
||
class RunInfo(Enum):
    """Column keys for a run's row in `RUN_DB`."""

    # Owner of the run; read from the ZETTA_USER env var at registration.
    ZETTA_USER = "zetta_user"
    # Unix timestamp of the last liveness ping (see `run_ctx_manager`).
    HEARTBEAT = "heartbeat"
    # JSON-serialized list of ClusterInfo dicts the run allocated.
    CLUSTERS = "clusters"
    # One of the `RunState` values.
    STATE = "state"
|
||
|
||
class RunState(Enum):
    """Lifecycle states recorded under the `RunInfo.STATE` column."""

    RUNNING = "running"
    TIMEOUT = "timeout"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
|
||
|
||
def register(clusters: list[ClusterInfo]) -> None:  # pragma: no cover
    """
    Record this run's metadata in the run database for the garbage collector.

    Writes the owning user (from the ZETTA_USER env var), an initial
    heartbeat, the serialized cluster list, and a RUNNING state.

    :param clusters: Clusters allocated for this run.
    """
    row: DBRowDataT = {}
    row[RunInfo.ZETTA_USER.value] = os.environ["ZETTA_USER"]
    row[RunInfo.HEARTBEAT.value] = time.time()
    row[RunInfo.CLUSTERS.value] = json.dumps([attrs.asdict(c) for c in clusters])
    row[RunInfo.STATE.value] = RunState.RUNNING.value
    _update_run_info(row)
|
||
|
||
def _update_run_info(info: DBRowDataT) -> None:  # pragma: no cover
    """Write the given columns into the current run's row in `RUN_DB`."""
    # Row keys are namespaced with a "run-" prefix; column keys are taken
    # straight from the payload mapping.
    RUN_DB[(f"run-{RUN_ID}", tuple(info))] = info
|
||
|
||
@contextmanager
def run_ctx_manager(run_id: Optional[str] = None, heartbeat_interval: int = 5):
    """
    Context manager that marks a run as active and keeps its heartbeat fresh.

    Sets the module-level `RUN_ID`, optionally starts a background timer that
    periodically refreshes the heartbeat column, and on an exception records a
    FAILED state before re-raising.

    :param run_id: Id to use for this run; a unique id is generated when None.
    :param heartbeat_interval: Seconds between heartbeat writes; values <= 0
        disable the heartbeat timer.
    """

    def _send_heartbeat():
        # Refresh the heartbeat column so the garbage collector does not
        # consider this run stale.
        info: DBRowDataT = {RunInfo.HEARTBEAT.value: time.time()}
        _update_run_info(info)

    heartbeat = None
    if run_id is None:
        run_id = id_generation.get_unique_id(slug_len=4, add_uuid=False, max_len=50)

    global RUN_ID  # pylint: disable=global-statement
    RUN_ID = run_id
    try:
        if heartbeat_interval > 0:
            heartbeat = RepeatTimer(heartbeat_interval, _send_heartbeat)
            heartbeat.start()
        yield
    except Exception:
        # Record the failure for the garbage collector. BUG FIX: use a bare
        # `raise` instead of `raise e from None`, which discarded the
        # exception's context/chaining information.
        info: DBRowDataT = {RunInfo.STATE.value: RunState.FAILED.value}
        _update_run_info(info)
        raise
    finally:
        # BUG FIX: cancel the heartbeat BEFORE clearing RUN_ID; previously a
        # timer tick between the two statements would write to a bogus
        # "run-None" row.
        if heartbeat:
            heartbeat.cancel()
        RUN_ID = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
""" | ||
Garbage collection for run resources. | ||
""" | ||
|
||
import json | ||
import logging | ||
import os | ||
import time | ||
from typing import Any, Dict, List, Mapping | ||
|
||
import taskqueue | ||
from boto3.exceptions import Boto3Error | ||
from google.api_core.exceptions import GoogleAPICallError | ||
from kubernetes.client.exceptions import ApiException as K8sApiException | ||
|
||
from kubernetes import client as k8s_client # type: ignore | ||
from zetta_utils.cloud_management.resource_allocation.k8s import ( | ||
ClusterInfo, | ||
get_cluster_data, | ||
) | ||
from zetta_utils.log import get_logger | ||
from zetta_utils.message_queues.sqs import utils as sqs_utils | ||
from zetta_utils.run import ( | ||
RESOURCE_DB, | ||
RUN_DB, | ||
Resource, | ||
ResourceKeys, | ||
ResourceTypes, | ||
RunInfo, | ||
) | ||
|
||
logger = get_logger("zetta_utils") | ||
|
||
|
||
def _delete_db_entry(client: Any, entry_id: str, columns: List[str]): | ||
parent_key = client.key("Row", entry_id) | ||
for column in columns: | ||
col_key = client.key("Column", column, parent=parent_key) | ||
client.delete(col_key) | ||
|
||
|
||
def _delete_run_entry(run_id: str):  # pragma: no cover
    """Remove every `RunInfo` column of a run's row from the run database."""
    _delete_db_entry(
        RUN_DB.backend.client,  # type: ignore
        run_id,
        [info.value for info in RunInfo],
    )
|
||
|
||
def _delete_resource_entry(resource_id: str):  # pragma: no cover
    """Remove every `ResourceKeys` column of a resource row from the resource database."""
    _delete_db_entry(
        RESOURCE_DB.backend.client,  # type: ignore
        resource_id,
        [key.value for key in ResourceKeys],
    )
|
||
|
||
def _get_stale_run_ids() -> list[str]:  # pragma: no cover
    """
    Return ids of runs whose heartbeat is older than the lookback window.

    The window length (seconds) comes from the EXECUTION_HEARTBEAT_LOOKBACK
    env var; raises KeyError if it is unset.
    """
    cutoff = time.time() - int(os.environ["EXECUTION_HEARTBEAT_LOOKBACK"])

    client = RUN_DB.backend.client  # type: ignore
    query = client.query(kind="Column", filters=[("heartbeat", "<", cutoff)])
    query.keys_only()

    # The parent of each stale heartbeat column is the run's row key.
    return [entity.key.parent.id_or_name for entity in query.fetch()]
|
||
|
||
def _read_clusters(run_id_key: str) -> list[ClusterInfo]:  # pragma: no cover
    """Load the cluster list registered for a run; empty if none was recorded."""
    col_keys = ("clusters",)
    try:
        raw = RUN_DB[(run_id_key, col_keys)]["clusters"]
    except KeyError:
        # Run never registered any clusters.
        return []
    return [ClusterInfo(**entry) for entry in json.loads(raw)]
|
||
|
||
def _read_run_resources(run_id: str) -> Dict[str, Resource]:
    """
    Return a mapping of resource id -> `Resource` for everything registered
    under `run_id` in the resource database.
    """
    client = RESOURCE_DB.backend.client  # type: ignore

    query = client.query(kind="Column").add_filter("run_id", "=", run_id)
    query.keys_only()
    # Parents of the matching columns are the resource row keys.
    resource_ids = [entity.key.parent.id_or_name for entity in query.fetch()]

    rows = RESOURCE_DB[(resource_ids, ("type", "name"))]
    return {
        resource_id: Resource(run_id=run_id, **row)
        for resource_id, row in zip(resource_ids, rows)
    }
|
||
|
||
def _delete_k8s_resources(run_id: str, resources: Dict[str, Resource]) -> bool:  # pragma: no cover
    """
    Delete the k8s deployments/secrets registered for `run_id` on each of the
    run's clusters.

    :param run_id: Run whose resources should be removed.
    :param resources: Mapping of resource id -> Resource, as returned by
        `_read_run_resources`.
    :return: True iff every resource was deleted or already gone.
    """
    success = True
    logger.info(f"Deleting k8s resources from run {run_id}")
    clusters = _read_clusters(run_id)
    for cluster in clusters:
        try:
            configuration, _ = get_cluster_data(cluster)
        except GoogleAPICallError as exc:
            if exc.code == 404:
                # Cluster does not exist anymore, so its resources are gone
                # too; discard the bookkeeping entries.
                for resource_id in resources.keys():
                    _delete_resource_entry(resource_id)
            else:
                # BUG FIX: previously a non-404 error fell through and used
                # `configuration` while it was unbound, raising NameError.
                # Record the failure and move on to the next cluster.
                success = False
                logger.warning(f"Failed to get cluster data for run {run_id}: {exc}")
            continue

        k8s_client.Configuration.set_default(configuration)

        k8s_apps_v1_api = k8s_client.AppsV1Api()
        k8s_core_v1_api = k8s_client.CoreV1Api()
        for resource_id, resource in resources.items():
            try:
                if resource.type == ResourceTypes.K8S_DEPLOYMENT.value:
                    logger.info(f"Deleting k8s deployment `{resource.name}`")
                    k8s_apps_v1_api.delete_namespaced_deployment(
                        name=resource.name, namespace="default"
                    )
                elif resource.type == ResourceTypes.K8S_SECRET.value:
                    logger.info(f"Deleting k8s secret `{resource.name}`")
                    k8s_core_v1_api.delete_namespaced_secret(
                        name=resource.name, namespace="default"
                    )
            except K8sApiException as exc:
                if exc.status == 404:
                    # Already gone — treat as success and drop the entry.
                    logger.info(f"Resource does not exist: `{resource.name}`: {exc}")
                    _delete_resource_entry(resource_id)
                else:
                    # BUG FIX: previously re-raised here, which aborted cleanup
                    # of the remaining resources/clusters and broke the bool
                    # return contract; now record the failure and continue.
                    success = False
                    logger.warning(f"Failed to delete k8s resource `{resource.name}`: {exc}")
    return success
|
||
|
||
def _delete_sqs_queues(resources: Dict[str, Resource]) -> bool:  # pragma: no cover
    """
    Delete every SQS queue in `resources`; non-queue resources are skipped.

    :return: True iff every queue was deleted or already gone.
    """
    success = True
    for resource_id, resource in resources.items():
        if resource.type != ResourceTypes.SQS_QUEUE.value:
            continue
        # Empty/missing region falls back to the taskqueue default.
        region_name = resource.region or taskqueue.secrets.AWS_DEFAULT_REGION
        sqs_client = sqs_utils.get_sqs_client(region_name=region_name)
        try:
            logger.info(f"Deleting SQS queue `{resource.name}`")
            queue_url = sqs_client.get_queue_url(QueueName=resource.name)["QueueUrl"]
            sqs_client.delete_queue(QueueUrl=queue_url)
        except sqs_client.exceptions.QueueDoesNotExist as exc:
            # Already gone — drop the bookkeeping entry.
            logger.info(f"Queue does not exist: `{resource.name}`: {exc}")
            _delete_resource_entry(resource_id)
        except Boto3Error as exc:
            success = False
            logger.warning(f"Failed to delete queue `{resource.name}`: {exc}")
    return success
|
||
|
||
def cleanup_run(run_id: str):
    """
    Tear down all resources registered for a run; the run's own DB entry is
    removed only when every resource was cleaned up.
    """
    resources = _read_run_resources(run_id)
    # Both deletions must always execute, so combine the flags afterwards
    # rather than short-circuiting.
    k8s_clean = _delete_k8s_resources(run_id, resources)
    sqs_clean = _delete_sqs_queues(resources)

    if k8s_clean and sqs_clean:
        _delete_run_entry(run_id)
        logger.info(f"`{run_id}` run cleanup complete.")
    else:
        logger.info(f"`{run_id}` run cleanup incomplete.")
|
||
|
||
if __name__ == "__main__":  # pragma: no cover
    # Garbage-collector entry point: find runs whose heartbeat is stale and
    # clean up everything they registered.
    run_ids = _get_stale_run_ids()
    logger.setLevel(logging.INFO)
    for _id in run_ids:
        logger.info(f"Cleaning up run `{_id}`")
        cleanup_run(_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import uuid | ||
from enum import Enum | ||
|
||
import attrs | ||
|
||
from zetta_utils.layer.db_layer import build_db_layer | ||
from zetta_utils.layer.db_layer.datastore import DatastoreBackend | ||
|
||
# GCP project and Datastore namespace where resource bookkeeping rows live.
DEFAULT_PROJECT = "zetta-research"
RESOURCE_DB_NAME = "run-resource"

# Datastore-backed table keyed by UUID rows; columns are ResourceKeys values.
RESOURCE_DB = build_db_layer(
    DatastoreBackend(
        namespace=RESOURCE_DB_NAME,
        project=DEFAULT_PROJECT,  # all resources must be tracked in the default project
    )
)
|
||
|
||
class ResourceTypes(Enum):
    """Kinds of externally allocated resources that can be tracked for a run."""

    K8S_CONFIGMAP = "k8s_configmap"
    K8S_DEPLOYMENT = "k8s_deployment"
    K8S_JOB = "k8s_job"
    K8S_SECRET = "k8s_secret"
    K8S_SERVICE = "k8s_service"
    SQS_QUEUE = "sqs_queue"
|
||
|
||
class ResourceKeys(Enum):
    """Column keys for resource rows in `RESOURCE_DB`."""

    RUN_ID = "run_id"
    TYPE = "type"
    NAME = "name"
|
||
|
||
@attrs.frozen
class Resource:
    """Immutable record of an externally allocated resource tied to a run."""

    run_id: str
    type: str  # one of the ResourceTypes values
    name: str
    region: str = ""  # used for SQS queues; empty selects the default region
|
||
|
||
def register_resource(resource: Resource) -> None:
    """Persist a resource record in `RESOURCE_DB` under a fresh UUID row key."""
    row = attrs.asdict(resource)
    # Column keys come straight from the attrs field names.
    RESOURCE_DB[(str(uuid.uuid4()), tuple(row))] = row