Skip to content

Commit

Permalink
Update logic for siblings, make it work for static splits as well
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 14, 2025
1 parent 253a2f5 commit 17a4489
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,41 +1278,58 @@ def immediate_successors(self) -> Dict[str, List[str]]:
return self._get_related_tasks(is_ancestor=False)

@property
def immediate_siblings(self) -> Dict[str, List[str]]:
def siblings(self) -> Dict[str, List[str]]:
"""
Returns a dictionary of closest sibling task pathspecs of this task for each
sibling step.
Returns a dictionary of sibling task pathspecs of this task. Siblings of a task share
the same parent task.
Returns
-------
Dict[str, List[str]]
Dictionary of closest siblings of this task. The keys are the
Dictionary of sibling task pathspecs of this task. The keys are the
names of the current step and the values are the corresponding
task pathspecs of the siblings.
"""
flow_id, run_id, step_name, _ = self.path_components

foreach_stack = self.metadata_dict.get("foreach-stack", [])
foreach_step_names = self.metadata_dict.get("foreach-step-names", [])
if len(foreach_stack) == 0:
raise MetaflowInternalError("Task is not part of any foreach split")
if step_name != foreach_step_names[-1]:
raise MetaflowInternalError(
f"Step {step_name} does not have any direct siblings since it is not part "
f"of a new foreach split."
)
ancestor_steps = self.metadata_dict.get("previous-steps")
cur_foreach_stack_len = len(self.metadata_dict.get("foreach-indices", []))
if len(ancestor_steps) > 1 or step_name in ("start", "end"):
# This is a static join, or a start/end step. The current task will have no siblings.
return {
step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"],
}

field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices-truncated")
# We find all tasks of the same step that have the same foreach-indices-truncated value
return {
step_name: [
f"{flow_id}/{run_id}/{step_name}/{task_id}"
for task_id in self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, field_name, field_value
)
]
}
# This can be a linear step, a foreach split, a foreach join, or a static split.
query_task = self._get_task_for_queried_step(flow_id, run_id, ancestor_steps[0])
query_foreach_stack_len = len(
query_task.metadata_dict.get("foreach-indices", [])
)
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach join; there will be no siblings
return {
step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"],
}
elif query_foreach_stack_len < cur_foreach_stack_len:
# This is a foreach split; there will be multiple siblings
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices-truncated")
# We find all tasks of the same step that have the same foreach-indices-truncated value
return {
step_name: [
f"{flow_id}/{run_id}/{step_name}/{task_id}"
for task_id in self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, field_name, field_value
)
]
}

# Logic for static splits, and linear steps
# To find siblings, we first find the single ancestor task of the current task,
# and then we find all the successor tasks of this ancestor task.
ancestor_task_pathspecs = self.immediate_ancestors.get(ancestor_steps[0])
ancestor_task = Task(ancestor_task_pathspecs[0], _namespace_check=False)
return ancestor_task.immediate_successors

@property
def metadata(self) -> List[Metadata]:
Expand Down

0 comments on commit 17a4489

Please sign in to comment.