From 678a340607d13013ea66a4009a10b8ea8227d6f8 Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Thu, 2 Jan 2025 15:22:24 -0800 Subject: [PATCH] [robust heartbeats] - convert heartbeats from threads to a process --- metaflow/metadata_provider/heartbeat.py | 91 ++++++++++++++----------- 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/metaflow/metadata_provider/heartbeat.py b/metaflow/metadata_provider/heartbeat.py index 7320c5e7bde..65e413fcfd5 100644 --- a/metaflow/metadata_provider/heartbeat.py +++ b/metaflow/metadata_provider/heartbeat.py @@ -3,14 +3,46 @@ from threading import Thread import requests +import sys +import json -from metaflow.exception import MetaflowException +from metaflow.sidecar import MessageTypes, Message from metaflow.metaflow_config import SERVICE_HEADERS -from metaflow.sidecar import Message, MessageTypes +from metaflow.exception import MetaflowException +from multiprocessing import Process HB_URL_KEY = "hb_url" +def call_heartbeat(hb_url, headers): + response = requests.post(url=hb_url, data="{}", headers=headers) + if response.status_code == 200: + return json.loads(response.json()).get("wait_time_in_seconds") + else: + raise HeartBeatException( + "HeartBeat request (%s) failed" + " (code %s): %s" % (hb_url, response.status_code, response.text) + ) + + +# todo make this function more robust. +# todo add some logging for this to be captured somewhere +def ping_heartbeat(hb_url, headers, default_frequency_secs): + retry_counter = 0 + while True: + try: + frequency_secs = call_heartbeat(hb_url, headers) + + if frequency_secs is None or frequency_secs <= 0: + frequency_secs = default_frequency_secs + + time.sleep(frequency_secs) + retry_counter = 0 + except HeartBeatException as e: + retry_counter = retry_counter + 1 + time.sleep(4**retry_counter) + + class HeartBeatException(MetaflowException): headline = "Metaflow heart beat error" @@ -21,58 +53,35 @@ def __init__(self, msg): class MetadataHeartBeat(object): def __init__(self): self.headers = SERVICE_HEADERS - self.req_thread = Thread(target=self._ping) - self.req_thread.daemon = True self.default_frequency_secs = 10 self.hb_url = None + self.hb_process = None def process_message(self, msg): # type: (Message) -> None if msg.msg_type == MessageTypes.SHUTDOWN: self._shutdown() - if not self.req_thread.is_alive(): + + if not self.hb_process: # set post url self.hb_url = msg.payload[HB_URL_KEY] - # start thread - self.req_thread.start() + self.hb_process = Process( + target=ping_heartbeat, + args=(self.hb_url, self.headers, self.default_frequency_secs), + daemon=True, + ) + self.hb_process.start() @classmethod def get_worker(cls): return cls - def _ping(self): - retry_counter = 0 - while True: - try: - frequency_secs = self._heartbeat() - - if frequency_secs is None or frequency_secs <= 0: - frequency_secs = self.default_frequency_secs - - time.sleep(frequency_secs) - retry_counter = 0 - except HeartBeatException as e: - retry_counter = retry_counter + 1 - time.sleep(4**retry_counter) - - def _heartbeat(self): - if self.hb_url is not None: - response = requests.post( - url=self.hb_url, data="{}", headers=self.headers.copy() - ) - # Unfortunately, response.json() returns a string that we need - # to cast to json; however when the request encounters an error - # the return type is a json blob :/ - if response.status_code == 200: - return json.loads(response.json()).get("wait_time_in_seconds") - else: - raise HeartBeatException( - "HeartBeat request (%s) failed" - " (code %s): %s" - % (self.hb_url, response.status_code, response.text) - ) - return None - def _shutdown(self): # attempts sending one last heartbeat - self._heartbeat() + if self.hb_process is not None and self.hb_url is not None: + try: + call_heartbeat(self.hb_url, self.headers) + except HeartBeatException as e: + pass + self.hb_process.terminate() + self.hb_process.join()