Skip to content

Commit

Permalink
Test ResultFrameAggregator for all hierarchical column indices.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Jan 30, 2024
1 parent 28a8223 commit 8203f21
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
19 changes: 13 additions & 6 deletions mikeio1d/pandas_extension/result_frame_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ def aggregate(self, df: pd.DataFrame) -> pd.DataFrame:
"""
self._validate_df(df)

# df = compact_dataframe(df)
# df = self._remove_group_level(df)

for agg_level in self._agg_levels[:-1]:
agg = self.get_agg_strategy(agg_level)
df = self._aggregate_along_level(df, agg_level, agg)
Expand Down Expand Up @@ -283,8 +280,11 @@ def _finalize_quantity_index(self, quantity_index: pd.Index) -> pd.Index:

levels_to_keep = ["quantity", self._agg_level_name]
for level in self._quantity_levels:
if level not in levels_to_keep:
quantity_index = quantity_index.droplevel(level)
if level in levels_to_keep:
continue
if level not in quantity_index.names:
continue
quantity_index = quantity_index.droplevel(level)

quantity_index = quantity_index.map("_".join)

Expand All @@ -301,6 +301,8 @@ def _finalize_entity_index(self, entity_index: pd.Index) -> pd.Index:
for level in self._entity_levels:
if level in levels_to_keep:
continue
if level not in entity_index.names:
continue

is_singular = entity_index.get_level_values(level).nunique() == 1
if is_singular:
Expand All @@ -314,7 +316,12 @@ def _finalize_df_post_aggregate(self, df: pd.DataFrame) -> pd.DataFrame:
"""

df = df.rename_axis(self._agg_level_name)
df = df.stack(self._quantity_levels).T
for level in self._quantity_levels:
if level not in df.columns.names:
continue
df = df.stack(level)

df = df.T

df.columns = self._finalize_quantity_index(df.columns)
df.index = self._finalize_entity_index(df.index)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_result_frame_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,9 @@ def test_finalize_df_post_aggregate(self, df_dummy):


class TestResultFrameAggregator:
def test_max_aggregation(self, res1d_river_network):
df_reaches = res1d_river_network.reaches.read(column_mode="all")
@pytest.mark.parametrize("column_mode", ["all", "compact"])
def test_max_aggregation(self, res1d_river_network, column_mode):
df_reaches = res1d_river_network.reaches.read(column_mode=column_mode)
rfa = ResultFrameAggregator("max")
df_agg = rfa.aggregate(df_reaches)

Expand Down

0 comments on commit 8203f21

Please sign in to comment.