From f238868a8d9b5f11b4068fba2445a75a546fae70 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 3 Sep 2024 01:32:08 -0700 Subject: [PATCH] Update deployer API to support more orchestrators --- metaflow/plugins/argo/argo_workflows_deployer.py | 5 ++++- .../aws/step_functions/step_functions_deployer.py | 5 ++++- metaflow/runner/deployer.py | 15 +++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows_deployer.py b/metaflow/plugins/argo/argo_workflows_deployer.py index 57f95dda34d..bae6f1dfcba 100644 --- a/metaflow/plugins/argo/argo_workflows_deployer.py +++ b/metaflow/plugins/argo/argo_workflows_deployer.py @@ -226,7 +226,9 @@ def trigger(instance: DeployedFlow, **kwargs): ) command_obj = instance.deployer.spm.get(pid) - content = handle_timeout(tfp_runner_attribute, command_obj) + content = handle_timeout( + tfp_runner_attribute, command_obj, instance.deployer.TIMEOUT + ) if command_obj.process.returncode == 0: triggered_run = TriggeredRun(deployer=instance.deployer, content=content) @@ -257,6 +259,7 @@ class ArgoWorkflowsDeployer(DeployerImpl): """ TYPE: ClassVar[Optional[str]] = "argo-workflows" + TIMEOUT: ClassVar[Optional[int]] = 10 def __init__(self, deployer_kwargs, **kwargs): """ diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer.py b/metaflow/plugins/aws/step_functions/step_functions_deployer.py index 1bb5f2cf048..2226c75837b 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_deployer.py +++ b/metaflow/plugins/aws/step_functions/step_functions_deployer.py @@ -193,7 +193,9 @@ def trigger(instance: DeployedFlow, **kwargs): ) command_obj = instance.deployer.spm.get(pid) - content = handle_timeout(tfp_runner_attribute, command_obj) + content = handle_timeout( + tfp_runner_attribute, command_obj, instance.deployer.TIMEOUT + ) if command_obj.process.returncode == 0: triggered_run = TriggeredRun(deployer=instance.deployer, content=content) @@ -217,6 +219,7 @@ class StepFunctionsDeployer(DeployerImpl): """ TYPE: ClassVar[Optional[str]] = "step-functions" + TIMEOUT: ClassVar[Optional[int]] = 10 def __init__(self, deployer_kwargs, **kwargs): """ diff --git a/metaflow/runner/deployer.py b/metaflow/runner/deployer.py index 680e8b0a5cd..3a58c5d3018 100644 --- a/metaflow/runner/deployer.py +++ b/metaflow/runner/deployer.py @@ -11,7 +11,7 @@ from metaflow.runner.utils import read_from_file_when_ready -def handle_timeout(tfp_runner_attribute, command_obj: CommandManager): +def handle_timeout(tfp_runner_attribute, command_obj: CommandManager, timeout): """ Handle the timeout for a running subprocess command that reads a file and raises an error with appropriate logs if a TimeoutError occurs. @@ -35,7 +35,7 @@ def handle_timeout(tfp_runner_attribute, command_obj: CommandManager): stdout and stderr logs. """ try: - content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=10) + content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=timeout) return content except TimeoutError as e: stdout_log = open(command_obj.log_files["stdout"]).read() @@ -252,7 +252,10 @@ def _enrich_object(self, env): class DeployerImpl(object): """ Base class for deployer implementations. Each implementation should define a TYPE - class variable that matches the name of the CLI group. + class variable that matches the name of the CLI group. The deployer implementation + should also define a `TIMEOUT` class variable that specifies the timeout for reading + the contents of the temporary file. + Parameters ---------- @@ -274,6 +277,7 @@ class variable that matches the name of the CLI group. """ TYPE: ClassVar[Optional[str]] = None + TIMEOUT: ClassVar[Optional[int]] = None def __init__( self, @@ -349,11 +353,14 @@ def create(self, **kwargs) -> DeployedFlow: ) command_obj = self.spm.get(pid) - content = handle_timeout(tfp_runner_attribute, command_obj) + content = handle_timeout( + tfp_runner_attribute, command_obj, timeout=self.TIMEOUT + ) content = json.loads(content) self.name = content.get("name") self.flow_name = content.get("flow_name") self.metadata = content.get("metadata") + self.additional_metadata = content.get("additional_metadata", {}) if command_obj.process.returncode == 0: deployed_flow = DeployedFlow(deployer=self)