Skip to content

Commit

Permalink
Support querying ancestors and successors in local metadata provider
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 4, 2025
1 parent c6fb9ac commit 82dc61b
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 11 deletions.
29 changes: 18 additions & 11 deletions metaflow/metadata_provider/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,21 +672,28 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
if metadata:
self.register_metadata(run_id, step_name, task_id, metadata)

@classmethod
def _filter_tasks_by_metadata(
cls, flow_id, run_id, query_step, field_name, field_value
):
raise NotImplementedError()

@classmethod
def filter_tasks_by_metadata(
cls, flow_id, run_id, query_step, field_name, field_value
):
# TODO: Do we need to do anything wrt to task attempt?
task_ids = cls._filter_tasks_by_metadata(
flow_id, run_id, query_step, field_name, field_value
)
return task_ids
"""
Filter tasks by metadata field and value, and returns the list of task_ids
that satisfy the query.
Parameters
----------
flow_id : str
Flow id, that the run belongs to.
run_id: str
Run id, together with flow_id, that identifies the specific Run whose tasks to query
query_step: str
Step name to query tasks from
field_name: str
Metadata field name to query
field_value: str
Metadata field value to query
"""
raise NotImplementedError()

@staticmethod
def _apply_filter(elts, filters):
Expand Down
97 changes: 97 additions & 0 deletions metaflow/plugins/metadata_providers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,103 @@ def _optimistically_mutate():
"Tagging failed due to too many conflicting updates from other processes"
)

@classmethod
def filter_tasks_by_metadata(cls, flow_id: str, run_id: str, query_step: str,
field_name: str, field_value: str) -> list:
"""
Filter tasks by metadata field and value, returning task IDs that match criteria.
Parameters
----------
flow_id : str
Identifier for the flow
run_id : str
Identifier for the run
query_step : str
Name of the step to query tasks from
field_name : str
Name of metadata field to query
field_value : str
Value to match in metadata field
Returns
-------
list
List of task IDs that match the query criteria
Raises
------
JSONDecodeError
If metadata file is corrupted or empty
FileNotFoundError
If metadata file is not found
"""

def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple:
"""Find the most recent metadata file for the given field prefix."""
# The metadata is saved as files with the format: sysmeta_<field_name>_<timestamp>.json
# and the artifact files are saved as: <attempt>_artifact__<artifact_name>.json
# We loop over all the JSON files in the directory and find the latest one
# that matches the field prefix.
json_files = glob.glob(os.path.join(path, "*.json"))
matching_files = []

for file_path in json_files:
filename = os.path.basename(file_path)
name, timestamp = filename.rsplit('_', 1)
timestamp = timestamp.split('.')[0]

if name == field_prefix:
matching_files.append((file_path, int(timestamp)))

if not matching_files:
return None

return max(matching_files, key=lambda x: x[1])

def _read_metadata_value(file_path: str) -> dict:
"""Read and parse metadata from JSON file."""
try:
with open(file_path, 'r') as f:
return json.load(f)
except json.JSONDecodeError:
raise json.JSONDecodeError("Failed to decode metadata JSON file - may be corrupted or empty")
except Exception as e:
raise Exception(f"Error reading metadata file: {str(e)}")

try:
# Get all tasks for the given step
tasks = LocalMetadataProvider.get_object(
"step", "task", {}, None, flow_id, run_id, query_step
)

filtered_task_ids = []
field_name_prefix = f"sysmeta_{field_name}"

# Filter tasks based on metadata
for task in tasks:
task_id = task.get("task_id")
if not task_id:
continue

meta_path = LocalMetadataProvider._get_metadir(
flow_id, run_id, query_step, task_id
)

latest_file = _get_latest_metadata_file(meta_path, field_name_prefix)
if not latest_file:
continue

# Read metadata and check value
metadata = _read_metadata_value(latest_file[0])
if metadata.get("value") == field_value:
filtered_task_ids.append(task_id)

return filtered_task_ids

except Exception as e:
raise Exception(f"Failed to filter tasks: {str(e)}")

@classmethod
def _get_object_internal(
cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args
Expand Down

0 comments on commit 82dc61b

Please sign in to comment.