Skip to content

Commit

Permalink
refactor(enums): make __eq__ faster (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 11, 2024
1 parent 67f14dc commit 573eb2f
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 125 deletions.
3 changes: 2 additions & 1 deletion openfisca_core/indexed_enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from . import types
from ._enum_type import EnumType
from ._errors import EnumMemberNotFoundError
from ._errors import EnumEncodingError, EnumMemberNotFoundError
from .config import ENUM_ARRAY_DTYPE
from .enum import Enum
from .enum_array import EnumArray
Expand All @@ -11,6 +11,7 @@
"ENUM_ARRAY_DTYPE",
"Enum",
"EnumArray",
"EnumEncodingError",
"EnumMemberNotFoundError",
"EnumType",
"types",
Expand Down
24 changes: 0 additions & 24 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ class EnumType(t.EnumType):
#: The items of the indexed enum class.
items: t.RecArray

#: The names as if they were sorted.
_sorted_names_: t.StrArray

#: The enums as if they were sorted.
_sorted_enums_: t.ObjArray

#: The indices that would sort the names.
_sorted_names_index_: t.IndexArray

#: The indices that would sort the enums.
_sorted_enums_index_: t.IndexArray

@property
def indices(cls) -> t.IndexArray:
"""Return the indices of the indexed enum class."""
Expand Down Expand Up @@ -121,18 +109,6 @@ def __new__(
# Add the items attribute to the enum class.
cls.items = _item_array(cls)

# Add the indices that would sort the names.
cls._sorted_names_index_ = numpy.argsort(cls.names).astype(t.EnumDType)

# Add the indices that would sort the enums.
cls._sorted_enums_index_ = numpy.argsort(cls.enums).astype(t.EnumDType)

# Add the names as if they were sorted.
cls._sorted_names_ = cls.names[cls._sorted_names_index_]

# Add the enums as if they were sorted.
cls._sorted_enums_ = cls.enums[cls._sorted_enums_index_]

# Return the modified enum class.
return cls

Expand Down
14 changes: 13 additions & 1 deletion openfisca_core/indexed_enums/_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from . import types as t


class EnumEncodingError(TypeError):
    """Raised when an enum is encoded with an unsupported type.

    Args:
        enum_class: The enum class the encoding was attempted against.
        value: The array that could not be encoded.

    """

    def __init__(self, enum_class: type[t.Enum], value: t.VarArray) -> None:
        # Building the message must never fail on its own: ``value[0]`` raises
        # an ``IndexError`` on empty and 0-d arrays, which would hide the
        # actual encoding error. Inspect the first element when there is one,
        # and fall back to the array's dtype otherwise.
        if value.size > 0:
            type_name = value.flat[0].__class__.__name__
        else:
            type_name = value.dtype.type.__name__
        msg = (
            f"Failed to encode \"{value}\" of type '{type_name}', "
            "as it is not supported. Please, try again with an array of "
            f"'{int.__name__}', '{str.__name__}', or '{enum_class.__name__}'."
        )
        super().__init__(msg)


class EnumMemberNotFoundError(IndexError):
"""Raised when a member is not found in an enum."""

Expand All @@ -15,4 +27,4 @@ def __init__(self, enum_class: type[t.Enum], value: str) -> None:
super().__init__(msg)


__all__ = ["EnumMemberNotFoundError"]
__all__ = ["EnumEncodingError", "EnumMemberNotFoundError"]
83 changes: 20 additions & 63 deletions openfisca_core/indexed_enums/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ def _enum_to_index(enum_class: type[t.Enum], value: t.ObjArray) -> t.IndexArray:
Returns:
The index array.
Raises:
EnumMemberNotFoundError: If one value is not in the enum class.
Examples:
>>> import numpy
Expand All @@ -37,43 +34,24 @@ def _enum_to_index(enum_class: type[t.Enum], value: t.ObjArray) -> t.IndexArray:
>>> class Rogue(enum.Enum):
... BOULEVARD = "More like a shady impasse, to be honest."
# >>> _enum_to_index(Road, numpy.array(Road.AVENUE))
# array([1], dtype=uint8)
#
# >>> _enum_to_index(Road, numpy.array([Road.AVENUE]))
# array([1], dtype=uint8)
#
# >>> value = numpy.array([Road.STREET, Road.AVENUE, Road.STREET])
# >>> _enum_to_index(Road, value)
# array([0, 1, 0], dtype=uint8)
>>> value = numpy.array([Road.AVENUE, Road.AVENUE, Rogue.BOULEVARD])
>>> _enum_to_index(Road, value)
>>> _enum_to_index(Road, numpy.array(Road.AVENUE))
Traceback (most recent call last):
EnumMemberNotFoundError: Member BOULEVARD not found in enum 'Road'...
"""
# Create a mask to determine which values are in the enum class.
mask = numpy.isin(value, enum_class.enums)

# Get the values that are not in the enum class.
ko = value[~mask]

# If there are values that are not in the enum class, raise an error.
if ko.size > 0:
raise EnumMemberNotFoundError(enum_class, ko[0].name)
TypeError: iteration over a 0-d array
# In case we're dealing with a scalar, we need to convert it to an array.
ok = value[mask]
>>> _enum_to_index(Road, numpy.array([Road.AVENUE]))
array([1], dtype=uint8)
# Get the index positions of the enums in the sorted enums.
index_where = numpy.searchsorted(enum_class._sorted_enums_, ok)
>>> value = numpy.array([Road.STREET, Road.AVENUE, Road.STREET])
>>> _enum_to_index(Road, value)
array([0, 1, 0], dtype=uint8)
# Get the actual index of the enums in the enum class.
index = enum_class._sorted_enums_index_[index_where]
>>> value = numpy.array([Road.AVENUE, Road.AVENUE, Rogue.BOULEVARD])
>>> _enum_to_index(Road, value)
array([1, 1, 0], dtype=uint8)
# Finally, return the index array.
return numpy.array(index, dtype=t.EnumDType)
"""
index = [member.index for member in value]
return _int_to_index(enum_class, numpy.array(index))


def _int_to_index(enum_class: type[t.Enum], value: t.IndexArray) -> t.IndexArray:
Expand Down Expand Up @@ -121,7 +99,7 @@ def _int_to_index(enum_class: type[t.Enum], value: t.IndexArray) -> t.IndexArray
"""
# Create a mask to determine which values are in the enum class.
mask = numpy.isin(value, enum_class.indices)
mask = value < enum_class.items.size

# Get the values that are not in the enum class.
ko = value[~mask]
Expand All @@ -144,9 +122,6 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray:
Returns:
The index array.
Raises:
EnumMemberNotFoundError: If one value is not in the enum class.
Examples:
>>> import numpy
Expand All @@ -165,7 +140,8 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray:
... )
>>> _str_to_index(Road, numpy.array("AVENUE"))
array([1], dtype=uint8)
Traceback (most recent call last):
TypeError: iteration over a 0-d array
>>> _str_to_index(Road, numpy.array(["AVENUE"]))
array([1], dtype=uint8)
Expand All @@ -174,31 +150,12 @@ def _str_to_index(enum_class: type[t.Enum], value: t.StrArray) -> t.IndexArray:
array([0, 1, 0], dtype=uint8)
>>> _str_to_index(Road, numpy.array(["AVENUE", "AVENUE", "BOULEVARD"]))
Traceback (most recent call last):
EnumMemberNotFoundError: Member BOULEVARD not found in enum 'Road'...
array([1, 1, 0], dtype=uint8)
"""
# Create a mask to determine which values are in the enum class.
mask = numpy.isin(value, enum_class.names)

# Get the values that are not in the enum class.
ko = value[~mask]

# If there are values that are not in the enum class, raise an error.
if ko.size > 0:
raise EnumMemberNotFoundError(enum_class, ko[0])

# In case we're dealing with a scalar, we need to convert it to an array.
ok = value[mask]

# Get the index positions of the names in the sorted names.
index_where = numpy.searchsorted(enum_class._sorted_names_, ok)

# Get the actual index of the names in the enum class.
index = enum_class._sorted_names_index_[index_where]

# Finally, return the index array.
return numpy.array(index, dtype=t.EnumDType)
names = enum_class.names
index = [enum_class[name].index if name in names else 0 for name in value]
return _int_to_index(enum_class, numpy.array(index))


__all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"]
68 changes: 41 additions & 27 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import types as t
from ._enum_type import EnumType
from ._errors import EnumEncodingError
from ._guards import _is_int_array, _is_obj_array, _is_str_array
from ._utils import _enum_to_index, _int_to_index, _str_to_index
from .enum_array import EnumArray
Expand Down Expand Up @@ -111,37 +112,55 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"

def __hash__(self) -> int:
return hash(self.__class__.__name__) ^ hash(self.index)
return object.__hash__(self)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return hash(self) ^ hash(other) == 0
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index == other.index
return NotImplemented

def __ne__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return hash(self) ^ hash(other) != 0
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index != other.index
return NotImplemented

def __lt__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return self.index < other.index
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index < other.index
return NotImplemented

def __le__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return self.index <= other.index
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index <= other.index
return NotImplemented

def __gt__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return self.index > other.index
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index > other.index
return NotImplemented

def __ge__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented
return self.index >= other.index
if (
isinstance(other, Enum)
and self.__class__.__name__ == other.__class__.__name__
):
return self.index >= other.index
return NotImplemented

@classmethod
def encode(
Expand All @@ -167,7 +186,7 @@ def encode(
Raises:
NotImplementedError: If ``array`` is a scalar :class:`~numpy.ndarray`.
TypeError: If ``array`` is of a diffent :class:`.Enum` type.
EnumEncodingError: If ``array`` is of a different :class:`.Enum` type.
Examples:
>>> import numpy
Expand Down Expand Up @@ -211,7 +230,7 @@ def encode(
>>> array = numpy.array([b"TENANT"])
>>> enum_array = Housing.encode(array)
Traceback (most recent call last):
TypeError: Failed to encode "[b'TENANT']" of type 'bytes_', as i...
EnumEncodingError: Failed to encode "[b'TENANT']" of type 'bytes...
.. seealso::
:meth:`.EnumArray.decode` for decoding.
Expand Down Expand Up @@ -259,12 +278,7 @@ def encode(
if _is_obj_array(array) and cls.__name__ is array[0].__class__.__name__:
return EnumArray(_enum_to_index(cls, array), cls)

msg = (
f"Failed to encode \"{array}\" of type '{array[0].__class__.__name__}', "
"as it is not supported. Please, try again with an array of "
f"'{int.__name__}', '{str.__name__}', or '{cls.__name__}'."
)
raise TypeError(msg)
raise EnumEncodingError(cls, array)


__all__ = ["Enum"]
6 changes: 3 additions & 3 deletions openfisca_core/indexed_enums/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def test_enum_encode_with_str_scalar_array():


def test_enum_encode_with_str_with_bad_value():
"""Does not encode when called with a value not in an Enum."""
"""Encodes to the first member when called with a value not in an Enum."""
array = numpy.array(["JAIBA"])
with pytest.raises(IndexError):
Animal.encode(array)
enum_array = Animal.encode(array)
assert Animal.CAT in enum_array


# Unsupported encodings
Expand Down
4 changes: 2 additions & 2 deletions openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy
from numpy import (
bool_ as BoolDType,
generic as AnyDType,
generic as VarDType,
int32 as IntDType,
object_ as ObjDType,
str_ as StrDType,
Expand Down Expand Up @@ -44,7 +44,7 @@
ObjArray: TypeAlias = Array[ObjDType]

#: Type for generic arrays.
AnyArray: TypeAlias = Array[AnyDType]
VarArray: TypeAlias = Array[VarDType]

__all__ = [
"ArrayLike",
Expand Down
4 changes: 0 additions & 4 deletions openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ def plural(self, /) -> None | RolePlural: ...

class EnumType(enum.EnumMeta):
items: RecArray
_sorted_names_: Array[DTypeStr]
_sorted_enums_: Array[DTypeObject]
_sorted_names_index_: Array[DTypeEnum]
_sorted_enums_index_: Array[DTypeEnum]

@property
@abc.abstractmethod
Expand Down

0 comments on commit 573eb2f

Please sign in to comment.