Skip to content

Commit

Permalink
Test registering example custom derived quantity
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Sep 18, 2024
1 parent 62fca31 commit 05fc719
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions tests/test_derived_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from pandas.testing import assert_frame_equal

from mikeio1d import Res1D

NODE_DERIVED_QUANTITIES = [
"NodeFlooding",
"NodeWaterDepth",
Expand Down Expand Up @@ -110,11 +112,36 @@ def test_read_derived_quantities_single_location_reach(res1d_network, reach_deri

def set_multiindex_level_values(df, level, value):
"""Set all values of a MultiIndex level to a specific value."""
num_columns = df.shape[1]
df.columns = df.columns.set_levels([value for _ in range(num_columns)], level=level)
df.columns = df.columns.set_levels([value], level=level)
return df


def test_example_custom_derived_quantity(res1d_network):
from mikeio1d.quantities.derived.derived_quantity_example import ExampleDerivedQuantity
from mikeio1d.res1d import derived_quantity_manager as dqm

dqm.register(ExampleDerivedQuantity)
res1d_network = Res1D(res1d_network.file_path)

assert "WaterLevelPlusOne" in res1d_network.derived_quantities
assert "WaterLevelPlusOne" in res1d_network.nodes.derived_quantities
assert "WaterLevelPlusOne" in res1d_network.reaches.derived_quantities

df_nodes = res1d_network.nodes.WaterLevelPlusOne.read()
df_nodes_expected = res1d_network.nodes.WaterLevel.read(column_mode="compact") + 1
df_nodes_expected = set_multiindex_level_values(
df_nodes_expected, "quantity", "WaterLevelPlusOne"
)
assert_frame_equal(df_nodes, df_nodes_expected)

df_reaches = res1d_network.reaches.WaterLevelPlusOne.read()
df_reaches_expected = res1d_network.reaches.WaterLevel.read(column_mode="compact") + 1
df_reaches_expected = set_multiindex_level_values(
df_reaches_expected, "quantity", "WaterLevelPlusOne"
)
assert_frame_equal(df_reaches, df_reaches_expected)


class TestSpecificDerivedQuantities:
def test_node_flooding(self, res1d_network):
df = res1d_network.nodes["1"].NodeFlooding.read()
Expand Down

0 comments on commit 05fc719

Please sign in to comment.