Merge pull request #113 from ihmeuw-msca/feat/run_subset_of_stages
Evaluating a subset of OneMod Pipeline Stages
blsmxiu47 authored Nov 12, 2024
2 parents b9df575 + 56712b9 commit d6bd6ae
Showing 15 changed files with 508 additions and 289 deletions.
21 changes: 10 additions & 11 deletions examples/pipeline_example.py
@@ -1,16 +1,16 @@
 """Example OneMod pipeline."""
 
 import fire
+from custom_stage import CustomStage
 
 from onemod import Pipeline
-from onemod.stage import PreprocessingStage, RoverStage, SpxmodStage, KregStage
-
-from custom_stage import CustomStage
+from onemod.stage import KregStage, PreprocessingStage, RoverStage, SpxmodStage
 
 
 def create_pipeline(directory: str, data: str):
     # Create stages
-    # TODO: Does stage-specific validation info go here or in class definitions?
+    # Stage-specific validation specifications go here.
+    # Stage classes may also implement default validation specifications.
     preprocessing = PreprocessingStage(name="preprocessing", config={})
     covariate_selection = RoverStage(
         name="covariate_selection",
@@ -100,18 +100,17 @@ def create_pipeline(directory: str, data: str):
     # Serialize pipeline
     example_pipeline.to_json()
 
-    # TODO: Validate and serialize
-    # User could call this method themself, but run/fit/predict should
-    # probably also call it in case updates have been made to the
+    # User could call this method themself, but evaluate() also
+    # calls it in case updates have been made to the
     # pipeline (e.g., someone is experimenting with a pipeline in a
     # a notebook)
-    # example_pipeline.build()
+    example_pipeline.build()
 
     # Run (fit and predict) entire pipeline
-    # example_pipeline.run()
+    example_pipeline.evaluate(method="run")
 
-    # TODO: Fit specific stages
-    # example_pipeline.fit(stages=["preprocessing", "covariate_selection"])
+    # Fit specific stages
+    example_pipeline.fit(stages=["preprocessing", "covariate_selection"])
 
     # TODO: Predict for specific locations
     # example_pipeline.predict(id_subsets={"location_id": [1, 2, 3]})
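Read end to end, the example now builds the pipeline, runs it, and then refits only a subset of stages. A minimal sketch of the same subset call expressed through evaluate() directly (assuming fit() is a thin wrapper over evaluate(method="fit"), which the pipeline.py changes below suggest):

# Hedged sketch; reuses example_pipeline from the example above.
# Assumes fit(stages=...) forwards to evaluate(method="fit", stages=...).
example_pipeline.evaluate(
    method="fit",
    backend="local",
    stages={"preprocessing", "covariate_selection"},
)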
1 change: 1 addition & 0 deletions pytest.ini
@@ -9,3 +9,4 @@ markers =
     integration: Integration tests.
     unit: Unit tests.
     requires_data: Tests that require external datasets and are excluded from CI.
+    requires_jobmon: Tests that require jobmon and are excluded from CI.
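With the marker registered, jobmon-dependent tests can be deselected in CI via pytest -m "not requires_jobmon". A hypothetical test using it:

import pytest

@pytest.mark.requires_jobmon
def test_jobmon_workflow_builds():
    # Hypothetical body; runs only where jobmon is installed.
    jobmon = pytest.importorskip("jobmon")
    assert jobmon is not None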
29 changes: 26 additions & 3 deletions src/onemod/backend/jobmon_backend.py
@@ -169,10 +169,18 @@ def get_upstream_tasks(
     method: Literal["run", "fit", "predict"],
     stages: dict[str, Stage],
     task_dict: dict[str, list[Task]],
+    specified_stages: set[str] | None = None,
 ) -> list[Task]:
     """Get upstream stage tasks."""
     upstream_tasks = []
+
     for upstream_name in stage.dependencies:
+        if (
+            specified_stages is not None
+            and upstream_name not in specified_stages
+        ):
+            continue
+
         upstream = stages[upstream_name]
         if method not in upstream.skip:
             if (
@@ -182,6 +190,7 @@ def get_upstream_tasks(
                 upstream_tasks.append(task_dict[upstream_name][-1])
             else:
                 upstream_tasks.extend(task_dict[upstream_name])
+
     return upstream_tasks
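The specified_stages filter skips dependency edges that point outside the requested subset; those upstream stages are expected to have written their outputs in an earlier run. A toy illustration of the pruning (stage names hypothetical, not library code):

# Prune upstream edges the same way the `continue` above does.
deps = {"covariate_selection": ["preprocessing"], "smoothing": ["covariate_selection"]}
specified = {"covariate_selection", "smoothing"}
kept = {s: [d for d in deps[s] if d in specified] for s in specified}
assert kept == {"covariate_selection": [], "smoothing": ["covariate_selection"]}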


@@ -192,6 +201,7 @@ def evaluate_with_jobmon(
     resources: Path | str,
     python: Path | str | None = None,
     method: Literal["run", "fit", "predict", "collect"] = "run",
+    stages: list[str] | None = None,
 ) -> None:
     """Evaluate pipeline or stage method with Jobmon.
 
@@ -208,6 +218,8 @@ def evaluate_with_jobmon(
         Default is None.
     method : str, optional
         Name of method to evaluate. Default is 'run'.
+    stages : list of str or None, optional
+        List of stage names to evaluate. Default is None.
 
     TODO: Optional stage-specific Python environments
     TODO: User-defined max_attempts
@@ -217,25 +229,36 @@
     # Get tool
     tool = get_tool(model.name, cluster, resources)
 
-    # Create tasks
+    # Set config
     if isinstance(model, Stage):
         model_config = model.dataif.config
     elif isinstance(model, Pipeline):
         model_config = model.config
+
     task_args: dict[str, str] = {
         "python": str(python or sys.executable),
         "config": str(model_config),
     }
+
+    # Create tasks
     if isinstance(model, Pipeline):
         tasks = []
         task_dict: dict[str, list[Task]] = {}
-        for stage_name in model.get_execution_order():
+
+        if stages is None:
+            stages = model.get_execution_order()
+
+        for stage_name in stages:
             stage = model.stages[stage_name]
             if (
                 method not in stage.skip and method != "collect"
             ):  # TODO: handle collect
                 upstream_tasks = get_upstream_tasks(
-                    stage, method, model.stages, task_dict
+                    stage,
+                    method,
+                    model.stages,
+                    task_dict,
+                    specified_stages=set(stages),
                 )
                 task_dict[stage_name] = get_tasks(
                     tool, resources, stage, method, task_args, upstream_tasks
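A hedged sketch of calling this backend directly with a stage subset. The cluster name and resources path are placeholders, `pipeline` stands for a built onemod Pipeline, and the `cluster` keyword is an assumption inferred from the get_tool call above rather than a signature shown in this diff:

# Normally Pipeline.evaluate(backend="jobmon") forwards here with
# stages already filtered into execution order.
evaluate_with_jobmon(
    model=pipeline,
    cluster="slurm",                            # placeholder cluster name
    resources="config/jobmon/resources.yaml",   # placeholder path
    method="fit",
    stages=["preprocessing", "covariate_selection"],
)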
8 changes: 7 additions & 1 deletion src/onemod/backend/local_backend.py
@@ -13,6 +13,7 @@
 def evaluate_local(
     model: Pipeline | Stage,
     method: Literal["run", "fit", "predict"] = "run",
+    stages: list[str] | None = None,
     **kwargs,
 ) -> None:
     """Evaluate pipeline or stage method locally.
@@ -30,10 +31,15 @@ def evaluate_local(
         Submodel data subset ID. Only used for model stages.
     param_id : int, optional
         Submodel parameter set ID. Only used for model stages.
+    stages : list of str or None, optional
+        List of stage names to evaluate. Default is None.
 
     """
     if isinstance(model, Pipeline):
-        for stage_name in model.get_execution_order():
+        if stages is None:
+            stages = model.get_execution_order()
+
+        for stage_name in stages:
             stage = model.stages[stage_name]
             if method not in stage.skip:
                 _evaluate_stage(stage, method)
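Since stages defaults to None and falls back to get_execution_order(), existing callers keep the old behavior. A short sketch of both call shapes (`pipeline` is a built Pipeline):

# Unchanged behavior: every non-skipped stage, in execution order.
evaluate_local(model=pipeline, method="run")

# New behavior: only the listed stages are evaluated.
evaluate_local(model=pipeline, method="run", stages=["preprocessing"])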
50 changes: 47 additions & 3 deletions src/onemod/pipeline.py
@@ -136,6 +136,19 @@ def add_stage(self, stage: Stage) -> None:
         stage.config.inherit(self.config)
         self._stages[stage.name] = stage
 
+    def check_upstream_outputs_exist(
+        self, stage_name: str, upstream_name: str
+    ) -> bool:
+        """Check if outputs from the specified upstream dependency exist for the inputs of a given stage."""
+        stage = self.stages[stage_name]
+
+        for input_name, input_data in stage.input.items.items():
+            if input_data.stage == upstream_name:
+                upstream_output_path = input_data.path
+                if not upstream_output_path.exists():
+                    return False
+        return True
+
     def get_execution_order(self) -> list[str]:
         """
         Return topologically sorted order of stages, ensuring no cycles.
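The helper checks only the inputs wired to the named upstream stage, so a stage whose inputs already exist on disk can run without its producer. A hedged usage sketch with stage names from the example pipeline (`pipeline` is a built Pipeline):

# Hypothetical check before requesting a subset: passes only once
# preprocessing's outputs exist on disk.
ok = pipeline.check_upstream_outputs_exist(
    stage_name="covariate_selection", upstream_name="preprocessing"
)
if not ok:
    raise ValueError("Run 'preprocessing' first or include it in `stages`.")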
@@ -262,6 +275,7 @@ def evaluate(
         method: Literal["run", "fit", "predict"] = "run",
         backend: Literal["local", "jobmon"] = "local",
         build: bool = True,
+        stages: set[str] | None = None,
         **kwargs,
     ) -> None:
         """Evaluate pipeline method.
@@ -272,6 +286,10 @@ def evaluate(
             Name of method to evaluate. Default is 'run'.
         backend : str, optional
             How to evaluate the method. Default is 'local'.
+        build : bool, optional
+            Whether to build the pipeline before evaluation. Default is True.
+        stages : set of str, optional
+            Stages to evaluate. Default is None.
 
         Other Parameters
         ----------------
@@ -281,20 +299,46 @@ def evaluate(
             Path to resources yaml file. Required if `backend` is
             'jobmon'.
 
-        TODO: Add options to run subset of stages
         TODO: Add options to run subset of IDs
         """
         if build:
             self.build()
+
+        if stages is not None:
+            for stage_name in stages:
+                if stage_name not in self.stages:
+                    raise ValueError(
+                        f"Stage '{stage_name}' not found in pipeline."
+                    )
+
+            for stage_name in stages:
+                stage: Stage = self.stages.get(stage_name)
+                for dep in stage.dependencies:
+                    if dep not in stages:
+                        if not self.check_upstream_outputs_exist(
+                            stage_name, dep
+                        ):
+                            raise ValueError(
+                                f"Required input to stage '{stage_name}' is missing. Missing output from upstream dependency '{dep}'."
+                            )
+
+        ordered_stages = (
+            [stage for stage in self.get_execution_order() if stage in stages]
+            if stages is not None
+            else self.get_execution_order()
+        )
+
         if backend == "jobmon":
             from onemod.backend import evaluate_with_jobmon
 
-            evaluate_with_jobmon(model=self, method=method, **kwargs)
+            evaluate_with_jobmon(
+                model=self, method=method, stages=ordered_stages, **kwargs
+            )
         else:
             from onemod.backend import evaluate_local
 
-            evaluate_local(model=self, method=method)
+            evaluate_local(model=self, method=method, stages=ordered_stages)
 
     def run(
         self,
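Taken together, a subset request now fails fast if an excluded dependency has not yet produced its outputs. A hedged usage sketch (stage names from the example pipeline; `pipeline` is a built Pipeline):

try:
    pipeline.evaluate(method="run", stages={"covariate_selection"})
except ValueError as err:
    # Raised when preprocessing outputs required by covariate_selection
    # are missing from disk.
    print(err)

# Either run the upstream stage first, or request both at once; the
# ordered_stages comprehension keeps them in execution order either way.
pipeline.evaluate(method="run", stages={"preprocessing", "covariate_selection"})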
31 changes: 31 additions & 0 deletions tests/conftest.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Generator
 
 import pytest
@@ -23,6 +24,36 @@ def test_assets_dir():
     return test_dir
 
 
+@pytest.fixture
+def small_input_data(request, test_assets_dir):
+    """Fixture providing path to test input data for tests marked with requires_data."""
+    if request.node.get_closest_marker("requires_data") is None:
+        pytest.skip("Skipping test because it requires data assets.")
+
+    small_input_data_path = Path(
+        test_assets_dir, "e2e", "example1", "data", "small_data.parquet"
+    )
+    return small_input_data_path
+
+
+@pytest.fixture
+def dummy_resources(request, test_assets_dir):
+    """Fixture providing path to test resources for tests marked with requires_data."""
+    if request.node.get_closest_marker("requires_data") is None:
+        pytest.skip("Skipping test because it requires data assets.")
+
+    dummy_resources_path = Path(
+        test_assets_dir, "e2e", "example1", "config", "jobmon", "resources.yaml"
+    )
+    return dummy_resources_path
+
+
+@pytest.fixture
+def test_base_dir(tmp_path_factory):
+    test_base_dir = tmp_path_factory.mktemp("test_base_dir")
+    return test_base_dir
+
+
 @pytest.fixture(scope="session")
 def validation_collector() -> Generator[ValidationErrorCollector, None, None]:
     """Fixture that manages the validation context for tests."""
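A hypothetical test consuming the new fixtures (both markers keep it out of CI):

import pytest

@pytest.mark.requires_data
@pytest.mark.requires_jobmon
def test_pipeline_subset_e2e(small_input_data, dummy_resources, test_base_dir):
    # Hypothetical end-to-end check built on the fixture paths above.
    assert small_input_data.name == "small_data.parquet"
    assert dummy_resources.name == "resources.yaml"
    assert test_base_dir.exists()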