From 82dc61bf455a971e0bcc713284b10513aa85804e Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Fri, 3 Jan 2025 16:40:45 -0800
Subject: [PATCH] Support querying ancestors and successors in local metadata provider

---
 metaflow/metadata_provider/metadata.py       |  29 +++---
 metaflow/plugins/metadata_providers/local.py | 101 ++++++++++++++++++++
 2 files changed, 119 insertions(+), 11 deletions(-)

diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py
index ac713505099..7409fa371c0 100644
--- a/metaflow/metadata_provider/metadata.py
+++ b/metaflow/metadata_provider/metadata.py
@@ -672,21 +672,28 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
         if metadata:
             self.register_metadata(run_id, step_name, task_id, metadata)
 
-    @classmethod
-    def _filter_tasks_by_metadata(
-        cls, flow_id, run_id, query_step, field_name, field_value
-    ):
-        raise NotImplementedError()
-
     @classmethod
     def filter_tasks_by_metadata(
         cls, flow_id, run_id, query_step, field_name, field_value
     ):
-        # TODO: Do we need to do anything wrt to task attempt?
-        task_ids = cls._filter_tasks_by_metadata(
-            flow_id, run_id, query_step, field_name, field_value
-        )
-        return task_ids
+        """
+        Filter tasks by metadata field and value, and returns the list of task_ids
+        that satisfy the query.
+
+        Parameters
+        ----------
+        flow_id : str
+            Flow id, that the run belongs to.
+        run_id: str
+            Run id, together with flow_id, that identifies the specific Run whose tasks to query
+        query_step: str
+            Step name to query tasks from
+        field_name: str
+            Metadata field name to query
+        field_value: str
+            Metadata field value to query
+        """
+        raise NotImplementedError()
 
     @staticmethod
     def _apply_filter(elts, filters):
diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py
index ea7754cac5f..344f5d912c0 100644
--- a/metaflow/plugins/metadata_providers/local.py
+++ b/metaflow/plugins/metadata_providers/local.py
@@ -202,6 +202,107 @@ def _optimistically_mutate():
             "Tagging failed due to too many conflicting updates from other processes"
         )
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_id: str,
+        run_id: str,
+        query_step: str,
+        field_name: str,
+        field_value: str,
+    ) -> list:
+        """
+        Filter tasks by metadata field and value, returning task IDs that match criteria.
+
+        For each task of the step, the most recent metadata file for the field is
+        read and its "value" entry is compared for equality with `field_value`.
+
+        Parameters
+        ----------
+        flow_id : str
+            Identifier for the flow
+        run_id : str
+            Identifier for the run
+        query_step : str
+            Name of the step to query tasks from
+        field_name : str
+            Name of metadata field to query
+        field_value : str
+            Value to match in metadata field
+
+        Returns
+        -------
+        list
+            List of task IDs that match the query criteria
+
+        Raises
+        ------
+        json.JSONDecodeError
+            If a metadata file is corrupted or empty
+        FileNotFoundError
+            If a metadata file disappears before it can be read
+        """
+
+        def _get_latest_metadata_file(path: str, field_prefix: str):
+            """Return the newest metadata file for the field prefix, or None."""
+            # Metadata files are named <field_prefix>_<timestamp>.json; other JSON
+            # files (e.g. artifact files) can live in the same directory, so every
+            # candidate is checked against the prefix before its timestamp is parsed.
+            latest_path, latest_timestamp = None, None
+            for file_path in glob.glob(os.path.join(path, "*.json")):
+                # rpartition yields an empty name when there is no underscore,
+                # which can never equal the (non-empty) prefix.
+                name, _, suffix = os.path.basename(file_path).rpartition("_")
+                if name != field_prefix:
+                    continue
+                try:
+                    timestamp = int(suffix.split(".")[0])
+                except ValueError:
+                    # Same prefix but no numeric timestamp: not a metadata record.
+                    continue
+                if latest_timestamp is None or timestamp > latest_timestamp:
+                    latest_path, latest_timestamp = file_path, timestamp
+            return latest_path
+
+        def _read_metadata_value(file_path: str) -> dict:
+            """Read and parse metadata from JSON file."""
+            with open(file_path, "r") as f:
+                try:
+                    return json.load(f)
+                except json.JSONDecodeError as e:
+                    # JSONDecodeError requires (msg, doc, pos); constructing it from
+                    # a message alone raises TypeError and hides the real failure.
+                    raise json.JSONDecodeError(
+                        "Failed to decode metadata JSON file - may be corrupted or empty",
+                        e.doc,
+                        e.pos,
+                    ) from e
+
+        # Get all tasks for the given step
+        tasks = LocalMetadataProvider.get_object(
+            "step", "task", {}, None, flow_id, run_id, query_step
+        )
+
+        field_name_prefix = f"sysmeta_{field_name}"
+        filtered_task_ids = []
+
+        # Filter tasks based on metadata
+        for task in tasks:
+            task_id = task.get("task_id")
+            if not task_id:
+                continue
+
+            meta_path = LocalMetadataProvider._get_metadir(
+                flow_id, run_id, query_step, task_id
+            )
+            latest_file = _get_latest_metadata_file(meta_path, field_name_prefix)
+            if latest_file is None:
+                continue
+
+            # Read metadata and check value
+            if _read_metadata_value(latest_file).get("value") == field_value:
+                filtered_task_ids.append(task_id)
+
+        return filtered_task_ids
+
     @classmethod
     def _get_object_internal(
         cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args