diff --git a/metaflow/cli.py b/metaflow/cli.py
index 3a8dc4ecaa9..3c6612227b3 100644
--- a/metaflow/cli.py
+++ b/metaflow/cli.py
@@ -134,6 +134,8 @@ def config_merge_cb(ctx, param, value):
         "step": "metaflow.cli_components.step_cmd.step",
         "run": "metaflow.cli_components.run_cmds.run",
         "resume": "metaflow.cli_components.run_cmds.resume",
+        "spin": "metaflow.cli_components.run_cmds.spin",
+        "spin-internal": "metaflow.cli_components.step_cmd.spin_internal",
     },
 )
 def cli(ctx):
@@ -384,7 +386,6 @@ def start(
     # second one processed will return the actual options. The order of processing
     # depends on what (and in what order) the user specifies on the command line.
     config_options = config_file or config_value
-
     if (
         hasattr(ctx, "saved_args")
         and ctx.saved_args
@@ -462,14 +463,10 @@ def start(
     ctx.obj.event_logger = LOGGING_SIDECARS[event_logger](
         flow=ctx.obj.flow, env=ctx.obj.environment
     )
-    ctx.obj.event_logger.start()
-    _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger)
 
     ctx.obj.monitor = MONITOR_SIDECARS[monitor](
         flow=ctx.obj.flow, env=ctx.obj.environment
     )
-    ctx.obj.monitor.start()
-    _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor)
 
     ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0](
         ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor
@@ -485,6 +482,41 @@ def start(
 
     ctx.obj.config_options = config_options
 
+    # Override values for spin
+    if hasattr(ctx, "saved_args") and ctx.saved_args and ctx.saved_args[0] == "spin":
+        # For spin, we only use the local metadata provider, datastore, and
+        # environment, plus the null event logger and monitor
+        ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0](
+            ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor
+        )
+        ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"](
+            flow=ctx.obj.flow, env=ctx.obj.environment
+        )
+        ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"](
+            flow=ctx.obj.flow, env=ctx.obj.environment
+        )
+        ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0]
+        datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config(
+            ctx.obj.echo
+        )
+        ctx.obj.datastore_impl.datastore_root = datastore_root
+
+        FlowDataStore.default_storage_impl = ctx.obj.datastore_impl
+        ctx.obj.flow_datastore = FlowDataStore(
+            ctx.obj.flow.name,
+            ctx.obj.environment,
+            ctx.obj.metadata,
+            ctx.obj.event_logger,
+            ctx.obj.monitor,
+        )
+
+    # Start event logger and monitor
+    ctx.obj.event_logger.start()
+    _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger)
+
+    ctx.obj.monitor.start()
+    _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor)
+
     decorators._init(ctx.obj.flow)
 
     # It is important to initialize flow decorators early as some of the
@@ -528,7 +560,7 @@ def start(
     if (
         hasattr(ctx, "saved_args")
         and ctx.saved_args
-        and ctx.saved_args[0] not in ("run", "resume")
+        and ctx.saved_args[0] not in ("run", "resume", "spin")
    ):
         # run/resume are special cases because they can add more decorators with --with,
         # so they have to take care of themselves.
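
For orientation, the two subcommands registered above are driven like any other Metaflow CLI command. A minimal sketch of expected usage follows; the flow file myflow.py, step name train, and pathspec are hypothetical placeholders, not output from this patch:

    import subprocess
    import sys

    # Spin up "train" against an arbitrary task from the latest run
    # (the default when --task-pathspec is omitted; see run_cmds.py below).
    subprocess.run([sys.executable, "myflow.py", "spin", "train"], check=True)

    # Pin the spun step to a specific prior task and skip its decorators.
    subprocess.run(
        [
            sys.executable,
            "myflow.py",
            "spin",
            "train",
            "--task-pathspec",
            "MyFlow/42/train/1234",
            "--skip-decorators",
        ],
        check=True,
    )
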
diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py
index bf77d16ad1f..8227a027273 100644
--- a/metaflow/cli_components/run_cmds.py
+++ b/metaflow/cli_components/run_cmds.py
@@ -9,11 +9,11 @@
 from ..graph import FlowGraph
 from ..metaflow_current import current
 from ..package import MetaflowPackage
-from ..runtime import NativeRuntime
+from ..runtime import NativeRuntime, SpinRuntime
 from ..system import _system_logger
 from ..tagging_util import validate_tags
-from ..util import get_latest_run_id, write_latest_run_id
+from ..util import get_latest_run_id, write_latest_run_id, get_latest_task_pathspec
 
 
 def before_run(obj, tags, decospecs):
@@ -70,6 +70,28 @@ def write_file(file_path, content):
             f.write(str(content))
 
 
+def common_runner_options(func):
+    @click.option(
+        "--run-id-file",
+        default=None,
+        show_default=True,
+        type=str,
+        help="Write the ID of this run to the file specified.",
+    )
+    @click.option(
+        "--runner-attribute-file",
+        default=None,
+        show_default=True,
+        type=str,
+        help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.",
+    )
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def common_run_options(func):
     @click.option(
         "--tag",
@@ -110,20 +132,6 @@ def common_run_options(func):
         "option multiple times to attach multiple decorators "
         "in steps.",
     )
-    @click.option(
-        "--run-id-file",
-        default=None,
-        show_default=True,
-        type=str,
-        help="Write the ID of this run to the file specified.",
-    )
-    @click.option(
-        "--runner-attribute-file",
-        default=None,
-        show_default=True,
-        type=str,
-        help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.",
-    )
     @wraps(func)
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
@@ -167,6 +175,7 @@ def wrapper(*args, **kwargs):
 @click.argument("step-to-rerun", required=False)
 @click.command(help="Resume execution of a previous run of this flow.")
 @common_run_options
+@common_runner_options
 @click.pass_obj
 def resume(
     obj,
@@ -285,6 +294,7 @@ def resume(
 @click.command(help="Run the workflow locally.")
 @tracing.cli_entrypoint("cli/run")
 @common_run_options
+@common_runner_options
 @click.option(
     "--namespace",
     "user_namespace",
@@ -360,3 +370,79 @@ def run(
             f,
         )
     runtime.execute()
+
+
+@click.command(help="Spins up a step locally")
+@click.argument(
+    "step-name",
+    required=True,
+    type=str,
+)
+@click.option(
+    "--task-pathspec",
+    default=None,
+    show_default=True,
+    help="Task pathspec to use when spinning up the step. The spun-up step will "
+    "use the artifacts corresponding to this task pathspec. If not provided, an "
+    "arbitrary task from the latest run of the step will be used.",
+)
+@click.option(
+    "--skip-decorators/--no-skip-decorators",
+    is_flag=True,
+    default=False,
+    show_default=True,
+    help="Skip decorators attached to the step.",
+)
+@common_runner_options
+@click.pass_obj
+def spin(
+    obj,
+    step_name,
+    task_pathspec=None,
+    skip_decorators=False,
+    run_id_file=None,
+    runner_attribute_file=None,
+    **kwargs
+):
+    before_run(obj, [], [])
+    if task_pathspec is None:
+        task_pathspec = get_latest_task_pathspec(obj.flow.name, step_name)
+
+    obj.echo(
+        f"Spinning up step *{step_name}* locally with task pathspec *{task_pathspec}*"
+    )
+    obj.flow._set_constants(obj.graph, kwargs, obj.config_options)
+    step_func = getattr(obj.flow, step_name)
+
+    spin_runtime = SpinRuntime(
+        obj.flow,
+        obj.graph,
+        obj.flow_datastore,
+        obj.metadata,
+        obj.environment,
+        obj.package,
+        obj.logger,
+        obj.entrypoint,
+        obj.event_logger,
+        obj.monitor,
+        step_func,
+        task_pathspec,
+        skip_decorators,
+    )
+
+    write_latest_run_id(obj, spin_runtime.run_id)
+    write_file(run_id_file, spin_runtime.run_id)
+    spin_runtime.execute()
+
+    local_metadata = f"{obj.metadata.__class__.TYPE}@{obj.metadata.__class__.INFO}"
+    if runner_attribute_file:
+        with open(runner_attribute_file, "w") as f:
+            json.dump(
+                {
+                    "task_id": spin_runtime.task.task_id,
+                    "step_name": step_name,
+                    "run_id": spin_runtime.run_id,
+                    "flow_name": obj.flow.name,
+                    "metadata": local_metadata,
+                },
+                f,
+            )
diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py
index 4b40c9e5e54..fd3bfcdbd5a 100644
--- a/metaflow/cli_components/step_cmd.py
+++ b/metaflow/cli_components/step_cmd.py
@@ -174,3 +174,107 @@ def step(
     )
 
     echo("Success", fg="green", bold=True, indent=True)
+
+
+@click.command(help="Internal command to spin a single task.", hidden=True)
+@click.argument("step-name")
+@click.option(
+    "--run-id",
+    default=None,
+    required=True,
+    help="Run ID for the step that's about to be spun",
+)
+@click.option(
+    "--task-id",
+    default=None,
+    required=True,
+    help="Task ID for the step that's about to be spun",
+)
+@click.option(
+    "--task-pathspec",
+    default=None,
+    show_default=True,
+    help="Task pathspec to be used in the spun step.",
+)
+@click.option(
+    "--input-paths",
+    help="A comma-separated list of pathspecs specifying inputs for this step.",
+)
+@click.option(
+    "--split-index",
+    type=int,
+    default=None,
+    show_default=True,
+    help="Index of this foreach split.",
+)
+@click.option(
+    "--retry-count",
+    default=0,
+    help="How many times we have attempted to run this task.",
+)
+@click.option(
+    "--max-user-code-retries",
+    default=0,
+    help="How many times we should attempt running the user code.",
+)
+@click.option(
+    "--namespace",
+    "namespace",
+    default=None,
+    help="Change namespace from the default (your username) to the specified tag.",
+)
+@click.option(
+    "--skip-decorators/--no-skip-decorators",
+    is_flag=True,
+    default=False,
+    show_default=True,
+    help="Skip decorators attached to the step.",
+)
+@click.pass_context
+def spin_internal(
+    ctx,
+    step_name,
+    run_id=None,
+    task_id=None,
+    task_pathspec=None,
+    input_paths=None,
+    split_index=None,
+    retry_count=None,
+    max_user_code_retries=None,
+    namespace=None,
+    skip_decorators=False,
+):
+    import sys
+    import time
+
+    if ctx.obj.is_quiet:
+        echo = echo_dev_null
+    else:
+        echo = echo_always
+    echo("Spinning a task, *%s*" % step_name, fg="magenta", bold=False)
+
+    task = MetaflowTask(
+        ctx.obj.flow,
+        ctx.obj.flow_datastore,  # local datastore
+        ctx.obj.metadata,  # local metadata provider
+        ctx.obj.environment,  # local environment
+        ctx.obj.echo,
+        ctx.obj.event_logger,  # null logger
+        ctx.obj.monitor,  # null monitor
+        None,  # no unbounded foreach context
+    )
+
+    start = time.time()
+    task.run_spin_step(
+        step_name,
+        task_pathspec,
+        run_id,
+        task_id,
+        input_paths,
+        split_index,
+        retry_count,
+        max_user_code_retries,
+        namespace,
+        skip_decorators,
+    )
+    print("Time taken for the whole thing: ", time.time() - start)
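
The --runner-attribute-file JSON written by spin above is the handshake that the Runner API reads back (see metaflow_runner.py further down). A sketch of consuming it directly, assuming a hypothetical file path:

    import json

    with open("/tmp/spin_attrs.json") as f:  # hypothetical path passed on the CLI
        attrs = json.load(f)

    # Keys written by run_cmds.spin: task_id, step_name, run_id, flow_name, metadata
    pathspec = "{flow_name}/{run_id}/{step_name}/{task_id}".format(**attrs)
    print(pathspec, attrs["metadata"])
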
diff --git a/metaflow/client/core.py b/metaflow/client/core.py
index 87b6a88c37c..e5273db9e33 100644
--- a/metaflow/client/core.py
+++ b/metaflow/client/core.py
@@ -1,5 +1,6 @@
 from __future__ import print_function
 
+import time
 import json
 import os
 import tarfile
@@ -379,7 +380,7 @@ def __iter__(self) -> Iterator["MetaflowObject"]:
             _CLASSES[self._CHILD_CLASS]._NAME,
             query_filter,
             self._attempt,
-            *self.path_components
+            *self.path_components,
         )
         unfiltered_children = unfiltered_children if unfiltered_children else []
         children = filter(
@@ -1118,11 +1119,218 @@ class Task(MetaflowObject):
 
     def __init__(self, *args, **kwargs):
         super(Task, self).__init__(*args, **kwargs)
+        # We want to cache the metadata dictionary since it is used in many places
+        self._metadata_dict = None
 
     def _iter_filter(self, x):
         # exclude private data artifacts
         return x.id[0] != "_"
 
+    def _get_task_for_queried_step(self, flow_id, run_id, query_step):
+        """
+        Returns a Task object corresponding to the queried step.
+        If the queried step has several tasks, the first task is returned.
+        """
+        return Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False).task
+
+    def _get_metadata_query_vals(
+        self,
+        flow_id: str,
+        run_id: str,
+        cur_foreach_stack_len: int,
+        steps: List[str],
+        is_ancestor: bool,
+    ):
+        """
+        Returns the field name and field value to be used for querying the metadata
+        of successor or ancestor tasks.
+
+        Parameters
+        ----------
+        flow_id : str
+            Flow ID of the task
+        run_id : str
+            Run ID of the task
+        cur_foreach_stack_len : int
+            Length of the foreach stack of the current task
+        steps : List[str]
+            List of step names whose tasks will be returned. For static joins and
+            static splits, we can have ancestors and successors across multiple steps.
+        is_ancestor : bool
+            Set this to True when querying for ancestor tasks.
+        """
+        # For each task, we also log additional metadata fields such as
+        # `foreach-indices` and `foreach-indices-truncated`, which help us in
+        # querying ancestor and successor tasks.
+        # `foreach-indices`: the indices of the foreach stack at the time of task
+        # execution.
+        # `foreach-indices-truncated`: the indices of the foreach stack at the time
+        # of task execution, truncated by 1.
+        # For example, a task that is nested 3 levels deep in a foreach stack may
+        # have the following values:
+        # foreach-indices = [0, 1, 2]
+        # foreach-indices-truncated = [0, 1]
+
+        if len(steps) > 1:
+            # This is a static join or a static split. There will be no change in
+            # foreach stack length.
+            query_foreach_stack_len = cur_foreach_stack_len
+        else:
+            # For linear steps, or foreach splits and joins, ancestor and successor
+            # tasks will all belong to the same step.
+            query_task = self._get_task_for_queried_step(flow_id, run_id, steps[0])
+            query_foreach_stack_len = len(
+                query_task.metadata_dict.get("foreach-indices", [])
+            )
+
+        if query_foreach_stack_len == cur_foreach_stack_len:
+            # The successor or ancestor tasks belong to the same foreach stack level
+            field_name = "foreach-indices"
+            field_value = self.metadata_dict.get(field_name)
+        elif is_ancestor:
+            if query_foreach_stack_len > cur_foreach_stack_len:
+                # This is a foreach join
+                # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
+                # Ancestor Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
+                # We compare the foreach-indices-truncated value of the ancestor task
+                # with the foreach-indices value of the current task
+                field_name = "foreach-indices-truncated"
+                field_value = self.metadata_dict.get("foreach-indices")
+            else:
+                # This is a foreach split
+                # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
+                # Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
+                # We compare the foreach-indices value of the ancestor task with the
+                # foreach-indices-truncated value of the current task
+                field_name = "foreach-indices"
+                field_value = self.metadata_dict.get("foreach-indices-truncated")
+        else:
+            if query_foreach_stack_len > cur_foreach_stack_len:
+                # This is a foreach split
+                # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
+                # Successor Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
+                # We compare the foreach-indices value of the current task with the
+                # foreach-indices-truncated value of the successor tasks
+                field_name = "foreach-indices-truncated"
+                field_value = self.metadata_dict.get("foreach-indices")
+            else:
+                # This is a foreach join
+                # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
+                # Successor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
+                # We compare the foreach-indices-truncated value of the current task
+                # with the foreach-indices value of the successor tasks
+                field_name = "foreach-indices"
+                field_value = self.metadata_dict.get("foreach-indices-truncated")
+        return field_name, field_value
+
+    def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
+        flow_id, run_id, _, _ = self.path_components
+        steps = (
+            self.metadata_dict.get("previous-steps")
+            if is_ancestor
+            else self.metadata_dict.get("successor-steps")
+        )
+
+        if not steps:
+            return {}
+
+        field_name, field_value = self._get_metadata_query_vals(
+            flow_id,
+            run_id,
+            len(self.metadata_dict.get("foreach-indices", [])),
+            steps,
+            is_ancestor=is_ancestor,
+        )
+
+        return {
+            step: [
+                f"{flow_id}/{run_id}/{step}/{task_id}"
+                for task_id in self._metaflow.metadata.filter_tasks_by_metadata(
+                    flow_id, run_id, step, field_name, field_value
+                )
+            ]
+            for step in steps
+        }
+
+    @property
+    def immediate_ancestors(self) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary of the immediate ancestor task pathspecs of this task,
+        one entry per previous step.
+
+        Returns
+        -------
+        Dict[str, List[str]]
+            Dictionary of the immediate ancestors of this task. The keys are the
+            names of the ancestor steps and the values are the corresponding
+            task pathspecs of the ancestors.
+        """
+        return self._get_related_tasks(is_ancestor=True)
+
+    @property
+    def immediate_successors(self) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary of the immediate successor task pathspecs of this task,
+        one entry per successor step.
+
+        Returns
+        -------
+        Dict[str, List[str]]
+            Dictionary of the immediate successors of this task. The keys are the
+            names of the successor steps and the values are the corresponding
+            task pathspecs of the successors.
+        """
+        return self._get_related_tasks(is_ancestor=False)
+
+    @property
+    def siblings(self) -> Dict[str, List[str]]:
+        """
+        Returns a dictionary of the sibling task pathspecs of this task. Siblings
+        of a task share the same parent task.
+
+        Returns
+        -------
+        Dict[str, List[str]]
+            Dictionary of the sibling task pathspecs of this task. The key is the
+            name of the current step and the values are the corresponding
+            task pathspecs of the siblings.
+        """
+        flow_id, run_id, step_name, _ = self.path_components
+
+        ancestor_steps = self.metadata_dict.get("previous-steps") or []
+        cur_foreach_stack_len = len(self.metadata_dict.get("foreach-indices", []))
+        if len(ancestor_steps) > 1 or step_name in ("start", "end"):
+            # This is a static join, or a start/end step. The current task has no
+            # siblings other than itself.
+            return {
+                step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"],
+            }
+
+        # This can be a linear step, a foreach split, a foreach join, or a static split.
+        query_task = self._get_task_for_queried_step(flow_id, run_id, ancestor_steps[0])
+        query_foreach_stack_len = len(
+            query_task.metadata_dict.get("foreach-indices", [])
+        )
+        if query_foreach_stack_len > cur_foreach_stack_len:
+            # This is a foreach join; there are no siblings
+            return {
+                step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"],
+            }
+        elif query_foreach_stack_len < cur_foreach_stack_len:
+            # This is a foreach split; there can be multiple siblings
+            field_name = "foreach-indices-truncated"
+            field_value = self.metadata_dict.get("foreach-indices-truncated")
+            # We find all tasks of the same step that have the same
+            # foreach-indices-truncated value
+            return {
+                step_name: [
+                    f"{flow_id}/{run_id}/{step_name}/{task_id}"
+                    for task_id in self._metaflow.metadata.filter_tasks_by_metadata(
+                        flow_id, run_id, step_name, field_name, field_value
+                    )
+                ]
+            }
+
+        # Logic for static splits and linear steps:
+        # to find siblings, we first find the single ancestor task of the current
+        # task, and then we find all the successor tasks of that ancestor task.
+        ancestor_task_pathspecs = self.immediate_ancestors.get(ancestor_steps[0])
+        ancestor_task = Task(ancestor_task_pathspecs[0], _namespace_check=False)
+        return ancestor_task.immediate_successors
+
     @property
     def metadata(self) -> List[Metadata]:
         """
@@ -1211,9 +1419,12 @@ def metadata_dict(self) -> Dict[str, str]:
             Dictionary mapping metadata name with value
         """
         # use the newest version of each key, hence sorting
-        return {
-            m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at)
-        }
+        if self._metadata_dict is None:
+            self._metadata_dict = {
+                m.name: m.value
+                for m in sorted(self.metadata, key=lambda m: m.created_at)
+            }
+        return self._metadata_dict
 
     @property
     def index(self) -> Optional[int]:
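
Taken together, the lineage properties added above are meant to be queried from the client API roughly as follows; the flow, run, and step names are hypothetical, and results depend on the new metadata fields having been registered when the run executed:

    from metaflow import Task

    task = Task("MyFlow/42/join_models/77", _namespace_check=False)

    # Each property maps a step name to a list of task pathspecs, e.g.
    # {"train": ["MyFlow/42/train/3", "MyFlow/42/train/4"]}
    print(task.immediate_ancestors)
    print(task.immediate_successors)
    print(task.siblings)
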
diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py
index 793251b0cff..8560b5627cd 100644
--- a/metaflow/datastore/__init__.py
+++ b/metaflow/datastore/__init__.py
@@ -2,3 +2,6 @@
 from .flow_datastore import FlowDataStore
 from .datastore_set import TaskDataStoreSet
 from .task_datastore import TaskDataStore
+from .spin_datastore.step_datastore import LinearStepDatastore
+from .spin_datastore.inputs_datastore import SpinInputsDatastore
+from .spin_datastore.inputs_datastore import StaticSpinInputsDatastore
diff --git a/metaflow/datastore/spin_datastore/__init__.py b/metaflow/datastore/spin_datastore/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/metaflow/datastore/spin_datastore/inputs_datastore.py b/metaflow/datastore/spin_datastore/inputs_datastore.py
new file mode 100644
index 00000000000..0fd8ad4b768
--- /dev/null
+++ b/metaflow/datastore/spin_datastore/inputs_datastore.py
@@ -0,0 +1,119 @@
+from itertools import chain
+
+
+class SpinInput(object):
+    def __init__(self, artifacts, task):
+        self.artifacts = artifacts
+        self.task = task
+
+    def __getattr__(self, name):
+        # We always look for any artifacts provided by the user first
+        if self.artifacts is not None and name in self.artifacts:
+            return self.artifacts[name]
+
+        try:
+            return getattr(self.task.artifacts, name).data
+        except AttributeError:
+            raise AttributeError(
+                f"Attribute '{name}' not found in the previous execution of the "
+                f"task for `{self.task.parent.id}`."
+            )
+
+
+class StaticSpinInputsDatastore(object):
+    def __init__(self, task, immediate_ancestors, artifacts={}):
+        self.task = task
+        self.immediate_ancestors = immediate_ancestors
+        self.artifacts = artifacts
+
+        self._previous_tasks = None
+        self._previous_steps = None
+
+    @property
+    def previous_tasks(self):
+        if self._previous_tasks:
+            return self._previous_tasks
+
+        from metaflow import Task
+
+        prev_task_pathspecs = self.immediate_ancestors
+        # Static join step, so each previous step only has one task
+        self._previous_tasks = {
+            step_name: Task(prev_task_pathspec[0], _namespace_check=False)
+            for step_name, prev_task_pathspec in prev_task_pathspecs.items()
+        }
+        return self._previous_tasks
+
+    @property
+    def previous_steps(self):
+        if self._previous_steps:
+            return self._previous_steps
+        self._previous_steps = self.task.metadata_dict.get("previous-steps")
+        return self._previous_steps
+
+    def __getattr__(self, name):
+        if name not in self.previous_steps:
+            raise AttributeError(
+                f"Step '{self.task.parent.id}' does not have a previous step "
+                f"with name '{name}'."
+            )
+
+        input_step = SpinInput(
+            # Artifacts corresponding to the previous step
+            self.artifacts.get(name, {}),
+            # Task corresponding to the previous step
+            self.previous_tasks.get(name),
+        )
+        setattr(self, name, input_step)
+        return input_step
+
+    def __iter__(self):
+        for prev_step_name in self.previous_steps:
+            yield getattr(self, prev_step_name)
+
+    def __len__(self):
+        return len(self.previous_steps)
+
+
+class SpinInputsDatastore(object):
+    def __init__(self, task, immediate_ancestors, artifacts={}):
+        self.task = task
+        self.immediate_ancestors = immediate_ancestors
+        self.artifacts = artifacts
+
+        self._previous_tasks = None
+
+    def __len__(self):
+        return len(self.previous_tasks)
+
+    def __getitem__(self, idx):
+        _item_task = self.previous_tasks[idx]
+        _item_artifacts = self.artifacts.get(self.previous_step, {}).get(idx, {})
+        return SpinInput(_item_artifacts, _item_task)
+
+    def __iter__(self):
+        for idx in range(len(self.previous_tasks)):
+            yield self[idx]
+
+    @property
+    def previous_step(self):
+        return self.task.metadata_dict.get("previous-steps")[0]
+
+    @property
+    def previous_tasks(self):
+        if self._previous_tasks:
+            return self._previous_tasks
+
+        # Foreach join step, so we have one previous step with multiple tasks
+        from metaflow import Task
+
+        prev_task_pathspecs = list(
+            chain.from_iterable(self.immediate_ancestors.values())
+        )
+        self._previous_tasks = [
+            Task(prev_task_pathspec, _namespace_check=False)
+            for prev_task_pathspec in prev_task_pathspecs
+        ]
+        self._previous_tasks = sorted(self._previous_tasks, key=lambda x: x.index)
+        return self._previous_tasks
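
A sketch of how these classes stand in for the `inputs` object of a spun join step; the pathspec and the `partial_sum` artifact name are hypothetical:

    from metaflow import Task
    from metaflow.datastore import SpinInputsDatastore

    task = Task("MyFlow/42/join/9", _namespace_check=False)  # hypothetical
    inputs = SpinInputsDatastore(task, task.immediate_ancestors)

    # Indexable and iterable like a foreach join's inputs; attribute access on
    # each SpinInput falls back to the recorded artifacts of the ancestor task.
    total = sum(inp.partial_sum for inp in inputs)
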
diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py
new file mode 100644
index 00000000000..7d84fb6273e
--- /dev/null
+++ b/metaflow/datastore/spin_datastore/step_datastore.py
@@ -0,0 +1,106 @@
+from itertools import chain
+
+
+class LinearStepDatastore(object):
+    def __init__(self, task, immediate_ancestors, artifacts={}):
+        self._task = task
+        self._immediate_ancestors = immediate_ancestors
+        self._artifacts = artifacts
+        self._previous_task = None
+        self._data = {}
+
+        # Set these to empty dictionaries in order to persist artifacts.
+        # See the `persist` method in `TaskDatastore` for more details.
+        self._objects = {}
+        self._info = {}
+
+    def __contains__(self, name):
+        try:
+            _ = self.__getattr__(name)
+        except AttributeError:
+            return False
+        return True
+
+    def __getitem__(self, name):
+        return self.__getattr__(name)
+
+    def __setitem__(self, name, value):
+        self._data[name] = value
+
+    def __getattr__(self, name):
+        # Check internal data first
+        if name in self._data:
+            return self._data[name]
+
+        # We always look for any artifacts provided by the user first
+        if name in self._artifacts:
+            return self._artifacts[name]
+
+        # If the linear step is part of a foreach, we need to provide the `input`
+        # attribute and the `index` attribute
+        if name == "input":
+            if self._task.index is None:
+                raise AttributeError(
+                    f"Attribute '{name}' does not exist for step "
+                    f"`{self._task.parent.id}` as it is not part of a foreach step."
+                )
+            # `input` only exists for steps immediately after a foreach split;
+            # we check for that by comparing the length of the foreach-step-names
+            # attribute of the task and its immediate ancestors
+            foreach_step_names = self._task.metadata_dict.get("foreach-step-names")
+            prev_task_foreach_step_names = self.previous_task.metadata_dict.get(
+                "foreach-step-names"
+            )
+            if len(foreach_step_names) <= len(prev_task_foreach_step_names):
+                return None  # `input` does not exist, so we return None
+
+            foreach_stack = self._task["_foreach_stack"].data
+            foreach_index = foreach_stack[-1].index
+            foreach_var = foreach_stack[-1].var
+
+            # Fetch the artifact corresponding to the foreach var and index from
+            # the previous task
+            input_val = self.previous_task[foreach_var].data[foreach_index]
+            setattr(self, name, input_val)
+            return input_val
+
+        # If the linear step is part of a foreach, we need to provide the index
+        # attribute
+        if name == "index":
+            if self._task.index is None:
+                raise AttributeError(
+                    f"Attribute '{name}' does not exist for step "
+                    f"`{self._task.parent.id}` as it is not part of a foreach step."
+                )
+            foreach_stack = self._task["_foreach_stack"].data
+            foreach_index = foreach_stack[-1].index
+            setattr(self, name, foreach_index)
+            return foreach_index
+
+        # If the user has not provided the artifact, we look for it in the
+        # task using the client API
+        try:
+            return getattr(self.previous_task.artifacts, name).data
+        except AttributeError:
+            raise AttributeError(
+                f"Attribute '{name}' not found in the previous execution of the "
+                f"task for `{self._task.parent.id}`."
+            )
+
+    @property
+    def previous_task(self):
+        if self._previous_task:
+            return self._previous_task
+
+        # This is a linear step, so we only have one immediate ancestor
+        from metaflow import Task
+
+        prev_task_pathspec = list(
+            chain.from_iterable(self._immediate_ancestors.values())
+        )[0]
+        self._previous_task = Task(prev_task_pathspec, _namespace_check=False)
+        return self._previous_task
+
+    def get(self, key, default=None):
+        try:
+            return self.__getattr__(key)
+        except AttributeError:
+            return default
diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py
index 888d89f4bc6..69b82c133a3 100644
--- a/metaflow/flowspec.py
+++ b/metaflow/flowspec.py
@@ -184,6 +184,12 @@ def __init__(self, use_cli=True):
         self._transition = None
         self._cached_input = {}
 
+        # Spin state
+        self._spin = False
+        self._spin_index = None
+        self._spin_input = None
+        self._spin_foreach_stack = None
+
         if use_cli:
             with parameters.flow_context(self.__class__) as _:
                 from . import cli
@@ -459,6 +465,8 @@ def index(self) -> Optional[int]:
         int, optional
             Index of the task in a foreach step.
         """
+        if self._spin:
+            return self._spin_index
         if self._foreach_stack:
             return self._foreach_stack[-1].index
 
@@ -479,6 +487,8 @@ def input(self) -> Optional[Any]:
         object, optional
             Input passed to the foreach task.
         """
+        if self._spin:
+            return self._datastore["input"]
         return self._find_input()
 
     def foreach_stack(self) -> Optional[List[Tuple[int, int, Any]]]:
@@ -528,6 +538,8 @@ def nest_2(self):
         List[Tuple[int, int, Any]]
             An array describing the current stack of foreach steps.
         """
+        if self._spin:
+            return self._spin_foreach_stack
         return [
             (frame.index, frame.num_splits, self._find_input(stack_index=i))
             for i, frame in enumerate(self._foreach_stack)
diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py
index 11c3873a85e..5f66433bf3d 100644
--- a/metaflow/metadata_provider/metadata.py
+++ b/metaflow/metadata_provider/metadata.py
@@ -5,6 +5,7 @@
 from collections import namedtuple
 from itertools import chain
+from typing import List
 
 from metaflow.exception import MetaflowInternalError, MetaflowTaggingError
 from metaflow.tagging_util import validate_tag
 from metaflow.util import get_username, resolve_identity_as_tuple, is_stringish
@@ -672,6 +673,39 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
         if metadata:
             self.register_metadata(run_id, step_name, task_id, metadata)
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_id: str,
+        run_id: str,
+        query_step: str,
+        field_name: str,
+        field_value: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and value, and return the list of task_ids
+        that satisfy the query.
+
+        Parameters
+        ----------
+        flow_id : str
+            Flow ID that the run belongs to
+        run_id : str
+            Run ID that, together with flow_id, identifies the specific run whose
+            tasks to query
+        query_step : str
+            Step name to query tasks from
+        field_name : str
+            Metadata field name to query
+        field_value : str
+            Metadata field value to query
+
+        Returns
+        -------
+        List[str]
+            List of task_ids that satisfy the query
+        """
+        raise NotImplementedError()
+
     @staticmethod
     def _apply_filter(elts, filters):
         if filters is None:
diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index 8dd0b76e7cd..56be74c9592 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -47,6 +47,14 @@
     "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows"
 )
 
+###
+# Spin configuration
+###
+SPIN_ALLOWED_DECORATORS = from_conf(
+    "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"]
+)
+
+
 ###
 # User configuration
 ###
@@ -248,8 +256,7 @@
 # Default container registry
 DEFAULT_CONTAINER_REGISTRY = from_conf("DEFAULT_CONTAINER_REGISTRY")
 # Controls whether to include foreach stack information in metadata.
-# TODO(Darin, 05/01/24): Remove this flag once we are confident with this feature.
-INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", False)
+INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", True)
 # Maximum length of the foreach value string to be stored in each ForeachFrame.
 MAXIMUM_FOREACH_VALUE_CHARS = from_conf("MAXIMUM_FOREACH_VALUE_CHARS", 30)
 # The default runtime limit (In seconds) of jobs launched by any compute provider. Default of 5 days.
diff --git a/metaflow/metaflow_current.py b/metaflow/metaflow_current.py
index 8443c1d75ab..73c9f8359d8 100644
--- a/metaflow/metaflow_current.py
+++ b/metaflow/metaflow_current.py
@@ -45,6 +45,7 @@ def _set_env(
         username=None,
         metadata_str=None,
         is_running=True,
+        is_spin=False,
         tags=None,
     ):
         if flow is not None:
@@ -60,6 +61,7 @@ def _set_env(
         self._username = username
         self._metadata_str = metadata_str
         self._is_running = is_running
+        self._is_spin = is_spin
         self._tags = tags
 
     def _update_env(self, env):
diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py
index ea7754cac5f..995e589daf6 100644
--- a/metaflow/plugins/metadata_providers/local.py
+++ b/metaflow/plugins/metadata_providers/local.py
@@ -6,6 +6,7 @@
 import tempfile
 import time
 from collections import namedtuple
+from typing import List
 
 from metaflow.exception import MetaflowInternalError, MetaflowTaggingError
 from metaflow.metadata_provider.metadata import ObjectOrder
@@ -202,6 +203,108 @@ def _optimistically_mutate():
                 "Tagging failed due to too many conflicting updates from other processes"
             )
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_id: str,
+        run_id: str,
+        query_step: str,
+        field_name: str,
+        field_value: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and value, returning the task IDs that
+        match the criteria.
+
+        Parameters
+        ----------
+        flow_id : str
+            Identifier for the flow
+        run_id : str
+            Identifier for the run
+        query_step : str
+            Name of the step to query tasks from
+        field_name : str
+            Name of the metadata field to query
+        field_value : str
+            Value to match in the metadata field
+
+        Returns
+        -------
+        List[str]
+            List of task IDs that match the query criteria
+
+        Raises
+        ------
+        MetaflowInternalError
+            If a metadata file is corrupted or empty
+        """
+
+        def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple:
+            """Find the most recent metadata file for the given field prefix."""
+            # Metadata is saved as files named sysmeta_<field>_<timestamp>.json
+            # (artifact files are saved as <name>_artifact_<timestamp>.json).
+            # We loop over all the JSON files in the directory and find the
+            # latest one that matches the field prefix.
+            json_files = glob.glob(os.path.join(path, f"{field_prefix}*.json"))
+            matching_files = []
+
+            for file_path in json_files:
+                filename = os.path.basename(file_path)
+                name, timestamp = filename.rsplit("_", 1)
+                timestamp = timestamp.split(".")[0]
+
+                if name == field_prefix:
+                    matching_files.append((file_path, int(timestamp)))
+
+            if not matching_files:
+                return None
+
+            return max(matching_files, key=lambda x: x[1])
+
+        def _read_metadata_value(file_path: str) -> dict:
+            """Read and parse metadata from a JSON file."""
+            try:
+                with open(file_path, "r") as f:
+                    return json.load(f)
+            except json.JSONDecodeError as e:
+                raise MetaflowInternalError(
+                    "Failed to decode metadata JSON file - may be corrupted or empty"
+                ) from e
+            except Exception as e:
+                raise Exception(f"Error reading metadata file: {str(e)}")
+
+        try:
+            # Get all tasks for the given step
+            tasks = LocalMetadataProvider.get_object(
+                "step", "task", {}, None, flow_id, run_id, query_step
+            )
+
+            filtered_task_ids = []
+            field_name_prefix = f"sysmeta_{field_name}"
+
+            # Filter tasks based on metadata
+            for task in tasks:
+                task_id = task.get("task_id")
+
+                meta_path = LocalMetadataProvider._get_metadir(
+                    flow_id, run_id, query_step, task_id
+                )
+                latest_file = _get_latest_metadata_file(meta_path, field_name_prefix)
+                if not latest_file:
+                    continue
+
+                # Read the metadata and check its value
+                metadata = _read_metadata_value(latest_file[0])
+                if metadata.get("value") == field_value:
+                    filtered_task_ids.append(task_id)
+
+            return filtered_task_ids
+
+        except Exception as e:
+            raise Exception(f"Failed to filter tasks: {str(e)}")
+
     @classmethod
     def _get_object_internal(
         cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args
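
For reference, querying the local provider directly with the new classmethod might look like this; the identifiers are hypothetical, and `field_value` must compare equal to the value stored in the task's metadata (the exact stored representation is provider-specific):

    from metaflow.plugins.metadata_providers.local import LocalMetadataProvider

    task_ids = LocalMetadataProvider.filter_tasks_by_metadata(
        flow_id="MyFlow",
        run_id="42",
        query_step="train",
        field_name="foreach-indices-truncated",
        field_value="[0, 1]",
    )
    print(task_ids)  # task IDs whose latest metadata value matches
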
diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py
index 2e69026deb1..2acf95a5955 100644
--- a/metaflow/plugins/metadata_providers/service.py
+++ b/metaflow/plugins/metadata_providers/service.py
@@ -4,6 +4,7 @@
 import requests
 import time
 
+from typing import List
 from metaflow.exception import (
     MetaflowException,
     MetaflowTaggingError,
@@ -17,7 +18,7 @@
 from metaflow.metadata_provider import MetadataProvider
 from metaflow.metadata_provider.heartbeat import HB_URL_KEY
 from metaflow.sidecar import Message, MessageTypes, Sidecar
-
+from urllib.parse import urlencode
 from metaflow.util import version_parse
 
 
@@ -304,6 +305,46 @@ def _new_task(
         self._register_system_metadata(run_id, step_name, task["task_id"], attempt)
         return task["task_id"], did_create
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_id: str,
+        run_id: str,
+        query_step: str,
+        field_name: str,
+        field_value: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and value, and return the list of task_ids
+        that satisfy the query.
+
+        Parameters
+        ----------
+        flow_id : str
+            Flow ID that the run belongs to
+        run_id : str
+            Run ID that, together with flow_id, identifies the specific run whose
+            tasks to query
+        query_step : str
+            Step name to query tasks from
+        field_name : str
+            Metadata field name to query
+        field_value : str
+            Metadata field value to query
+
+        Returns
+        -------
+        List[str]
+            List of task_ids that satisfy the query
+        """
+        query_params = {
+            "field_name": field_name,
+            "field_value": field_value,
+            "query_step": query_step,
+        }
+        url = ServiceMetadataProvider._obj_path(flow_id, run_id, query_step)
+        url = f"{url}/tasks?{urlencode(query_params)}"
+        return cls._request(None, url, "GET")
+
     @staticmethod
     def _obj_path(
         flow_name,
diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py
index 4759a13b191..38a0525b0fb 100644
--- a/metaflow/runner/metaflow_runner.py
+++ b/metaflow/runner/metaflow_runner.py
@@ -4,7 +4,6 @@
 import json
 
 from typing import Dict, Iterator, Optional, Tuple
-
 from metaflow import Run
 from metaflow.plugins import get_runner_cli
@@ -17,27 +16,33 @@
 from .subprocess_manager import CommandManager, SubprocessManager
 
 
-class ExecutingRun(object):
+class ExecutingProcess(object):
     """
-    This class contains a reference to a `metaflow.Run` object representing
-    the currently executing or finished run, as well as metadata related
-    to the process.
+    This is the base class for the `ExecutingRun` and `ExecutingTask` classes,
+    which are returned by methods in `Runner` and `NBRunner`.
 
-    `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not
-    meant to be instantiated directly.
+    `ExecutingRun` contains a reference to a `metaflow.Run` object representing
+    the currently executing or finished run, as well as metadata related to the
+    process. Similarly, `ExecutingTask` contains a reference to a `metaflow.Task`
+    object representing the currently executing or finished task, as well as
+    metadata related to the process.
+
+    Neither this class nor its subclasses are meant to be instantiated directly.
+    The class works as a context manager, allowing you to use a pattern like
 
-    This class works as a context manager, allowing you to use a pattern like
     ```python
     with Runner(...).run() as running:
         ...
     ```
-    Note that you should use either this object as the context manager or
-    `Runner`, not both in a nested manner.
+
+    Note that you should use either this object as the context manager or `Runner`,
+    not both in a nested manner.
     """
 
-    def __init__(
-        self, runner: "Runner", command_obj: CommandManager, run_obj: Run
-    ) -> None:
+    def __init__(self, runner: "Runner", command_obj: CommandManager) -> None:
         """
         Create a new ExecutingRun -- this should not be done by the user directly but
         instead user Runner.run()
@@ -48,12 +53,9 @@ def __init__(
             Parent runner for this run.
         command_obj : CommandManager
             CommandManager containing the subprocess executing this run.
-        run_obj : Run
-            Run object corresponding to this run.
         """
         self.runner = runner
         self.command_obj = command_obj
-        self.run = run_obj
 
     def __enter__(self) -> "ExecutingRun":
         return self
@@ -72,11 +74,10 @@ async def wait(
 
         Parameters
         ----------
-        timeout : float, optional, default None
-            The maximum time, in seconds, to wait for the run to finish.
-            If the timeout is reached, the run is terminated. If not specified, wait
-            forever.
-        stream : str, optional, default None
+        timeout : Optional[float], default None
+            The maximum time, in seconds, to wait for the run to finish.
+            If the timeout is reached, the run is terminated.
+        stream : Optional[str], default None
             If specified, the specified stream is printed to stdout. `stream` can
             be one of `stdout` or `stderr`.
 
@@ -173,7 +174,7 @@ async def stream_log(
         ----------
         stream : str
             The stream to stream logs from. Can be one of `stdout` or `stderr`.
-        position : int, optional, default None
+        position : Optional[int], default None
             The position in the log file to start streaming from. If None, it starts
             from the beginning of the log file. This allows resuming streaming from a
             previously known position
@@ -189,6 +190,83 @@ async def stream_log(
                 yield position, line
 
 
+class ExecutingTask(ExecutingProcess):
+    """
+    This class contains a reference to a `metaflow.Task` object representing
+    the currently executing or finished task, as well as metadata related
+    to the process.
+
+    `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not
+    meant to be instantiated directly.
+
+    This class works as a context manager, allowing you to use a pattern like
+    ```python
+    with Runner(...).spin() as running:
+        ...
+    ```
+    Note that you should use either this object as the context manager or
+    `Runner`, not both in a nested manner.
+    """
+
+    def __init__(
+        self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task"
+    ) -> None:
+        """
+        Create a new ExecutingTask -- this should not be done by the user directly
+        but instead via Runner.spin()
+
+        Parameters
+        ----------
+        runner : Runner
+            Parent runner for this task.
+        command_obj : CommandManager
+            CommandManager containing the subprocess executing this task.
+        task_obj : Task
+            Task object corresponding to this task.
+        """
+        super().__init__(runner, command_obj)
+        self.task = task_obj
+
+
+class ExecutingRun(ExecutingProcess):
+    """
+    This class contains a reference to a `metaflow.Run` object representing
+    the currently executing or finished run, as well as metadata related
+    to the process.
+
+    `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not
+    meant to be instantiated directly.
+
+    This class works as a context manager, allowing you to use a pattern like
+    ```python
+    with Runner(...).run() as running:
+        ...
+    ```
+    Note that you should use either this object as the context manager or
+    `Runner`, not both in a nested manner.
+    """
+
+    def __init__(
+        self, runner: "Runner", command_obj: CommandManager, run_obj: Run
+    ) -> None:
+        """
+        Create a new ExecutingRun -- this should not be done by the user directly
+        but instead via Runner.run()
+
+        Parameters
+        ----------
+        runner : Runner
+            Parent runner for this run.
+        command_obj : CommandManager
+            CommandManager containing the subprocess executing this run.
+        run_obj : Run
+            Run object corresponding to this run.
+        """
+        super().__init__(runner, command_obj)
+        self.run = run_obj
+
+
 class RunnerMeta(type):
     def __new__(mcs, name, bases, dct):
         cls = super().__new__(mcs, name, bases, dct)
@@ -322,6 +400,23 @@ def __get_executing_run(self, attribute_file_fd, command_obj):
         )
         return ExecutingRun(self, command_obj, run_object)
 
+    def __get_executing_task(self, attribute_file_fd, command_obj):
+        content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout)
+
+        command_obj.sync_wait()
+
+        content = json.loads(content)
+        pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}"
+
+        # Set the correct metadata from the runner_attribute file corresponding
+        # to this task.
+        metadata_for_flow = content.get("metadata")
+        from metaflow import Task
+
+        task_object = Task(
+            pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
+        )
+        return ExecutingTask(self, command_obj, task_object)
+
     async def __async_get_executing_run(self, attribute_file_fd, command_obj):
         content = await async_handle_timeout(
             attribute_file_fd, command_obj, self.file_read_timeout
@@ -337,6 +432,23 @@ async def __async_get_executing_run(self, attribute_file_fd, command_obj):
         )
         return ExecutingRun(self, command_obj, run_object)
 
+    async def __async_get_executing_task(self, attribute_file_fd, command_obj):
+        content = await async_handle_timeout(
+            attribute_file_fd, command_obj, self.file_read_timeout
+        )
+        content = json.loads(content)
+        pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}"
+
+        # Set the correct metadata from the runner_attribute file corresponding
+        # to this task.
+        metadata_for_flow = content.get("metadata")
+
+        from metaflow import Task
+
+        task_object = Task(
+            pathspec, _namespace_check=False, _current_metadata=metadata_for_flow
+        )
+        return ExecutingTask(self, command_obj, task_object)
+
     def run(self, **kwargs) -> ExecutingRun:
         """
         Blocking execution of the run. This method will wait until
@@ -399,6 +511,44 @@ def resume(self, **kwargs) -> ExecutingRun:
 
         return self.__get_executing_run(attribute_file_fd, command_obj)
 
+    def spin(
+        self, step_name: str, task_pathspec: Optional[str] = None, **kwargs
+    ) -> ExecutingTask:
+        """
+        Blocking execution of a spun step. This method will wait until
+        the spun task has completed execution.
+
+        Parameters
+        ----------
+        step_name : str
+            The name of the step to spin.
+        task_pathspec : Optional[str], default None
+            The task pathspec whose artifacts the spun task should use. If None,
+            an arbitrary task from the latest run of the step is used.
+        **kwargs : Any
+            Additional arguments that you would pass to `python ./myflow.py` after
+            the `spin` command.
+
+        Returns
+        -------
+        ExecutingTask
+            ExecutingTask containing the results of the spun task.
+        """
+        with temporary_fifo() as (attribute_file_path, attribute_file_fd):
+            command = self.api(**self.top_level_kwargs).spin(
+                step_name=step_name,
+                task_pathspec=task_pathspec,
+                runner_attribute_file=attribute_file_path,
+                **kwargs,
+            )
+
+            pid = self.spm.run_command(
+                [sys.executable, *command],
+                env=self.env_vars,
+                cwd=self.cwd,
+                show_output=self.show_output,
+            )
+            command_obj = self.spm.get(pid)
+
+            return self.__get_executing_task(attribute_file_fd, command_obj)
+
     async def async_run(self, **kwargs) -> ExecutingRun:
         """
         Non-blocking execution of the run. This method will return as soon as the
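
End to end, the Runner surface added above is expected to be used roughly like this; the flow file and names are hypothetical:

    from metaflow import Runner

    with Runner("myflow.py").spin(
        step_name="train",
        task_pathspec="MyFlow/42/train/1234",  # omit to use the latest run
    ) as executing:
        print(executing.task.pathspec)  # the freshly spun task
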
diff --git a/metaflow/runtime.py b/metaflow/runtime.py
index 7e9269841fb..7623bafd103 100644
--- a/metaflow/runtime.py
+++ b/metaflow/runtime.py
@@ -19,12 +19,15 @@
 from functools import partial
 from concurrent import futures
 
+
 from metaflow.datastore.exceptions import DataException
+from itertools import chain
 
 from contextlib import contextmanager
 from . import get_namespace
 from .metadata_provider import MetaDatum
 from .metaflow_config import MAX_ATTEMPTS, UI_URL
+from .metaflow_config import SPIN_ALLOWED_DECORATORS
 from .exception import (
     MetaflowException,
     MetaflowInternalError,
@@ -73,6 +76,273 @@
 # TODO option: output dot graph periodically about execution
 
 
+class SpinRuntime(object):
+    def __init__(
+        self,
+        flow,
+        graph,
+        flow_datastore,
+        metadata,
+        environment,
+        package,
+        logger,
+        entrypoint,
+        event_logger,
+        monitor,
+        step_func,
+        task_pathspec,
+        skip_decorators=False,
+        max_log_size=MAX_LOG_SIZE,
+    ):
+        from metaflow import Task
+
+        self._flow = flow
+        self._graph = graph
+        self._flow_datastore = flow_datastore
+        self._metadata = metadata
+        self._environment = environment
+        self._package = package
+        self._logger = logger
+        self._entrypoint = entrypoint
+        self._event_logger = event_logger
+        self._monitor = monitor
+
+        self._step_func = step_func
+        self._task_pathspec = task_pathspec
+        self._task = Task(self._task_pathspec, _namespace_check=False)
+        self._input_paths = None
+        self._split_index = None
+        self._whitelist_decorators = None
+        self._config_file_name = None
+        self._skip_decorators = skip_decorators
+        self._max_log_size = max_log_size
+        self._encoding = sys.stdout.encoding or "UTF-8"
+
+        # Create a new run_id for the spin task
+        self.run_id = self._metadata.new_run_id()
+        for deco in self.whitelist_decorators:
+            deco.runtime_init(flow, graph, package, self.run_id)
+
+    @property
+    def split_index(self):
+        if self._split_index:
+            return self._split_index
+        foreach_indices = self._task.metadata_dict.get("foreach-indices", [])
+        self._split_index = foreach_indices[-1] if foreach_indices else None
+        return self._split_index
+
+    @property
+    def input_paths(self):
+        def _format_input_paths(pathspec):
+            _, run_id, step_name, task_id = pathspec.split("/")
+            return f"{run_id}/{step_name}/{task_id}"
+
+        if self._input_paths:
+            return self._input_paths
+
+        if self._step_func.name == "start":
+            from metaflow import Step
+
+            flow_name, run_id, _, _ = self._task_pathspec.split("/")
+
+            step = Step(f"{flow_name}/{run_id}/_parameters", _namespace_check=False)
+            task = next(iter(step.tasks()), None)
+            if not task:
+                raise MetaflowException(
+                    f"Task not found for {step} in the metadata store"
+                )
+            self._input_paths = [f"{run_id}/_parameters/{task.id}"]
+        else:
+            ancestors = self._task.immediate_ancestors
+            self._input_paths = [
+                _format_input_paths(ancestor)
+                for ancestor in chain.from_iterable(ancestors.values())
+            ]
+        return self._input_paths
+
+    @property
+    def whitelist_decorators(self):
+        if self._skip_decorators:
+            return []
+        if self._whitelist_decorators:
+            return self._whitelist_decorators
+        self._whitelist_decorators = [
+            deco
+            for deco in self._step_func.decorators
+            if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS)
+        ]
+        return self._whitelist_decorators
+
+    def _new_task(self, step, input_paths=None, **kwargs):
+        return Task(
+            self._flow_datastore,
+            self._flow,
+            step,
+            self.run_id,
+            self._metadata,
+            self._environment,
+            self._entrypoint,
+            self._event_logger,
+            self._monitor,
+            input_paths=self.input_paths,
+            decos=self.whitelist_decorators,
+            logger=self._logger,
+            split_index=self.split_index,
+            **kwargs,
+        )
+
+    def execute(self):
+        exception = None
+        with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file:
+            config_value = dump_config_values(self._flow)
+            if config_value:
+                json.dump(config_value, config_file)
+                config_file.flush()
+                self._config_file_name = config_file.name
+            else:
+                self._config_file_name = None
+
+            self.task = self._new_task(self._step_func.name, {})
+            _ds = self._flow_datastore.get_task_datastore(
+                self.run_id,
+                self._step_func.name,
+                self.task.task_id,
+                attempt=0,
+                mode="w",
+            )
+
+            for deco in self.whitelist_decorators:
+                deco.runtime_task_created(
+                    _ds,
+                    self.task.task_id,
+                    self.split_index,
+                    self.input_paths,
+                    is_cloned=False,
+                    ubf_context=None,
+                )
+
+            try:
+                self._launch_and_monitor_task()
+            except Exception as ex:
+                self._logger("Task failed.", system_msg=True, bad=True)
+                exception = ex
+                raise
+            finally:
+                for deco in self.whitelist_decorators:
+                    deco.runtime_finished(exception)
+
+    def _launch_and_monitor_task(self):
+        args = CLIArgs(self.task, spin=True, prev_task_pathspec=self._task_pathspec)
+        env = dict(os.environ)
+
+        for deco in self.task.decos:
+            deco.runtime_step_cli(
+                args,
+                self.task.retries,
+                self.task.user_code_retries,
+                self.task.ubf_context,
+            )
+
+        # Add user configurations using a file to avoid using up too much space
+        # on the command line
+        if self._config_file_name:
+            args.top_level_options["local-config-file"] = self._config_file_name
+
+        # Add the skip-decorators flag to the command options
+        args.command_options.update({"skip-decorators": self._skip_decorators})
+
+        env.update(args.get_env())
+        env["PYTHONUNBUFFERED"] = "x"
+
+        stdout_buffer = TruncatedBuffer("stdout", self._max_log_size)
+        stderr_buffer = TruncatedBuffer("stderr", self._max_log_size)
+
+        cmdline = args.get_args()
+        self._logger(f"Launching command: {' '.join(cmdline)}", system_msg=True)
+
+        try:
+            process = subprocess.Popen(
+                cmdline,
+                env=env,
+                bufsize=1,  # Line buffering
+                stdin=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                text=True,
+            )
+        except Exception as e:
+            raise TaskFailed(self.task, f"Failed to launch task: {str(e)}")
+
+        poll = procpoll.make_poll()
+        poll.add(process.stdout.fileno())
+        poll.add(process.stderr.fileno())
+
+        # Map file descriptors to their respective streams and buffers
+        fd_map = {
+            process.stdout.fileno(): (process.stdout, stdout_buffer, False),
+            process.stderr.fileno(): (process.stderr, stderr_buffer, True),
+        }
+
+        while True:
+            # Poll for events with a timeout
+            events = poll.poll(POLL_TIMEOUT)
+
+            if not events:
+                if process.poll() is not None:
+                    break
+                continue
+
+            for event in events:
+                if event.can_read:
+                    stream, buffer, is_stderr = fd_map[event.fd]
+                    line = stream.readline()
+                    if line:
+                        self._process_output(line, buffer, is_stderr)
+
+                if event.is_terminated:
+                    poll.remove(event.fd)
+
+            if process.poll() is not None:
+                break
+
+        # Process any remaining output
+        for stream, buffer, is_stderr in fd_map.values():
+            for line in stream:
+                self._process_output(line, buffer, is_stderr)
+
+        returncode = process.wait()
+
+        self.task.save_metadata(
+            "runtime",
+            {
+                "return_code": returncode,
+                "success": returncode == 0,
+            },
+        )
+
+        if returncode != 0:
+            raise TaskFailed(self.task, f"Task failed with return code {returncode}")
+        else:
+            self._logger("Task finished successfully.", system_msg=True)
+
+        self.task.save_logs(
+            {
+                "stdout": stdout_buffer.get_buffer(),
+                "stderr": stderr_buffer.get_buffer(),
+            }
+        )
+
+    def _process_output(self, line, buffer, is_stderr=False):
+        buffer.write(line.encode(self._encoding))
+        text = line.strip()
+        self.task.log(
+            text,
+            system_msg=False,
+            timestamp=datetime.now(),
+        )
+
+
 class NativeRuntime(object):
     def __init__(
         self,
@@ -1508,11 +1778,13 @@ class CLIArgs(object):
     for step execution in StepDecorator.runtime_step_cli().
     """
 
-    def __init__(self, task):
+    def __init__(self, task, spin=False, prev_task_pathspec=None):
         self.task = task
+        self.spin = spin
+        self.prev_task_pathspec = prev_task_pathspec
         self.entrypoint = list(task.entrypoint)
         self.top_level_options = {
-            "quiet": True,
+            "quiet": spin,
             "metadata": self.task.metadata_type,
             "environment": self.task.environment_type,
             "datastore": self.task.datastore_type,
@@ -1542,18 +1814,40 @@ def __init__(self, task):
             (k, ConfigInput.make_key_name(k)) for k in configs
         ]
 
+        if spin:
+            self.spin_args()
+        else:
+            self.default_args()
+
+    def default_args(self):
         self.commands = ["step"]
         self.command_args = [self.task.step]
 
         self.command_options = {
-            "run-id": task.run_id,
-            "task-id": task.task_id,
-            "input-paths": compress_list(task.input_paths),
-            "split-index": task.split_index,
-            "retry-count": task.retries,
-            "max-user-code-retries": task.user_code_retries,
-            "tag": task.tags,
+            "run-id": self.task.run_id,
+            "task-id": self.task.task_id,
+            "input-paths": compress_list(self.task.input_paths),
+            "split-index": self.task.split_index,
+            "retry-count": self.task.retries,
+            "max-user-code-retries": self.task.user_code_retries,
+            "tag": self.task.tags,
+            "namespace": get_namespace() or "",
+            "ubf-context": self.task.ubf_context,
+        }
+        self.env = {}
+
+    def spin_args(self):
+        self.commands = ["spin-internal"]
+        self.command_args = [self.task.step]
+
+        self.command_options = {
+            "run-id": self.task.run_id,
+            "task-id": self.task.task_id,
+            "task-pathspec": self.prev_task_pathspec,
+            "input-paths": self.task.input_paths,
+            "split-index": self.task.split_index,
+            "retry-count": self.task.retries,
+            "max-user-code-retries": self.task.user_code_retries,
             "namespace": get_namespace() or "",
-            "ubf-context": task.ubf_context,
         }
         self.env = {}
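
A minimal sketch of driving SpinRuntime directly, mirroring how the spin CLI command (run_cmds.py above) wires it up; all objects come from an initialized CLI context, and the step name and pathspec are hypothetical:

    spin_runtime = SpinRuntime(
        flow, graph, flow_datastore, metadata, environment, package,
        logger, entrypoint, event_logger, monitor,
        step_func=getattr(flow, "train"),      # hypothetical step
        task_pathspec="MyFlow/42/train/1234",  # hypothetical pathspec
    )
    spin_runtime.execute()  # launches the spin-internal subprocess built by CLIArgs
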
""" + @staticmethod + def _dynamic_runtime_metadata(foreach_stack): + foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack] + foreach_indices_truncated = foreach_indices[:-1] + foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack] + return foreach_indices, foreach_indices_truncated, foreach_step_names + + def _static_runtime_metadata(self, graph_info, step_name): + prev_steps = [ + node_name + for node_name, attributes in graph_info["steps"].items() + if step_name in attributes["next"] + ] + succesor_steps = graph_info["steps"][step_name]["next"] + return prev_steps, succesor_steps + def __init__( self, flow, @@ -372,6 +398,176 @@ def _finalize_control_task(self): ) ) + def run_spin_step( + self, + step_name, + task_pathspec, + new_run_id, + new_task_id, + input_paths, + split_index, + retry_count, + max_user_code_retries, + namespace, + skip_decorators, + ): + def is_join_step(immediate_ancestors): + prev_task_pathspecs = set(chain.from_iterable(immediate_ancestors.values())) + return len(prev_task_pathspecs) > 1 + + step_func = getattr(self.flow, step_name) + whitelisted_decorators = ( + [] + if skip_decorators + else [ + deco + for deco in step_func.decorators + if any( + deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS + ) + ] + ) + + # initialize output datastore + output = self.flow_datastore.get_task_datastore( + new_run_id, step_name, new_task_id, 0, mode="w" + ) + + output.init_task() + + # How we access the input and index attributes depends on the execution context. + # If spin is set to True, we short-circuit attribute access to getattr directly + # Also set the other attributes that are needed for the task to execute + from metaflow import Task + + self.task = Task(task_pathspec, _namespace_check=False) + immediate_ancestors = self.task.immediate_ancestors + self.flow._spin = True + self.flow._spin_foreach_stack = self.task["_foreach_stack"].data + self.flow._spin_index = split_index + self.flow._current_step = step_name + self.flow._success = False + self.flow._task_ok = None + self.flow._exception = None + + # Set inputs + inp_datastore = None + is_join = is_join_step(immediate_ancestors) + if is_join: + # Join step + if len(self.task.metadata_dict.get("previous-steps")) > 1: + # Static join step + inp_datastore = StaticSpinInputsDatastore( + self.task, immediate_ancestors, artifacts={} + ) + else: + # Foreach join step + inp_datastore = SpinInputsDatastore( + self.task, immediate_ancestors, artifacts={} + ) + self.flow._set_datastore(output) + else: + # Linear step + self.flow._set_datastore( + LinearStepDatastore(self.task, immediate_ancestors, artifacts={}) + ) + + current._set_env( + flow=self.flow, + run_id=new_run_id, + step_name=step_name, + task_id=new_task_id, + retry_count=retry_count, + namespace=resolve_identity(), + username=get_username(), + metadata_str="%s@%s" + % (self.metadata.__class__.TYPE, self.metadata.__class__.INFO), + is_running=True, + is_spin=True, + ) + + # task_pre_step decorator hooks + for deco in whitelisted_decorators: + deco.task_pre_step( + step_name=step_name, + task_datastore=output, + metadata=self.metadata, + run_id=new_run_id, + task_id=new_task_id, + flow=self.flow, + graph=self.flow._graph, + retry_count=retry_count, + max_user_code_retries=max_user_code_retries, + ubf_context=self.ubf_context, + inputs=inp_datastore, + ) + + # task_decorate decorator hooks + for deco in whitelisted_decorators: + step_func = deco.task_decorate( + step_func=step_func, + flow=self.flow, + 
+                graph=self.flow._graph,
+                retry_count=retry_count,
+                max_user_code_retries=max_user_code_retries,
+                ubf_context=self.ubf_context,
+            )
+
+        # Execute the step function
+        try:
+            if is_join:
+                # Join step
+                self._exec_step_function(step_func, input_obj=inp_datastore)
+            else:
+                self._exec_step_function(step_func)
+
+            # task_post_step decorator hooks
+            for deco in whitelisted_decorators:
+                deco.task_post_step(
+                    step_name,
+                    self.flow,
+                    self.flow._graph,
+                    retry_count,
+                    max_user_code_retries,
+                )
+
+            self.flow._task_ok = True
+            self.flow._success = True
+        except Exception as ex:
+            exception_handled = False
+            for deco in whitelisted_decorators:
+                res = deco.task_exception(
+                    ex,
+                    step_name,
+                    self.flow,
+                    self.flow._graph,
+                    retry_count,
+                    max_user_code_retries,
+                )
+                exception_handled = bool(res) or exception_handled
+
+            if exception_handled:
+                self.flow._task_ok = True
+            else:
+                self.flow._task_ok = False
+                self.flow._exception = MetaflowExceptionWrapper(ex)
+                print("%s failed:" % self.flow, file=sys.stderr)
+                raise
+        finally:
+            output.persist(self.flow)
+            output.done()
+
+            # task_finished decorator hooks
+            for deco in whitelisted_decorators:
+                deco.task_finished(
+                    step_name,
+                    self.flow,
+                    self.flow._graph,
+                    self.flow._task_ok,
+                    retry_count,
+                    max_user_code_retries,
+                )
+
     def run_step(
         self,
         step_name,
@@ -493,6 +689,36 @@ def run_step(
             )
         )

+        # Add runtime DAG info - for a nested foreach this may look like:
+        #   foreach_indices: [0, 1]
+        #   foreach_indices_truncated: [0]
+        #   foreach_step_names: ['step1', 'step2']
+        foreach_indices, foreach_indices_truncated, foreach_step_names = (
+            self._dynamic_runtime_metadata(foreach_stack)
+        )
+        metadata.extend(
+            [
+                MetaDatum(
+                    field="foreach-indices",
+                    value=foreach_indices,
+                    type="foreach-indices",
+                    tags=metadata_tags,
+                ),
+                MetaDatum(
+                    field="foreach-indices-truncated",
+                    value=foreach_indices_truncated,
+                    type="foreach-indices-truncated",
+                    tags=metadata_tags,
+                ),
+                MetaDatum(
+                    field="foreach-step-names",
+                    value=foreach_step_names,
+                    type="foreach-step-names",
+                    tags=metadata_tags,
+                ),
+            ]
+        )
+
         self.metadata.register_metadata(
             run_id,
             step_name,
@@ -559,6 +785,7 @@ def run_step(
         self.flow._success = False
         self.flow._task_ok = None
         self.flow._exception = None
+
         # Note: All internal flow attributes (ie: non-user artifacts)
         # should either be set prior to running the user code or listed in
         # FlowSpec._EPHEMERAL to allow for proper merging/importing of
@@ -616,7 +843,9 @@ def run_step(
                     "graph_info": self.flow._graph_info,
                 }
             )
-
+            previous_steps, successor_steps = self._static_runtime_metadata(
+                self.flow._graph_info, step_name
+            )
             for deco in decorators:
                 deco.task_pre_step(
                     step_name,
@@ -727,8 +956,20 @@ def run_step(
                         field="attempt_ok",
                         value=attempt_ok,
                         type="internal_attempt_status",
-                        tags=["attempt_id:{0}".format(retry_count)],
-                    )
+                        tags=metadata_tags,
+                    ),
+                    MetaDatum(
+                        field="previous-steps",
+                        value=previous_steps,
+                        type="previous-steps",
+                        tags=metadata_tags,
+                    ),
+                    MetaDatum(
+                        field="successor-steps",
+                        value=successor_steps,
+                        type="successor-steps",
+                        tags=metadata_tags,
+                    ),
                 ],
             )
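To make the runtime-DAG comment above concrete, here is a small sketch of what `_dynamic_runtime_metadata` computes for a task nested two foreaches deep. `Frame` is a stand-in for Metaflow's `ForeachFrame`; the fields shown are the ones the helper reads, and the step names and indices are made up.

```python
from collections import namedtuple

# Stand-in for metaflow.tuple_util.ForeachFrame (illustrative fields only).
Frame = namedtuple("Frame", ["step", "var", "num_splits", "index"])

foreach_stack = [
    Frame("step1", "outer_items", 3, 0),  # outer foreach, currently at split 0
    Frame("step2", "inner_items", 5, 1),  # inner foreach, currently at split 1
]

foreach_indices = [f.index for f in foreach_stack]    # [0, 1]
foreach_indices_truncated = foreach_indices[:-1]      # [0]
foreach_step_names = [f.step for f in foreach_stack]  # ['step1', 'step2']
```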
diff --git a/metaflow/util.py b/metaflow/util.py
index cd3447d0e48..e9355e31ae1 100644
--- a/metaflow/util.py
+++ b/metaflow/util.py
@@ -9,6 +9,7 @@
 from itertools import takewhile
 import re

+from typing import Callable
 from metaflow.exception import MetaflowUnknownUser, MetaflowInternalError

 try:
@@ -193,6 +194,49 @@ def get_latest_run_id(echo, flow_name):
     return None


+def get_latest_task_pathspec(flow_name: str, step_name: str) -> str:
+    """
+    Returns a task pathspec from the latest run of the flow for the queried step.
+    If the queried step has several tasks, the task pathspec of the first task is
+    returned.
+
+    Parameters
+    ----------
+    flow_name : str
+        The name of the flow.
+    step_name : str
+        The name of the step.
+
+    Returns
+    -------
+    str
+        The task pathspec of the first task of the queried step.
+
+    Raises
+    ------
+    MetaflowNotFound
+        If no run is found for the flow, or if no task is found for the
+        queried step.
+    """
+    from metaflow import Flow, Step
+    from metaflow.exception import MetaflowNotFound
+
+    run = Flow(flow_name, _namespace_check=False).latest_run
+
+    if run is None:
+        raise MetaflowNotFound(f"No run found for the flow {flow_name}")
+
+    try:
+        step = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False)
+    except Exception:
+        raise MetaflowNotFound(
+            f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*"
+        )
+
+    task = next(iter(step.tasks()), None)
+    if task:
+        return f"{flow_name}/{run.id}/{step_name}/{task.id}"
+    raise MetaflowNotFound(f"No task found for the queried step {step_name}")
+
+
 def write_latest_run_id(obj, run_id):
     from metaflow.plugins.datastores.local_storage import LocalStorage
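A usage sketch for the helper added above; `HelloFlow` and the printed pathspec are made-up examples, and the call assumes at least one run of the flow exists in the configured metadata store.

```python
from metaflow.util import get_latest_task_pathspec

# Resolve the first task of the "start" step in the latest run of HelloFlow.
pathspec = get_latest_task_pathspec("HelloFlow", "start")
print(pathspec)  # e.g. "HelloFlow/1718290000000001/start/42"
```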
+ """ + + PRIORITY = 1 + + @steps(0, ["start"]) + def step_start(self): + from metaflow import current + + self.step_name = current.step_name + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + self.parent_pathspecs = set() + + @steps(1, ["join"]) + def step_join(self): + from metaflow import current + + self.step_name = current.step_name + + # Store the parent task ids + # Store the task pathspec for all the parent tasks + self.parent_pathspecs = set(inp.task_pathspec for inp in inputs) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + @steps(2, ["all"]) + def step_all(self): + from metaflow import current + + self.step_name = current.step_name + # Store the parent task ids + # Task only has one parent, so we store the parent task id + self.parent_pathspecs = set([self.task_pathspec]) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + def check_results(self, flow, checker): + from itertools import chain + + run = checker.get_run() + + if run is None: + print("Run is None") + # very basic sanity check for CLI checker + for step in flow: + checker.assert_artifact(step.name, "step_name", step.name) + return + + # For each step in the flow + for step in run: + # For each task in the step + for task in step: + ancestors = task.immediate_ancestors + ancestor_pathspecs = set(chain.from_iterable(ancestors.values())) + + # Compare with stored parent_task_pathspecs + task_pathspec = task.data.task_pathspec + assert ancestor_pathspecs == task.data.parent_pathspecs, ( + f"Mismatch in ancestor task ids for task {task_pathspec}: Expected {task.data.parent_pathspecs}, " + f"got {ancestor_pathspecs}" + ) diff --git a/test/core/tests/client_successors.py b/test/core/tests/client_successors.py new file mode 100644 index 00000000000..c349192c2ee --- /dev/null +++ b/test/core/tests/client_successors.py @@ -0,0 +1,105 @@ +from metaflow_test import MetaflowTest, ExpectationFailed, steps + + +class ImmediateSuccessorTest(MetaflowTest): + """ + Test that immediate_successors API returns correct successor tasks + by comparing with parent task ids stored during execution. 
+ """ + + PRIORITY = 1 + + @steps(0, ["start"]) + def step_start(self): + from metaflow import current + + self.step_name = current.step_name + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + self.parent_pathspecs = set() + + @steps(1, ["join"]) + def step_join(self): + from metaflow import current + + self.step_name = current.step_name + + # Store the parent task ids + # Store the task pathspec for all the parent tasks + self.parent_pathspecs = set(inp.task_pathspec for inp in inputs) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + @steps(2, ["all"]) + def step_all(self): + from metaflow import current + + self.step_name = current.step_name + # Store the parent task ids + # Task only has one parent, so we store the parent task id + self.parent_pathspecs = set([self.task_pathspec]) + + # Set the current task id + self.task_pathspec = f"{current.flow_name}/{current.run_id}/{current.step_name}/{current.task_id}" + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + def check_results(self, flow, checker): + from metaflow import Task + from itertools import chain + + run = checker.get_run() + + if run is None: + print("Run is None") + # very basic sanity check for CLI checker + for step in flow: + checker.assert_artifact(step.name, "step_name", step.name) + return + + # For each step in the flow + for step in run: + # For each task in the step + for task in step: + cur_task_pathspec = task.data.task_pathspec + successors = task.immediate_successors + actual_successors_pathspecs_set = set( + chain.from_iterable(successors.values()) + ) + expected_successor_pathspecs_set = set() + for successor_step_name, successor_pathspecs in successors.items(): + # Assert that the current task is in the parent_pathspecs of the successor tasks + for successor_pathspec in successor_pathspecs: + successor_task = Task( + successor_pathspec, _namespace_check=False + ) + print(f"Successor task: {successor_task}") + assert ( + task.data.task_pathspec + in successor_task.data.parent_pathspecs + ), ( + f"Task {task.data.task_pathspec} is not in the parent_pathspecs of the successor task " + f"{successor_task.data.task_pathspec}" + ) + + successor_step = run[successor_step_name] + for successor_task in successor_step: + if cur_task_pathspec in successor_task.data.parent_pathspecs: + expected_successor_pathspecs_set.add( + successor_task.data.task_pathspec + ) + + # Assert that None of the tasks in the successor steps have the current task in their + # parent_pathspecs + assert ( + actual_successors_pathspecs_set == expected_successor_pathspecs_set + ), ( + f"Expected successor pathspecs: {expected_successor_pathspecs_set}, got " + f"{actual_successors_pathspecs_set}" + ) diff --git a/test/core/tests/tag_catch.py b/test/core/tests/tag_catch.py index a8efb919560..5cdffe3fcd6 100644 --- a/test/core/tests/tag_catch.py +++ b/test/core/tests/tag_catch.py @@ -62,7 +62,6 @@ def step_all(self): os.kill(os.getpid(), signal.SIGKILL) def check_results(self, flow, checker): - checker.assert_log( "start", "stdout", "stdout testing logs 3\n", exact_match=False ) @@ -71,7 +70,6 @@ def check_results(self, flow, checker): ) for step in flow: - if step.name == "start": checker.assert_artifact("start", "test_attempt", 3) try: