From 85e2aaa74be729d0b3ccf6a8c06fd470e2e61f00 Mon Sep 17 00:00:00 2001 From: Wes Warriner Date: Sat, 21 Dec 2024 02:24:29 -0700 Subject: [PATCH 1/3] feat: make data io, loader, interface, and stage util functions use, return pandas df by default instead of polars data objects; update tests accordingly and add a few additional distinct dataif tests for pandas vs polars return types --- src/onemod/fsutils/data_loader.py | 43 ++++++++++------- src/onemod/fsutils/interface.py | 8 ++-- src/onemod/fsutils/io.py | 6 +-- src/onemod/stage/base.py | 8 ++-- src/onemod/stage/model_stages/rover_stage.py | 7 ++- src/onemod/utils/parameters.py | 25 +++++----- src/onemod/utils/subsets.py | 29 ++++++------ tests/helpers/dummy_stages.py | 3 +- .../test_integration_pipeline_evaluate.py | 10 ++-- .../integration/test_integration_stage_io.py | 31 +++++++++--- tests/unit/fsutils/test_data_interface.py | 47 +++++++++++++++++-- tests/unit/fsutils/test_io.py | 34 ++++++++++---- 12 files changed, 165 insertions(+), 86 deletions(-) diff --git a/src/onemod/fsutils/data_loader.py b/src/onemod/fsutils/data_loader.py index e1987a6e..89ee7ad0 100644 --- a/src/onemod/fsutils/data_loader.py +++ b/src/onemod/fsutils/data_loader.py @@ -17,33 +17,44 @@ def load( self, path: Path, return_type: Literal[ - "polars_dataframe", "polars_lazyframe", "pandas_dataframe" - ] = "polars_dataframe", + "pandas_dataframe", "polars_dataframe", "polars_lazyframe" + ] = "pandas_dataframe", columns: list[str] | None = None, id_subsets: dict[str, list] | None = None, **options, - ) -> pl.DataFrame | pl.LazyFrame | pd.DataFrame: + ) -> pd.DataFrame | pl.DataFrame | pl.LazyFrame: """Load data with lazy loading and subset filtering. Polars and Pandas options available for the type of the returned data object.""" if path.suffix not in self.io_dict: raise ValueError(f"Unsupported data format for '{path.suffix}'") - polars_lf = self.io_dict[path.suffix].load_lazy(path, **options) + if return_type == "pandas_dataframe": + pandas_df = self.io_dict[path.suffix].load_eager(path, **options) - if columns: - polars_lf = polars_lf.select(columns) + if columns: + pandas_df = pandas_df[columns] - if id_subsets: - for col, values in id_subsets.items(): - polars_lf = polars_lf.filter(pl.col(col).is_in(values)) + if id_subsets: + for col, values in id_subsets.items(): + pandas_df = pandas_df[pandas_df[col].isin(values)] + pandas_df.reset_index(drop=True, inplace=True) - if return_type == "polars_dataframe": - return polars_lf.collect() - elif return_type == "polars_lazyframe": - return polars_lf - elif return_type == "pandas_dataframe": - return polars_lf.collect().to_pandas() + return pandas_df + elif return_type in ["polars_dataframe", "polars_lazyframe"]: + polars_lf = self.io_dict[path.suffix].load_lazy(path, **options) + + if columns: + polars_lf = polars_lf.select(columns) + + if id_subsets: + for col, values in id_subsets.items(): + polars_lf = polars_lf.filter(pl.col(col).is_in(values)) + + if return_type == "polars_dataframe": + return polars_lf.collect() + elif return_type == "polars_lazyframe": + return polars_lf else: raise ValueError( "Return type must be one of 'polars_dataframe', 'polars_lazyframe', or 'pandas_dataframe'" @@ -51,7 +62,7 @@ def load( def dump( self, - obj: pl.DataFrame | pl.LazyFrame | pd.DataFrame, + obj: pd.DataFrame | pl.DataFrame | pl.LazyFrame, path: Path, **options, ) -> None: diff --git a/src/onemod/fsutils/interface.py b/src/onemod/fsutils/interface.py index ed22c8a7..016275ae 100644 --- a/src/onemod/fsutils/interface.py +++ b/src/onemod/fsutils/interface.py @@ -22,8 +22,8 @@ def load( *fparts: str, key: str, return_type: Literal[ - "polars_dataframe", "polars_lazyframe", "pandas_dataframe" - ] = "polars_dataframe", + "pandas_dataframe", "polars_dataframe", "polars_lazyframe" + ] = "pandas_dataframe", columns: list[str] | None = None, id_subsets: dict[str, list] | None = None, **options, @@ -32,7 +32,7 @@ def load( Parameters ---------- - return_type : {'polars_dataframe', 'polars_lazyframe', 'pandas_dataframe'}, optional + return_type : {'pandas_dataframe', 'polars_dataframe', 'polars_lazyframe'}, optional Return type of loaded data object, applicable only for data files. columns : list of str, optional Specific columns to load, applicable only for data files. @@ -63,7 +63,7 @@ def load( def dump(self, obj: Any, *fparts: str, key: str, **options) -> None: """Dump data or config files based on object type and key.""" path = self.get_full_path(*fparts, key=key) - if isinstance(obj, (pl.DataFrame, pl.LazyFrame, pd.DataFrame)): + if isinstance(obj, (pd.DataFrame, pl.DataFrame, pl.LazyFrame)): return self.data_loader.dump(obj, path, **options) else: return self.config_loader.dump(obj, path, **options) diff --git a/src/onemod/fsutils/io.py b/src/onemod/fsutils/io.py index 06d4c1f5..f868e247 100644 --- a/src/onemod/fsutils/io.py +++ b/src/onemod/fsutils/io.py @@ -15,14 +15,14 @@ class DataIO(ABC): """Bridge class that unifies the I/O for different data file types.""" fextns: tuple[str, ...] = ("",) - dtypes: tuple[Type, ...] = (pl.DataFrame, pl.LazyFrame, pd.DataFrame) + dtypes: tuple[Type, ...] = (pd.DataFrame, pl.DataFrame, pl.LazyFrame) def load_eager( self, fpath: Path | str, - backend: Literal["polars", "pandas"] = "polars", + backend: Literal["pandas", "polars"] = "pandas", **options, - ) -> pl.DataFrame | pd.DataFrame: + ) -> pd.DataFrame | pl.DataFrame: """Load data from given path.""" fpath = Path(fpath) if fpath.suffix not in self.fextns: diff --git a/src/onemod/stage/base.py b/src/onemod/stage/base.py index 30acad31..81285fd6 100644 --- a/src/onemod/stage/base.py +++ b/src/onemod/stage/base.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Literal -from polars import DataFrame +from pandas import DataFrame from pydantic import BaseModel, ConfigDict, Field, validate_call import onemod.stage as onemod_stages @@ -474,14 +474,14 @@ def create_stage_subsets( f"{self.name} does not have a groupby attribute" ) - lf = self.dataif.load( + df = self.dataif.load( key=data_key, columns=list(self.groupby), id_subsets=id_subsets, - return_type="polars_lazyframe", + return_type="pandas_dataframe", ) - subsets_df = create_subsets(self.groupby, lf.collect().to_pandas()) + subsets_df = create_subsets(self.groupby, df) self._subset_ids = set(subsets_df["subset_id"].to_list()) self.dataif.dump(subsets_df, "subsets.csv", key="output") diff --git a/src/onemod/stage/model_stages/rover_stage.py b/src/onemod/stage/model_stages/rover_stage.py index 47bc0e9b..c21454a7 100644 --- a/src/onemod/stage/model_stages/rover_stage.py +++ b/src/onemod/stage/model_stages/rover_stage.py @@ -16,7 +16,6 @@ import warnings import pandas as pd -import polars as pl from loguru import logger from modrover.api import Rover @@ -52,8 +51,8 @@ def fit(self, subset_id: int, *args, **kwargs) -> None: """ # Load data and filter by subset logger.info(f"Loading {self.name} data subset {subset_id}") - data = self.get_stage_subset(subset_id).filter( - pl.col(self.config["test_column"]) == 0 + data = self.get_stage_subset(subset_id).query( + f"{self.config['test_column']} == 0" ) if len(data) > 0: @@ -71,7 +70,7 @@ def fit(self, subset_id: int, *args, **kwargs) -> None: # Fit submodel submodel.fit( - data=data.to_pandas(), + data=data, strategies=list(self.config.strategies), top_pct_score=self.config.top_pct_score, top_pct_learner=self.config.top_pct_learner, diff --git a/src/onemod/utils/parameters.py b/src/onemod/utils/parameters.py index e22367fd..86163d9c 100644 --- a/src/onemod/utils/parameters.py +++ b/src/onemod/utils/parameters.py @@ -3,12 +3,12 @@ from itertools import product from typing import Any -import polars as pl +from pandas import DataFrame from onemod.config import ModelConfig -def create_params(config: ModelConfig) -> pl.DataFrame | None: +def create_params(config: ModelConfig) -> DataFrame | None: """Create parameter sets from crossby.""" param_dict = { param_name: param_values @@ -18,17 +18,18 @@ def create_params(config: ModelConfig) -> pl.DataFrame | None: if len(param_dict) == 0: return None - crossby = list(param_dict.keys()) - params = pl.DataFrame( - [list(param_set) for param_set in product(*param_dict.values())], - schema=crossby, - orient="row", + params = DataFrame( + [param_set for param_set in product(*param_dict.values())], + columns=(crossby := param_dict.keys()), ) + params["param_id"] = params.index - params = params.with_row_index(name="param_id") - return params.select(["param_id", *crossby]) + return params[["param_id", *crossby]] -def get_params(params: pl.DataFrame, param_id: int) -> dict[str, Any]: - params = params.filter(pl.col("param_id") == param_id).drop("param_id") - return {str(col): params[col][0] for col in params.columns} +def get_params(params: DataFrame, param_id: int) -> dict[str, Any]: + params = params.query("param_id == @param_id").drop(columns=["param_id"]) + return { + param_name: param_value.item() + for param_name, param_value in params.items() + } diff --git a/src/onemod/utils/subsets.py b/src/onemod/utils/subsets.py index 38037123..9933a32b 100644 --- a/src/onemod/utils/subsets.py +++ b/src/onemod/utils/subsets.py @@ -1,7 +1,6 @@ """Functions for working with groupby and subsets.""" import pandas as pd -import polars as pl def create_subsets(groupby: set[str], data: pd.DataFrame) -> pd.DataFrame: @@ -17,11 +16,11 @@ def create_subsets(groupby: set[str], data: pd.DataFrame) -> pd.DataFrame: def get_subset( - data: pl.DataFrame, - subsets: pl.DataFrame, + data: pd.DataFrame, + subsets: pd.DataFrame, subset_id: int, id_names: list[str] | None = None, -) -> pl.DataFrame: +) -> pd.DataFrame: """Get data subset by subset_id.""" id_subsets = get_id_subsets(subsets, subset_id) if id_names is not None: @@ -31,21 +30,21 @@ def get_subset( return filter_data(data, id_subsets) -def get_id_subsets(subsets: pl.DataFrame, subset_id: int) -> dict: +def get_id_subsets(subsets: pd.DataFrame, subset_id: int) -> dict: """Get ID names and values that define a data subset.""" return ( - subsets.filter(pl.col("subset_id") == subset_id) - .drop("subset_id") - .to_dict(as_series=False) + subsets.query("subset_id == @subset_id") + .drop(columns=["subset_id"]) + .to_dict(orient="list") ) def filter_data( - data: pl.DataFrame, id_subsets: dict[str, set[int]] -) -> pl.DataFrame: + data: pd.DataFrame, id_subsets: dict[str, set[int]] +) -> pd.DataFrame: """Filter data by ID subsets.""" - filter_expr = pl.lit(True) - for key, value in id_subsets.items(): - filter_expr &= pl.col(key).is_in(value) - - return data.filter(filter_expr) + return data.query( + " & ".join( + [f"{key}.isin({value})" for key, value in id_subsets.items()] + ) + ).reset_index(drop=True) diff --git a/tests/helpers/dummy_stages.py b/tests/helpers/dummy_stages.py index 557856ad..36a7d06b 100644 --- a/tests/helpers/dummy_stages.py +++ b/tests/helpers/dummy_stages.py @@ -1,4 +1,3 @@ -import polars as pl from pydantic import Field from onemod.config import ( @@ -249,7 +248,7 @@ class MultiplyByTwoStage(ModelStage): def run(self, subset_id: int, *args, **kwargs) -> None: """Run MultiplyByTwoStage.""" df = self.get_stage_subset(subset_id) - df = df.with_columns((pl.col("value") * 2).alias("value")) + df["value"] = df["value"] * 2 self.dataif.dump(df, "data.parquet", key="output") def fit(self) -> None: diff --git a/tests/integration/test_integration_pipeline_evaluate.py b/tests/integration/test_integration_pipeline_evaluate.py index 2566b066..3b88349c 100644 --- a/tests/integration/test_integration_pipeline_evaluate.py +++ b/tests/integration/test_integration_pipeline_evaluate.py @@ -1,6 +1,6 @@ from unittest.mock import patch -import polars as pl +import pandas as pd import pytest from tests.helpers.dummy_pipeline import get_expected_args, setup_dummy_pipeline from tests.helpers.dummy_stages import MultiplyByTwoStage, assert_stage_logs @@ -150,8 +150,8 @@ def test_invalid_id_subsets_keys(small_input_data, test_base_dir, method): def test_evaluate_with_id_subsets(test_base_dir, sample_data): """Test that Pipeline.evaluate() correctly evaluates single stage with id_subsets.""" sample_input_data = test_base_dir / "test_input_data.parquet" - df = pl.DataFrame(sample_data) - df.write_parquet(sample_input_data) + df = pd.DataFrame(sample_data) + df.to_parquet(sample_input_data) test_pipeline = Pipeline( name="dummy_pipeline", @@ -171,14 +171,14 @@ def test_evaluate_with_id_subsets(test_base_dir, sample_data): # Ensure input data is as expected for the test assert sample_input_data.exists() - input_df = pl.read_parquet(sample_input_data) + input_df = pd.read_parquet(sample_input_data) assert input_df.shape == (4, 4) test_pipeline.evaluate(method="run", id_subsets={"age_group_id": [1]}) # Verify that output only contains rows with specified subset(s) for age_group_id output_df = test_stage.dataif.load("data.parquet", key="output") - assert output_df.select(pl.col("age_group_id")).n_unique() == 1 + assert output_df["age_group_id"].nunique() == 1 assert output_df.shape == (1, 4) diff --git a/tests/integration/test_integration_stage_io.py b/tests/integration/test_integration_stage_io.py index 3563815f..5b1c6fb5 100644 --- a/tests/integration/test_integration_stage_io.py +++ b/tests/integration/test_integration_stage_io.py @@ -52,12 +52,29 @@ def test_input(stage_1): @pytest.mark.integration def test_output(stage_1): - assert stage_1.output == Output( - stage=stage_1.name, - items={ - "predictions": Data(stage=stage_1.name, path="predictions.parquet"), - "model": Data(stage=stage_1.name, path="model.pkl"), - }, + # print(stage_1.output) + # print(Output( + # stage=stage_1.name, + # items={ + # "model": Data(stage=stage_1.name, path="model.pkl"), + # "predictions": Data(stage=stage_1.name, path="predictions.parquet"), + # }, + # )) + assert ( + stage_1.output + == Output( + stage=stage_1.name, + items={ + "predictions": Data( + stage=stage_1.name, + path="predictions.parquet", + format="parquet", + ), + "model": Data( + stage=stage_1.name, path="model.pkl", format="pkl" + ), # FIXME: implicit format pending update of Data class with new version of DataInterface + }, + ) ) @@ -81,7 +98,7 @@ def test_input_with_missing(): with pytest.raises(KeyError) as error: stage_3(priors="/path/to/priors.pkl") observed = str(error.value).strip('"') - expected = f"{stage_3.name} missing required input: " + expected = f"Stage '{stage_3.name}' missing required input: " assert ( observed == expected + "['data', 'covariates']" or observed == expected + "['covariates', 'data']" diff --git a/tests/unit/fsutils/test_data_interface.py b/tests/unit/fsutils/test_data_interface.py index b182d778..dd205608 100644 --- a/tests/unit/fsutils/test_data_interface.py +++ b/tests/unit/fsutils/test_data_interface.py @@ -1,6 +1,7 @@ import numpy as np +import pandas as pd +import polars as pl import pytest -from polars import DataFrame from onemod.fsutils.interface import DataInterface @@ -27,10 +28,10 @@ def sample_config(): @pytest.mark.unit @pytest.mark.parametrize("extension", [".csv", ".parquet"]) -def test_data_interface(sample_data1, extension, tmp_path): +def test_data_interface_eager_polars(sample_data1, extension, tmp_path): dataif = DataInterface(tmp=tmp_path) - df = DataFrame(sample_data1) + df = pl.DataFrame(sample_data1) dataif.dump(df, "data" + extension, key="tmp") @@ -40,6 +41,44 @@ def test_data_interface(sample_data1, extension, tmp_path): assert np.allclose(sample_data1[key], loaded_data[key]) +@pytest.mark.unit +@pytest.mark.parametrize("extension", [".csv", ".parquet"]) +def test_data_interface_lazy_polars(sample_data1, extension, tmp_path): + dataif = DataInterface(tmp=tmp_path) + + df = pl.LazyFrame(sample_data1) + + dataif.dump(df, "data" + extension, key="tmp") + + lf = dataif.load( + "data" + extension, key="tmp", return_type="polars_lazyframe" + ) + + assert type(lf) is pl.LazyFrame + + loaded_data = lf.collect() + + for key in ["a", "b"]: + assert np.allclose(sample_data1[key], loaded_data[key]) + + +@pytest.mark.unit +@pytest.mark.parametrize("extension", [".csv", ".parquet"]) +def test_data_interface_eager_pandas(sample_data1, extension, tmp_path): + dataif = DataInterface(tmp=tmp_path) + + df = pd.DataFrame(sample_data1) + + dataif.dump(df, "data" + extension, key="tmp") + + loaded_data = dataif.load( + "data" + extension, key="tmp", return_type="pandas_dataframe" + ) + + for key in ["a", "b"]: + assert np.allclose(sample_data1[key], loaded_data[key]) + + @pytest.mark.unit def test_add_path(tmp_path): dataif = DataInterface() @@ -77,7 +116,7 @@ def test_remove_path(tmp_path): @pytest.fixture def data_files(sample_data2, tmp_path): """Create small CSV and Parquet files for testing.""" - data = DataFrame(sample_data2) + data = pl.DataFrame(sample_data2) csv_path = tmp_path / "data.csv" parquet_path = tmp_path / "data.parquet" diff --git a/tests/unit/fsutils/test_io.py b/tests/unit/fsutils/test_io.py index c98dcaa7..5d0f251d 100644 --- a/tests/unit/fsutils/test_io.py +++ b/tests/unit/fsutils/test_io.py @@ -1,6 +1,7 @@ import numpy as np +import pandas as pd +import polars as pl import pytest -from polars import DataFrame, LazyFrame from onemod.fsutils.io import CSVIO, JSONIO, TOMLIO, YAMLIO, ParquetIO, PickleIO @@ -12,12 +13,12 @@ def data(): @pytest.mark.unit def test_csvio_eager(data, tmp_path): - data = DataFrame(data) + data = pd.DataFrame(data) port = CSVIO() port.dump(data, tmp_path / "file.csv") loaded_data = port.load_eager(tmp_path / "file.csv") - assert type(loaded_data) is DataFrame + assert type(loaded_data) is pd.DataFrame for key in ["a", "b"]: assert np.allclose(data[key], loaded_data[key]) @@ -25,12 +26,12 @@ def test_csvio_eager(data, tmp_path): @pytest.mark.unit def test_csvio_lazy(data, tmp_path): - data = DataFrame(data) + data = pd.DataFrame(data) port = CSVIO() port.dump(data, tmp_path / "file.csv") lazy_loaded_data = port.load_lazy(tmp_path / "file.csv") - assert type(lazy_loaded_data) is LazyFrame + assert type(lazy_loaded_data) is pl.LazyFrame loaded_data = lazy_loaded_data.collect() @@ -59,13 +60,26 @@ def test_yamlio(data, tmp_path): @pytest.mark.unit -def test_parquetio_eager(data, tmp_path): - data = DataFrame(data) +def test_parquetio_eager_pandas(data, tmp_path): + data = pd.DataFrame(data) port = ParquetIO() port.dump(data, tmp_path / "file.parquet") loaded_data = port.load_eager(tmp_path / "file.parquet") - assert type(loaded_data) is DataFrame + assert type(loaded_data) is pd.DataFrame + + for key in ["a", "b"]: + assert np.allclose(data[key], loaded_data[key]) + + +@pytest.mark.unit +def test_parquetio_eager_polars(data, tmp_path): + data = pl.DataFrame(data) + port = ParquetIO() + port.dump(data, tmp_path / "file.parquet") + loaded_data = port.load_eager(tmp_path / "file.parquet", backend="polars") + + assert type(loaded_data) is pl.DataFrame for key in ["a", "b"]: assert np.allclose(data[key], loaded_data[key]) @@ -73,12 +87,12 @@ def test_parquetio_eager(data, tmp_path): @pytest.mark.unit def test_parquetio_lazy(data, tmp_path): - data = DataFrame(data) + data = pd.DataFrame(data) port = ParquetIO() port.dump(data, tmp_path / "file.parquet") lazy_loaded_data = port.load_lazy(tmp_path / "file.parquet") - assert type(lazy_loaded_data) is LazyFrame + assert type(lazy_loaded_data) is pl.LazyFrame loaded_data = lazy_loaded_data.collect() From 6c2d70510b124848c50a926cdcd85b7617928f83 Mon Sep 17 00:00:00 2001 From: Wes Warriner Date: Sat, 21 Dec 2024 02:58:31 -0700 Subject: [PATCH 2/3] cleanup: fix mypy issues --- src/onemod/fsutils/data_loader.py | 3 +++ src/onemod/stage/model_stages/spxmod_stage.py | 2 +- src/onemod/utils/parameters.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/onemod/fsutils/data_loader.py b/src/onemod/fsutils/data_loader.py index 89ee7ad0..f60fa1d7 100644 --- a/src/onemod/fsutils/data_loader.py +++ b/src/onemod/fsutils/data_loader.py @@ -31,6 +31,9 @@ def load( if return_type == "pandas_dataframe": pandas_df = self.io_dict[path.suffix].load_eager(path, **options) + assert isinstance( + pandas_df, pd.DataFrame + ), "Expected a pandas DataFrame" if columns: pandas_df = pandas_df[columns] diff --git a/src/onemod/stage/model_stages/spxmod_stage.py b/src/onemod/stage/model_stages/spxmod_stage.py index 9cf69459..2909a91e 100644 --- a/src/onemod/stage/model_stages/spxmod_stage.py +++ b/src/onemod/stage/model_stages/spxmod_stage.py @@ -210,7 +210,7 @@ def _get_submodel_data( """Load submodel data.""" # Load data and filter by subset logger.info(f"Loading {self.name} data subset {subset_id}") - data = self.get_stage_subset(subset_id).to_pandas() + data = self.get_stage_subset(subset_id) # Add spline basis to data spline_vars = [] diff --git a/src/onemod/utils/parameters.py b/src/onemod/utils/parameters.py index 86163d9c..5370fac0 100644 --- a/src/onemod/utils/parameters.py +++ b/src/onemod/utils/parameters.py @@ -20,7 +20,7 @@ def create_params(config: ModelConfig) -> DataFrame | None: params = DataFrame( [param_set for param_set in product(*param_dict.values())], - columns=(crossby := param_dict.keys()), + columns=(crossby := list(param_dict.keys())), ) params["param_id"] = params.index @@ -30,6 +30,6 @@ def create_params(config: ModelConfig) -> DataFrame | None: def get_params(params: DataFrame, param_id: int) -> dict[str, Any]: params = params.query("param_id == @param_id").drop(columns=["param_id"]) return { - param_name: param_value.item() + str(param_name): param_value.item() for param_name, param_value in params.items() } From 8eb0364afc4e7ec0bc71acfaa21c538b64707ee6 Mon Sep 17 00:00:00 2001 From: Wes Warriner Date: Sat, 21 Dec 2024 22:26:46 -0800 Subject: [PATCH 3/3] tdd: minor fixes to resolve cluster requires_data pipeline evaluation tests --- src/onemod/stage/base.py | 5 ++++- src/onemod/utils/parameters.py | 7 +++---- tests/integration/test_integration_pipeline_evaluate.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/onemod/stage/base.py b/src/onemod/stage/base.py index 81285fd6..757acbb0 100644 --- a/src/onemod/stage/base.py +++ b/src/onemod/stage/base.py @@ -500,7 +500,10 @@ def create_stage_params(self) -> None: """Create stage parameter sets from config.""" params = create_params(self.config) if params is not None: - self._crossby = set(params.drop("param_id").columns) + if "param_id" not in params.columns: + raise KeyError("Parameter set ID column 'param_id' not found") + + self._crossby = set(params.columns) - {"param_id"} self._param_ids = set(params["param_id"]) self.dataif.dump(params, "parameters.csv", key="output") diff --git a/src/onemod/utils/parameters.py b/src/onemod/utils/parameters.py index 5370fac0..3c878396 100644 --- a/src/onemod/utils/parameters.py +++ b/src/onemod/utils/parameters.py @@ -15,16 +15,15 @@ def create_params(config: ModelConfig) -> DataFrame | None: for param_name in config.crossable_params if isinstance(param_values := config[param_name], set) } - if len(param_dict) == 0: + if not param_dict: return None params = DataFrame( - [param_set for param_set in product(*param_dict.values())], - columns=(crossby := list(param_dict.keys())), + list(product(*param_dict.values())), columns=list(param_dict.keys()) ) params["param_id"] = params.index - return params[["param_id", *crossby]] + return params[["param_id", *param_dict.keys()]] def get_params(params: DataFrame, param_id: int) -> dict[str, Any]: diff --git a/tests/integration/test_integration_pipeline_evaluate.py b/tests/integration/test_integration_pipeline_evaluate.py index 3b88349c..3b733799 100644 --- a/tests/integration/test_integration_pipeline_evaluate.py +++ b/tests/integration/test_integration_pipeline_evaluate.py @@ -118,8 +118,8 @@ def test_missing_dependency_error(small_input_data, test_base_dir, method): subset_stage_names = {"covariate_selection"} with pytest.raises( - ValueError, - match="Required input to stage 'covariate_selection' is missing. Missing output from upstream dependency 'preprocessing'.", + FileNotFoundError, + match=f"Stage covariate_selection input items do not exist: {{'data': '{test_base_dir}/preprocessing/data.parquet'}}", ): dummy_pipeline.evaluate(method=method, stages=subset_stage_names)