diff --git a/CHANGELOG.md b/CHANGELOG.md index 07d6e2869..e2fb2a4d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ Code freeze date: YYYY-MM-DD ### Added +- `ImpactFunc` and `ImpactFuncSet` now support equality comparisons via `==` + ### Changed - `Hazard.local_exceedance_intensity`, `Hazard.local_return_period` and `Impact.local_exceedance_impact`, `Impact.local_return_period`, using the `climada.util.interpolation` module: New default (no binning), binning on decimals, and faster implementation [#1012](https://github.com/CLIMADA-project/climada_python/pull/1012) ### Fixed diff --git a/climada/entity/impact_funcs/base.py b/climada/entity/impact_funcs/base.py index 287391a79..c51540d57 100644 --- a/climada/entity/impact_funcs/base.py +++ b/climada/entity/impact_funcs/base.py @@ -97,6 +97,19 @@ def __init__( self.mdd = mdd if mdd is not None else np.array([]) self.paa = paa if paa is not None else np.array([]) + def __eq__(self, value: object, /) -> bool: + if isinstance(value, ImpactFunc): + return ( + self.haz_type == value.haz_type + and self.id == value.id + and self.name == value.name + and self.intensity_unit == value.intensity_unit + and np.array_equal(self.intensity, value.intensity) + and np.array_equal(self.mdd, value.mdd) + and np.array_equal(self.paa, value.paa) + ) + return False + def calc_mdr(self, inten: Union[float, np.ndarray]) -> np.ndarray: """Interpolate impact function to a given intensity. @@ -177,7 +190,7 @@ def from_step_impf( mdd: tuple[float, float] = (0, 1), paa: tuple[float, float] = (1, 1), impf_id: int = 1, - **kwargs + **kwargs, ): """Step function type impact function. @@ -218,7 +231,7 @@ def from_step_impf( intensity=intensity, mdd=mdd, paa=paa, - **kwargs + **kwargs, ) def set_step_impf(self, *args, **kwargs): @@ -238,7 +251,7 @@ def from_sigmoid_impf( x0: float, haz_type: str, impf_id: int = 1, - **kwargs + **kwargs, ): r"""Sigmoid type impact function hinging on three parameter. @@ -287,7 +300,7 @@ def from_sigmoid_impf( intensity=intensity, paa=paa, mdd=mdd, - **kwargs + **kwargs, ) def set_sigmoid_impf(self, *args, **kwargs): @@ -308,7 +321,7 @@ def from_poly_s_shape( exponent: float, haz_type: str, impf_id: int = 1, - **kwargs + **kwargs, ): r"""S-shape polynomial impact function hinging on four parameter. diff --git a/climada/entity/impact_funcs/impact_func_set.py b/climada/entity/impact_funcs/impact_func_set.py index e94ff8b82..030f73f2b 100755 --- a/climada/entity/impact_funcs/impact_func_set.py +++ b/climada/entity/impact_funcs/impact_func_set.py @@ -109,6 +109,12 @@ def __init__(self, impact_funcs: Optional[Iterable[ImpactFunc]] = None): for impf in impact_funcs: self.append(impf) + def __eq__(self, value: object, /) -> bool: + if isinstance(value, ImpactFuncSet): + return self._data == value._data + + return False + def clear(self): """Reinitialize attributes.""" self._data = dict() # {hazard_type : {id:ImpactFunc}} diff --git a/climada/entity/impact_funcs/test/test_base.py b/climada/entity/impact_funcs/test/test_base.py index b0652a1be..59fc5a676 100644 --- a/climada/entity/impact_funcs/test/test_base.py +++ b/climada/entity/impact_funcs/test/test_base.py @@ -26,6 +26,74 @@ from climada.entity.impact_funcs.base import ImpactFunc +class TestEquality(unittest.TestCase): + """Test equality method""" + + def setUp(self): + self.impf1 = ImpactFunc( + haz_type="TC", + id=1, + intensity=np.array([1, 2, 3]), + mdd=np.array([0.1, 0.2, 0.3]), + paa=np.array([0.4, 0.5, 0.6]), + intensity_unit="m/s", + name="Test Impact", + ) + self.impf2 = ImpactFunc( + haz_type="TC", + id=1, + intensity=np.array([1, 2, 3]), + mdd=np.array([0.1, 0.2, 0.3]), + paa=np.array([0.4, 0.5, 0.6]), + intensity_unit="m/s", + name="Test Impact", + ) + self.impf3 = ImpactFunc( + haz_type="FL", + id=2, + intensity=np.array([4, 5, 6]), + mdd=np.array([0.7, 0.8, 0.9]), + paa=np.array([0.1, 0.2, 0.3]), + intensity_unit="m", + name="Another Impact", + ) + + def test_reflexivity(self): + self.assertEqual(self.impf1, self.impf1) + + def test_symmetry(self): + self.assertEqual(self.impf1, self.impf2) + self.assertEqual(self.impf2, self.impf1) + + def test_transitivity(self): + impf4 = ImpactFunc( + haz_type="TC", + id=1, + intensity=np.array([1, 2, 3]), + mdd=np.array([0.1, 0.2, 0.3]), + paa=np.array([0.4, 0.5, 0.6]), + intensity_unit="m/s", + name="Test Impact", + ) + self.assertEqual(self.impf1, self.impf2) + self.assertEqual(self.impf2, impf4) + self.assertEqual(self.impf1, impf4) + + def test_consistency(self): + self.assertEqual(self.impf1, self.impf2) + self.assertEqual(self.impf1, self.impf2) + + def test_comparison_with_none(self): + self.assertNotEqual(self.impf1, None) + + def test_different_types(self): + self.assertNotEqual(self.impf1, "Not an ImpactFunc") + + def test_inequality(self): + self.assertNotEqual(self.impf1, self.impf3) + self.assertTrue(self.impf1 != self.impf3) + + class TestInterpolation(unittest.TestCase): """Impact function interpolation test""" @@ -139,5 +207,8 @@ def test_aux_vars(impf): # Execute Tests if __name__ == "__main__": - TESTS = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation) - unittest.TextTestRunner(verbosity=2).run(TESTS) + equality_tests = unittest.TestLoader().loadTestsFromTestCase(TestEquality) + interpolation_tests = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation) + unittest.TextTestRunner(verbosity=2).run( + unittest.TestSuite([equality_tests, interpolation_tests]) + ) diff --git a/climada/entity/impact_funcs/test/test_imp_fun_set.py b/climada/entity/impact_funcs/test/test_imp_fun_set.py index 3bc60559b..cad52d46b 100644 --- a/climada/entity/impact_funcs/test/test_imp_fun_set.py +++ b/climada/entity/impact_funcs/test/test_imp_fun_set.py @@ -288,6 +288,55 @@ def test_remove_add_pass(self): self.assertEqual([1], imp_fun.get_ids("TC")) +class TestEquality(unittest.TestCase): + """Test equality method for ImpactFuncSet""" + + def setUp(self): + intensity = np.array([0, 20]) + paa = np.array([0, 1]) + mdd = np.array([0, 0.5]) + + fun_1 = ImpactFunc("TC", 3, intensity, mdd, paa) + fun_2 = ImpactFunc("TC", 3, deepcopy(intensity), deepcopy(mdd), deepcopy(paa)) + fun_3 = ImpactFunc("TC", 4, intensity + 1, mdd, paa) + + self.impact_set1 = ImpactFuncSet([fun_1]) + self.impact_set2 = ImpactFuncSet([fun_2]) + self.impact_set3 = ImpactFuncSet([fun_3]) + self.impact_set4 = ImpactFuncSet([fun_1, fun_3]) + + def test_reflexivity(self): + self.assertEqual(self.impact_set1, self.impact_set1) + + def test_symmetry(self): + self.assertEqual(self.impact_set1, self.impact_set2) + self.assertEqual(self.impact_set2, self.impact_set1) + + def test_transitivity(self): + impact_set5 = ImpactFuncSet([self.impact_set1._data["TC"][3]]) + self.assertEqual(self.impact_set1, self.impact_set2) + self.assertEqual(self.impact_set2, impact_set5) + self.assertEqual(self.impact_set1, impact_set5) + + def test_consistency(self): + self.assertEqual(self.impact_set1, self.impact_set2) + self.assertEqual(self.impact_set1, self.impact_set2) + + def test_comparison_with_none(self): + self.assertNotEqual(self.impact_set1, None) + + def test_different_types(self): + self.assertNotEqual(self.impact_set1, "Not an ImpactFuncSet") + + def test_field_comparison(self): + self.assertNotEqual(self.impact_set1, self.impact_set3) + self.assertNotEqual(self.impact_set1, self.impact_set4) + + def test_inequality(self): + self.assertNotEqual(self.impact_set1, self.impact_set3) + self.assertTrue(self.impact_set1 != self.impact_set3) + + class TestChecker(unittest.TestCase): """Test loading funcions from the ImpactFuncSet class""" @@ -592,6 +641,7 @@ def test_write_read_pass(self): # Execute Tests if __name__ == "__main__": TESTS = unittest.TestLoader().loadTestsFromTestCase(TestContainer) + TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestEquality)) TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestChecker)) TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestExtend)) TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestReaderExcel))