Skip to content

Commit

Permalink
Update deployer API to support more orchestrators
Browse files (browse the repository at this point in the history)
  • Loading branch information
talsperre committed Sep 3, 2024
1 parent dda21b9 commit f238868
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
5 changes: 4 additions & 1 deletion metaflow/plugins/argo/argo_workflows_deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def trigger(instance: DeployedFlow, **kwargs):
)

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(tfp_runner_attribute, command_obj)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.TIMEOUT
)

if command_obj.process.returncode == 0:
triggered_run = TriggeredRun(deployer=instance.deployer, content=content)
Expand Down Expand Up @@ -257,6 +259,7 @@ class ArgoWorkflowsDeployer(DeployerImpl):
"""

TYPE: ClassVar[Optional[str]] = "argo-workflows"
TIMEOUT: ClassVar[Optional[int]] = 10

def __init__(self, deployer_kwargs, **kwargs):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def trigger(instance: DeployedFlow, **kwargs):
)

command_obj = instance.deployer.spm.get(pid)
content = handle_timeout(tfp_runner_attribute, command_obj)
content = handle_timeout(
tfp_runner_attribute, command_obj, instance.deployer.TIMEOUT
)

if command_obj.process.returncode == 0:
triggered_run = TriggeredRun(deployer=instance.deployer, content=content)
Expand All @@ -217,6 +219,7 @@ class StepFunctionsDeployer(DeployerImpl):
"""

TYPE: ClassVar[Optional[str]] = "step-functions"
TIMEOUT: ClassVar[Optional[int]] = 10

def __init__(self, deployer_kwargs, **kwargs):
"""
Expand Down
15 changes: 11 additions & 4 deletions metaflow/runner/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metaflow.runner.utils import read_from_file_when_ready


def handle_timeout(tfp_runner_attribute, command_obj: CommandManager):
def handle_timeout(tfp_runner_attribute, command_obj: CommandManager, timeout):
"""
Handle the timeout for a running subprocess command that reads a file
and raises an error with appropriate logs if a TimeoutError occurs.
Expand All @@ -35,7 +35,7 @@ def handle_timeout(tfp_runner_attribute, command_obj: CommandManager):
stdout and stderr logs.
"""
try:
content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=10)
content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=timeout)
return content
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
Expand Down Expand Up @@ -252,7 +252,10 @@ def _enrich_object(self, env):
class DeployerImpl(object):
"""
Base class for deployer implementations. Each implementation should define a TYPE
class variable that matches the name of the CLI group.
class variable that matches the name of the CLI group. The deployer implementation
should also define a `TIMEOUT` class variable that specifies the timeout for reading
the contents of the temporary file.

Parameters
----------
Expand All @@ -274,6 +277,7 @@ class variable that matches the name of the CLI group.
"""

TYPE: ClassVar[Optional[str]] = None
TIMEOUT: ClassVar[Optional[int]] = None

def __init__(
self,
Expand Down Expand Up @@ -349,11 +353,14 @@ def create(self, **kwargs) -> DeployedFlow:
)

command_obj = self.spm.get(pid)
content = handle_timeout(tfp_runner_attribute, command_obj)
content = handle_timeout(
tfp_runner_attribute, command_obj, timeout=self.TIMEOUT
)
content = json.loads(content)
self.name = content.get("name")
self.flow_name = content.get("flow_name")
self.metadata = content.get("metadata")
self.additional_metadata = content.get("additional_metadata", {})

if command_obj.process.returncode == 0:
deployed_flow = DeployedFlow(deployer=self)
Expand Down

0 comments on commit f238868

Please sign in to comment.