Skip to content

Commit

Permalink
Tests for runner (#1859)
Browse files Browse the repository at this point in the history
* synchronous run with resume and async resume

* add nbrun

* remove runtime error if not in jupyter env

* fix env vars update

* suggested changes to nbrun

* tests using lower level chaining API

* use higher level API

* add runner import inside try

* set the correct metadata while getting the Run object

* pass USER env in tox

* reset old env vars when returning a Run object

* change name of metadata_pathspec_file to runner_attribute_file

* add comments for why reset of env and metadata is needed

* minor fix

* Update test.yml

* formatting

* fix

---------

Co-authored-by: Madhur Tandon <[email protected]>
  • Loading branch information
savingoyal and madhur-ob authored May 25, 2024
1 parent 125a018 commit 8cd344c
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 152 deletions.
72 changes: 36 additions & 36 deletions metaflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,16 @@
import traceback
from datetime import datetime
from functools import wraps
import metaflow.tracing as tracing

import metaflow.tracing as tracing
from metaflow._vendor import click

from . import lint
from . import plugins
from . import parameters
from . import decorators
from . import metaflow_version
from . import namespace
from .metaflow_current import current
from . import decorators, lint, metaflow_version, namespace, parameters, plugins
from .cli_args import cli_args
from .tagging_util import validate_tags
from .util import (
resolve_identity,
decompress_list,
write_latest_run_id,
get_latest_run_id,
)
from .task import MetaflowTask
from .client.core import get_metadata
from .datastore import FlowDataStore, TaskDataStore, TaskDataStoreSet
from .exception import CommandException, MetaflowException
from .graph import FlowGraph
from .datastore import FlowDataStore, TaskDataStoreSet, TaskDataStore

from .runtime import NativeRuntime
from .package import MetaflowPackage
from .plugins import (
DATASTORES,
ENVIRONMENTS,
LOGGING_SIDECARS,
METADATA_PROVIDERS,
MONITOR_SIDECARS,
)
from .metaflow_config import (
DEFAULT_DATASTORE,
DEFAULT_ENVIRONMENT,
Expand All @@ -44,12 +21,29 @@
DEFAULT_MONITOR,
DEFAULT_PACKAGE_SUFFIXES,
)
from .metaflow_current import current
from .metaflow_environment import MetaflowEnvironment
from .mflog import LOG_SOURCES, mflog
from .package import MetaflowPackage
from .plugins import (
DATASTORES,
ENVIRONMENTS,
LOGGING_SIDECARS,
METADATA_PROVIDERS,
MONITOR_SIDECARS,
)
from .pylint_wrapper import PyLint
from .R import use_r, metaflow_r_version
from .mflog import mflog, LOG_SOURCES
from .R import metaflow_r_version, use_r
from .runtime import NativeRuntime
from .tagging_util import validate_tags
from .task import MetaflowTask
from .unbounded_foreach import UBF_CONTROL, UBF_TASK

from .util import (
decompress_list,
get_latest_run_id,
resolve_identity,
write_latest_run_id,
)

ERASE_TO_EOL = "\033[K"
HIGHLIGHT = "red"
Expand Down Expand Up @@ -558,11 +552,11 @@ def common_run_options(func):
help="Write the ID of this run to the file specified.",
)
@click.option(
"--pathspec-file",
"--runner-attribute-file",
default=None,
show_default=True,
type=str,
help="Write the pathspec of this run to the file specified.",
help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.",
)
@wraps(func)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -622,7 +616,7 @@ def resume(
decospecs=None,
run_id_file=None,
resume_identifier=None,
pathspec_file=None,
runner_attribute_file=None,
):
before_run(obj, tags, decospecs + obj.environment.decospecs())

Expand Down Expand Up @@ -679,10 +673,13 @@ def resume(
resume_identifier=resume_identifier,
)
write_file(run_id_file, runtime.run_id)
write_file(pathspec_file, "/".join((obj.flow.name, runtime.run_id)))
runtime.print_workflow_info()

runtime.persist_constants()
write_file(
runner_attribute_file,
"%s:%s" % (get_metadata(), "/".join((obj.flow.name, runtime.run_id))),
)
if clone_only:
runtime.clone_original_run()
else:
Expand Down Expand Up @@ -713,7 +710,7 @@ def run(
max_log_size=None,
decospecs=None,
run_id_file=None,
pathspec_file=None,
runner_attribute_file=None,
user_namespace=None,
**kwargs
):
Expand All @@ -738,11 +735,14 @@ def run(
)
write_latest_run_id(obj, runtime.run_id)
write_file(run_id_file, runtime.run_id)
write_file(pathspec_file, "/".join((obj.flow.name, runtime.run_id)))

obj.flow._set_constants(obj.graph, kwargs)
runtime.print_workflow_info()
runtime.persist_constants()
write_file(
runner_attribute_file,
"%s:%s" % (get_metadata(), "/".join((obj.flow.name, runtime.run_id))),
)
runtime.execute()


Expand Down
40 changes: 18 additions & 22 deletions metaflow/runner/click_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,34 @@
"""
)

import inspect
import datetime
import importlib
import inspect
import itertools
import uuid
from collections import OrderedDict
import uuid, datetime
from typing import (
Optional,
List,
OrderedDict as TOrderedDict,
Any,
Union,
Dict,
Callable,
)
from typing import Any, Callable, Dict, List, Optional
from typing import OrderedDict as TOrderedDict
from typing import Union

from metaflow import FlowSpec, Parameter
from metaflow.cli import start
from metaflow._vendor import click
from metaflow.parameters import JSONTypeClass
from metaflow.includefile import FilePathClass
from metaflow._vendor.typeguard import check_type, TypeCheckError
from metaflow._vendor.click.types import (
StringParamType,
IntParamType,
FloatParamType,
BoolParamType,
UUIDParameterType,
Path,
DateTime,
Tuple,
Choice,
DateTime,
File,
FloatParamType,
IntParamType,
Path,
StringParamType,
Tuple,
UUIDParameterType,
)
from metaflow._vendor.typeguard import TypeCheckError, check_type
from metaflow.cli import start
from metaflow.includefile import FilePathClass
from metaflow.parameters import JSONTypeClass

click_to_python_types = {
StringParamType: str,
Expand Down
70 changes: 50 additions & 20 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os
import sys
import time
import tempfile
import time
from typing import Dict, Iterator, Optional, Tuple
from metaflow import Run
from .subprocess_manager import SubprocessManager, CommandManager

from metaflow import Run, metadata

from .subprocess_manager import CommandManager, SubprocessManager


def clear_and_set_os_environ(env: Dict):
    """Replace the current process environment with the contents of `env`.

    After this call `os.environ` contains exactly the key/value pairs of
    `env`; every variable that was set before and is absent from `env` is
    removed.

    Parameters
    ----------
    env : Dict
        Mapping of environment variable names to values to install.

    Raises
    ------
    TypeError
        If `env` cannot be converted to a dict (e.g. it is None). In that
        case the existing environment is left untouched.
    """
    # Snapshot first: if `env` is `os.environ` itself (or otherwise backed by
    # it), clearing before reading would leave nothing to restore. Copying up
    # front also means an invalid `env` fails before anything is destroyed.
    snapshot = dict(env)
    os.environ.clear()
    os.environ.update(snapshot)


def read_from_file_when_ready(file_path: str, timeout: float = 5):
Expand Down Expand Up @@ -227,7 +234,8 @@ def __init__(
from metaflow.runner.click_api import MetaflowAPI

self.flow_file = flow_file
self.env_vars = os.environ.copy()
self.old_env = os.environ.copy()
self.env_vars = self.old_env.copy()
self.env_vars.update(env or {})
if profile:
self.env_vars["METAFLOW_PROFILE"] = profile
Expand All @@ -241,9 +249,21 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

def __get_executing_run(self, tfp_pathspec, command_obj):
def __get_executing_run(self, tfp_runner_attribute, command_obj):
# When two 'Runner' executions happen sequentially (one after the other),
# the second run would otherwise pick up the metadata and environment
# variables left behind by the first run.

# It is therefore necessary to reset both to the correct values before we
# return the Run object.
try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=10)
# Set the environment variables to what they were before the run executed.
clear_and_set_os_environ(self.old_env)

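# Set the correct metadata from the runner_attribute file corresponding to this run.
# NOTE(review): the file content is written as "<metadata>:<pathspec>"; if the metadata
# string can itself contain ':' (e.g. a service URL), splitting on the first ':' would
# mis-parse it -- confirm, and consider splitting from the right instead.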
content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=10)
metadata_for_flow, pathspec = content.split(":", maxsplit=1)
metadata(metadata_for_flow)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
Expand Down Expand Up @@ -280,17 +300,19 @@ def run(self, show_output: bool = False, **kwargs) -> ExecutingRun:
ExecutingRun object for this run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
runner_attribute_file=tfp_runner_attribute.name, **kwargs
)

pid = self.spm.run_command(
[sys.executable, *command], env=self.env_vars, show_output=show_output
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)
return self.__get_executing_run(tfp_runner_attribute, command_obj)

def resume(self, show_output: bool = False, **kwargs):
"""
Expand All @@ -315,17 +337,19 @@ def resume(self, show_output: bool = False, **kwargs):
ExecutingRun object for this resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
command = self.api(**self.top_level_kwargs).resume(
pathspec_file=tfp_pathspec.name, **kwargs
runner_attribute_file=tfp_runner_attribute.name, **kwargs
)

pid = self.spm.run_command(
[sys.executable, *command], env=self.env_vars, show_output=show_output
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)
return self.__get_executing_run(tfp_runner_attribute, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -344,17 +368,20 @@ async def async_run(self, **kwargs) -> ExecutingRun:
ExecutingRun object for this run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
runner_attribute_file=tfp_runner_attribute.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
[sys.executable, *command],
env=self.env_vars,
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)
return self.__get_executing_run(tfp_runner_attribute, command_obj)

async def async_resume(self, **kwargs):
"""
Expand All @@ -373,17 +400,20 @@ async def async_resume(self, **kwargs):
ExecutingRun object for this resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
command = self.api(**self.top_level_kwargs).resume(
pathspec_file=tfp_pathspec.name, **kwargs
runner_attribute_file=tfp_runner_attribute.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
[sys.executable, *command],
env=self.env_vars,
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)
return self.__get_executing_run(tfp_runner_attribute, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down
9 changes: 5 additions & 4 deletions metaflow/runner/nbrun.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import ast
import os
import shutil
import tempfile
from typing import Optional, Dict
from typing import Dict, Optional

from metaflow import Runner

try:
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(
profile: Optional[str] = None,
env: Optional[Dict] = None,
base_dir: Optional[str] = None,
**kwargs
**kwargs,
):
self.cell = get_current_cell()
self.flow = flow
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(
flow_file=self.tmp_flow_file.name,
profile=profile,
env=self.env_vars,
**kwargs
**kwargs,
)

def nbrun(self, **kwargs):
Expand Down
12 changes: 6 additions & 6 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio
import os
import sys
import time
import signal
import shutil
import asyncio
import signal
import subprocess
import sys
import tempfile
import threading
import subprocess
from typing import List, Dict, Optional, Callable, Iterator, Tuple
import time
from typing import Callable, Dict, Iterator, List, Optional, Tuple


def kill_process_and_descendants(pid, termination_timeout):
Expand Down
Loading

0 comments on commit 8cd344c

Please sign in to comment.