Skip to content

Commit

Permalink
Refactor and simplify client code
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 4, 2025
1 parent 82dc61b commit 68808c9
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def __iter__(self) -> Iterator["MetaflowObject"]:
_CLASSES[self._CHILD_CLASS]._NAME,
query_filter,
self._attempt,
*self.path_components
*self.path_components,
)
unfiltered_children = unfiltered_children if unfiltered_children else []
children = filter(
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
Returns a Task object corresponding to the queried step.
If the queried step has several tasks, the first task is returned.
"""
# Find any previous task for current step
# Find any task corresponding to the queried step
step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
Expand Down Expand Up @@ -1184,28 +1184,30 @@ def _get_filter_query_value(
field_value = self.metadata_dict.get("foreach-indices-truncated")
return field_name, field_value

def _get_related_tasks(
self, steps_key: str, relation_type: str
) -> Dict[str, List[str]]:
def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
flow_id, run_id, _, _ = self.path_components
query_steps = self.metadata_dict.get(steps_key)
steps = (
self.metadata_dict.get("previous_steps")
if relation_type == "ancestor"
else self.metadata_dict.get("successor_steps")
)

if not query_steps:
if not steps:
return {}

field_name, field_value = self._get_filter_query_value(
flow_id,
run_id,
len(self.metadata_dict.get("foreach-stack", [])),
query_steps,
steps,
relation_type,
)

return {
query_step: self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, query_step, field_name, field_value
step: self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step, field_name, field_value
)
for query_step in query_steps
for step in steps
}

@property
Expand All @@ -1221,7 +1223,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]:
names of the ancestors steps and the values are the corresponding
task ids of the ancestors.
"""
return self._get_related_tasks("previous_steps", "ancestor")
return self._get_related_tasks("ancestor")

@property
def immediate_successors(self) -> Dict[str, List[str]]:
Expand All @@ -1236,7 +1238,7 @@ def immediate_successors(self) -> Dict[str, List[str]]:
names of the successors steps and the values are the corresponding
task ids of the successors.
"""
return self._get_related_tasks("successor_steps", "successor")
return self._get_related_tasks("successor")

@property
def immediate_siblings(self) -> Dict[str, List[str]]:
Expand Down

0 comments on commit 68808c9

Please sign in to comment.