From 5a9a74753e50108afa5e3168a8ffd4397eadeb8e Mon Sep 17 00:00:00 2001 From: Ryan Murray <74630349+rywm-dhi@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:30:50 +0100 Subject: [PATCH] Test ResultFrameAggregator for all hierarchical column indices. --- .../result_frame_aggregator.py | 19 +++++++++++++------ tests/test_result_frame_aggregator.py | 5 +++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/mikeio1d/pandas_extension/result_frame_aggregator.py b/mikeio1d/pandas_extension/result_frame_aggregator.py index 26849ae8..48f99804 100644 --- a/mikeio1d/pandas_extension/result_frame_aggregator.py +++ b/mikeio1d/pandas_extension/result_frame_aggregator.py @@ -155,9 +155,6 @@ def aggregate(self, df: pd.DataFrame) -> pd.DataFrame: """ self._validate_df(df) - # df = compact_dataframe(df) - # df = self._remove_group_level(df) - for agg_level in self._agg_levels[:-1]: agg = self.get_agg_strategy(agg_level) df = self._aggregate_along_level(df, agg_level, agg) @@ -283,8 +280,11 @@ def _finalize_quantity_index(self, quantity_index: pd.Index) -> pd.Index: levels_to_keep = ["quantity", self._agg_level_name] for level in self._quantity_levels: - if level not in levels_to_keep: - quantity_index = quantity_index.droplevel(level) + if level in levels_to_keep: + continue + if level not in quantity_index.names: + continue + quantity_index = quantity_index.droplevel(level) quantity_index = quantity_index.map("_".join) @@ -301,6 +301,8 @@ def _finalize_entity_index(self, entity_index: pd.Index) -> pd.Index: for level in self._entity_levels: if level in levels_to_keep: continue + if level not in entity_index.names: + continue is_singular = entity_index.get_level_values(level).nunique() == 1 if is_singular: @@ -314,7 +316,12 @@ def _finalize_df_post_aggregate(self, df: pd.DataFrame) -> pd.DataFrame: """ df = df.rename_axis(self._agg_level_name) - df = df.stack(self._quantity_levels).T + for level in self._quantity_levels: + if level not in df.columns.names: + continue + df = df.stack(level) + + df = df.T df.columns = self._finalize_quantity_index(df.columns) df.index = self._finalize_entity_index(df.index) diff --git a/tests/test_result_frame_aggregator.py b/tests/test_result_frame_aggregator.py index 29c18f1f..03de0474 100644 --- a/tests/test_result_frame_aggregator.py +++ b/tests/test_result_frame_aggregator.py @@ -338,8 +338,9 @@ def test_finalize_df_post_aggregate(self, df_dummy): class TestResultFrameAggregator: - def test_max_aggregation(self, res1d_river_network): - df_reaches = res1d_river_network.reaches.read(column_mode="all") + @pytest.mark.parametrize("column_mode", ["all", "compact"]) + def test_max_aggregation(self, res1d_river_network, column_mode): + df_reaches = res1d_river_network.reaches.read(column_mode=column_mode) rfa = ResultFrameAggregator("max") df_agg = rfa.aggregate(df_reaches)