Skip to content

Implements equality methods for impf and impfset #1027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 28, 2025
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Removed:
`plot_intensity`, `plot_fraction`, `_event_plot` to mask plotting when regions are too far from data points [#1047](https://github.com/CLIMADA-project/climada_python/pull/1047). To recreate previous plots (no masking), the parameter can be set to None.
- Added instructions to install Climada petals on Euler cluster in `doc.guide.Guide_Euler.ipynb` [#1029](https://github.com/CLIMADA-project/climada_python/pull/1029)

- `ImpactFunc` and `ImpactFuncSet` now support equality comparisons via `==` [#1027](https://github.com/CLIMADA-project/climada_python/pull/1027)

### Changed

- `Hazard.local_exceedance_intensity`, `Hazard.local_return_period` and `Impact.local_exceedance_impact`, `Impact.local_return_period`, using the `climada.util.interpolation` module: New default (no binning), binning on decimals, and faster implementation [#1012](https://github.com/CLIMADA-project/climada_python/pull/1012)
Expand Down
23 changes: 18 additions & 5 deletions climada/entity/impact_funcs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@
self.mdd = mdd if mdd is not None else np.array([])
self.paa = paa if paa is not None else np.array([])

def __eq__(self, value: object, /) -> bool:
if isinstance(value, ImpactFunc):
return (
self.haz_type == value.haz_type
and self.id == value.id
and self.name == value.name
and self.intensity_unit == value.intensity_unit
and np.array_equal(self.intensity, value.intensity)
and np.array_equal(self.mdd, value.mdd)
and np.array_equal(self.paa, value.paa)
)
return False

def calc_mdr(self, inten: Union[float, np.ndarray]) -> np.ndarray:
"""Interpolate impact function to a given intensity.

Expand Down Expand Up @@ -177,7 +190,7 @@
mdd: tuple[float, float] = (0, 1),
paa: tuple[float, float] = (1, 1),
impf_id: int = 1,
**kwargs
**kwargs,
):
"""Step function type impact function.

Expand Down Expand Up @@ -218,7 +231,7 @@
intensity=intensity,
mdd=mdd,
paa=paa,
**kwargs
**kwargs,
)

def set_step_impf(self, *args, **kwargs):
Expand All @@ -235,10 +248,10 @@
intensity: tuple[float, float, float],
L: float,
k: float,
x0: float,

Check warning on line 251 in climada/entity/impact_funcs/base.py

View check run for this annotation

Jenkins - WCR / Pylint

invalid-name

LOW: Argument name "x0" doesn't conform to '(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$' pattern
Raw output
Used when the name doesn't match the regular expression associated to its type(constant, variable, class...).
haz_type: str,
impf_id: int = 1,
**kwargs
**kwargs,
):
r"""Sigmoid type impact function hinging on three parameter.

Expand Down Expand Up @@ -287,7 +300,7 @@
intensity=intensity,
paa=paa,
mdd=mdd,
**kwargs
**kwargs,
)

def set_sigmoid_impf(self, *args, **kwargs):
Expand All @@ -308,7 +321,7 @@
exponent: float,
haz_type: str,
impf_id: int = 1,
**kwargs
**kwargs,
):
r"""S-shape polynomial impact function hinging on four parameter.

Expand Down
6 changes: 6 additions & 0 deletions climada/entity/impact_funcs/impact_func_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def __init__(self, impact_funcs: Optional[Iterable[ImpactFunc]] = None):
for impf in impact_funcs:
self.append(impf)

def __eq__(self, value: object, /) -> bool:
if isinstance(value, ImpactFuncSet):
return self._data == value._data

return False

def clear(self):
"""Reinitialize attributes."""
self._data = dict() # {hazard_type : {id:ImpactFunc}}
Expand Down
75 changes: 73 additions & 2 deletions climada/entity/impact_funcs/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,74 @@
from climada.entity.impact_funcs.base import ImpactFunc


class TestEquality(unittest.TestCase):
"""Test equality method"""

def setUp(self):
self.impf1 = ImpactFunc(
haz_type="TC",
id=1,
intensity=np.array([1, 2, 3]),
mdd=np.array([0.1, 0.2, 0.3]),
paa=np.array([0.4, 0.5, 0.6]),
intensity_unit="m/s",
name="Test Impact",
)
self.impf2 = ImpactFunc(
haz_type="TC",
id=1,
intensity=np.array([1, 2, 3]),
mdd=np.array([0.1, 0.2, 0.3]),
paa=np.array([0.4, 0.5, 0.6]),
intensity_unit="m/s",
name="Test Impact",
)
self.impf3 = ImpactFunc(
haz_type="FL",
id=2,
intensity=np.array([4, 5, 6]),
mdd=np.array([0.7, 0.8, 0.9]),
paa=np.array([0.1, 0.2, 0.3]),
intensity_unit="m",
name="Another Impact",
)

def test_reflexivity(self):
self.assertEqual(self.impf1, self.impf1)

def test_symmetry(self):
self.assertEqual(self.impf1, self.impf2)
self.assertEqual(self.impf2, self.impf1)

def test_transitivity(self):
impf4 = ImpactFunc(
haz_type="TC",
id=1,
intensity=np.array([1, 2, 3]),
mdd=np.array([0.1, 0.2, 0.3]),
paa=np.array([0.4, 0.5, 0.6]),
intensity_unit="m/s",
name="Test Impact",
)
self.assertEqual(self.impf1, self.impf2)
self.assertEqual(self.impf2, impf4)
self.assertEqual(self.impf1, impf4)

def test_consistency(self):
self.assertEqual(self.impf1, self.impf2)
self.assertEqual(self.impf1, self.impf2)

def test_comparison_with_none(self):
self.assertNotEqual(self.impf1, None)

def test_different_types(self):
self.assertNotEqual(self.impf1, "Not an ImpactFunc")

def test_inequality(self):
self.assertNotEqual(self.impf1, self.impf3)
self.assertTrue(self.impf1 != self.impf3)


class TestInterpolation(unittest.TestCase):
"""Impact function interpolation test"""

Expand Down Expand Up @@ -139,5 +207,8 @@ def test_aux_vars(impf):

# Execute Tests
if __name__ == "__main__":
TESTS = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation)
unittest.TextTestRunner(verbosity=2).run(TESTS)
equality_tests = unittest.TestLoader().loadTestsFromTestCase(TestEquality)
interpolation_tests = unittest.TestLoader().loadTestsFromTestCase(TestInterpolation)
unittest.TextTestRunner(verbosity=2).run(
unittest.TestSuite([equality_tests, interpolation_tests])
)
51 changes: 51 additions & 0 deletions climada/entity/impact_funcs/test/test_imp_fun_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""

import unittest
from copy import deepcopy

import numpy as np

Expand Down Expand Up @@ -288,6 +289,55 @@ def test_remove_add_pass(self):
self.assertEqual([1], imp_fun.get_ids("TC"))


class TestEquality(unittest.TestCase):
"""Test equality method for ImpactFuncSet"""

def setUp(self):
intensity = np.array([0, 20])
paa = np.array([0, 1])
mdd = np.array([0, 0.5])

fun_1 = ImpactFunc("TC", 3, intensity, mdd, paa)
fun_2 = ImpactFunc("TC", 3, deepcopy(intensity), deepcopy(mdd), deepcopy(paa))
fun_3 = ImpactFunc("TC", 4, intensity + 1, mdd, paa)

self.impact_set1 = ImpactFuncSet([fun_1])
self.impact_set2 = ImpactFuncSet([fun_2])
self.impact_set3 = ImpactFuncSet([fun_3])
self.impact_set4 = ImpactFuncSet([fun_1, fun_3])

def test_reflexivity(self):
self.assertEqual(self.impact_set1, self.impact_set1)

def test_symmetry(self):
self.assertEqual(self.impact_set1, self.impact_set2)
self.assertEqual(self.impact_set2, self.impact_set1)

def test_transitivity(self):
impact_set5 = ImpactFuncSet([self.impact_set1._data["TC"][3]])
self.assertEqual(self.impact_set1, self.impact_set2)
self.assertEqual(self.impact_set2, impact_set5)
self.assertEqual(self.impact_set1, impact_set5)

def test_consistency(self):
self.assertEqual(self.impact_set1, self.impact_set2)
self.assertEqual(self.impact_set1, self.impact_set2)

def test_comparison_with_none(self):
self.assertNotEqual(self.impact_set1, None)

def test_different_types(self):
self.assertNotEqual(self.impact_set1, "Not an ImpactFuncSet")

def test_field_comparison(self):
self.assertNotEqual(self.impact_set1, self.impact_set3)
self.assertNotEqual(self.impact_set1, self.impact_set4)

def test_inequality(self):
self.assertNotEqual(self.impact_set1, self.impact_set3)
self.assertTrue(self.impact_set1 != self.impact_set3)


class TestChecker(unittest.TestCase):
"""Test loading funcions from the ImpactFuncSet class"""

Expand Down Expand Up @@ -592,6 +642,7 @@ def test_write_read_pass(self):
# Execute Tests
if __name__ == "__main__":
TESTS = unittest.TestLoader().loadTestsFromTestCase(TestContainer)
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestEquality))
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestChecker))
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestExtend))
TESTS.addTests(unittest.TestLoader().loadTestsFromTestCase(TestReaderExcel))
Expand Down
Loading