diff --git a/paramdb/_param_data/_primitives.py b/paramdb/_param_data/_primitives.py index 8e92309..f5c4683 100644 --- a/paramdb/_param_data/_primitives.py +++ b/paramdb/_param_data/_primitives.py @@ -178,7 +178,10 @@ def __bool__(self) -> bool: return False def __eq__(self, other: object) -> bool: - return isinstance(other, ParamNone) + return other is None or isinstance(other, ParamNone) + + def __hash__(self) -> int: + return hash(self.value) def __repr__(self) -> str: # Show empty parentheses diff --git a/tests/_param_data/test_primitives.py b/tests/_param_data/test_primitives.py index ec7c6c3..be2ce5c 100644 --- a/tests/_param_data/test_primitives.py +++ b/tests/_param_data/test_primitives.py @@ -6,6 +6,7 @@ import pytest from paramdb import ParamInt, ParamFloat, ParamBool, ParamStr, ParamNone from tests.helpers import ( + SimpleParam, CustomParamInt, CustomParamFloat, CustomParamBool, @@ -174,19 +175,38 @@ def test_param_primitive_eq( ) -> None: """ Parameter primitive objects are equal to themselves, their vaues, and custom - parameter primitive objects, and are not equal to other objects. + parameter primitive objects. """ # pylint: disable=comparison-with-itself assert param_primitive == param_primitive + assert param_primitive == deepcopy(param_primitive) assert param_primitive == custom_param_primitive - assert custom_param_primitive == custom_param_primitive - assert custom_param_primitive == param_primitive - if isinstance(param_primitive, ParamNone): - assert param_primitive != param_primitive.value - assert custom_param_primitive != custom_param_primitive.value - else: - assert param_primitive == param_primitive.value - assert custom_param_primitive == custom_param_primitive.value + assert param_primitive == param_primitive.value + + +def test_param_primitive_ne( + simple_param: SimpleParam, + param_primitive: ParamPrimitive, + custom_param_primitive: CustomParamPrimitive, +) -> None: + """ + Parameter primitive objects are not equal to other objects or parameter primitives + with different values. + """ + assert param_primitive != simple_param + assert custom_param_primitive != simple_param + if not isinstance(param_primitive, ParamNone): + assert param_primitive != type(param_primitive)() + assert custom_param_primitive != type(custom_param_primitive)() + + +def test_param_primitive_hash( + param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive +) -> None: + """Parameter primitive objects has the same hash as objects they are equal to.""" + assert hash(param_primitive) == hash(deepcopy(param_primitive)) + assert hash(param_primitive) == hash(custom_param_primitive) + assert hash(param_primitive) == hash(param_primitive.value) def test_param_int_methods_return_int(param_int: ParamInt) -> None: