From 171013f095f4aca2d0a9edd7a5afc3631deada53 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 21 Oct 2024 15:35:17 -0700 Subject: [PATCH 01/28] Add static and runtime dag info, API to fetch ancestor tasks --- metaflow/client/core.py | 80 ++++++++++++++++++++++++++ metaflow/metadata_provider/metadata.py | 15 +++++ metaflow/task.py | 65 ++++++++++++++++++++- 3 files changed, 157 insertions(+), 3 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 87b6a88c37c..e1f3a0cd69d 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,6 +1123,86 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" + def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + """ + Returns a dictionary with iterators over the immediate ancestors of this task. + + Returns + ------- + Dict[str, Iterator[Task]] + Dictionary of immediate ancestors of this task. The keys are the + names of the ancestors steps and the values are iterators over the + tasks of the corresponding steps. + """ + + def _prev_task(flow_id, run_id, previous_step): + # Find any previous task for current step + + step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No previous task found for step {previous_step}") + + flow_id, run_id, step_name, task_id = self.path_components + previous_steps = self.metadata_dict.get("previous_steps", None) + print( + f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" + ) + print(f"previous_steps: {previous_steps}") + + if not previous_steps or len(previous_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + ancestor_iters = {} + if len(previous_steps) > 1: + # This is a static join, so there is no change in foreach stack length + prev_foreach_stack_len = cur_foreach_stack_len + else: + prev_task = _prev_task(flow_id, run_id, previous_steps[0]) + prev_foreach_stack_len = len( + prev_task.metadata_dict.get("foreach-stack", []) + ) + + print( + f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" + ) + if prev_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif prev_foreach_stack_len > cur_foreach_stack_len: + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-stack-truncated value of current task with the + # foreach-stack value of tasks in previous steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for prev_step in previous_steps: + # print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and " + # f"{field_value}") + ancestor_iters[prev_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, prev_step, field_name, field_value + ) + ) + + return ancestor_iters + + # def closest_siblings(self) -> Iterator["Task"]: + # """ + # Returns an iterator over the closest siblings of this task. 
+ # + # Returns + # ------- + # Iterator[Task] + # Iterator over the closest siblings of this task + # """ + # flow_id, run_id, step_name, task_id = self.path_components + # print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}") + @property def metadata(self) -> List[Metadata]: """ diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index 11c3873a85e..a2d607eff7a 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -672,6 +672,21 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): if metadata: self.register_metadata(run_id, step_name, task_id, metadata) + @classmethod + def _filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + raise NotImplementedError() + + @classmethod + def filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + task_ids = cls._filter_tasks_by_metadata( + flow_id, run_id, step_name, prev_step, field_name, field_value + ) + return task_ids + @staticmethod def _apply_filter(elts, filters): if filters is None: diff --git a/metaflow/task.py b/metaflow/task.py index 6b73302652b..2ccaeab9c42 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -4,6 +4,8 @@ import sys import os import time +import json +import hashlib import traceback from types import MethodType, FunctionType @@ -37,6 +39,23 @@ class MetaflowTask(object): MetaflowTask prepares a Flow instance for execution of a single step. """ + @staticmethod + def _dynamic_runtime_metadata(foreach_stack): + foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack] + foreach_indices_truncated = foreach_indices[:-1] + foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack] + return foreach_indices, foreach_indices_truncated, foreach_step_names + + def _static_runtime_metadata(self, graph_info, step_name): + if step_name == "start": + return [] + + return [ + node_name + for node_name, attributes in graph_info["steps"].items() + if step_name in attributes["next"] + ] + def __init__( self, flow, @@ -493,6 +512,33 @@ def run_step( ) ) + # Add runtime dag info + foreach_indices, foreach_indices_truncated, foreach_step_names = ( + self._dynamic_runtime_metadata(foreach_stack) + ) + metadata.extend( + [ + MetaDatum( + field="foreach-indices", + value=foreach_indices, + type="foreach-indices", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-indices-truncated", + value=foreach_indices_truncated, + type="foreach-indices-truncated", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-step-names", + value=foreach_step_names, + type="foreach-step-names", + tags=metadata_tags, + ), + ] + ) + self.metadata.register_metadata( run_id, step_name, @@ -559,12 +605,17 @@ def run_step( self.flow._success = False self.flow._task_ok = None self.flow._exception = None + # Note: All internal flow attributes (ie: non-user artifacts) # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. if join_type: + if join_type == "foreach": + # We only want to persist one of the input paths + self.flow._input_paths = str(input_paths[0]) + # Join step: # Ensure that we have the right number of inputs. 
The @@ -616,7 +667,9 @@ def run_step( "graph_info": self.flow._graph_info, } ) - + previous_steps = self._static_runtime_metadata( + self.flow._graph_info, step_name + ) for deco in decorators: deco.task_pre_step( step_name, @@ -727,8 +780,14 @@ def run_step( field="attempt_ok", value=attempt_ok, type="internal_attempt_status", - tags=["attempt_id:{0}".format(retry_count)], - ) + tags=metadata_tags, + ), + MetaDatum( + field="previous_steps", + value=previous_steps, + type="previous_steps", + tags=metadata_tags, + ), ], ) From a842dc1b011dc5017365023cf7e3e1789f308e6c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 31 Oct 2024 02:53:31 -0700 Subject: [PATCH 02/28] Add API to get immediate successors --- metaflow/client/core.py | 82 +++++++++++++++++++++++++++++++++-------- metaflow/task.py | 15 +++++--- 2 files changed, 76 insertions(+), 21 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index e1f3a0cd69d..fa9c5a028be 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,21 +1123,21 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + def immediate_ancestors(self) -> Dict[str, List[str]]: """ - Returns a dictionary with iterators over the immediate ancestors of this task. + Returns a dictionary of immediate ancestors task ids of this task for each + previous step. Returns ------- - Dict[str, Iterator[Task]] + Dict[str, List[str]] Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are iterators over the - tasks of the corresponding steps. + names of the ancestors steps and the values are the corresponding + task ids of the ancestors. """ def _prev_task(flow_id, run_id, previous_step): # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) task = next(iter(step.tasks()), None) if task: @@ -1146,10 +1146,6 @@ def _prev_task(flow_id, run_id, previous_step): flow_id, run_id, step_name, task_id = self.path_components previous_steps = self.metadata_dict.get("previous_steps", None) - print( - f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" - ) - print(f"previous_steps: {previous_steps}") if not previous_steps or len(previous_steps) == 0: return @@ -1165,9 +1161,6 @@ def _prev_task(flow_id, run_id, previous_step): prev_task.metadata_dict.get("foreach-stack", []) ) - print( - f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" - ) if prev_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) @@ -1181,16 +1174,73 @@ def _prev_task(flow_id, run_id, previous_step): field_value = self.metadata_dict.get("foreach-indices-truncated") for prev_step in previous_steps: - # print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and " - # f"{field_value}") ancestor_iters[prev_step] = ( self._metaflow.metadata.filter_tasks_by_metadata( flow_id, run_id, step_name, prev_step, field_name, field_value ) ) - return ancestor_iters + def immediate_successors(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of immediate successors task ids of this task for each + previous step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of immediate successors of this task. 
The keys are the + names of the successors steps and the values are the corresponding + task ids of the successors. + """ + + def _successor_task(flow_id, run_id, successor_step): + # Find any previous task for current step + step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No successor task found for step {successor_step}") + + flow_id, run_id, step_name, task_id = self.path_components + successor_steps = self.metadata_dict.get("successor_steps", None) + + if not successor_steps or len(successor_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + successor_iters = {} + if len(successor_steps) > 1: + # This is a static split, so there is no change in foreach stack length + successor_foreach_stack_len = cur_foreach_stack_len + else: + successor_task = _successor_task(flow_id, run_id, successor_steps[0]) + successor_foreach_stack_len = len( + successor_task.metadata_dict.get("foreach-stack", []) + ) + + if successor_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif successor_foreach_stack_len > cur_foreach_stack_len: + # We will compare the foreach-indices value of current task with the + # foreach-indices-truncated value of tasks in successor steps + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-stack-truncated value of current task with the + # foreach-stack value of tasks in successor steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for successor_step in successor_steps: + successor_iters[successor_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, successor_step, field_name, field_value + ) + ) + return successor_iters + # def closest_siblings(self) -> Iterator["Task"]: # """ # Returns an iterator over the closest siblings of this task. 
diff --git a/metaflow/task.py b/metaflow/task.py
index 2ccaeab9c42..2e3c2975169 100644
--- a/metaflow/task.py
+++ b/metaflow/task.py
@@ -47,14 +47,13 @@ def _dynamic_runtime_metadata(foreach_stack):
         return foreach_indices, foreach_indices_truncated, foreach_step_names
 
     def _static_runtime_metadata(self, graph_info, step_name):
-        if step_name == "start":
-            return []
-
-        return [
+        prev_steps = [
             node_name
             for node_name, attributes in graph_info["steps"].items()
             if step_name in attributes["next"]
         ]
+        successor_steps = graph_info["steps"][step_name]["next"]
+        return prev_steps, successor_steps
 
     def __init__(
         self,
@@ -667,7 +666,7 @@ def run_step(
                 "graph_info": self.flow._graph_info,
             }
         )
-        previous_steps = self._static_runtime_metadata(
+        previous_steps, successor_steps = self._static_runtime_metadata(
             self.flow._graph_info, step_name
         )
         for deco in decorators:
@@ -788,6 +787,12 @@ def run_step(
                             type="previous_steps",
                             tags=metadata_tags,
                         ),
+                        MetaDatum(
+                            field="successor_steps",
+                            value=successor_steps,
+                            type="successor_steps",
+                            tags=metadata_tags,
+                        ),
                     ],
                 )
 

From 3f9fdfd7d3771ecc397a3fcb985ef17985c17439 Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Thu, 31 Oct 2024 11:13:24 -0700
Subject: [PATCH 03/28] Add API for getting closest siblings

---
 metaflow/client/core.py | 42 ++++++++++++++++++++++++++++++-----------
 1 file changed, 31 insertions(+), 11 deletions(-)

diff --git a/metaflow/client/core.py b/metaflow/client/core.py
index fa9c5a028be..014adad1ed5 100644
--- a/metaflow/client/core.py
+++ b/metaflow/client/core.py
@@ -1241,17 +1241,37 @@ def _successor_task(flow_id, run_id, successor_step):
         )
         return successor_iters
 
-    # def closest_siblings(self) -> Iterator["Task"]:
-    #     """
-    #     Returns an iterator over the closest siblings of this task.
-    #
-    #     Returns
-    #     -------
-    #     Iterator[Task]
-    #         Iterator over the closest siblings of this task
-    #     """
-    #     flow_id, run_id, step_name, task_id = self.path_components
-    #     print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}")
+    def closest_siblings(self) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary of closest siblings of this task for each step.
+
+        Returns
+        -------
+        Dict[str, List[str]]
+            Dictionary of closest siblings of this task. The keys are the
+            names of the current step and the values are the corresponding
+            task ids of the siblings.
+        """
+        flow_id, run_id, step_name, task_id = self.path_components
+
+        foreach_stack = self.metadata_dict.get("foreach-stack", [])
+        foreach_step_names = self.metadata_dict.get("foreach-step-names", [])
+        if len(foreach_stack) == 0:
+            raise MetaflowInternalError("Task is not part of any foreach split")
+        elif step_name != foreach_step_names[-1]:
+            raise MetaflowInternalError(
+                f"Step {step_name} does not have any direct siblings since it is not part "
+                f"of a new foreach split."
+ ) + + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices-truncated") + # We find all tasks of the same step that have the same foreach-indices-truncated value + return { + step_name: self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, step_name, field_name, field_value + ) + } @property def metadata(self) -> List[Metadata]: From 4d0298dfe4d71147afbfdbf1528c901ddd32ff5b Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 11:34:18 -0700 Subject: [PATCH 04/28] Update metadata API params --- metaflow/metadata_provider/metadata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index a2d607eff7a..9e11515abde 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -674,16 +674,16 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def _filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, prev_step, field_name, field_value + cls, flow_id, run_id, step_name, query_step, field_name, field_value ): raise NotImplementedError() @classmethod def filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, prev_step, field_name, field_value + cls, flow_id, run_id, step_name, query_step, field_name, field_value ): task_ids = cls._filter_tasks_by_metadata( - flow_id, run_id, step_name, prev_step, field_name, field_value + flow_id, run_id, step_name, query_step, field_name, field_value ) return task_ids From a42f31f05e22c60ab7b94c1495a87072904a16dd Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 15:13:53 -0700 Subject: [PATCH 05/28] Refactor ancestor and successor client code --- metaflow/client/core.py | 191 ++++++++++++------------- metaflow/metadata_provider/metadata.py | 7 +- metaflow/task.py | 4 - 3 files changed, 99 insertions(+), 103 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 014adad1ed5..68b5a18e3aa 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,64 +1123,108 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def immediate_ancestors(self) -> Dict[str, List[str]]: + def _get_task_for_queried_step(self, flow_id, run_id, query_step): """ - Returns a dictionary of immediate ancestors task ids of this task for each - previous step. - - Returns - ------- - Dict[str, List[str]] - Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are the corresponding - task ids of the ancestors. + Returns a Task object corresponding to the queried step. + If the queried step has several tasks, the first task is returned. 
""" + # Find any previous task for current step + step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No task found for the queried step {query_step}") - def _prev_task(flow_id, run_id, previous_step): - # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) - task = next(iter(step.tasks()), None) - if task: - return task - raise MetaflowNotFound(f"No previous task found for step {previous_step}") - - flow_id, run_id, step_name, task_id = self.path_components - previous_steps = self.metadata_dict.get("previous_steps", None) - - if not previous_steps or len(previous_steps) == 0: - return - - cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) - ancestor_iters = {} - if len(previous_steps) > 1: + def _get_filter_query_value( + self, flow_id, run_id, cur_foreach_stack_len, query_steps, query_type + ): + """ + For a given query type, returns the field name and value to be used for filtering tasks + based on the task's metadata. + """ + if len(query_steps) > 1: # This is a static join, so there is no change in foreach stack length - prev_foreach_stack_len = cur_foreach_stack_len + query_foreach_stack_len = cur_foreach_stack_len else: - prev_task = _prev_task(flow_id, run_id, previous_steps[0]) - prev_foreach_stack_len = len( - prev_task.metadata_dict.get("foreach-stack", []) + query_task = self._get_task_for_queried_step( + flow_id, run_id, query_steps[0] + ) + query_foreach_stack_len = len( + query_task.metadata_dict.get("foreach-stack", []) ) - if prev_foreach_stack_len == cur_foreach_stack_len: + # print(f"query_foreach_stack_len: {query_foreach_stack_len} cur_foreach_stack_len: {cur_foreach_stack_len}") + if query_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) - elif prev_foreach_stack_len > cur_foreach_stack_len: - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices") + elif query_type == "ancestor": + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach join + # We will compare the foreach-indices-truncated value of ancestor task with the + # foreach-indices value of current task + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # This is a foreach split + # We will compare the foreach-indices value of ancestor task with the + # foreach-indices value of current task + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") else: - # We will compare the foreach-stack-truncated value of current task with the - # foreach-stack value of tasks in previous steps - field_name = "foreach-indices" - field_value = self.metadata_dict.get("foreach-indices-truncated") + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach split + # We will compare the foreach-indices value of current task with the + # foreach-indices-truncated value of successor tasks + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # This is a foreach join + # We will compare the foreach-indices-truncated value of current task with the + # foreach-indices value of successor tasks + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + return field_name, field_value + + 
def _get_related_tasks( + self, steps_key: str, relation_type: str + ) -> Dict[str, List[str]]: + flow_id, run_id, _, _ = self.path_components + query_steps = self.metadata_dict.get(steps_key) + + if not query_steps: + return {} + + field_name, field_value = self._get_filter_query_value( + flow_id, + run_id, + len(self.metadata_dict.get("foreach-stack", [])), + query_steps, + relation_type, + ) - for prev_step in previous_steps: - ancestor_iters[prev_step] = ( - self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, prev_step, field_name, field_value - ) + return { + query_step: self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, query_step, field_name, field_value ) - return ancestor_iters + for query_step in query_steps + } + + @property + def immediate_ancestors(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of immediate ancestors task ids of this task for each + previous step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of immediate ancestors of this task. The keys are the + names of the ancestors steps and the values are the corresponding + task ids of the ancestors. + """ + return self._get_related_tasks("previous_steps", "ancestor") + @property def immediate_successors(self) -> Dict[str, List[str]]: """ Returns a dictionary of immediate successors task ids of this task for each @@ -1193,55 +1237,10 @@ def immediate_successors(self) -> Dict[str, List[str]]: names of the successors steps and the values are the corresponding task ids of the successors. """ + return self._get_related_tasks("successor_steps", "successor") - def _successor_task(flow_id, run_id, successor_step): - # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False) - task = next(iter(step.tasks()), None) - if task: - return task - raise MetaflowNotFound(f"No successor task found for step {successor_step}") - - flow_id, run_id, step_name, task_id = self.path_components - successor_steps = self.metadata_dict.get("successor_steps", None) - - if not successor_steps or len(successor_steps) == 0: - return - - cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) - successor_iters = {} - if len(successor_steps) > 1: - # This is a static split, so there is no change in foreach stack length - successor_foreach_stack_len = cur_foreach_stack_len - else: - successor_task = _successor_task(flow_id, run_id, successor_steps[0]) - successor_foreach_stack_len = len( - successor_task.metadata_dict.get("foreach-stack", []) - ) - - if successor_foreach_stack_len == cur_foreach_stack_len: - field_name = "foreach-indices" - field_value = self.metadata_dict.get(field_name) - elif successor_foreach_stack_len > cur_foreach_stack_len: - # We will compare the foreach-indices value of current task with the - # foreach-indices-truncated value of tasks in successor steps - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices") - else: - # We will compare the foreach-stack-truncated value of current task with the - # foreach-stack value of tasks in successor steps - field_name = "foreach-indices" - field_value = self.metadata_dict.get("foreach-indices-truncated") - - for successor_step in successor_steps: - successor_iters[successor_step] = ( - self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, successor_step, field_name, field_value - ) - ) - return successor_iters - - def closest_siblings(self) -> Dict[str, List[str]]: + 
@property + def immediate_siblings(self) -> Dict[str, List[str]]: """ Returns a dictionary of closest siblings of this task for each step. @@ -1252,13 +1251,13 @@ def closest_siblings(self) -> Dict[str, List[str]]: names of the current step and the values are the corresponding task ids of the siblings. """ - flow_id, run_id, step_name, task_id = self.path_components + flow_id, run_id, step_name, _ = self.path_components foreach_stack = self.metadata_dict.get("foreach-stack", []) foreach_step_names = self.metadata_dict.get("foreach-step-names", []) if len(foreach_stack) == 0: raise MetaflowInternalError("Task is not part of any foreach split") - elif step_name != foreach_step_names[-1]: + if step_name != foreach_step_names[-1]: raise MetaflowInternalError( f"Step {step_name} does not have any direct siblings since it is not part " f"of a new foreach split." @@ -1269,7 +1268,7 @@ def closest_siblings(self) -> Dict[str, List[str]]: # We find all tasks of the same step that have the same foreach-indices-truncated value return { step_name: self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, step_name, field_name, field_value + flow_id, run_id, step_name, field_name, field_value ) } diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index 9e11515abde..ac713505099 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -674,16 +674,17 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def _filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, query_step, field_name, field_value + cls, flow_id, run_id, query_step, field_name, field_value ): raise NotImplementedError() @classmethod def filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, query_step, field_name, field_value + cls, flow_id, run_id, query_step, field_name, field_value ): + # TODO: Do we need to do anything wrt to task attempt? task_ids = cls._filter_tasks_by_metadata( - flow_id, run_id, step_name, query_step, field_name, field_value + flow_id, run_id, query_step, field_name, field_value ) return task_ids diff --git a/metaflow/task.py b/metaflow/task.py index 2e3c2975169..39d8c3f705f 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -611,10 +611,6 @@ def run_step( # user artifacts in the user's step code. if join_type: - if join_type == "foreach": - # We only want to persist one of the input paths - self.flow._input_paths = str(input_paths[0]) - # Join step: # Ensure that we have the right number of inputs. 
The From 493c3faed8059b47ca45d43c046819757da2d57c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 15:36:34 -0700 Subject: [PATCH 06/28] Remove unneccessary prints --- metaflow/client/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 68b5a18e3aa..0b33ad7108f 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1153,7 +1153,6 @@ def _get_filter_query_value( query_task.metadata_dict.get("foreach-stack", []) ) - # print(f"query_foreach_stack_len: {query_foreach_stack_len} cur_foreach_stack_len: {cur_foreach_stack_len}") if query_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) From 26a0a6aca1411bebc14d04f2aced7fe82f7e6060 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 3 Jan 2025 16:40:45 -0800 Subject: [PATCH 07/28] Support querying ancestors and successors in local metadata provider --- metaflow/metadata_provider/metadata.py | 29 +++--- metaflow/plugins/metadata_providers/local.py | 97 ++++++++++++++++++++ 2 files changed, 115 insertions(+), 11 deletions(-) diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index ac713505099..7409fa371c0 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -672,21 +672,28 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): if metadata: self.register_metadata(run_id, step_name, task_id, metadata) - @classmethod - def _filter_tasks_by_metadata( - cls, flow_id, run_id, query_step, field_name, field_value - ): - raise NotImplementedError() - @classmethod def filter_tasks_by_metadata( cls, flow_id, run_id, query_step, field_name, field_value ): - # TODO: Do we need to do anything wrt to task attempt? - task_ids = cls._filter_tasks_by_metadata( - flow_id, run_id, query_step, field_name, field_value - ) - return task_ids + """ + Filter tasks by metadata field and value, and returns the list of task_ids + that satisfy the query. + + Parameters + ---------- + flow_id : str + Flow id, that the run belongs to. + run_id: str + Run id, together with flow_id, that identifies the specific Run whose tasks to query + query_step: str + Step name to query tasks from + field_name: str + Metadata field name to query + field_value: str + Metadata field value to query + """ + raise NotImplementedError() @staticmethod def _apply_filter(elts, filters): diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index ea7754cac5f..344f5d912c0 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -202,6 +202,103 @@ def _optimistically_mutate(): "Tagging failed due to too many conflicting updates from other processes" ) + @classmethod + def filter_tasks_by_metadata(cls, flow_id: str, run_id: str, query_step: str, + field_name: str, field_value: str) -> list: + """ + Filter tasks by metadata field and value, returning task IDs that match criteria. 
+
+        Parameters
+        ----------
+        flow_id : str
+            Identifier for the flow
+        run_id : str
+            Identifier for the run
+        query_step : str
+            Name of the step to query tasks from
+        field_name : str
+            Name of metadata field to query
+        field_value : str
+            Value to match in metadata field
+
+        Returns
+        -------
+        list
+            List of task IDs that match the query criteria
+
+        Raises
+        ------
+        JSONDecodeError
+            If metadata file is corrupted or empty
+        FileNotFoundError
+            If metadata file is not found
+        """
+
+        def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple:
+            """Find the most recent metadata file for the given field prefix."""
+            # The metadata is saved as files with the format: sysmeta_<field name>_<timestamp>.json
+            # and the artifact files are saved as: <artifact name>_artifact_<timestamp>.json
+            # We loop over all the JSON files in the directory and find the latest one
+            # that matches the field prefix.
+            json_files = glob.glob(os.path.join(path, "*.json"))
+            matching_files = []
+
+            for file_path in json_files:
+                filename = os.path.basename(file_path)
+                name, timestamp = filename.rsplit('_', 1)
+                timestamp = timestamp.split('.')[0]
+
+                if name == field_prefix:
+                    matching_files.append((file_path, int(timestamp)))
+
+            if not matching_files:
+                return None
+
+            return max(matching_files, key=lambda x: x[1])
+
+        def _read_metadata_value(file_path: str) -> dict:
+            """Read and parse metadata from JSON file."""
+            try:
+                with open(file_path, 'r') as f:
+                    return json.load(f)
+            except json.JSONDecodeError as e:
+                raise json.JSONDecodeError("Failed to decode metadata JSON file - may be corrupted or empty", e.doc, e.pos)
+            except Exception as e:
+                raise Exception(f"Error reading metadata file: {str(e)}") from e
+
+        try:
+            # Get all tasks for the given step
+            tasks = LocalMetadataProvider.get_object(
+                "step", "task", {}, None, flow_id, run_id, query_step
+            )
+
+            filtered_task_ids = []
+            field_name_prefix = f"sysmeta_{field_name}"
+
+            # Filter tasks based on metadata
+            for task in tasks:
+                task_id = task.get("task_id")
+                if not task_id:
+                    continue
+
+                meta_path = LocalMetadataProvider._get_metadir(
+                    flow_id, run_id, query_step, task_id
+                )
+
+                latest_file = _get_latest_metadata_file(meta_path, field_name_prefix)
+                if not latest_file:
+                    continue
+
+                # Read metadata and check value
+                metadata = _read_metadata_value(latest_file[0])
+                if metadata.get("value") == field_value:
+                    filtered_task_ids.append(task_id)
+
+            return filtered_task_ids
+
+        except Exception as e:
+            raise Exception(f"Failed to filter tasks: {str(e)}") from e
+
     @classmethod
     def _get_object_internal(
         cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args

From 5f092e71c21379c2ac3fd92067c14de00dd3f343 Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Fri, 3 Jan 2025 17:10:52 -0800
Subject: [PATCH 08/28] Refactor and simplify client code

---
 metaflow/client/core.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/metaflow/client/core.py b/metaflow/client/core.py
index 0b33ad7108f..51f06ad3897 100644
--- a/metaflow/client/core.py
+++ b/metaflow/client/core.py
@@ -379,7 +379,7 @@ def __iter__(self) -> Iterator["MetaflowObject"]:
             _CLASSES[self._CHILD_CLASS]._NAME,
             query_filter,
             self._attempt,
-            *self.path_components
+            *self.path_components,
         )
         unfiltered_children = unfiltered_children if unfiltered_children else []
         children = filter(
@@ -1128,7 +1128,7 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
         Returns a Task object corresponding to the queried step.
         If the queried step has several tasks, the first task is returned.
""" - # Find any previous task for current step + # Find any task corresponding to the queried step step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False) task = next(iter(step.tasks()), None) if task: @@ -1184,28 +1184,30 @@ def _get_filter_query_value( field_value = self.metadata_dict.get("foreach-indices-truncated") return field_name, field_value - def _get_related_tasks( - self, steps_key: str, relation_type: str - ) -> Dict[str, List[str]]: + def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: flow_id, run_id, _, _ = self.path_components - query_steps = self.metadata_dict.get(steps_key) + steps = ( + self.metadata_dict.get("previous_steps") + if relation_type == "ancestor" + else self.metadata_dict.get("successor_steps") + ) - if not query_steps: + if not steps: return {} field_name, field_value = self._get_filter_query_value( flow_id, run_id, len(self.metadata_dict.get("foreach-stack", [])), - query_steps, + steps, relation_type, ) return { - query_step: self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, query_step, field_name, field_value + step: self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step, field_name, field_value ) - for query_step in query_steps + for step in steps } @property @@ -1221,7 +1223,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]: names of the ancestors steps and the values are the corresponding task ids of the ancestors. """ - return self._get_related_tasks("previous_steps", "ancestor") + return self._get_related_tasks("ancestor") @property def immediate_successors(self) -> Dict[str, List[str]]: @@ -1236,7 +1238,7 @@ def immediate_successors(self) -> Dict[str, List[str]]: names of the successors steps and the values are the corresponding task ids of the successors. """ - return self._get_related_tasks("successor_steps", "successor") + return self._get_related_tasks("successor") @property def immediate_siblings(self) -> Dict[str, List[str]]: From 890ff920d43c3cd9c9c27c8b0aba299963b54e71 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 3 Jan 2025 17:44:09 -0800 Subject: [PATCH 09/28] Make query logic more descriptive --- metaflow/client/core.py | 56 ++++++++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 51f06ad3897..e9cb7cefb74 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1135,36 +1135,70 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step): return task raise MetaflowNotFound(f"No task found for the queried step {query_step}") - def _get_filter_query_value( - self, flow_id, run_id, cur_foreach_stack_len, query_steps, query_type + def _get_metadata_query_vals( + self, + flow_id: str, + run_id: str, + cur_foreach_stack_len: int, + steps: List[str], + query_type: str, ): """ - For a given query type, returns the field name and value to be used for filtering tasks - based on the task's metadata. - """ - if len(query_steps) > 1: - # This is a static join, so there is no change in foreach stack length + Returns the field name and field value to be used for querying metadata of successor or ancestor tasks. + + Parameters + ---------- + flow_id : str + Flow ID of the task + run_id : str + Run ID of the task + cur_foreach_stack_len : int + Length of the foreach stack of the current task + steps : List[str] + List of step names whose tasks will be returned. 
For static joins, and static splits, we can have + ancestors and successors across multiple steps. + query_type : str + Type of query. Can be 'ancestor' or 'successor'. + """ + # For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated + # which help us in querying ancestor and successor tasks. + # `foreach-indices`: contains the indices of the foreach stack at the time of task execution. + # `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but + # truncated by 1 + # For example, a task thats nested 3 levels deep in a foreach stack may have the following values: + # foreach-indices = [0, 1, 2] + # foreach-indices-truncated = [0, 1] + + if len(steps) > 1: + # This is a static join or a static split. There will be no change in foreach stack length query_foreach_stack_len = cur_foreach_stack_len else: + # For linear steps, or foreach splits and joins, ancestor and successor tasks will all belong to + # the same step. query_task = self._get_task_for_queried_step( - flow_id, run_id, query_steps[0] + flow_id, run_id, steps[0] ) query_foreach_stack_len = len( query_task.metadata_dict.get("foreach-stack", []) ) if query_foreach_stack_len == cur_foreach_stack_len: + # The successor or ancestor tasks belong to the same foreach stack level field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) elif query_type == "ancestor": if query_foreach_stack_len > cur_foreach_stack_len: # This is a foreach join + # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] + # Ancestor Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1] # We will compare the foreach-indices-truncated value of ancestor task with the # foreach-indices value of current task field_name = "foreach-indices-truncated" field_value = self.metadata_dict.get("foreach-indices") else: # This is a foreach split + # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1] + # Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] # We will compare the foreach-indices value of ancestor task with the # foreach-indices value of current task field_name = "foreach-indices" @@ -1172,12 +1206,16 @@ def _get_filter_query_value( else: if query_foreach_stack_len > cur_foreach_stack_len: # This is a foreach split + # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] + # Successor Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1] # We will compare the foreach-indices value of current task with the # foreach-indices-truncated value of successor tasks field_name = "foreach-indices-truncated" field_value = self.metadata_dict.get("foreach-indices") else: # This is a foreach join + # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1] + # Successor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] # We will compare the foreach-indices-truncated value of current task with the # foreach-indices value of successor tasks field_name = "foreach-indices" @@ -1195,7 +1233,7 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: if not steps: return {} - field_name, field_value = self._get_filter_query_value( + field_name, field_value = self._get_metadata_query_vals( flow_id, run_id, len(self.metadata_dict.get("foreach-stack", [])), From 7f91bf813908c4172a1c809a95dec2f6e4e43d31 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 6 Jan 2025 18:09:21 -0800 Subject: 
[PATCH 10/28] Add core tests for ancestor task API --- metaflow/client/core.py | 39 +++++----- .../tests/card_default_editable_with_id.py | 1 - test/core/tests/catch_retry.py | 2 - test/core/tests/client_ancestors.py | 72 +++++++++++++++++++ test/core/tests/tag_catch.py | 2 - 5 files changed, 94 insertions(+), 22 deletions(-) create mode 100644 test/core/tests/client_ancestors.py diff --git a/metaflow/client/core.py b/metaflow/client/core.py index e9cb7cefb74..d11a056cf4e 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1175,9 +1175,7 @@ def _get_metadata_query_vals( else: # For linear steps, or foreach splits and joins, ancestor and successor tasks will all belong to # the same step. - query_task = self._get_task_for_queried_step( - flow_id, run_id, steps[0] - ) + query_task = self._get_task_for_queried_step(flow_id, run_id, steps[0]) query_foreach_stack_len = len( query_task.metadata_dict.get("foreach-stack", []) ) @@ -1242,53 +1240,57 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: ) return { - step: self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step, field_name, field_value - ) + step: [ + f"{flow_id}/{run_id}/{step}/{task_id}" + for task_id in self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step, field_name, field_value + ) + ] for step in steps } @property def immediate_ancestors(self) -> Dict[str, List[str]]: """ - Returns a dictionary of immediate ancestors task ids of this task for each + Returns a dictionary of immediate ancestors task pathspecs of this task for each previous step. Returns ------- Dict[str, List[str]] Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are the corresponding - task ids of the ancestors. + names of the ancestor steps and the values are the corresponding + task pathspecs of the ancestors. """ return self._get_related_tasks("ancestor") @property def immediate_successors(self) -> Dict[str, List[str]]: """ - Returns a dictionary of immediate successors task ids of this task for each + Returns a dictionary of immediate successors task pathspecs of this task for each previous step. Returns ------- Dict[str, List[str]] Dictionary of immediate successors of this task. The keys are the - names of the successors steps and the values are the corresponding - task ids of the successors. + names of the successor steps and the values are the corresponding + task pathspecs of the successors. """ return self._get_related_tasks("successor") @property def immediate_siblings(self) -> Dict[str, List[str]]: """ - Returns a dictionary of closest siblings of this task for each step. + Returns a dictionary of closest sibling task pathspecs of this task for each + sibling step. Returns ------- Dict[str, List[str]] Dictionary of closest siblings of this task. The keys are the names of the current step and the values are the corresponding - task ids of the siblings. + task pathspecs of the siblings. 
""" flow_id, run_id, step_name, _ = self.path_components @@ -1306,9 +1308,12 @@ def immediate_siblings(self) -> Dict[str, List[str]]: field_value = self.metadata_dict.get("foreach-indices-truncated") # We find all tasks of the same step that have the same foreach-indices-truncated value return { - step_name: self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, field_name, field_value - ) + step_name: [ + f"{flow_id}/{run_id}/{step_name}/{task_id}" + for task_id in self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, field_name, field_value + ) + ] } @property diff --git a/test/core/tests/card_default_editable_with_id.py b/test/core/tests/card_default_editable_with_id.py index c46c52248ab..483e156ce36 100644 --- a/test/core/tests/card_default_editable_with_id.py +++ b/test/core/tests/card_default_editable_with_id.py @@ -47,7 +47,6 @@ def check_results(self, flow, checker): cli_check_dict = checker.artifact_dict(step.name, "random_number") for task_pathspec in cli_check_dict: - task_id = task_pathspec.split("/")[-1] cards_info = checker.list_cards(step.name, task_id) number = cli_check_dict[task_pathspec]["random_number"] diff --git a/test/core/tests/catch_retry.py b/test/core/tests/catch_retry.py index 6d9231da460..a13bdf33750 100644 --- a/test/core/tests/catch_retry.py +++ b/test/core/tests/catch_retry.py @@ -60,7 +60,6 @@ def step_all(self): raise TestRetry() def check_results(self, flow, checker): - checker.assert_log( "start", "stdout", "stdout testing logs 3\n", exact_match=False ) @@ -69,7 +68,6 @@ def check_results(self, flow, checker): ) for step in flow: - if step.name == "start": checker.assert_artifact("start", "test_attempt", 3) try: diff --git a/test/core/tests/client_ancestors.py b/test/core/tests/client_ancestors.py new file mode 100644 index 00000000000..fd0a08af0cc --- /dev/null +++ b/test/core/tests/client_ancestors.py @@ -0,0 +1,72 @@ +from metaflow_test import MetaflowTest, ExpectationFailed, steps + + +class ImmediateAncestorTest(MetaflowTest): + """ + Test that immediate_ancestor API returns correct parent tasks + by comparing with parent task ids stored during execution. 
+ """ + + PRIORITY = 1 + + @steps(0, ["start"]) + def step_start(self): + from metaflow import current + + self.step_name = current.step_name + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + self.parent_pathspecs = set() + + @steps(1, ["join"]) + def step_join(self): + from metaflow import current + + self.step_name = current.step_name + + # Store the parent task ids + # Store the task pathspec for all the parent tasks + self.parent_pathspecs = set(inp.task_pathspec for inp in inputs) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print(f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}") + + @steps(2, ["all"]) + def step_all(self): + from metaflow import current + + self.step_name = current.step_name + # Store the parent task ids + # Task only has one parent, so we store the parent task id + self.parent_pathspecs = set([self.task_pathspec]) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print(f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}") + + + def check_results(self, flow, checker): + from itertools import chain + run = checker.get_run() + + if run is None: + print("Run is None") + # very basic sanity check for CLI checker + for step in flow: + checker.assert_artifact(step.name, "step_name", step.name) + return + + # For each step in the flow + for step in run: + # For each task in the step + for task in step: + ancestors = task.immediate_ancestors + ancestor_pathspecs = set(chain.from_iterable(ancestors.values())) + + # Compare with stored parent_task_pathspecs + task_pathspec = task.data.task_pathspec + assert ( + ancestor_pathspecs == task.data.parent_pathspecs + ), f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, got {ancestor_pathspecs}" diff --git a/test/core/tests/tag_catch.py b/test/core/tests/tag_catch.py index a8efb919560..5cdffe3fcd6 100644 --- a/test/core/tests/tag_catch.py +++ b/test/core/tests/tag_catch.py @@ -62,7 +62,6 @@ def step_all(self): os.kill(os.getpid(), signal.SIGKILL) def check_results(self, flow, checker): - checker.assert_log( "start", "stdout", "stdout testing logs 3\n", exact_match=False ) @@ -71,7 +70,6 @@ def check_results(self, flow, checker): ) for step in flow: - if step.name == "start": checker.assert_artifact("start", "test_attempt", 3) try: From 0cfea97f2b25c27468904488b12e28045f334626 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 7 Jan 2025 01:17:23 -0800 Subject: [PATCH 11/28] Add core test for immediate successor API --- test/core/tests/client_ancestors.py | 5 +- test/core/tests/client_successors.py | 97 ++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 test/core/tests/client_successors.py diff --git a/test/core/tests/client_ancestors.py b/test/core/tests/client_ancestors.py index fd0a08af0cc..b46f81952fc 100644 --- a/test/core/tests/client_ancestors.py +++ b/test/core/tests/client_ancestors.py @@ -3,7 +3,7 @@ class ImmediateAncestorTest(MetaflowTest): """ - Test that immediate_ancestor API returns correct parent tasks + Test that immediate_ancestors API returns correct parent tasks by comparing with parent task ids stored during execution. 
""" @@ -69,4 +69,5 @@ def check_results(self, flow, checker): task_pathspec = task.data.task_pathspec assert ( ancestor_pathspecs == task.data.parent_pathspecs - ), f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, got {ancestor_pathspecs}" + ), (f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, " + f"got {ancestor_pathspecs}") diff --git a/test/core/tests/client_successors.py b/test/core/tests/client_successors.py new file mode 100644 index 00000000000..cf308fd010c --- /dev/null +++ b/test/core/tests/client_successors.py @@ -0,0 +1,97 @@ +from metaflow_test import MetaflowTest, ExpectationFailed, steps + + +class ImmediateSuccessorTest(MetaflowTest): + """ + Test that immediate_successors API returns correct successor tasks + by comparing with parent task ids stored during execution. + """ + + PRIORITY = 1 + + @steps(0, ["start"]) + def step_start(self): + from metaflow import current + + self.step_name = current.step_name + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + self.parent_pathspecs = set() + + @steps(1, ["join"]) + def step_join(self): + from metaflow import current + + self.step_name = current.step_name + + # Store the parent task ids + # Store the task pathspec for all the parent tasks + self.parent_pathspecs = set(inp.task_pathspec for inp in inputs) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + @steps(2, ["all"]) + def step_all(self): + from metaflow import current + + self.step_name = current.step_name + # Store the parent task ids + # Task only has one parent, so we store the parent task id + self.parent_pathspecs = set([self.task_pathspec]) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + def check_results(self, flow, checker): + from metaflow import Task + from itertools import chain + + run = checker.get_run() + + if run is None: + print("Run is None") + # very basic sanity check for CLI checker + for step in flow: + checker.assert_artifact(step.name, "step_name", step.name) + return + + # For each step in the flow + for step in run: + # For each task in the step + for task in step: + cur_task_pathspec = task.data.task_pathspec + successors = task.immediate_successors + actual_successors_pathspecs_set = set(chain.from_iterable(successors.values())) + expected_successor_pathspecs_set = set() + for successor_step_name, successor_pathspecs in successors.items(): + # Assert that the current task is in the parent_pathspecs of the successor tasks + for successor_pathspec in successor_pathspecs: + successor_task = Task( + successor_pathspec, _namespace_check=False + ) + print(f"Successor task: {successor_task}") + assert ( + task.data.task_pathspec + in successor_task.data.parent_pathspecs + ), (f"Task {task.data.task_pathspec} is not in the parent_pathspecs of the successor task " + f"{successor_task.data.task_pathspec}") + + successor_step = run[successor_step_name] + for successor_task in successor_step: + if cur_task_pathspec in successor_task.data.parent_pathspecs: + expected_successor_pathspecs_set.add(successor_task.data.task_pathspec) + + # Assert 
that None of the tasks in the successor steps have the current task in their + # parent_pathspecs + assert ( + actual_successors_pathspecs_set == expected_successor_pathspecs_set + ), (f"Expected successor pathspecs: {expected_successor_pathspecs_set}, got " + f"{actual_successors_pathspecs_set}") From 6bd07784bdf9d8fd97d10578ad2eddc4f94669c8 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 7 Jan 2025 11:46:19 -0800 Subject: [PATCH 12/28] Add endpoint in OSS metadata service --- metaflow/metadata_provider/metadata.py | 10 ++++- metaflow/plugins/metadata_providers/local.py | 13 +++++-- .../plugins/metadata_providers/service.py | 38 +++++++++++++++++++ 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index 7409fa371c0..5caaf9a9ec5 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -5,6 +5,7 @@ from collections import namedtuple from itertools import chain +from typing import List from metaflow.exception import MetaflowInternalError, MetaflowTaggingError from metaflow.tagging_util import validate_tag from metaflow.util import get_username, resolve_identity_as_tuple, is_stringish @@ -674,8 +675,8 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def filter_tasks_by_metadata( - cls, flow_id, run_id, query_step, field_name, field_value - ): + cls, flow_id: str, run_id: str, query_step: str, field_name: str, field_value: str + ) -> List[str]: """ Filter tasks by metadata field and value, and returns the list of task_ids that satisfy the query. @@ -692,6 +693,11 @@ def filter_tasks_by_metadata( Metadata field name to query field_value: str Metadata field value to query + + Returns + ------- + List[str] + List of task_ids that satisfy the query """ raise NotImplementedError() diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 344f5d912c0..58177e045d9 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -6,6 +6,7 @@ import tempfile import time from collections import namedtuple +from typing import List from metaflow.exception import MetaflowInternalError, MetaflowTaggingError from metaflow.metadata_provider.metadata import ObjectOrder @@ -203,8 +204,14 @@ def _optimistically_mutate(): ) @classmethod - def filter_tasks_by_metadata(cls, flow_id: str, run_id: str, query_step: str, - field_name: str, field_value: str) -> list: + def filter_tasks_by_metadata( + cls, + flow_id: str, + run_id: str, + query_step: str, + field_name: str, + field_value: str + ) -> List[str]: """ Filter tasks by metadata field and value, returning task IDs that match criteria. 
@@ -223,7 +230,7 @@ def filter_tasks_by_metadata(cls, flow_id: str, run_id: str, query_step: str, Returns ------- - list + List[str] List of task IDs that match the query criteria Raises diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py index 2e69026deb1..3320f4e020d 100644 --- a/metaflow/plugins/metadata_providers/service.py +++ b/metaflow/plugins/metadata_providers/service.py @@ -4,6 +4,7 @@ import requests import time +from typing import List from metaflow.exception import ( MetaflowException, MetaflowTaggingError, @@ -304,6 +305,43 @@ def _new_task( self._register_system_metadata(run_id, step_name, task["task_id"], attempt) return task["task_id"], did_create + @classmethod + def filter_tasks_by_metadata( + cls, flow_id: str, run_id: str, query_step: str, field_name: str, field_value: str + ) -> List[str]: + """ + Filter tasks by metadata field and value, and returns the list of task_ids + that satisfy the query. + + Parameters + ---------- + flow_id : str + Flow id, that the run belongs to. + run_id: str + Run id, together with flow_id, that identifies the specific Run whose tasks to query + query_step: str + Step name to query tasks from + field_name: str + Metadata field name to query + field_value: str + Metadata field value to query + + Returns + ------- + List[str] + List of task_ids that satisfy the query + """ + query_params = { + "field_name": field_name, + "field_value": field_value, + "query_step": query_step, + } + url = ServiceMetadataProvider._obj_path( + flow_id, run_id, query_step + ) + url = f"{url}/tasks?{urlencode(query_params)}" + return cls._request(cls._monitor, url, "GET") + @staticmethod def _obj_path( flow_name, From b92581b7450931d5bfcc639346203ebeb61272ef Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 10 Jan 2025 12:25:35 -0800 Subject: [PATCH 13/28] Add logs to tests --- test/core/tests/client_ancestors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/core/tests/client_ancestors.py b/test/core/tests/client_ancestors.py index b46f81952fc..5548392e774 100644 --- a/test/core/tests/client_ancestors.py +++ b/test/core/tests/client_ancestors.py @@ -63,6 +63,7 @@ def check_results(self, flow, checker): # For each task in the step for task in step: ancestors = task.immediate_ancestors + print(f"Task is {task.data.task_pathspec} and ancestors are {ancestors}") ancestor_pathspecs = set(chain.from_iterable(ancestors.values())) # Compare with stored parent_task_pathspecs From 6002395a78283a06cb52a3300b6a7d411de165fb Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sat, 11 Jan 2025 19:11:55 -0800 Subject: [PATCH 14/28] Log for each stack to metadata, update query logic --- metaflow/client/core.py | 8 ++++++-- metaflow/metaflow_config.py | 3 +-- metaflow/plugins/metadata_providers/local.py | 1 - metaflow/task.py | 1 + test/core/contexts.json | 2 +- test/core/tests/client_ancestors.py | 1 - 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index d11a056cf4e..fdbc4ad5ed6 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1,5 +1,6 @@ from __future__ import print_function +import time import json import os import tarfile @@ -1177,7 +1178,7 @@ def _get_metadata_query_vals( # the same step. 
query_task = self._get_task_for_queried_step(flow_id, run_id, steps[0]) query_foreach_stack_len = len( - query_task.metadata_dict.get("foreach-stack", []) + query_task.metadata_dict.get("foreach-indices", []) ) if query_foreach_stack_len == cur_foreach_stack_len: @@ -1221,6 +1222,7 @@ def _get_metadata_query_vals( return field_name, field_value def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: + start_time = time.time() flow_id, run_id, _, _ = self.path_components steps = ( self.metadata_dict.get("previous_steps") @@ -1234,11 +1236,13 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: field_name, field_value = self._get_metadata_query_vals( flow_id, run_id, - len(self.metadata_dict.get("foreach-stack", [])), + len(self.metadata_dict.get("foreach-indices", [])), steps, relation_type, ) + cur_time = time.time() + return { step: [ f"{flow_id}/{run_id}/{step}/{task_id}" diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 8dd0b76e7cd..415a934cbe4 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -248,8 +248,7 @@ # Default container registry DEFAULT_CONTAINER_REGISTRY = from_conf("DEFAULT_CONTAINER_REGISTRY") # Controls whether to include foreach stack information in metadata. -# TODO(Darin, 05/01/24): Remove this flag once we are confident with this feature. -INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", False) +INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", True) # Maximum length of the foreach value string to be stored in each ForeachFrame. MAXIMUM_FOREACH_VALUE_CHARS = from_conf("MAXIMUM_FOREACH_VALUE_CHARS", 30) # The default runtime limit (In seconds) of jobs launched by any compute provider. Default of 5 days. diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 58177e045d9..9cf4d5a5099 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -291,7 +291,6 @@ def _read_metadata_value(file_path: str) -> dict: meta_path = LocalMetadataProvider._get_metadir( flow_id, run_id, query_step, task_id ) - latest_file = _get_latest_metadata_file(meta_path, field_name_prefix) if not latest_file: continue diff --git a/metaflow/task.py b/metaflow/task.py index 39d8c3f705f..81d96eb7809 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -501,6 +501,7 @@ def run_step( current_foreach_path_length += len(foreach_step) foreach_stack_formatted.append(foreach_step) + if foreach_stack_formatted: metadata.append( MetaDatum( diff --git a/test/core/contexts.json b/test/core/contexts.json index 3c40ce3e607..9d1e264ca47 100644 --- a/test/core/contexts.json +++ b/test/core/contexts.json @@ -24,7 +24,7 @@ "--tag=\u523a\u8eab means sashimi", "--tag=multiple tags should be ok" ], - "checks": [ "python3-cli", "python3-metadata"], + "checks": ["python3-cli", "python3-metadata"], "disabled_tests": [ "LargeArtifactTest", "S3FailureTest", diff --git a/test/core/tests/client_ancestors.py b/test/core/tests/client_ancestors.py index 5548392e774..b46f81952fc 100644 --- a/test/core/tests/client_ancestors.py +++ b/test/core/tests/client_ancestors.py @@ -63,7 +63,6 @@ def check_results(self, flow, checker): # For each task in the step for task in step: ancestors = task.immediate_ancestors - print(f"Task is {task.data.task_pathspec} and ancestors are {ancestors}") ancestor_pathspecs = set(chain.from_iterable(ancestors.values())) # Compare with stored parent_task_pathspecs From 
d505d8789906740099294416b337cae0ce9d2800 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sat, 11 Jan 2025 19:18:49 -0800 Subject: [PATCH 15/28] Add more comments to code --- metaflow/task.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metaflow/task.py b/metaflow/task.py index 81d96eb7809..6785c2a424a 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -501,7 +501,6 @@ def run_step( current_foreach_path_length += len(foreach_step) foreach_stack_formatted.append(foreach_step) - if foreach_stack_formatted: metadata.append( MetaDatum( @@ -512,7 +511,10 @@ def run_step( ) ) - # Add runtime dag info + # Add runtime dag info - for a nested foreach this may look like: + # foreach_indices: [0, 1] + # foreach_indices_truncated: [0] + # foreach_step_names: ['step1', 'step2'] foreach_indices, foreach_indices_truncated, foreach_step_names = ( self._dynamic_runtime_metadata(foreach_stack) ) From f71e26b5e1f31626ce4b492d5349d942c1a7a4fc Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sat, 11 Jan 2025 19:20:45 -0800 Subject: [PATCH 16/28] Run black formatting --- metaflow/metadata_provider/metadata.py | 7 ++++++- metaflow/plugins/metadata_providers/local.py | 12 ++++++----- .../plugins/metadata_providers/service.py | 11 ++++++---- test/core/tests/client_ancestors.py | 18 ++++++++++------- test/core/tests/client_successors.py | 20 +++++++++++++------ 5 files changed, 45 insertions(+), 23 deletions(-) diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index 5caaf9a9ec5..5f66433bf3d 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -675,7 +675,12 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def filter_tasks_by_metadata( - cls, flow_id: str, run_id: str, query_step: str, field_name: str, field_value: str + cls, + flow_id: str, + run_id: str, + query_step: str, + field_name: str, + field_value: str, ) -> List[str]: """ Filter tasks by metadata field and value, and returns the list of task_ids diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 9cf4d5a5099..88d0810ef5f 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -210,7 +210,7 @@ def filter_tasks_by_metadata( run_id: str, query_step: str, field_name: str, - field_value: str + field_value: str, ) -> List[str]: """ Filter tasks by metadata field and value, returning task IDs that match criteria. 
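To make the runtime DAG bookkeeping documented in PATCH 15 above concrete, here
is a small illustrative sketch of how the three fields relate to a foreach
stack; ForeachFrame is a simplified stand-in for the runtime's frame objects
(only the .step and .index attributes matter here):

    from collections import namedtuple

    ForeachFrame = namedtuple("ForeachFrame", ["step", "index"])

    # A task nested two levels deep in foreaches, at step1 -> step2:
    stack = [ForeachFrame("step1", 0), ForeachFrame("step2", 1)]

    foreach_indices = [frame.index for frame in stack]    # [0, 1]
    foreach_indices_truncated = foreach_indices[:-1]      # [0]
    foreach_step_names = [frame.step for frame in stack]  # ['step1', 'step2']

The truncated variant is what lets a task at one nesting level be matched
against tasks one level above or below it when querying ancestors and
successors.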
@@ -252,8 +252,8 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple: for file_path in json_files: filename = os.path.basename(file_path) - name, timestamp = filename.rsplit('_', 1) - timestamp = timestamp.split('.')[0] + name, timestamp = filename.rsplit("_", 1) + timestamp = timestamp.split(".")[0] if name == field_prefix: matching_files.append((file_path, int(timestamp))) @@ -266,10 +266,12 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple: def _read_metadata_value(file_path: str) -> dict: """Read and parse metadata from JSON file.""" try: - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return json.load(f) except json.JSONDecodeError: - raise json.JSONDecodeError("Failed to decode metadata JSON file - may be corrupted or empty") + raise json.JSONDecodeError( + "Failed to decode metadata JSON file - may be corrupted or empty" + ) except Exception as e: raise Exception(f"Error reading metadata file: {str(e)}") diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py index 3320f4e020d..042839bf5eb 100644 --- a/metaflow/plugins/metadata_providers/service.py +++ b/metaflow/plugins/metadata_providers/service.py @@ -307,7 +307,12 @@ def _new_task( @classmethod def filter_tasks_by_metadata( - cls, flow_id: str, run_id: str, query_step: str, field_name: str, field_value: str + cls, + flow_id: str, + run_id: str, + query_step: str, + field_name: str, + field_value: str, ) -> List[str]: """ Filter tasks by metadata field and value, and returns the list of task_ids @@ -336,9 +341,7 @@ def filter_tasks_by_metadata( "field_value": field_value, "query_step": query_step, } - url = ServiceMetadataProvider._obj_path( - flow_id, run_id, query_step - ) + url = ServiceMetadataProvider._obj_path(flow_id, run_id, query_step) url = f"{url}/tasks?{urlencode(query_params)}" return cls._request(cls._monitor, url, "GET") diff --git a/test/core/tests/client_ancestors.py b/test/core/tests/client_ancestors.py index b46f81952fc..0f875185c72 100644 --- a/test/core/tests/client_ancestors.py +++ b/test/core/tests/client_ancestors.py @@ -30,7 +30,9 @@ def step_join(self): # Set the current task id self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" - print(f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}") + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) @steps(2, ["all"]) def step_all(self): @@ -44,11 +46,13 @@ def step_all(self): # Set the current task id self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" - print(f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}") - + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) def check_results(self, flow, checker): from itertools import chain + run = checker.get_run() if run is None: @@ -67,7 +71,7 @@ def check_results(self, flow, checker): # Compare with stored parent_task_pathspecs task_pathspec = task.data.task_pathspec - assert ( - ancestor_pathspecs == task.data.parent_pathspecs - ), (f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, " - f"got {ancestor_pathspecs}") + assert ancestor_pathspecs == task.data.parent_pathspecs, ( + f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, " + f"got {ancestor_pathspecs}" + ) 
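The client-side surface exercised by this test reduces to the following
sketch (the pathspec is a hypothetical example; as of the later patches in
this series, immediate_ancestors maps each ancestor step name to a list of
task pathspecs):

    from itertools import chain
    from metaflow import Task

    # Look up a finished task by pathspec (hypothetical values).
    task = Task("MyFlow/1737/join_step/42", _namespace_check=False)
    ancestors = task.immediate_ancestors
    ancestor_pathspecs = set(chain.from_iterable(ancestors.values()))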
diff --git a/test/core/tests/client_successors.py b/test/core/tests/client_successors.py
index cf308fd010c..c349192c2ee 100644
--- a/test/core/tests/client_successors.py
+++ b/test/core/tests/client_successors.py
@@ -69,7 +69,9 @@ def check_results(self, flow, checker):
             for task in step:
                 cur_task_pathspec = task.data.task_pathspec
                 successors = task.immediate_successors
-                actual_successors_pathspecs_set = set(chain.from_iterable(successors.values()))
+                actual_successors_pathspecs_set = set(
+                    chain.from_iterable(successors.values())
+                )
                 expected_successor_pathspecs_set = set()
                 for successor_step_name, successor_pathspecs in successors.items():
                     # Assert that the current task is in the parent_pathspecs of the successor tasks
@@ -81,17 +83,23 @@ def check_results(self, flow, checker):
                         assert (
                             task.data.task_pathspec
                             in successor_task.data.parent_pathspecs
-                        ), (f"Task {task.data.task_pathspec} is not in the parent_pathspecs of the successor task "
-                            f"{successor_task.data.task_pathspec}")
+                        ), (
+                            f"Task {task.data.task_pathspec} is not in the parent_pathspecs of the successor task "
+                            f"{successor_task.data.task_pathspec}"
+                        )
                     successor_step = run[successor_step_name]
                     for successor_task in successor_step:
                         if cur_task_pathspec in successor_task.data.parent_pathspecs:
-                            expected_successor_pathspecs_set.add(successor_task.data.task_pathspec)
+                            expected_successor_pathspecs_set.add(
+                                successor_task.data.task_pathspec
+                            )

             # Assert that the collected successor pathspecs exactly match the expected
             # set derived from parent_pathspecs
             assert (
                 actual_successors_pathspecs_set == expected_successor_pathspecs_set
-            ), (f"Expected successor pathspecs: {expected_successor_pathspecs_set}, got "
-                f"{actual_successors_pathspecs_set}")
+            ), (
+                f"Expected successor pathspecs: {expected_successor_pathspecs_set}, got "
+                f"{actual_successors_pathspecs_set}"
+            )

From 536278d4887d6509d386b91d716e82d90723cedc Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Tue, 14 Jan 2025 10:25:59 -0800
Subject: [PATCH 17/28] Set monitor to None in filter tasks API

---
 metaflow/plugins/metadata_providers/service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py
index 042839bf5eb..02ef8626b1c 100644
--- a/metaflow/plugins/metadata_providers/service.py
+++ b/metaflow/plugins/metadata_providers/service.py
@@ -343,7 +343,7 @@ def filter_tasks_by_metadata(
         }
         url = ServiceMetadataProvider._obj_path(flow_id, run_id, query_step)
         url = f"{url}/tasks?{urlencode(query_params)}"
-        return cls._request(cls._monitor, url, "GET")
+        return cls._request(None, url, "GET")

     @staticmethod
     def _obj_path(

From 25599994d325b72bbae3658fad0b1e8a45797bfc Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Tue, 14 Jan 2025 10:28:07 -0800
Subject: [PATCH 18/28] import urlencode

---
 metaflow/plugins/metadata_providers/service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py
index 02ef8626b1c..2acf95a5955 100644
--- a/metaflow/plugins/metadata_providers/service.py
+++ b/metaflow/plugins/metadata_providers/service.py
@@ -18,7 +18,7 @@
 from metaflow.metadata_provider import MetadataProvider
 from metaflow.metadata_provider.heartbeat import HB_URL_KEY
 from metaflow.sidecar import Message, MessageTypes, Sidecar
-
+from urllib.parse import urlencode
 from metaflow.util import version_parse


From c172ef225dc9bc52de262c4fb4ffa8b152ce97d1
Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 14 Jan 2025 12:29:53 -0800 Subject: [PATCH 19/28] Address comments --- metaflow/client/core.py | 47 +++++++++----------- metaflow/plugins/metadata_providers/local.py | 4 +- metaflow/task.py | 8 ++-- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index fdbc4ad5ed6..aba68a8f0df 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1119,6 +1119,8 @@ class Task(MetaflowObject): def __init__(self, *args, **kwargs): super(Task, self).__init__(*args, **kwargs) + # We want to cache metadata dictionary since it's used in many places + self._metadata_dict = None def _iter_filter(self, x): # exclude private data artifacts @@ -1129,12 +1131,7 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step): Returns a Task object corresponding to the queried step. If the queried step has several tasks, the first task is returned. """ - # Find any task corresponding to the queried step - step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False) - task = next(iter(step.tasks()), None) - if task: - return task - raise MetaflowNotFound(f"No task found for the queried step {query_step}") + return Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False).task def _get_metadata_query_vals( self, @@ -1142,7 +1139,7 @@ def _get_metadata_query_vals( run_id: str, cur_foreach_stack_len: int, steps: List[str], - query_type: str, + is_ancestor: bool, ): """ Returns the field name and field value to be used for querying metadata of successor or ancestor tasks. @@ -1158,15 +1155,15 @@ def _get_metadata_query_vals( steps : List[str] List of step names whose tasks will be returned. For static joins, and static splits, we can have ancestors and successors across multiple steps. - query_type : str - Type of query. Can be 'ancestor' or 'successor'. + is_ancestor : bool + If we are querying for ancestor tasks, set this to True. """ # For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated # which help us in querying ancestor and successor tasks. # `foreach-indices`: contains the indices of the foreach stack at the time of task execution. 
# `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but # truncated by 1 - # For example, a task thats nested 3 levels deep in a foreach stack may have the following values: + # For example, a task that's nested 3 levels deep in a foreach stack may have the following values: # foreach-indices = [0, 1, 2] # foreach-indices-truncated = [0, 1] @@ -1185,7 +1182,7 @@ def _get_metadata_query_vals( # The successor or ancestor tasks belong to the same foreach stack level field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) - elif query_type == "ancestor": + elif is_ancestor: if query_foreach_stack_len > cur_foreach_stack_len: # This is a foreach join # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] @@ -1199,7 +1196,7 @@ def _get_metadata_query_vals( # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1] # Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0] # We will compare the foreach-indices value of ancestor task with the - # foreach-indices value of current task + # foreach-indices-truncated value of current task field_name = "foreach-indices" field_value = self.metadata_dict.get("foreach-indices-truncated") else: @@ -1221,13 +1218,12 @@ def _get_metadata_query_vals( field_value = self.metadata_dict.get("foreach-indices-truncated") return field_name, field_value - def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: - start_time = time.time() + def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]: flow_id, run_id, _, _ = self.path_components steps = ( - self.metadata_dict.get("previous_steps") - if relation_type == "ancestor" - else self.metadata_dict.get("successor_steps") + self.metadata_dict.get("previous-steps") + if is_ancestor + else self.metadata_dict.get("successor-steps") ) if not steps: @@ -1238,11 +1234,9 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]: run_id, len(self.metadata_dict.get("foreach-indices", [])), steps, - relation_type, + is_ancestor=is_ancestor, ) - cur_time = time.time() - return { step: [ f"{flow_id}/{run_id}/{step}/{task_id}" @@ -1266,7 +1260,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]: names of the ancestor steps and the values are the corresponding task pathspecs of the ancestors. """ - return self._get_related_tasks("ancestor") + return self._get_related_tasks(is_ancestor=True) @property def immediate_successors(self) -> Dict[str, List[str]]: @@ -1281,7 +1275,7 @@ def immediate_successors(self) -> Dict[str, List[str]]: names of the successor steps and the values are the corresponding task pathspecs of the successors. 
""" - return self._get_related_tasks("successor") + return self._get_related_tasks(is_ancestor=False) @property def immediate_siblings(self) -> Dict[str, List[str]]: @@ -1408,9 +1402,12 @@ def metadata_dict(self) -> Dict[str, str]: Dictionary mapping metadata name with value """ # use the newest version of each key, hence sorting - return { - m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at) - } + if self._metadata_dict is None: + self._metadata_dict = { + m.name: m.value + for m in sorted(self.metadata, key=lambda m: m.created_at) + } + return self._metadata_dict @property def index(self) -> Optional[int]: diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 88d0810ef5f..995e589daf6 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -247,7 +247,7 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple: # and the artifact files are saved as: _artifact__.json # We loop over all the JSON files in the directory and find the latest one # that matches the field prefix. - json_files = glob.glob(os.path.join(path, "*.json")) + json_files = glob.glob(os.path.join(path, f"{field_prefix}*.json")) matching_files = [] for file_path in json_files: @@ -287,8 +287,6 @@ def _read_metadata_value(file_path: str) -> dict: # Filter tasks based on metadata for task in tasks: task_id = task.get("task_id") - if not task_id: - continue meta_path = LocalMetadataProvider._get_metadir( flow_id, run_id, query_step, task_id diff --git a/metaflow/task.py b/metaflow/task.py index 6785c2a424a..3e9b553d03b 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -781,15 +781,15 @@ def run_step( tags=metadata_tags, ), MetaDatum( - field="previous_steps", + field="previous-steps", value=previous_steps, - type="previous_steps", + type="previous-steps", tags=metadata_tags, ), MetaDatum( - field="successor_steps", + field="successor-steps", value=successor_steps, - type="successor_steps", + type="successor-steps", tags=metadata_tags, ), ], From c825d66c916934db8b3e09bbda6758702aa35585 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 14 Jan 2025 15:31:20 -0800 Subject: [PATCH 20/28] Update logic for siblings, make it work for static splits as well --- metaflow/client/core.py | 65 ++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index aba68a8f0df..e5273db9e33 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1278,41 +1278,58 @@ def immediate_successors(self) -> Dict[str, List[str]]: return self._get_related_tasks(is_ancestor=False) @property - def immediate_siblings(self) -> Dict[str, List[str]]: + def siblings(self) -> Dict[str, List[str]]: """ - Returns a dictionary of closest sibling task pathspecs of this task for each - sibling step. + Returns a dictionary of sibling task pathspecs of this task. Siblings of a task have + the same common parent task. Returns ------- Dict[str, List[str]] - Dictionary of closest siblings of this task. The keys are the + Dictionary of siblings task pathspecs of this task. The keys are the names of the current step and the values are the corresponding task pathspecs of the siblings. 
""" flow_id, run_id, step_name, _ = self.path_components - foreach_stack = self.metadata_dict.get("foreach-stack", []) - foreach_step_names = self.metadata_dict.get("foreach-step-names", []) - if len(foreach_stack) == 0: - raise MetaflowInternalError("Task is not part of any foreach split") - if step_name != foreach_step_names[-1]: - raise MetaflowInternalError( - f"Step {step_name} does not have any direct siblings since it is not part " - f"of a new foreach split." - ) + ancestor_steps = self.metadata_dict.get("previous-steps") + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-indices", [])) + if len(ancestor_steps) > 1 or step_name in ("start", "end"): + # This is a static join, or a start/end step. The current task will have no siblings. + return { + step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"], + } - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices-truncated") - # We find all tasks of the same step that have the same foreach-indices-truncated value - return { - step_name: [ - f"{flow_id}/{run_id}/{step_name}/{task_id}" - for task_id in self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, field_name, field_value - ) - ] - } + # This can be a linear step, a foreach split, a foreach join, or a static split. + query_task = self._get_task_for_queried_step(flow_id, run_id, ancestor_steps[0]) + query_foreach_stack_len = len( + query_task.metadata_dict.get("foreach-indices", []) + ) + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach join, there will be no siblings + return { + step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"], + } + elif query_foreach_stack_len < cur_foreach_stack_len: + # This is a foreach split, there will be multiple siblings + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices-truncated") + # We find all tasks of the same step that have the same foreach-indices-truncated value + return { + step_name: [ + f"{flow_id}/{run_id}/{step_name}/{task_id}" + for task_id in self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, field_name, field_value + ) + ] + } + + # Logic for static splits, and linear steps + # To find siblings, we first find the single ancestor task of the current task. + # And then we find all the successor tasks of this ancestor task. 
+ ancestor_task_pathspecs = self.immediate_ancestors.get(ancestor_steps[0]) + ancestor_task = Task(ancestor_task_pathspecs[0], _namespace_check=False) + return ancestor_task.immediate_successors @property def metadata(self) -> List[Metadata]: From 81aca618cbc6336584f5cea985962059dc705f31 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 13 Jan 2025 14:33:08 -0800 Subject: [PATCH 21/28] Initial commit for spin steps --- metaflow/cli.py | 50 ++++- metaflow/cli_components/run_cmds.py | 97 +++++++-- metaflow/cli_components/step_cmd.py | 75 +++++++ metaflow/metaflow_config.py | 8 + metaflow/plugins/pypi/conda_decorator.py | 1 + metaflow/runtime.py | 251 ++++++++++++++++++++++- metaflow/util.py | 44 ++++ 7 files changed, 495 insertions(+), 31 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 3a8dc4ecaa9..cdce10deca2 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -134,6 +134,8 @@ def config_merge_cb(ctx, param, value): "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-internal": "metaflow.cli_components.step_cmd.spin_internal", }, ) def cli(ctx): @@ -384,7 +386,6 @@ def start( # second one processed will return the actual options. The order of processing # depends on what (and in what order) the user specifies on the command line. config_options = config_file or config_value - if ( hasattr(ctx, "saved_args") and ctx.saved_args @@ -462,14 +463,10 @@ def start( ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor @@ -485,6 +482,47 @@ def start( ctx.obj.config_options = config_options + # Override values for spin + if hasattr(ctx, "saved_args") and ctx.saved_args and ctx.saved_args[0] == "spin": + # For spin, we will only use the local metadata provider, datastore, environment + # and null event logger and monitor + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( + ctx.obj.echo + ) + ctx.obj.datastore_impl.datastore_root = datastore_root + + FlowDataStore.default_storage_impl = ctx.obj.datastore_impl + ctx.obj.flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, + ctx.obj.metadata, + ctx.obj.event_logger, + ctx.obj.monitor, + ) + echo( + "Using local metadata provider, datastore, environment, and null event logger and monitor for spin." 
+        )
+        print(f"Using metadata provider: {ctx.obj.metadata}")
+        echo(f"Using Datastore root: {datastore_root}")
+        echo(f"Using Flow Datastore: {ctx.obj.flow_datastore}")
+
+    # Start event logger and monitor
+    ctx.obj.event_logger.start()
+    _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger)
+
+    ctx.obj.monitor.start()
+    _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor)
+
     decorators._init(ctx.obj.flow)

     # It is important to initialize flow decorators early as some of the
@@ -528,7 +566,7 @@ def start(
     if (
         hasattr(ctx, "saved_args")
         and ctx.saved_args
-        and ctx.saved_args[0] not in ("run", "resume")
+        and ctx.saved_args[0] not in ("run", "resume", "spin")
     ):
         # run/resume are special cases because they can add more decorators with --with,
         # so they have to take care of themselves.
diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py
index bf77d16ad1f..fc873452d47 100644
--- a/metaflow/cli_components/run_cmds.py
+++ b/metaflow/cli_components/run_cmds.py
@@ -9,11 +9,11 @@
 from ..graph import FlowGraph
 from ..metaflow_current import current
 from ..package import MetaflowPackage
-from ..runtime import NativeRuntime
+from ..runtime import NativeRuntime, SpinRuntime
 from ..system import _system_logger
 from ..tagging_util import validate_tags

-from ..util import get_latest_run_id, write_latest_run_id
+from ..util import get_latest_run_id, write_latest_run_id, get_latest_task_pathspec


 def before_run(obj, tags, decospecs):
@@ -70,6 +70,28 @@ def write_file(file_path, content):
             f.write(str(content))


+def common_runner_options(func):
+    @click.option(
+        "--run-id-file",
+        default=None,
+        show_default=True,
+        type=str,
+        help="Write the ID of this run to the file specified.",
+    )
+    @click.option(
+        "--runner-attribute-file",
+        default=None,
+        show_default=True,
+        type=str,
+        help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.",
+    )
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def common_run_options(func):
     @click.option(
         "--tag",
@@ -110,20 +132,6 @@ def common_run_options(func):
         "option multiple times to attach multiple decorators "
         "in steps.",
     )
-    @click.option(
-        "--run-id-file",
-        default=None,
-        show_default=True,
-        type=str,
-        help="Write the ID of this run to the file specified.",
-    )
-    @click.option(
-        "--runner-attribute-file",
-        default=None,
-        show_default=True,
-        type=str,
-        help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.",
-    )
     @wraps(func)
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
@@ -167,6 +175,7 @@ def wrapper(*args, **kwargs):
 @click.argument("step-to-rerun", required=False)
 @click.command(help="Resume execution of a previous run of this flow.")
 @common_run_options
+@common_runner_options
 @click.pass_obj
 def resume(
     obj,
@@ -285,6 +294,7 @@ def resume(
 @click.command(help="Run the workflow locally.")
 @tracing.cli_entrypoint("cli/run")
 @common_run_options
+@common_runner_options
 @click.option(
     "--namespace",
     "user_namespace",
@@ -360,3 +370,58 @@ def run(
             f,
         )
         runtime.execute()
+
+
+@click.command(help="Spins up a step locally")
+@click.argument(
+    "step-name",
+    required=True,
+    type=str,
+)
+@click.option(
+    "--task-pathspec",
+    default=None,
+    show_default=True,
+    help="Task ID to use when spinning up the step. The spun-up step will use the artifacts "
+    "corresponding to this task ID. 
If not provided, an arbitrary task ID from the latest run will be used.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + step_name, + task_pathspec=None, + run_id_file=None, + runner_attribute_file=None, + **kwargs +): + before_run(obj, [], []) + if task_pathspec is None: + task_pathspec = get_latest_task_pathspec(obj.flow.name, step_name) + + obj.echo( + f"Spinning up step *{step_name}* locally with task pathspec *{task_pathspec}*" + ) + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name) + + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.flow_datastore, + obj.metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + step_func, + task_pathspec, + ) + + # write_latest_run_id(obj, runtime.run_id) + # write_file(run_id_file, runtime.run_id) + + spin_runtime.execute() + pass diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 4b40c9e5e54..376937a4298 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -174,3 +174,78 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + "--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.pass_context +def spin_internal( + ctx, + step_name, + run_id=None, + task_id=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + namespace=None, +): + if ctx.obj.is_quiet: + echo = echo_dev_null + else: + echo = echo_always + print("I am here 1") + print("I am here 2") + # echo("Spinning a task, *%s*" % step_name, fg="magenta", bold=False) + + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, # local datastore + ctx.obj.metadata, # local metadata provider + ctx.obj.environment, # local environment + ctx.obj.echo, + ctx.obj.event_logger, # null logger + ctx.obj.monitor, # null monitor + None, # no unbounded foreach context + ) + # echo("Task is: ", task) + # pass diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 415a934cbe4..56be74c9592 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -47,6 +47,14 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] +) + + ### # User configuration ### diff --git a/metaflow/plugins/pypi/conda_decorator.py b/metaflow/plugins/pypi/conda_decorator.py index b1b7ee833d9..f43a68425dc 100644 --- a/metaflow/plugins/pypi/conda_decorator.py +++ 
b/metaflow/plugins/pypi/conda_decorator.py @@ -287,6 +287,7 @@ def task_pre_step( def runtime_step_cli( self, cli_args, retry_count, max_user_code_retries, ubf_context ): + print("Let's go - I am here") if self.disabled: return # Ensure local installation of Metaflow is visible to user code diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 7e9269841fb..c446266be32 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -73,6 +73,220 @@ # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + step_func, + task_pathspec, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + self._args = args or [] + self._kwargs = kwargs or {} + + self._step_func = step_func + self._task_pathspec = task_pathspec + self._prev_task = Task(self._task_pathspec, _namespace_check=False) + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._max_log_size = max_log_size + self._run_queue = [] + self._poll = procpoll.make_poll() + self._workers = {} # fd -> subprocess mapping + self._finished = {} + + # Create a new run_id for the spin task + self._run_id = self._metadata.new_run_id() + print( + f"New run_id for spin task: {self._run_id} and step func: {self._step_func.name}" + ) + + for deco in self._step_func.decorators: + deco.runtime_init(flow, graph, package, self._run_id) + + print(f"Input paths: {self.input_paths}") + + @property + def split_index(self): + if self._split_index: + return self._split_index + foreach_indices = self._prev_task.metadata_dict.get("foreach-indices", []) + self._split_index = foreach_indices[-1] if foreach_indices else None + return self._split_index + + @property + def input_paths(self): + def _format_input_paths(task_id): + _, run_id, step_name, task_id = task_id.split("/") + return f"{run_id}/{step_name}/{task_id}" + + if self._input_paths: + return self._input_paths + + if self._step_func.name == "start": + from metaflow import Step + + flow_name, run_id, _, _ = self._task_pathspec.split("/") + + step = Step(f"{flow_name}/{run_id}/_parameters", _namespace_check=False) + task = next(iter(step.tasks()), None) + if not task: + raise MetaflowException( + f"Task not found for {step} in the metadata store" + ) + self._input_paths = [f"{run_id}/_parameters/{task.id}"] + else: + ancestors = self._prev_task.immediate_ancestors + self._input_paths = [ + _format_input_paths(ancestor) + for i, ancestor in enumerate(chain.from_iterable(ancestors.values())) + ] + return self._input_paths + + @property + def whitelist_decorators(self): + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + self._flow_datastore, + self._flow, + step, + self._run_id, + self._metadata, + self._environment, + self._entrypoint, + self._event_logger, + self._monitor, + 
input_paths=self.input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + **kwargs, + ) + + def execute(self): + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + # Configurations are passed through a file to avoid overloading the + # command-line. We only need to create this file once and it can be reused + # for any task launch + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + + task = self._new_task(self._step_func.name, {}) + _ds = self._flow_datastore.get_task_datastore( + self._run_id, self._step_func.name, task.task_id, attempt=0, mode="w" + ) + + for deco in self.whitelist_decorators: + deco.runtime_task_created( + _ds, + task.task_id, + self.split_index, + self.input_paths, + is_cloned=False, + ubf_context=None, + ) + + # Start a new worker to spin a step + worker = Worker(task, self._max_log_size, self._config_file_name, spin=True) + for fd in worker.fds(): + self._workers[fd] = worker + self._poll.add(fd) + + finished_tasks = list(self._poll_workers()) + try: + pass + except KeyboardInterrupt as ex: + self._logger("Workflow interrupted.", system_msg=True, bad=True) + self._killall() + exception = ex + raise + + except Exception as ex: + self._logger("Workflow failed.", system_msg=True, bad=True) + self._killall() + exception = ex + raise + finally: + # on finish clean tasks + for step in self._flow: + for deco in step.decorators: + deco.runtime_finished(exception) + + def _launch_spin(self): + args = CLIArgs(self.task, self.spin) + env = dict(os.environ) + + for deco in self.task.decos: + deco.runtime_step_cli( + args, + self.task.retries, + self.task.user_code_retries, + self.task.ubf_context, + ) + + # Add user configurations using a file to avoid using up too much space on the + # command line + if self._config_file_name: + args.top_level_options["local-config-file"] = self._config_file_name + + print(f"Args Entrypoint updated is {args.entrypoint}") + env.update(args.get_env()) + env["PYTHONUNBUFFERED"] = "x" + cmdline = args.get_args() + print(f"Command line is: {cmdline}") + + process = subprocess.Popen( + cmdline, + env=env, + bufsize=1, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + # Read and print subprocess output + stdout, stderr = process.communicate() + print(f"stdout: {stdout.decode()}") + print(f"stderr: {stderr.decode()}") + + class NativeRuntime(object): def __init__( self, @@ -1508,8 +1722,9 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). 
""" - def __init__(self, task): + def __init__(self, task, spin=False): self.task = task + self.spin = spin self.entrypoint = list(task.entrypoint) self.top_level_options = { "quiet": True, @@ -1542,18 +1757,36 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + if spin: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, + "ubf-context": self.task.ubf_context, + } + self.env = {} + + def spin_args(self): + self.commands = ["spin-internal"] + self.command_args = [self.task.step] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": self.task.input_paths, + "split-index": self.task.split_index, } self.env = {} diff --git a/metaflow/util.py b/metaflow/util.py index cd3447d0e48..e9355e31ae1 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -9,6 +9,7 @@ from itertools import takewhile import re +from typing import Callable from metaflow.exception import MetaflowUnknownUser, MetaflowInternalError try: @@ -193,6 +194,49 @@ def get_latest_run_id(echo, flow_name): return None +def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: + """ + Returns a task pathspec from the latest run of the flow for the queried step. + If the queried step has several tasks, the task pathspec of the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. + step_name : str + The name of the step. + + Returns + ------- + str + The task pathspec of the first task of the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step. 
+ """ + from metaflow import Flow, Step + from metaflow.exception import MetaflowNotFound + + run = Flow(flow_name, _namespace_check=False).latest_run + + if run is None: + raise MetaflowNotFound(f"No run found for the flow {flow_name}") + + try: + step = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False) + except Exception: + raise MetaflowNotFound( + f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*" + ) + + task = next(iter(step.tasks()), None) + if task: + return f"{flow_name}/{run.id}/{step_name}/{task.id}" + raise MetaflowNotFound(f"No task found for the queried step {query_step}") + + def write_latest_run_id(obj, run_id): from metaflow.plugins.datastores.local_storage import LocalStorage From 5cb5292c1992db180183d2b4d1c5f70ee84b2c74 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 14 Jan 2025 22:27:55 -0800 Subject: [PATCH 22/28] Make it work with subprocess --- metaflow/runtime.py | 124 +++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 66 deletions(-) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index c446266be32..6f32aba499c 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -19,12 +19,15 @@ from functools import partial from concurrent import futures + from metaflow.datastore.exceptions import DataException +from itertools import chain from contextlib import contextmanager from . import get_namespace from .metadata_provider import MetaDatum from .metaflow_config import MAX_ATTEMPTS, UI_URL +from .metaflow_config import SPIN_ALLOWED_DECORATORS from .exception import ( MetaflowException, MetaflowInternalError, @@ -102,8 +105,6 @@ def __init__( self._entrypoint = entrypoint self._event_logger = event_logger self._monitor = monitor - self._args = args or [] - self._kwargs = kwargs or {} self._step_func = step_func self._task_pathspec = task_pathspec @@ -113,10 +114,6 @@ def __init__( self._whitelist_decorators = None self._config_file_name = None self._max_log_size = max_log_size - self._run_queue = [] - self._poll = procpoll.make_poll() - self._workers = {} # fd -> subprocess mapping - self._finished = {} # Create a new run_id for the spin task self._run_id = self._metadata.new_run_id() @@ -124,8 +121,18 @@ def __init__( f"New run_id for spin task: {self._run_id} and step func: {self._step_func.name}" ) + print( + f"Decorators for {self._step_func.name}: {list(self._step_func.decorators)}" + ) + for deco in self._step_func.decorators: + print( + f"Running runtime_init for {deco.__class__.__name__} at {self._step_func.name}" + ) + print("-" * 100) deco.runtime_init(flow, graph, package, self._run_id) + if hasattr(deco, "_metaflow_home"): + print(f"Metaflow home is {deco._metaflow_home}") print(f"Input paths: {self.input_paths}") @@ -208,83 +215,68 @@ def execute(self): else: self._config_file_name = None - task = self._new_task(self._step_func.name, {}) + self.task = self._new_task(self._step_func.name, {}) _ds = self._flow_datastore.get_task_datastore( - self._run_id, self._step_func.name, task.task_id, attempt=0, mode="w" + self._run_id, self._step_func.name, self.task.task_id, attempt=0, mode="w" ) for deco in self.whitelist_decorators: deco.runtime_task_created( _ds, - task.task_id, + self.task.task_id, self.split_index, self.input_paths, is_cloned=False, ubf_context=None, ) - # Start a new worker to spin a step - worker = Worker(task, self._max_log_size, self._config_file_name, spin=True) - for fd in worker.fds(): - self._workers[fd] = worker - self._poll.add(fd) + self.launch_spin() 
- finished_tasks = list(self._poll_workers()) - try: - pass - except KeyboardInterrupt as ex: - self._logger("Workflow interrupted.", system_msg=True, bad=True) - self._killall() - exception = ex - raise - - except Exception as ex: - self._logger("Workflow failed.", system_msg=True, bad=True) - self._killall() - exception = ex - raise - finally: - # on finish clean tasks - for step in self._flow: - for deco in step.decorators: - deco.runtime_finished(exception) - - def _launch_spin(self): - args = CLIArgs(self.task, self.spin) - env = dict(os.environ) + # Start a new worker to spin a step + # on finish clean tasks + exception = None + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) - for deco in self.task.decos: - deco.runtime_step_cli( - args, - self.task.retries, - self.task.user_code_retries, - self.task.ubf_context, - ) + def launch_spin(self): + args = CLIArgs(self.task, spin=True) + env = dict(os.environ) - # Add user configurations using a file to avoid using up too much space on the - # command line - if self._config_file_name: - args.top_level_options["local-config-file"] = self._config_file_name - - print(f"Args Entrypoint updated is {args.entrypoint}") - env.update(args.get_env()) - env["PYTHONUNBUFFERED"] = "x" - cmdline = args.get_args() - print(f"Command line is: {cmdline}") - - process = subprocess.Popen( - cmdline, - env=env, - bufsize=1, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, + for deco in self.task.decos: + deco.runtime_step_cli( + args, + self.task.retries, + self.task.user_code_retries, + self.task.ubf_context, ) - # Read and print subprocess output - stdout, stderr = process.communicate() - print(f"stdout: {stdout.decode()}") - print(f"stderr: {stderr.decode()}") + # Add user configurations using a file to avoid using up too much space on the + # command line + if self._config_file_name: + args.top_level_options["local-config-file"] = self._config_file_name + + print(f"Args Entrypoint updated is {args.entrypoint}") + env.update(args.get_env()) + env["PYTHONUNBUFFERED"] = "x" + cmdline = args.get_args() + print(f"Command line is: {cmdline}") + + process = subprocess.Popen( + cmdline, + env=env, + bufsize=1, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + # Read and print subprocess output + stdout, stderr = process.communicate() + print("STDOUT:\n") + print(f"{stdout.decode()}") + print("-" * 100) + print("STDERR:\n") + print(f"stderr: {stderr.decode()}") class NativeRuntime(object): From 23d7f1296cedd1e3ab39ef0d3f3938f264b8196c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 14 Jan 2025 23:33:51 -0800 Subject: [PATCH 23/28] Make it work with subprocess and truncated buffers --- metaflow/cli_components/step_cmd.py | 5 ++ metaflow/runtime.py | 129 ++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 38 deletions(-) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 376937a4298..035922452ff 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -229,6 +229,8 @@ def spin_internal( max_user_code_retries=None, namespace=None, ): + import sys + if ctx.obj.is_quiet: echo = echo_dev_null else: @@ -248,4 +250,7 @@ def spin_internal( None, # no unbounded foreach context ) # echo("Task is: ", task) + print("Task is: ", task) + print("I am here 3") + print("sys.executable: ", sys.executable) # pass diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 
6f32aba499c..f04703e64d7 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -114,6 +114,7 @@ def __init__( self._whitelist_decorators = None self._config_file_name = None self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" # Create a new run_id for the spin task self._run_id = self._metadata.new_run_id() @@ -203,10 +204,8 @@ def _new_task(self, step, input_paths=None, **kwargs): ) def execute(self): + exception = None with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: - # Configurations are passed through a file to avoid overloading the - # command-line. We only need to create this file once and it can be reused - # for any task launch config_value = dump_config_values(self._flow) if config_value: json.dump(config_value, config_file) @@ -215,30 +214,36 @@ def execute(self): else: self._config_file_name = None - self.task = self._new_task(self._step_func.name, {}) - _ds = self._flow_datastore.get_task_datastore( - self._run_id, self._step_func.name, self.task.task_id, attempt=0, mode="w" - ) - - for deco in self.whitelist_decorators: - deco.runtime_task_created( - _ds, + self.task = self._new_task(self._step_func.name, {}) + _ds = self._flow_datastore.get_task_datastore( + self._run_id, + self._step_func.name, self.task.task_id, - self.split_index, - self.input_paths, - is_cloned=False, - ubf_context=None, + attempt=0, + mode="w", ) - self.launch_spin() + for deco in self.whitelist_decorators: + deco.runtime_task_created( + _ds, + self.task.task_id, + self.split_index, + self.input_paths, + is_cloned=False, + ubf_context=None, + ) - # Start a new worker to spin a step - # on finish clean tasks - exception = None - for deco in self.whitelist_decorators: - deco.runtime_finished(exception) + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) - def launch_spin(self): + def _launch_and_monitor_task(self): args = CLIArgs(self.task, spin=True) env = dict(os.environ) @@ -255,28 +260,76 @@ def launch_spin(self): if self._config_file_name: args.top_level_options["local-config-file"] = self._config_file_name - print(f"Args Entrypoint updated is {args.entrypoint}") env.update(args.get_env()) env["PYTHONUNBUFFERED"] = "x" + + stdout_buffer = TruncatedBuffer("stdout", self._max_log_size) + stderr_buffer = TruncatedBuffer("stderr", self._max_log_size) + cmdline = args.get_args() - print(f"Command line is: {cmdline}") + self._logger(f"Launching command: {' '.join(cmdline)}", system_msg=True) - process = subprocess.Popen( - cmdline, - env=env, - bufsize=1, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, + try: + process = subprocess.Popen( + cmdline, + env=env, + bufsize=1, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + ) + except Exception as e: + raise TaskFailed(self.task, f"Failed to launch task: {str(e)}") + + while True: + stdout_line = process.stdout.readline() + if stdout_line: + self._process_output(stdout_line, stdout_buffer) + + stderr_line = process.stderr.readline() + if stderr_line: + self._process_output(stderr_line, stderr_buffer, is_stderr=True) + + if process.poll() is not None: + break + + # Process any remaining output + for line in process.stdout: + self._process_output(line, stdout_buffer) + for line in process.stderr: + self._process_output(line, stderr_buffer, 
is_stderr=True) + + returncode = process.wait() + + self.task.save_metadata( + "runtime", + { + "return_code": returncode, + "success": returncode == 0, + }, + ) + + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished successfully.", system_msg=True) + + self.task.save_logs( + { + "stdout": stdout_buffer.get_buffer(), + "stderr": stderr_buffer.get_buffer(), + } ) - # Read and print subprocess output - stdout, stderr = process.communicate() - print("STDOUT:\n") - print(f"{stdout.decode()}") - print("-" * 100) - print("STDERR:\n") - print(f"stderr: {stderr.decode()}") + def _process_output(self, line, buffer, is_stderr=False): + buffer.write(line.encode(self._encoding)) + text = line.strip() + self.task.log( + text, + system_msg=False, + timestamp=datetime.now(), + ) class NativeRuntime(object): From f4df5db66a4757802a28bd1ddd76756839038d52 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 15 Jan 2025 14:03:11 -0800 Subject: [PATCH 24/28] dummy commit --- metaflow/datastore/spin_datastore/__init__.py | 0 .../spin_datastore/inputs_datastore.py | 88 +++++++++++++++++ .../spin_datastore/step_datastore.py | 95 +++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 metaflow/datastore/spin_datastore/__init__.py create mode 100644 metaflow/datastore/spin_datastore/inputs_datastore.py create mode 100644 metaflow/datastore/spin_datastore/step_datastore.py diff --git a/metaflow/datastore/spin_datastore/__init__.py b/metaflow/datastore/spin_datastore/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/datastore/spin_datastore/inputs_datastore.py b/metaflow/datastore/spin_datastore/inputs_datastore.py new file mode 100644 index 00000000000..898048869a7 --- /dev/null +++ b/metaflow/datastore/spin_datastore/inputs_datastore.py @@ -0,0 +1,88 @@ +from . import SpinDatastore + + +class SpinInput(object): + def __init__(self, artifacts, task): + self.artifacts = artifacts + self.task = task + + def __getattr__(self, name): + # We always look for any artifacts provided by the user first + if self.artifacts is not None and name in self.artifacts: + return self.artifacts[name] + + try: + return getattr(self.task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." + ) + + +class StaticSpinInputsDatastore(SpinDatastore): + def __init__(self, spin_parser_validator): + super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator) + self._previous_tasks = {} + + def __getattr__(self, name): + if name not in self.previous_steps: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." 
+ ) + + input_step = SpinInput( + self.spin_parser_validator.artifacts["join"][name], + self.get_previous_tasks[name], + ) + setattr(self, name, input_step) + return input_step + + def __iter__(self): + for prev_step_name in self.previous_steps: + yield self[prev_step_name] + + def __len__(self): + return len(self.get_previous_tasks) + + @property + def get_previous_tasks(self): + if self._previous_tasks: + return self._previous_tasks + + for prev_step_name in self.previous_steps: + previous_task = self.get_all_previous_tasks(prev_step_name) + self._previous_tasks[prev_step_name] = previous_task + return self._previous_tasks + + +class SpinInputsDatastore(SpinDatastore): + def __init__(self, spin_parser_validator): + super(SpinInputsDatastore, self).__init__(spin_parser_validator) + self._previous_tasks = None + + def __len__(self): + return len(self.get_previous_tasks) + + def __getitem__(self, idx): + _item_task = self.get_previous_tasks[idx] + _item_artifacts = self.spin_parser_validator.artifacts + # _item_artifacts = self.spin_parser_validator.artifacts[idx] + return SpinInput(_item_artifacts, _item_task) + + def __iter__(self): + for idx in range(len(self.get_previous_tasks)): + yield self[idx] + + @property + def get_previous_tasks(self): + if self._previous_tasks: + return self._previous_tasks + + # This a join step for a foreach split, so only has one previous step + prev_step_name = self.previous_steps[0] + self._previous_tasks = self.get_all_previous_tasks(prev_step_name) + # Sort the tasks by index + self._previous_tasks = sorted(self._previous_tasks, key=lambda x: x.index) + return self._previous_tasks diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py new file mode 100644 index 00000000000..a33bbfca2cb --- /dev/null +++ b/metaflow/datastore/spin_datastore/step_datastore.py @@ -0,0 +1,95 @@ +class LinearStepDatastore(object): + def __init__(self, task_pathspec): + from metaflow import Task + + self._task_pathspec = task_pathspec + self._task = Task(task_pathspec, _namespace_check=False) + self._previous_task = None + self._data = {} + + # Set them to empty dictionaries in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._objects = {} + self._info = {} + + def __contains__(self, name): + try: + _ = self.__getattr__(name) + except AttributeError: + return False + return True + + def __getitem__(self, name): + return self.__getattr__(name) + + def __setitem__(self, name, value): + self._data[name] = value + + def __getattr__(self, name): + # Check internal data first + if name in self._data: + return self._data[name] + + # We always look for any artifacts provided by the user first + if name in self.artifacts: + return self.artifacts[name] + + if self.run_id is None: + raise AttributeError( + f"Attribute '{name}' not provided by the user and no `run_id` was provided. " + ) + + # If the linear step is part of a foreach step, we need to set the input attribute + # and the index attribute + if name == "input": + if not self._task.index: + raise AttributeError( + f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step." 
+ ) + + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + foreach_var = foreach_stack[-1].var + + # Fetch the artifact corresponding to the foreach var and index from the previous task + input_val = self.previous_task[foreach_var].data[foreach_index] + setattr(self, name, input_val) + return input_val + + # If the linear step is part of a foreach step, we need to set the index attribute + if name == "index": + if not self._task.index: + raise AttributeError( + f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step." + ) + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + setattr(self, name, foreach_index) + return foreach_index + + # If the user has not provided the artifact, we look for it in the + # task using the client API + try: + return getattr(self.previous_task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." + ) + + @property + def previous_task(self): + # This is a linear step, so we only have one immediate ancestor + if self._previous_task: + return self._previous_task + + prev_task_pathspecs = self._task.immediate_ancestors + prev_task_pathspec = list(chain.from_iterable(prev_task_pathspecs.values()))[0] + self._previous_task = Task(prev_task_pathspec, _namespace_check=False) + return self._previous_task + + def get(self, key, default=None): + try: + return self.__getattr__(key) + except AttributeError: + return default From 30f170ca1a8f7d09cd14e7e448713ee942602505 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 24 Jan 2025 16:49:34 -0800 Subject: [PATCH 25/28] Dummy commit --- metaflow/cli_components/step_cmd.py | 2 + .../spin_datastore/step_datastore.py | 2 +- metaflow/task.py | 40 +++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 035922452ff..131a40b5e01 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -253,4 +253,6 @@ def spin_internal( print("Task is: ", task) print("I am here 3") print("sys.executable: ", sys.executable) + + task.run_spin_step() # pass diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py index a33bbfca2cb..971ee88ab98 100644 --- a/metaflow/datastore/spin_datastore/step_datastore.py +++ b/metaflow/datastore/spin_datastore/step_datastore.py @@ -79,10 +79,10 @@ def __getattr__(self, name): @property def previous_task(self): - # This is a linear step, so we only have one immediate ancestor if self._previous_task: return self._previous_task + # This is a linear step, so we only have one immediate ancestor prev_task_pathspecs = self._task.immediate_ancestors prev_task_pathspec = list(chain.from_iterable(prev_task_pathspecs.values()))[0] self._previous_task = Task(prev_task_pathspec, _namespace_check=False) diff --git a/metaflow/task.py b/metaflow/task.py index 3e9b553d03b..1b5ff760481 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -390,6 +390,46 @@ def _finalize_control_task(self): ) ) + def run_spin_step( + self, + step_name, + task_pathspec, + new_run_id, + new_task_id, + input_paths, + split_index, + retry_count, + max_user_code_retries, + ): + step_func = getattr(self.flow, step_name) + decorators = step_func.decorators + + # initialize output datastore + output = 
self.flow_datastore.get_task_datastore(
+            new_run_id, step_name, new_task_id, 0, mode="w"
+        )
+
+        output.init_task()
+
+        # How we access the input and index attributes depends on the execution context.
+        # If spin is set to True, we short-circuit attribute access to getattr directly
+        # Also set the other attributes that are needed for the task to execute
+        self.flow._spin = True
+        self.flow._current_step = step_name
+        self.flow._success = False
+        self.flow._task_ok = None
+        self.flow._exception = None
+
+        # Set inputs
+        if spin_parser_validator.step_type == "join":
+            self.flow._set_datastore(output)
+        else:
+            # Set inputs
+            self.flow._set_datastore(step_datastore)
+            if step_datastore.is_foreach_step:
+                setattr(self.flow, "_spin_input", step_datastore.input)
+                setattr(self.flow, "_spin_index", step_datastore.index)
+
     def run_step(
         self,
         step_name,

From 2b821d06343f04f5bab888bf7ab3695cb5816023 Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Sun, 26 Jan 2025 01:34:10 -0800
Subject: [PATCH 26/28] Working implementation

---
 metaflow/cli.py                               |   6 -
 metaflow/cli_components/run_cmds.py           |  10 +-
 metaflow/cli_components/step_cmd.py           |  34 +++++-
 metaflow/datastore/__init__.py                |   3 +
 .../spin_datastore/inputs_datastore.py        |  99 +++++++----
 .../spin_datastore/step_datastore.py          |  43 +++--
 metaflow/flowspec.py                          |  12 ++
 metaflow/metaflow_current.py                  |   2 +
 metaflow/plugins/pypi/conda_decorator.py      |   1 -
 metaflow/runtime.py                           |  34 ++--
 metaflow/task.py                              | 154 +++++++++++++++++-
 11 files changed, 306 insertions(+), 92 deletions(-)

diff --git a/metaflow/cli.py b/metaflow/cli.py
index cdce10deca2..3c6612227b3 100644
--- a/metaflow/cli.py
+++ b/metaflow/cli.py
@@ -509,12 +509,6 @@ def start(
             ctx.obj.event_logger,
             ctx.obj.monitor,
         )
-        echo(
-            "Using local metadata provider, datastore, environment, and null event logger and monitor for spin."
-        )
-        print(f"Using metadata provider: {ctx.obj.metadata}")
-        echo(f"Using Datastore root: {datastore_root}")
-        echo(f"Using Flow Datastore: {ctx.obj.flow_datastore}")

     # Start event logger and monitor
     ctx.obj.event_logger.start()
diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py
index fc873452d47..b280435d559 100644
--- a/metaflow/cli_components/run_cmds.py
+++ b/metaflow/cli_components/run_cmds.py
@@ -385,12 +385,20 @@ def run(
     help="Task ID to use when spinning up the step. The spun-up step will use the artifacts "
    "corresponding to this task ID. 
If not provided, an arbitrary task ID from the latest run will be used.", ) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=True, + show_default=True, + help="Skip decorators attached to the step.", +) @common_runner_options @click.pass_obj def spin( obj, step_name, task_pathspec=None, + skip_decorators=False, run_id_file=None, runner_attribute_file=None, **kwargs @@ -418,10 +426,10 @@ def spin( obj.monitor, step_func, task_pathspec, + skip_decorators, ) # write_latest_run_id(obj, runtime.run_id) # write_file(run_id_file, runtime.run_id) spin_runtime.execute() - pass diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 131a40b5e01..00b957a8e2a 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -190,6 +190,12 @@ def step( required=True, help="Task ID for the step that's about to be spun", ) +@click.option( + "--task-pathspec", + default=None, + show_default=True, + help="Task Pathspec to be used in the spun step.", +) @click.option( "--input-paths", help="A comma-separated list of pathspecs specifying inputs for this step.", @@ -223,11 +229,13 @@ def spin_internal( step_name, run_id=None, task_id=None, + task_pathspec=None, input_paths=None, split_index=None, retry_count=None, max_user_code_retries=None, namespace=None, + skip_decorators=False, ): import sys @@ -235,9 +243,7 @@ def spin_internal( echo = echo_dev_null else: echo = echo_always - print("I am here 1") - print("I am here 2") - # echo("Spinning a task, *%s*" % step_name, fg="magenta", bold=False) + echo("Spinning a task, *%s*" % step_name, fg="magenta", bold=False) task = MetaflowTask( ctx.obj.flow, @@ -249,10 +255,22 @@ def spin_internal( ctx.obj.monitor, # null monitor None, # no unbounded foreach context ) - # echo("Task is: ", task) - print("Task is: ", task) - print("I am here 3") + # print("Task is: ", task) + # print("I am here 3") print("sys.executable: ", sys.executable) + import time - task.run_spin_step() - # pass + start = time.time() + task.run_spin_step( + step_name, + task_pathspec, + run_id, + task_id, + input_paths, + split_index, + retry_count, + max_user_code_retries, + namespace, + skip_decorators, + ) + print("Time taken for the whole thing: ", time.time() - start) diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 793251b0cff..8560b5627cd 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,3 +2,6 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore +from .spin_datastore.step_datastore import LinearStepDatastore +from .spin_datastore.inputs_datastore import SpinInputsDatastore +from .spin_datastore.inputs_datastore import StaticSpinInputsDatastore diff --git a/metaflow/datastore/spin_datastore/inputs_datastore.py b/metaflow/datastore/spin_datastore/inputs_datastore.py index 898048869a7..0fd8ad4b768 100644 --- a/metaflow/datastore/spin_datastore/inputs_datastore.py +++ b/metaflow/datastore/spin_datastore/inputs_datastore.py @@ -1,4 +1,4 @@ -from . import SpinDatastore +from itertools import chain class SpinInput(object): @@ -16,73 +16,104 @@ def __getattr__(self, name): except AttributeError: raise AttributeError( f"Attribute '{name}' not found in the previous execution of the task for " - f"`{self.step_name}`." + f"`{self.task.parent.id}`." 
) -class StaticSpinInputsDatastore(SpinDatastore): - def __init__(self, spin_parser_validator): - super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator) - self._previous_tasks = {} +class StaticSpinInputsDatastore(object): + def __init__(self, task, immediate_ancestors, artifacts={}): + self.task = task + self.immediate_ancestors = immediate_ancestors + self.artifacts = artifacts + + self._previous_tasks = None + self._previous_steps = None + + @property + def previous_tasks(self): + if self._previous_tasks: + return self._previous_tasks + + from metaflow import Task + + prev_task_pathspecs = self.immediate_ancestors + # Static Join step, so each previous step only has one task + self._previous_tasks = { + step_name: Task(prev_task_pathspec[0], _namespace_check=False) + for step_name, prev_task_pathspec in prev_task_pathspecs.items() + } + return self._previous_tasks + + @property + def previous_steps(self): + if self._previous_steps: + return self._previous_steps + self._previous_steps = self.task.metadata_dict.get("previous-steps") + return self._previous_steps def __getattr__(self, name): if name not in self.previous_steps: raise AttributeError( - f"Attribute '{name}' not found in the previous execution of the task for " - f"`{self.step_name}`." + f"Step '{self.task.parent.id}' does not have a previous step with name '{name}'." ) input_step = SpinInput( - self.spin_parser_validator.artifacts["join"][name], - self.get_previous_tasks[name], + self.artifacts.get( + name, {} + ), # Get the artifacts corresponding to the previous step + self.previous_tasks.get( + name + ), # Get the task corresponding to the previous step ) setattr(self, name, input_step) return input_step def __iter__(self): for prev_step_name in self.previous_steps: - yield self[prev_step_name] + yield getattr(self, prev_step_name) def __len__(self): - return len(self.get_previous_tasks) + return len(self.previous_steps) - @property - def get_previous_tasks(self): - if self._previous_tasks: - return self._previous_tasks - - for prev_step_name in self.previous_steps: - previous_task = self.get_all_previous_tasks(prev_step_name) - self._previous_tasks[prev_step_name] = previous_task - return self._previous_tasks +class SpinInputsDatastore(object): + def __init__(self, task, immediate_ancestors, artifacts={}): + self.task = task + self.immediate_ancestors = immediate_ancestors + self.artifacts = artifacts -class SpinInputsDatastore(SpinDatastore): - def __init__(self, spin_parser_validator): - super(SpinInputsDatastore, self).__init__(spin_parser_validator) self._previous_tasks = None def __len__(self): - return len(self.get_previous_tasks) + return len(self.previous_tasks) def __getitem__(self, idx): - _item_task = self.get_previous_tasks[idx] - _item_artifacts = self.spin_parser_validator.artifacts - # _item_artifacts = self.spin_parser_validator.artifacts[idx] + _item_task = self.previous_tasks[idx] + _item_artifacts = self.artifacts.get(self.previous_step, {}).get(idx, {}) return SpinInput(_item_artifacts, _item_task) def __iter__(self): - for idx in range(len(self.get_previous_tasks)): + for idx in range(len(self.previous_tasks)): yield self[idx] @property - def get_previous_tasks(self): + def previous_step(self): + return self.task.metadata_dict.get("previous-steps")[0] + + @property + def previous_tasks(self): if self._previous_tasks: return self._previous_tasks - # This a join step for a foreach split, so only has one previous step - prev_step_name = self.previous_steps[0] - self._previous_tasks = 
self.get_all_previous_tasks(prev_step_name) - # Sort the tasks by index + # Foreach Join step so we have one previous step with multiple tasks + from metaflow import Task + + prev_task_pathspecs = list( + chain.from_iterable(self.immediate_ancestors.values()) + ) + self._previous_tasks = [ + Task(prev_task_pathspec, _namespace_check=False) + for prev_task_pathspec in prev_task_pathspecs + ] self._previous_tasks = sorted(self._previous_tasks, key=lambda x: x.index) return self._previous_tasks diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py index 971ee88ab98..7d84fb6273e 100644 --- a/metaflow/datastore/spin_datastore/step_datastore.py +++ b/metaflow/datastore/spin_datastore/step_datastore.py @@ -1,9 +1,11 @@ -class LinearStepDatastore(object): - def __init__(self, task_pathspec): - from metaflow import Task +from itertools import chain + - self._task_pathspec = task_pathspec - self._task = Task(task_pathspec, _namespace_check=False) +class LinearStepDatastore(object): + def __init__(self, task, immediate_ancestors, artifacts={}): + self._task = task + self._immediate_ancestors = immediate_ancestors + self._artifacts = artifacts self._previous_task = None self._data = {} @@ -31,21 +33,26 @@ def __getattr__(self, name): return self._data[name] # We always look for any artifacts provided by the user first - if name in self.artifacts: - return self.artifacts[name] - - if self.run_id is None: - raise AttributeError( - f"Attribute '{name}' not provided by the user and no `run_id` was provided. " - ) + if name in self._artifacts: + return self._artifacts[name] # If the linear step is part of a foreach step, we need to set the input attribute # and the index attribute if name == "input": if not self._task.index: raise AttributeError( - f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step." + f"Attribute '{name}' does not exist for step `{self._task.parent.id}` as it is not part of " + f"a foreach step." ) + # input only exists for steps immediately after a foreach split + # we check for that by comparing the length of the foreach-step-names + # attribute of the task and its immediate ancestors + foreach_step_names = self._task.metadata_dict.get("foreach-step-names") + prev_task_foreach_step_names = self.previous_task.metadata_dict.get( + "foreach-step-names" + ) + if len(foreach_step_names) <= len(prev_task_foreach_step_names): + return None # input does not exist, so we return None foreach_stack = self._task["_foreach_stack"].data foreach_index = foreach_stack[-1].index @@ -60,7 +67,8 @@ def __getattr__(self, name): if name == "index": if not self._task.index: raise AttributeError( - f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step." + f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a " + f"foreach step." 
) foreach_stack = self._task["_foreach_stack"].data foreach_index = foreach_stack[-1].index @@ -83,8 +91,11 @@ def previous_task(self): return self._previous_task # This is a linear step, so we only have one immediate ancestor - prev_task_pathspecs = self._task.immediate_ancestors - prev_task_pathspec = list(chain.from_iterable(prev_task_pathspecs.values()))[0] + from metaflow import Task + + prev_task_pathspec = list( + chain.from_iterable(self._immediate_ancestors.values()) + )[0] self._previous_task = Task(prev_task_pathspec, _namespace_check=False) return self._previous_task diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index 888d89f4bc6..69b82c133a3 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -184,6 +184,12 @@ def __init__(self, use_cli=True): self._transition = None self._cached_input = {} + # Spin state + self._spin = False + self._spin_index = None + self._spin_input = None + self._spin_foreach_stack = None + if use_cli: with parameters.flow_context(self.__class__) as _: from . import cli @@ -459,6 +465,8 @@ def index(self) -> Optional[int]: int, optional Index of the task in a foreach step. """ + if self._spin: + return self._spin_index if self._foreach_stack: return self._foreach_stack[-1].index @@ -479,6 +487,8 @@ def input(self) -> Optional[Any]: object, optional Input passed to the foreach task. """ + if self._spin: + return self._datastore["input"] return self._find_input() def foreach_stack(self) -> Optional[List[Tuple[int, int, Any]]]: @@ -528,6 +538,8 @@ def nest_2(self): List[Tuple[int, int, Any]] An array describing the current stack of foreach steps. """ + if self._spin: + return self._spin_foreach_stack return [ (frame.index, frame.num_splits, self._find_input(stack_index=i)) for i, frame in enumerate(self._foreach_stack) diff --git a/metaflow/metaflow_current.py b/metaflow/metaflow_current.py index 8443c1d75ab..73c9f8359d8 100644 --- a/metaflow/metaflow_current.py +++ b/metaflow/metaflow_current.py @@ -45,6 +45,7 @@ def _set_env( username=None, metadata_str=None, is_running=True, + is_spin=False, tags=None, ): if flow is not None: @@ -60,6 +61,7 @@ def _set_env( self._username = username self._metadata_str = metadata_str self._is_running = is_running + self._is_spin = is_spin self._tags = tags def _update_env(self, env): diff --git a/metaflow/plugins/pypi/conda_decorator.py b/metaflow/plugins/pypi/conda_decorator.py index f43a68425dc..b1b7ee833d9 100644 --- a/metaflow/plugins/pypi/conda_decorator.py +++ b/metaflow/plugins/pypi/conda_decorator.py @@ -287,7 +287,6 @@ def task_pre_step( def runtime_step_cli( self, cli_args, retry_count, max_user_code_retries, ubf_context ): - print("Let's go - I am here") if self.disabled: return # Ensure local installation of Metaflow is visible to user code diff --git a/metaflow/runtime.py b/metaflow/runtime.py index f04703e64d7..1d282dbe278 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -91,6 +91,7 @@ def __init__( monitor, step_func, task_pathspec, + skip_decorators=False, max_log_size=MAX_LOG_SIZE, ): from metaflow import Task @@ -108,11 +109,12 @@ def __init__( self._step_func = step_func self._task_pathspec = task_pathspec - self._prev_task = Task(self._task_pathspec, _namespace_check=False) + self._task = Task(self._task_pathspec, _namespace_check=False) self._input_paths = None self._split_index = None self._whitelist_decorators = None self._config_file_name = None + self._skip_decorators = skip_decorators self._max_log_size = max_log_size self._encoding = sys.stdout.encoding or 
"UTF-8" @@ -122,26 +124,15 @@ def __init__( f"New run_id for spin task: {self._run_id} and step func: {self._step_func.name}" ) - print( - f"Decorators for {self._step_func.name}: {list(self._step_func.decorators)}" - ) - - for deco in self._step_func.decorators: - print( - f"Running runtime_init for {deco.__class__.__name__} at {self._step_func.name}" - ) + for deco in self.whitelist_decorators: print("-" * 100) deco.runtime_init(flow, graph, package, self._run_id) - if hasattr(deco, "_metaflow_home"): - print(f"Metaflow home is {deco._metaflow_home}") - - print(f"Input paths: {self.input_paths}") @property def split_index(self): if self._split_index: return self._split_index - foreach_indices = self._prev_task.metadata_dict.get("foreach-indices", []) + foreach_indices = self._task.metadata_dict.get("foreach-indices", []) self._split_index = foreach_indices[-1] if foreach_indices else None return self._split_index @@ -167,7 +158,7 @@ def _format_input_paths(task_id): ) self._input_paths = [f"{run_id}/_parameters/{task.id}"] else: - ancestors = self._prev_task.immediate_ancestors + ancestors = self._task.immediate_ancestors self._input_paths = [ _format_input_paths(ancestor) for i, ancestor in enumerate(chain.from_iterable(ancestors.values())) @@ -176,6 +167,8 @@ def _format_input_paths(task_id): @property def whitelist_decorators(self): + if self._skip_decorators: + return [] if self._whitelist_decorators: return self._whitelist_decorators self._whitelist_decorators = [ @@ -244,7 +237,7 @@ def execute(self): deco.runtime_finished(exception) def _launch_and_monitor_task(self): - args = CLIArgs(self.task, spin=True) + args = CLIArgs(self.task, spin=True, prev_task_pathspec=self._task_pathspec) env = dict(os.environ) for deco in self.task.decos: @@ -1767,12 +1760,13 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). 
""" - def __init__(self, task, spin=False): + def __init__(self, task, spin=False, prev_task_pathspec=None): self.task = task self.spin = spin + self.prev_task_pathspec = prev_task_pathspec self.entrypoint = list(task.entrypoint) self.top_level_options = { - "quiet": True, + "quiet": True if spin else False, "metadata": self.task.metadata_type, "environment": self.task.environment_type, "datastore": self.task.datastore_type, @@ -1830,8 +1824,12 @@ def spin_args(self): self.command_options = { "run-id": self.task.run_id, "task-id": self.task.task_id, + "task-pathspec": self.prev_task_pathspec, "input-paths": self.task.input_paths, "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "namespace": get_namespace() or "", } self.env = {} diff --git a/metaflow/task.py b/metaflow/task.py index 1b5ff760481..06434d906e2 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -9,6 +9,7 @@ import traceback from types import MethodType, FunctionType +from itertools import chain from metaflow.sidecar import Message, MessageTypes from metaflow.datastore.exceptions import DataException @@ -16,7 +17,13 @@ from .metaflow_config import MAX_ATTEMPTS from .metadata_provider import MetaDatum from .mflog import TASK_LOG_SOURCE -from .datastore import Inputs, TaskDataStoreSet +from .datastore import ( + Inputs, + TaskDataStoreSet, + LinearStepDatastore, + SpinInputsDatastore, + StaticSpinInputsDatastore, +) from .exception import ( MetaflowInternalError, MetaflowDataMissing, @@ -29,6 +36,7 @@ from metaflow.system import _system_logger, _system_monitor from metaflow.tracing import get_trace_id from metaflow.tuple_util import ForeachFrame +from .metaflow_config import SPIN_ALLOWED_DECORATORS # Maximum number of characters of the foreach path that we store in the metadata. MAX_FOREACH_PATH_LENGTH = 256 @@ -400,9 +408,25 @@ def run_spin_step( split_index, retry_count, max_user_code_retries, + namespace, + skip_decorators, ): + def is_join_step(immediate_ancestors): + prev_task_pathspecs = set(chain.from_iterable(immediate_ancestors.values())) + return len(prev_task_pathspecs) > 1 + step_func = getattr(self.flow, step_name) - decorators = step_func.decorators + whitelisted_decorators = ( + [] + if skip_decorators + else [ + deco + for deco in step_func.decorators + if any( + deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS + ) + ] + ) # initialize output datastore output = self.flow_datastore.get_task_datastore( @@ -414,21 +438,135 @@ def run_spin_step( # How we access the input and index attributes depends on the execution context. 
# If spin is set to True, we short-circuit attribute access to getattr directly # Also set the other attributes that are needed for the task to execute + from metaflow import Task + + self.task = Task(task_pathspec, _namespace_check=False) + immediate_ancestors = self.task.immediate_ancestors self.flow._spin = True + self.flow._spin_foreach_stack = self.task["_foreach_stack"].data + self.flow._spin_index = split_index self.flow._current_step = step_name self.flow._success = False self.flow._task_ok = None self.flow._exception = None # Set inputs - if spin_parser_validator.step_type == "join": + inp_datastore = None + is_join = is_join_step(immediate_ancestors) + if is_join: + # Join step + if len(self.task.metadata_dict.get("previous-steps")) > 1: + # Static join step + inp_datastore = StaticSpinInputsDatastore( + self.task, immediate_ancestors, artifacts={} + ) + else: + # Foreach join step + inp_datastore = SpinInputsDatastore( + self.task, immediate_ancestors, artifacts={} + ) self.flow._set_datastore(output) else: - # Set inputs - self.flow._set_datastore(step_datastore) - if step_datastore.is_foreach_step: - setattr(self.flow, "_spin_input", step_datastore.input) - setattr(self.flow, "_spin_index", step_datastore.index) + # Linear step + self.flow._set_datastore( + LinearStepDatastore(self.task, immediate_ancestors, artifacts={}) + ) + + current._set_env( + flow=self.flow, + run_id=new_run_id, + step_name=step_name, + task_id=new_task_id, + retry_count=retry_count, + namespace=resolve_identity(), + username=get_username(), + metadata_str="%s@%s" + % (self.metadata.__class__.TYPE, self.metadata.__class__.INFO), + is_running=True, + is_spin=True, + ) + + # task_pre_step decorator hooks + for deco in whitelisted_decorators: + deco.task_pre_step( + step_name=step_name, + task_datastore=output, + metadata=self.metadata, + run_id=new_run_id, + task_id=new_task_id, + flow=self.flow, + graph=self.flow._graph, + retry_count=retry_count, + max_user_code_retries=max_user_code_retries, + ubf_context=self.ubf_context, + inputs=inp_datastore, + ) + + # task_decorate decorator hooks + for deco in whitelisted_decorators: + step_func = deco.task_decorate( + step_func=step_func, + flow=self.flow, + graph=self.flow._graph, + retry_count=retry_count, + max_user_code_retries=max_user_code_retries, + ubf_context=self.ubf_context, + ) + + # Execute the step function + try: + if is_join: + # Join step + self._exec_step_function(step_func, input_obj=inp_datastore) + else: + self._exec_step_function(step_func) + + # task_post_step decorator hooks + for deco in whitelisted_decorators: + deco.task_post_step( + step_name, + self.flow, + self.flow._graph, + retry_count, + max_user_code_retries, + ) + + self.flow._task_ok = True + self.flow._success = True + except Exception as ex: + exception_handled = False + for deco in whitelisted_decorators: + res = deco.task_exception( + ex, + step_name, + self.flow, + self.flow._graph, + retry_count, + max_user_code_retries, + ) + exception_handled = bool(res) or exception_handled + + if exception_handled: + self.flow._task_ok = True + else: + self.flow._task_ok = False + self.flow._exception = MetaflowExceptionWrapper(ex) + print("%s failed:" % self.flow, file=sys.stderr) + raise + finally: + output.persist(self.flow) + output.done() + + # task_finish decorator hooks + for deco in whitelisted_decorators: + deco.task_finished( + step_name, + self.flow, + self.flow._graph, + self.flow._task_ok, + retry_count, + max_user_code_retries, + ) def run_step( self, From 
6f57368e97d22d09a420fa65874f639e8b0937de Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 26 Jan 2025 02:40:12 -0800 Subject: [PATCH 27/28] Use polling for spin logging --- metaflow/runtime.py | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 1d282dbe278..05ea8123b79 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -120,10 +120,6 @@ def __init__( # Create a new run_id for the spin task self._run_id = self._metadata.new_run_id() - print( - f"New run_id for spin task: {self._run_id} and step func: {self._step_func.name}" - ) - for deco in self.whitelist_decorators: print("-" * 100) deco.runtime_init(flow, graph, package, self._run_id) @@ -266,7 +262,7 @@ def _launch_and_monitor_task(self): process = subprocess.Popen( cmdline, env=env, - bufsize=1, + bufsize=1, # Line buffering stdin=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE, @@ -275,23 +271,42 @@ def _launch_and_monitor_task(self): except Exception as e: raise TaskFailed(self.task, f"Failed to launch task: {str(e)}") + poll = procpoll.make_poll() + poll.add(process.stdout.fileno()) + poll.add(process.stderr.fileno()) + + # Map file descriptors to their respective streams and buffers + fd_map = { + process.stdout.fileno(): (process.stdout, stdout_buffer, False), + process.stderr.fileno(): (process.stderr, stderr_buffer, True), + } + while True: - stdout_line = process.stdout.readline() - if stdout_line: - self._process_output(stdout_line, stdout_buffer) + # Poll for events with a timeout + events = poll.poll(POLL_TIMEOUT) + + if not events: + if process.poll() is not None: + break + continue + + for event in events: + if event.can_read: + stream, buffer, is_stderr = fd_map[event.fd] + line = stream.readline() + if line: + self._process_output(line, buffer, is_stderr) - stderr_line = process.stderr.readline() - if stderr_line: - self._process_output(stderr_line, stderr_buffer, is_stderr=True) + if event.is_terminated: + poll.remove(event.fd) if process.poll() is not None: break # Process any remaining output - for line in process.stdout: - self._process_output(line, stdout_buffer) - for line in process.stderr: - self._process_output(line, stderr_buffer, is_stderr=True) + for stream, buffer, is_stderr in fd_map.values(): + for line in stream: + self._process_output(line, buffer, is_stderr) returncode = process.wait() From 86b872e5a21fbf9ac4c509fb710b9cda721c6e0e Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 26 Jan 2025 22:51:24 -0800 Subject: [PATCH 28/28] Make spin work with runner API --- metaflow/cli_components/run_cmds.py | 21 ++- metaflow/cli_components/step_cmd.py | 10 +- metaflow/runner/metaflow_runner.py | 194 ++++++++++++++++++++++++---- metaflow/runtime.py | 11 +- 4 files changed, 203 insertions(+), 33 deletions(-) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index b280435d559..8227a027273 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -388,7 +388,7 @@ def run( @click.option( "--skip-decorators/--no-skip-decorators", is_flag=True, - default=True, + default=False, show_default=True, help="Skip decorators attached to the step.", ) @@ -429,7 +429,20 @@ def spin( skip_decorators, ) - # write_latest_run_id(obj, runtime.run_id) - # write_file(run_id_file, runtime.run_id) - + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) 
     spin_runtime.execute()
+
+    local_metadata = f"{obj.metadata.__class__.TYPE}@{obj.metadata.__class__.INFO}"
+    if runner_attribute_file:
+        with open(runner_attribute_file, "w") as f:
+            json.dump(
+                {
+                    "task_id": spin_runtime.task.task_id,
+                    "step_name": step_name,
+                    "run_id": spin_runtime.run_id,
+                    "flow_name": obj.flow.name,
+                    "metadata": local_metadata,
+                },
+                f,
+            )
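For reference, the `runner_attribute_file` written by `spin` above is the handshake that `Runner.spin()`, later in this patch, uses to hand back an `ExecutingTask`. A minimal sketch of consuming that file on its own; the literal path below is hypothetical, since in practice the Runner supplies a temporary FIFO:

```python
import json

# Read the attributes that `spin` wrote via json.dump above
# (the file name here is a placeholder, not part of the patch).
with open("/tmp/spin_runner_attributes.json") as f:
    content = json.load(f)

# Rebuild the task pathspec the same way Runner.__get_executing_task does
pathspec = "{flow_name}/{run_id}/{step_name}/{task_id}".format(**content)
print(pathspec, content["metadata"])
```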
""" self.runner = runner self.command_obj = command_obj - self.run = run_obj def __enter__(self) -> "ExecutingRun": return self @@ -72,11 +74,10 @@ async def wait( Parameters ---------- - timeout : float, optional, default None - The maximum time, in seconds, to wait for the run to finish. - If the timeout is reached, the run is terminated. If not specified, wait - forever. - stream : str, optional, default None + timeout : Optional[float], default None + The maximum time to wait for the run to finish. + If the timeout is reached, the run is terminated + stream : Optional[str], default None If specified, the specified stream is printed to stdout. `stream` can be one of `stdout` or `stderr`. @@ -173,7 +174,7 @@ async def stream_log( ---------- stream : str The stream to stream logs from. Can be one of `stdout` or `stderr`. - position : int, optional, default None + position : Optional[int], default None The position in the log file to start streaming from. If None, it starts from the beginning of the log file. This allows resuming streaming from a previously known position @@ -189,6 +190,83 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead user Runner.spin() + + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. + """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead user Runner.run() + + Parameters + ---------- + runner : Runner + Parent runner for this run. + command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. 
+        """
+        super().__init__(runner, command_obj)
+        self.run = run_obj
+
+
 class RunnerMeta(type):
     def __new__(mcs, name, bases, dct):
         cls = super().__new__(mcs, name, bases, dct)
@@ -322,6 +400,23 @@ def __get_executing_run(self, attribute_file_fd, command_obj):
         )
         return ExecutingRun(self, command_obj, run_object)

+    def __get_executing_task(self, attribute_file_fd, command_obj):
+        content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
+
+        command_obj.sync_wait()
+
+        content = json.loads(content)
+        pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}"
+
+        # Set the correct metadata from the runner_attribute file corresponding to this run.
+        metadata_for_flow = content.get("metadata")
+        from metaflow import Task
+
+        task_object = Task(
+            pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
+        )
+        return ExecutingTask(self, command_obj, task_object)
+
     async def __async_get_executing_run(self, attribute_file_fd, command_obj):
         content = await async_handle_timeout(
             attribute_file_fd, command_obj, self.file_read_timeout
@@ -337,6 +432,23 @@ async def __async_get_executing_run(self, attribute_file_fd, command_obj):
         )
         return ExecutingRun(self, command_obj, run_object)

+    async def __async_get_executing_task(self, attribute_file_fd, command_obj):
+        content = await async_handle_timeout(
+            attribute_file_fd, command_obj, self.file_read_timeout
+        )
+        content = json.loads(content)
+        pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}"
+
+        # Set the correct metadata from the runner_attribute file corresponding to this run.
+        metadata_for_flow = content.get("metadata")
+
+        from metaflow import Task
+
+        task_object = Task(
+            pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
+        )
+        return ExecutingTask(self, command_obj, task_object)
+
     def run(self, **kwargs) -> ExecutingRun:
         """
         Blocking execution of the run. This method will wait until
@@ -399,6 +511,44 @@ def resume(self, **kwargs) -> ExecutingRun:

         return self.__get_executing_run(attribute_file_fd, command_obj)

+    def spin(self, step_name, task_pathspec, **kwargs):
+        """
+        Blocking execution of a single spun step.
+        This method will wait until the spun task has completed execution.
+
+        Parameters
+        ----------
+        step_name : str
+            The name of the step to spin.
+        task_pathspec : str
+            The pathspec of the task whose artifacts the spun task should use.
+        **kwargs : Any
+            Additional arguments that you would pass to `python ./myflow.py` after
+            the `spin` command.
+
+        Returns
+        -------
+        ExecutingTask
+            ExecutingTask containing the results of the spun task.
+        """
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
+            command = self.api(**self.top_level_kwargs).spin(
+                step_name=step_name,
+                task_pathspec=task_pathspec,
+                runner_attribute_file=attribute_file_path,
+                **kwargs,
+            )
+
+            pid = self.spm.run_command(
+                [sys.executable, *command],
+                env=self.env_vars,
+                cwd=self.cwd,
+                show_output=self.show_output,
+            )
+            command_obj = self.spm.get(pid)
+
+            return self.__get_executing_task(attribute_file_fd, command_obj)
+
     async def async_run(self, **kwargs) -> ExecutingRun:
         """
         Non-blocking execution of the run. 
This method will return as soon as the diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 05ea8123b79..7623bafd103 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -119,10 +119,10 @@ def __init__( self._encoding = sys.stdout.encoding or "UTF-8" # Create a new run_id for the spin task - self._run_id = self._metadata.new_run_id() + self.run_id = self._metadata.new_run_id() for deco in self.whitelist_decorators: print("-" * 100) - deco.runtime_init(flow, graph, package, self._run_id) + deco.runtime_init(flow, graph, package, self.run_id) @property def split_index(self): @@ -179,7 +179,7 @@ def _new_task(self, step, input_paths=None, **kwargs): self._flow_datastore, self._flow, step, - self._run_id, + self.run_id, self._metadata, self._environment, self._entrypoint, @@ -205,7 +205,7 @@ def execute(self): self.task = self._new_task(self._step_func.name, {}) _ds = self._flow_datastore.get_task_datastore( - self._run_id, + self.run_id, self._step_func.name, self.task.task_id, attempt=0, @@ -249,6 +249,9 @@ def _launch_and_monitor_task(self): if self._config_file_name: args.top_level_options["local-config-file"] = self._config_file_name + # Add the skip-decorators flag to the command options + args.command_options.update({"skip-decorators": self._skip_decorators}) + env.update(args.get_env()) env["PYTHONUNBUFFERED"] = "x"
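Taken together, patches 26-28 make a spun step drivable end to end, either through the CLI `spin` command or through the Runner API. A minimal usage sketch of the latter, under the assumptions that a flow file `myflow.py` defines a `train` step and that the quoted pathspec exists (all three names are placeholders, not part of the patch):

```python
from metaflow import Runner

# Spin a single step in isolation, seeding it with the artifacts of an
# existing task (flow file, step name, and pathspec are hypothetical).
with Runner("myflow.py").spin(
    step_name="train",
    task_pathspec="MyFlow/1234/train/5678",
) as running:
    # `running` is an ExecutingTask; `running.task` is the metaflow.Task
    # reconstructed from the runner_attribute_file handshake above.
    print(running.task.pathspec)
```

Since `spin` blocks until the spun task finishes, the resulting `Task` can be inspected as soon as the context is entered; passing `skip_decorators=True` through `**kwargs` should map onto the `--skip-decorators` flag added in patch 26, though that mapping is an assumption about the api machinery rather than something these diffs show directly.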