Skip to content

Commit

Permalink
Refactoring and allow multiindex entity indices.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Jan 30, 2024
1 parent 30e3341 commit 28a8223
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 28 deletions.
99 changes: 75 additions & 24 deletions mikeio1d/pandas_extension/result_frame_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class ResultFrameAggregator:
"""
Aggregates a MIKE IO 1D results DataFrame into scalar values associate with entities.
Aggregates a MIKE IO 1D result DataFrame into scalar values associated with entities.
Parameters
----------
Expand Down Expand Up @@ -50,6 +50,9 @@ class ResultFrameAggregator:
# Always take the first chainage value, but take the max of time
>>> agg = ResultFrameAggregator('max', chainage='first')
>>> agg.aggregate(df)
# Same result as above, but with explicit argument names.
>>> agg = ResultFrameAggregator(time='max', chainage='first')
"""

def __init__(self, agg: str | Callable = None, override_name: str = None, **kwargs):
Expand All @@ -60,10 +63,13 @@ def __init__(self, agg: str | Callable = None, override_name: str = None, **kwar
agg = kwargs["time"]

self._override_name = override_name
self._agg_level_name = "agg"

self._entity_levels = ("quantity", "group", "name", "tag", "derived")
self._entity_levels = ("group", "name", "tag")
self._agg_levels = ("duplicate", "chainage", "time")
self._quantity_levels = ("quantity", "derived")
self._agg_strategies = self._init_agg_strategies(agg, kwargs)

self._validate()

def _init_agg_strategies(self, agg, agg_kwargs: Dict) -> Dict[str, Any]:
Expand All @@ -86,31 +92,39 @@ def _validate_levels(self):
"""
entity_levels = set(self._entity_levels)
agg_levels = set(self._agg_levels)

if "time" not in agg_levels:
raise ValueError("Agg levels must contain 'time'.")

if self._agg_levels[-1] != "time":
raise ValueError("Agg levels must end with 'time'.")
quantity_levels = set(self._quantity_levels)

if len(entity_levels) != len(self._entity_levels):
raise ValueError("Entity levels must be unique.")

if len(agg_levels) != len(self._agg_levels):
raise ValueError("Agg levels must be unique.")

if not entity_levels.isdisjoint(agg_levels):
raise ValueError("Entity levels and agg levels must be mutually exclusive sets.")
if len(quantity_levels) != len(self._quantity_levels):
raise ValueError("Quantity levels must be unique.")

if (
not entity_levels.isdisjoint(agg_levels)
and not agg_levels.isdisjoint(quantity_levels)
and not quantity_levels.isdisjoint(entity_levels)
):
raise ValueError("Entity, quantity, and agg levels must be mutually exclusive sets.")

timeseries_id_fields = set(f.name for f in fields(TimeSeriesId))

agg_levels.remove("time") # time is not a field in TimeSeriesId
agg_levels.discard("time") # time is not a field in TimeSeriesId

if not (entity_levels | agg_levels) == timeseries_id_fields:
if not (entity_levels | agg_levels | quantity_levels) == timeseries_id_fields:
raise ValueError(
"Either entity_levels or agg_levels is missing a field from TimeSeriesId."
"Either entity_levels, quantity_levels, or agg_levels is missing a field from TimeSeriesId."
)

if self._agg_levels[-1] != "time":
raise ValueError("Agg levels must end with 'time'.")

if self._agg_levels[0] != "duplicate":
raise ValueError("Agg levels should start with 'duplicate'.")

def _validate_agg_strategies(self):
"""
Validate that the agg strategies are complete and valid.
Expand Down Expand Up @@ -141,8 +155,8 @@ def aggregate(self, df: pd.DataFrame) -> pd.DataFrame:
"""
self._validate_df(df)

df = compact_dataframe(df)
df = self._remove_group_level(df)
# df = compact_dataframe(df)
# df = self._remove_group_level(df)

for agg_level in self._agg_levels[:-1]:
agg = self.get_agg_strategy(agg_level)
Expand All @@ -169,6 +183,13 @@ def agg_levels(self) -> List[str]:
"""
return self._agg_levels

@property
def quantity_levels(self) -> List[str]:
    """
    Names of the levels that, taken together, uniquely identify a quantity.

    Returns
    -------
    The sequence held in ``self._quantity_levels``, handed back as-is
    (no copy is made).
    """
    levels = self._quantity_levels
    return levels

@property
def agg_strategies(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -253,20 +274,50 @@ def _aggregate_along_time(self, df: pd.DataFrame, agg: Any) -> pd.DataFrame:
agg = make_list_if_not_iterable(agg)
return df.agg(agg)

def _finalize_quantity_index(self, quantity_index: pd.Index) -> pd.Index:
    """
    Collapse the quantity index into flat, underscore-joined labels.

    Parameters
    ----------
    quantity_index : pd.Index
        Index whose levels include the configured quantity levels and the
        aggregation level (``self._agg_level_name``).

    Returns
    -------
    pd.Index
        The input unchanged when only one quantity level is configured.
        Otherwise an index where every configured quantity level other than
        'quantity' has been dropped and the remaining level values of each
        entry are joined with '_'.
    """
    # A single configured quantity level needs no flattening.
    if len(self._quantity_levels) == 1:
        return quantity_index

    retained = ("quantity", self._agg_level_name)
    discarded = [name for name in self._quantity_levels if name not in retained]

    flattened = quantity_index
    # Drop one level at a time, in the configured order.
    for name in discarded:
        flattened = flattened.droplevel(name)

    # Each remaining entry is a tuple of level values; join it into one label.
    return flattened.map("_".join)

def _finalize_entity_index(self, entity_index: pd.Index) -> pd.Index:
    """
    Strip entity levels that carry no distinguishing information.

    Parameters
    ----------
    entity_index : pd.Index
        Index whose levels are the configured entity levels.

    Returns
    -------
    pd.Index
        The input unchanged when only one entity level is configured.
        Otherwise the index with every level other than 'name' removed
        whenever that level holds exactly one unique value.
    """
    if len(self._entity_levels) == 1:
        return entity_index

    always_kept = {"name"}
    trimmed = entity_index
    for level_name in self._entity_levels:
        if level_name in always_kept:
            continue
        # Evaluate against the current (possibly already reduced) index.
        distinct_values = trimmed.get_level_values(level_name).nunique()
        if distinct_values == 1:
            trimmed = trimmed.droplevel(level_name)

    return trimmed

def _finalize_df_post_aggregate(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Finalize the DataFrame formatting after aggregation.
"""

# Set the name of the overall aggregation
if self._override_name is not None:
df.index = (self._override_name,)

# Join the overall aggregation name with individual quantities
df = df.stack("quantity")
df.index = df.index.map("_".join)
df = df.rename_axis(self._agg_level_name)
df = df.stack(self._quantity_levels).T

# Transform the DataFrame such that names are rows
df = df.T
df.columns = self._finalize_quantity_index(df.columns)
df.index = self._finalize_entity_index(df.index)

df = df.sort_index()
return df
9 changes: 5 additions & 4 deletions tests/test_result_frame_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,13 @@ def test_validate_agg_strategy(self):
with pytest.raises(ValueError):
rfa._validate_agg_strategy("time", None)

def test_aggregate(self):
pass

def test_entity_levels(self):
rfa = ResultFrameAggregator("max")
assert rfa.entity_levels == ("quantity", "group", "name", "tag", "derived")
assert rfa.entity_levels == ("group", "name", "tag")

def test_quantity_levels(self):
    """The default aggregator exposes ('quantity', 'derived') as its quantity levels."""
    expected_levels = ("quantity", "derived")
    aggregator = ResultFrameAggregator("max")
    assert aggregator.quantity_levels == expected_levels

def test_agg_levels(self):
rfa = ResultFrameAggregator("max")
Expand Down

0 comments on commit 28a8223

Please sign in to comment.