From 573eb2fa78a65a324bc2d908a482ca0eb63e690e Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Fri, 11 Oct 2024 19:09:27 +0200 Subject: [PATCH] refactor(enums): make __eq__ faster (#1233) --- openfisca_core/indexed_enums/__init__.py | 3 +- openfisca_core/indexed_enums/_enum_type.py | 24 ------ openfisca_core/indexed_enums/_errors.py | 14 +++- openfisca_core/indexed_enums/_utils.py | 83 +++++-------------- openfisca_core/indexed_enums/enum.py | 68 +++++++++------ .../indexed_enums/tests/test_enum.py | 6 +- openfisca_core/indexed_enums/types.py | 4 +- openfisca_core/types.py | 4 - 8 files changed, 81 insertions(+), 125 deletions(-) diff --git a/openfisca_core/indexed_enums/__init__.py b/openfisca_core/indexed_enums/__init__.py index 6268b8348..494601fc8 100644 --- a/openfisca_core/indexed_enums/__init__.py +++ b/openfisca_core/indexed_enums/__init__.py @@ -2,7 +2,7 @@ from . import types from ._enum_type import EnumType -from ._errors import EnumMemberNotFoundError +from ._errors import EnumEncodingError, EnumMemberNotFoundError from .config import ENUM_ARRAY_DTYPE from .enum import Enum from .enum_array import EnumArray @@ -11,6 +11,7 @@ "ENUM_ARRAY_DTYPE", "Enum", "EnumArray", + "EnumEncodingError", "EnumMemberNotFoundError", "EnumType", "types", diff --git a/openfisca_core/indexed_enums/_enum_type.py b/openfisca_core/indexed_enums/_enum_type.py index 6356984ab..a86969ccf 100644 --- a/openfisca_core/indexed_enums/_enum_type.py +++ b/openfisca_core/indexed_enums/_enum_type.py @@ -73,18 +73,6 @@ class EnumType(t.EnumType): #: The items of the indexed enum class. items: t.RecArray - #: The names as if they were sorted. - _sorted_names_: t.StrArray - - #: The enums as if they were sorted. - _sorted_enums_: t.ObjArray - - #: The indices that would sort the names. - _sorted_names_index_: t.IndexArray - - #: The indices that would sort the enums. 
- _sorted_enums_index_: t.IndexArray - @property def indices(cls) -> t.IndexArray: """Return the indices of the indexed enum class.""" @@ -121,18 +109,6 @@ def __new__( # Add the items attribute to the enum class. cls.items = _item_array(cls) - # Add the indices that would sort the names. - cls._sorted_names_index_ = numpy.argsort(cls.names).astype(t.EnumDType) - - # Add the indices that would sort the enums. - cls._sorted_enums_index_ = numpy.argsort(cls.enums).astype(t.EnumDType) - - # Add the names as if they were sorted. - cls._sorted_names_ = cls.names[cls._sorted_names_index_] - - # Add the enums as if they were sorted. - cls._sorted_enums_ = cls.enums[cls._sorted_enums_index_] - # Return the modified enum class. return cls diff --git a/openfisca_core/indexed_enums/_errors.py b/openfisca_core/indexed_enums/_errors.py index d16024cf2..7ec21eca4 100644 --- a/openfisca_core/indexed_enums/_errors.py +++ b/openfisca_core/indexed_enums/_errors.py @@ -1,6 +1,18 @@ from . import types as t +class EnumEncodingError(TypeError): + """Raised when an enum is encoded with an unsupported type.""" + + def __init__(self, enum_class: type[t.Enum], value: t.VarArray) -> None: + msg = ( + f"Failed to encode \"{value}\" of type '{value[0].__class__.__name__}', " + "as it is not supported. Please, try again with an array of " + f"'{int.__name__}', '{str.__name__}', or '{enum_class.__name__}'." 
+ ) + super().__init__(msg) + + class EnumMemberNotFoundError(IndexError): """Raised when a member is not found in an enum.""" @@ -15,4 +27,4 @@ def __init__(self, enum_class: type[t.Enum], value: str) -> None: super().__init__(msg) -__all__ = ["EnumMemberNotFoundError"] +__all__ = ["EnumEncodingError", "EnumMemberNotFoundError"] diff --git a/openfisca_core/indexed_enums/_utils.py b/openfisca_core/indexed_enums/_utils.py index c95104b76..0a29ff961 100644 --- a/openfisca_core/indexed_enums/_utils.py +++ b/openfisca_core/indexed_enums/_utils.py @@ -14,9 +14,6 @@ def _enum_to_index(enum_class: type[t.Enum], value: t.ObjArray) -> t.IndexArray: Returns: The index array. - Raises: - EnumMemberNotFoundError: If one value is not in the enum class. - Examples: >>> import numpy @@ -37,43 +34,24 @@ def _enum_to_index(enum_class: type[t.Enum], value: t.ObjArray) -> t.IndexArray: >>> class Rogue(enum.Enum): ... BOULEVARD = "More like a shady impasse, to be honest." - # >>> _enum_to_index(Road, numpy.array(Road.AVENUE)) - # array([1], dtype=uint8) - # - # >>> _enum_to_index(Road, numpy.array([Road.AVENUE])) - # array([1], dtype=uint8) - # - # >>> value = numpy.array([Road.STREET, Road.AVENUE, Road.STREET]) - # >>> _enum_to_index(Road, value) - # array([0, 1, 0], dtype=uint8) - - >>> value = numpy.array([Road.AVENUE, Road.AVENUE, Rogue.BOULEVARD]) - >>> _enum_to_index(Road, value) + >>> _enum_to_index(Road, numpy.array(Road.AVENUE)) Traceback (most recent call last): - EnumMemberNotFoundError: Member BOULEVARD not found in enum 'Road'... - - """ - # Create a mask to determine which values are in the enum class. - mask = numpy.isin(value, enum_class.enums) - - # Get the values that are not in the enum class. - ko = value[~mask] - - # If there are values that are not in the enum class, raise an error. 
- if ko.size > 0: - raise EnumMemberNotFoundError(enum_class, ko[0].name) + TypeError: iteration over a 0-d array - # In case we're dealing with a scalar, we need to convert it to an array. - ok = value[mask] + >>> _enum_to_index(Road, numpy.array([Road.AVENUE])) + array([1], dtype=uint8) - # Get the index positions of the enums in the sorted enums. - index_where = numpy.searchsorted(enum_class._sorted_enums_, ok) + >>> value = numpy.array([Road.STREET, Road.AVENUE, Road.STREET]) + >>> _enum_to_index(Road, value) + array([0, 1, 0], dtype=uint8) - # Get the actual index of the enums in the enum class. - index = enum_class._sorted_enums_index_[index_where] + >>> value = numpy.array([Road.AVENUE, Road.AVENUE, Rogue.BOULEVARD]) + >>> _enum_to_index(Road, value) + array([1, 1, 0], dtype=uint8) - # Finally, return the index array. - return numpy.array(index, dtype=t.EnumDType) + """ + index = [member.index for member in value] + return _int_to_index(enum_class, numpy.array(index)) def _int_to_index(enum_class: type[t.Enum], value: t.IndexArray) -> t.IndexArray: @@ -121,7 +99,7 @@ def _int_to_index(enum_class: type[t.Enum], value: t.IndexArray) -> t.IndexArray """ # Create a mask to determine which values are in the enum class. - mask = numpy.isin(value, enum_class.indices) + mask = value < enum_class.items.size # Get the values that are not in the enum class. ko = value[~mask] @@ -144,9 +122,6 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray: Returns: The index array. - Raises: - EnumMemberNotFoundError: If one value is not in the enum class. - Examples: >>> import numpy @@ -165,7 +140,8 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray: ... 
) >>> _str_to_index(Road, numpy.array("AVENUE")) - array([1], dtype=uint8) + Traceback (most recent call last): + TypeError: iteration over a 0-d array >>> _str_to_index(Road, numpy.array(["AVENUE"])) array([1], dtype=uint8) @@ -174,31 +150,12 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray: array([0, 1, 0], dtype=uint8) >>> _str_to_index(Road, numpy.array(["AVENUE", "AVENUE", "BOULEVARD"])) - Traceback (most recent call last): - EnumMemberNotFoundError: Member BOULEVARD not found in enum 'Road'... + array([1, 1, 0], dtype=uint8) """ - # Create a mask to determine which values are in the enum class. - mask = numpy.isin(value, enum_class.names) - - # Get the values that are not in the enum class. - ko = value[~mask] - - # If there are values that are not in the enum class, raise an error. - if ko.size > 0: - raise EnumMemberNotFoundError(enum_class, ko[0]) - - # In case we're dealing with a scalar, we need to convert it to an array. - ok = value[mask] - - # Get the index positions of the names in the sorted names. - index_where = numpy.searchsorted(enum_class._sorted_names_, ok) - - # Get the actual index of the names in the enum class. - index = enum_class._sorted_names_index_[index_where] - - # Finally, return the index array. - return numpy.array(index, dtype=t.EnumDType) + names = enum_class.names + index = [enum_class[name].index if name in names else 0 for name in value] + return _int_to_index(enum_class, numpy.array(index)) __all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"] diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index 7265e90bc..446d6cb74 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -4,6 +4,7 @@ from . 
import types as t from ._enum_type import EnumType +from ._errors import EnumEncodingError from ._guards import _is_int_array, _is_obj_array, _is_str_array from ._utils import _enum_to_index, _int_to_index, _str_to_index from .enum_array import EnumArray @@ -111,37 +112,55 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}.{self.name}" def __hash__(self) -> int: - return hash(self.__class__.__name__) ^ hash(self.index) + return object.__hash__(self) def __eq__(self, other: object) -> bool: - if not isinstance(other, Enum): - return NotImplemented - return hash(self) ^ hash(other) == 0 + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index == other.index + return NotImplemented def __ne__(self, other: object) -> bool: - if not isinstance(other, Enum): - return NotImplemented - return hash(self) ^ hash(other) != 0 + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index != other.index + return NotImplemented def __lt__(self, other: object) -> bool: - if not isinstance(other, Enum): - return NotImplemented - return self.index < other.index + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index < other.index + return NotImplemented def __le__(self, other: object) -> bool: - if not isinstance(other, Enum): - return NotImplemented - return self.index <= other.index + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index <= other.index + return NotImplemented def __gt__(self, other: object) -> bool: - if not isinstance(other, Enum): - return NotImplemented - return self.index > other.index + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index > other.index + return NotImplemented def __ge__(self, other: object) -> bool: - if not isinstance(other, Enum): - 
return NotImplemented - return self.index >= other.index + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index >= other.index + return NotImplemented @classmethod def encode( @@ -167,7 +186,7 @@ def encode( Raises: NotImplementedError: If ``array`` is a scalar :class:`~numpy.ndarray`. - TypeError: If ``array`` is of a diffent :class:`.Enum` type. + EnumEncodingError: If ``array`` is of a different :class:`.Enum` type. Examples: >>> import numpy @@ -211,7 +230,7 @@ def encode( >>> array = numpy.array([b"TENANT"]) >>> enum_array = Housing.encode(array) Traceback (most recent call last): - TypeError: Failed to encode "[b'TENANT']" of type 'bytes_', as i... + EnumEncodingError: Failed to encode "[b'TENANT']" of type 'bytes... .. seealso:: :meth:`.EnumArray.decode` for decoding. @@ -259,12 +278,7 @@ def encode( if _is_obj_array(array) and cls.__name__ is array[0].__class__.__name__: return EnumArray(_enum_to_index(cls, array), cls) - msg = ( - f"Failed to encode \"{array}\" of type '{array[0].__class__.__name__}', " - "as it is not supported. Please, try again with an array of " - f"'{int.__name__}', '{str.__name__}', or '{cls.__name__}'." 
- ) - raise TypeError(msg) + raise EnumEncodingError(cls, array) __all__ = ["Enum"] diff --git a/openfisca_core/indexed_enums/tests/test_enum.py b/openfisca_core/indexed_enums/tests/test_enum.py index 1f3e95a6f..8ffae5dd8 100644 --- a/openfisca_core/indexed_enums/tests/test_enum.py +++ b/openfisca_core/indexed_enums/tests/test_enum.py @@ -103,10 +103,10 @@ def test_enum_encode_with_str_scalar_array(): def test_enum_encode_with_str_with_bad_value(): - """Does not encode when called with a value not in an Enum.""" + """Encode when called with a value not in an Enum.""" array = numpy.array(["JAIBA"]) - with pytest.raises(IndexError): - Animal.encode(array) + enum_array = Animal.encode(array) + assert Animal.CAT in enum_array # Unsupported encodings diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py index b545d0bb6..784ae1e92 100644 --- a/openfisca_core/indexed_enums/types.py +++ b/openfisca_core/indexed_enums/types.py @@ -15,7 +15,7 @@ import numpy from numpy import ( bool_ as BoolDType, - generic as AnyDType, + generic as VarDType, int32 as IntDType, object_ as ObjDType, str_ as StrDType, @@ -44,7 +44,7 @@ ObjArray: TypeAlias = Array[ObjDType] #: Type for generic arrays. -AnyArray: TypeAlias = Array[AnyDType] +VarArray: TypeAlias = Array[VarDType] __all__ = [ "ArrayLike", diff --git a/openfisca_core/types.py b/openfisca_core/types.py index b1cb65970..702138e39 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -113,10 +113,6 @@ def plural(self, /) -> None | RolePlural: ... class EnumType(enum.EnumMeta): items: RecArray - _sorted_names_: Array[DTypeStr] - _sorted_enums_: Array[DTypeObject] - _sorted_names_index_: Array[DTypeEnum] - _sorted_enums_index_: Array[DTypeEnum] @property @abc.abstractmethod