Skip to content

Commit

Permalink
refactor(enums): improve performance enum array (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 14, 2024
1 parent 02c0576 commit ef85e1f
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 84 deletions.
2 changes: 2 additions & 0 deletions openfisca_core/indexed_enums/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Enumerations for variables with a limited set of possible values."""

from . import types
from ._enum_type import EnumType
from ._errors import EnumEncodingError, EnumMemberNotFoundError
from .config import ENUM_ARRAY_DTYPE
from .enum import Enum
Expand All @@ -12,5 +13,6 @@
"EnumArray",
"EnumEncodingError",
"EnumMemberNotFoundError",
"EnumType",
"types",
]
70 changes: 70 additions & 0 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import final

import numpy

from . import types as t


@final
class EnumType(t.EnumType):
"""Meta class for creating an indexed :class:`.Enum`.
Examples:
>>> from openfisca_core import indexed_enums as enum
>>> class Enum(enum.Enum, metaclass=enum.EnumType):
... pass
>>> Enum.items
Traceback (most recent call last):
AttributeError: ...
>>> class Housing(Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> Housing.indices
array([0, 1], dtype=uint8)
>>> Housing.names
array(['OWNER', 'TENANT'], dtype='<U6')
>>> Housing.enums
array([Housing.OWNER, Housing.TENANT], dtype=object)
"""

def __new__(
metacls,
name: str,
bases: tuple[type, ...],
classdict: t.EnumDict,
**kwds: object,
) -> t.EnumType:
"""Create a new indexed enum class."""
# Create the enum class.
cls = super().__new__(metacls, name, bases, classdict, **kwds)

# If the enum class has no members, return it as is.
if not cls.__members__:
return cls

# Add the indices attribute to the enum class.
cls.indices = numpy.arange(len(cls), dtype=t.EnumDType)

# Add the names attribute to the enum class.
cls.names = numpy.array(cls._member_names_, dtype=t.StrDType)

# Add the enums attribute to the enum class.
cls.enums = numpy.array(cls, dtype=t.ObjDType)

# Return the modified enum class.
return cls

def __dir__(cls) -> list[str]:
return sorted({"indices", "names", "enums", *super().__dir__()})


__all__ = ["EnumType"]
12 changes: 9 additions & 3 deletions openfisca_core/indexed_enums/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def _is_enum_array(array: t.VarArray) -> TypeIs[t.ObjArray]:
return array.dtype.type in objs


def _is_enum_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[t.Enum]]:
def _is_enum_array_like(
array: t.VarArray | t.ArrayLike[object],
) -> TypeIs[t.ArrayLike[t.Enum]]:
"""Narrow the type of a given array-like to an sequence of :class:`.Enum`.
Args:
Expand Down Expand Up @@ -109,7 +111,9 @@ def _is_int_array(array: t.VarArray) -> TypeIs[t.IndexArray]:
return array.dtype.type in ints


def _is_int_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[int]]:
def _is_int_array_like(
array: t.VarArray | t.ArrayLike[object],
) -> TypeIs[t.ArrayLike[int]]:
"""Narrow the type of a given array-like to a sequence of :obj:`int`.
Args:
Expand Down Expand Up @@ -165,7 +169,9 @@ def _is_str_array(array: t.VarArray) -> TypeIs[t.StrArray]:
return array.dtype.type in strs


def _is_str_array_like(array: t.ArrayLike[object]) -> TypeIs[t.ArrayLike[str]]:
def _is_str_array_like(
array: t.VarArray | t.ArrayLike[object],
) -> TypeIs[t.ArrayLike[str]]:
"""Narrow the type of a given array-like to an sequence of :obj:`str`.
Args:
Expand Down
92 changes: 29 additions & 63 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy

from . import types as t
from ._enum_type import EnumType
from ._errors import EnumEncodingError, EnumMemberNotFoundError
from ._guards import (
_is_enum_array,
Expand All @@ -18,7 +19,7 @@
from .enum_array import EnumArray


class Enum(t.Enum):
class Enum(t.Enum, metaclass=EnumType):
"""Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_.
Its items have an :class:`int` index, useful and performant when running
Expand Down Expand Up @@ -148,11 +149,6 @@ def encode(cls, array: t.VarArray | t.ArrayLike[object]) -> t.EnumArray:
Returns:
EnumArray: An :class:`.EnumArray` with the encoded input values.
Raises:
EnumEncodingError: If ``array`` is of diffent :class:`.Enum` type.
EnumMemberNotFoundError: If members are not found in :class:`.Enum`.
NotImplementedError: If ``array`` is a scalar :class:`~numpy.ndarray`.
Examples:
>>> import numpy
Expand Down Expand Up @@ -201,70 +197,40 @@ def encode(cls, array: t.VarArray | t.ArrayLike[object]) -> t.EnumArray:
:meth:`.EnumArray.decode` for decoding.
"""
# Array of indices
indices: t.IndexArray

if isinstance(array, EnumArray):
return array

# Array-like
if len(array) == 0:
return EnumArray(numpy.asarray(array, t.EnumDType), cls)
if isinstance(array, Sequence):
if len(array) == 0:
indices = numpy.array([], t.EnumDType)

elif _is_int_array_like(array):
indices = _int_to_index(cls, array)

elif _is_str_array_like(array):
indices = _str_to_index(cls, array)

elif _is_enum_array_like(array):
indices = _enum_to_index(array)

else:
raise EnumEncodingError(cls, array)
return cls._encode_array_like(array)
return cls._encode_array(array)

@classmethod
def _encode_array(cls, value: t.VarArray) -> t.EnumArray:
if _is_int_array(value):
indices = _int_to_index(cls, value)
elif _is_str_array(value): # type: ignore[unreachable]
indices = _str_to_index(cls, value)
elif _is_enum_array(value) and cls.__name__ is value[0].__class__.__name__:
indices = _enum_to_index(value)
else:
# Scalar arrays are not supported.
if array.ndim == 0:
msg = (
"Scalar arrays are not supported: expecting a vector array, "
f"instead. Please try again with `numpy.array([{array}])`."
)
raise NotImplementedError(msg)

# Empty arrays are returned as is.
if array.size == 0:
indices = numpy.array([], t.EnumDType)

# Index arrays.
elif _is_int_array(array):
indices = _int_to_index(cls, array)

# String arrays.
elif _is_str_array(array): # type: ignore[unreachable]
indices = _str_to_index(cls, array)

# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
elif _is_enum_array(array) and cls.__name__ is array[0].__class__.__name__:
indices = _enum_to_index(array)

else:
raise EnumEncodingError(cls, array)

if indices.size != len(array):
raise EnumEncodingError(cls, value)
if indices.size != len(value):
raise EnumMemberNotFoundError(cls)
return EnumArray(indices, cls)

@classmethod
def _encode_array_like(cls, value: t.ArrayLike[object]) -> t.EnumArray:
if _is_int_array_like(value):
indices = _int_to_index(cls, value)
elif _is_str_array_like(value): # type: ignore[unreachable]
indices = _str_to_index(cls, value)
elif _is_enum_array_like(value):
indices = _enum_to_index(value)
else:
raise EnumEncodingError(cls, value)
if indices.size != len(value):
raise EnumMemberNotFoundError(cls)
return EnumArray(indices, cls)


Expand Down
26 changes: 15 additions & 11 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class EnumArray(t.EnumArray):
"""

#: Enum type of the array items.
possible_values: None | type[t.Enum] = None
possible_values: None | type[t.Enum]

def __new__(
cls,
Expand Down Expand Up @@ -157,8 +157,12 @@ def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override]
isinstance(other, type(t.Enum))
and other.__name__ is self.possible_values.__name__
):
index = numpy.array([enum.index for enum in self.possible_values])
result = self.view(numpy.ndarray) == index[index <= max(self)]
result = (
self.view(numpy.ndarray)
== self.possible_values.indices[
self.possible_values.indices <= max(self)
]
)
return result
if (
isinstance(other, t.Enum)
Expand Down Expand Up @@ -265,16 +269,16 @@ def decode(self) -> t.ObjArray:
array([Housing.TENANT], dtype=object)
"""
result: t.ObjArray
if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
f"not defined."
)
raise TypeError(msg)
return numpy.select(
[self == item.index for item in self.possible_values],
list(self.possible_values), # pyright: ignore[reportArgumentType]
)
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
result = self.possible_values.enums[array]
return result

def decode_to_str(self) -> t.StrArray:
"""Decode itself to an array of strings.
Expand All @@ -300,16 +304,16 @@ def decode_to_str(self) -> t.StrArray:
array(['TENANT'], dtype='<U6')
"""
result: t.StrArray
if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
f"not defined."
)
raise TypeError(msg)
return numpy.select(
[self == item.index for item in self.possible_values],
[item.name for item in self.possible_values],
)
array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self
result = self.possible_values.names[array]
return result

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.decode()!s})"
Expand Down
8 changes: 4 additions & 4 deletions openfisca_core/indexed_enums/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_enum_encode_with_enum_sequence():
def test_enum_encode_with_enum_scalar_array():
"""Does not encode when called with an enum scalar array."""
array = numpy.array(Animal.DOG)
with pytest.raises(NotImplementedError):
with pytest.raises(TypeError):
Animal.encode(array)


Expand Down Expand Up @@ -67,7 +67,7 @@ def test_enum_encode_with_int_sequence():
def test_enum_encode_with_int_scalar_array():
"""Does not encode when called with an int scalar array."""
array = numpy.array(1)
with pytest.raises(NotImplementedError):
with pytest.raises(TypeError):
Animal.encode(array)


Expand Down Expand Up @@ -98,7 +98,7 @@ def test_enum_encode_with_str_sequence():
def test_enum_encode_with_str_scalar_array():
"""Does not encode when called with a str scalar array."""
array = numpy.array("DOG")
with pytest.raises(NotImplementedError):
with pytest.raises(TypeError):
Animal.encode(array)


Expand All @@ -124,7 +124,7 @@ def test_enum_encode_with_any_scalar_array():
"""Does not encode when called with unsupported types."""
value = 1.5
array = numpy.array(value)
with pytest.raises(NotImplementedError):
with pytest.raises(TypeError):
Animal.encode(array)


Expand Down
6 changes: 5 additions & 1 deletion openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing_extensions import TypeAlias

from openfisca_core.types import Array, ArrayLike, DTypeLike, Enum, EnumArray
from openfisca_core.types import Array, ArrayLike, DTypeLike, Enum, EnumArray, EnumType

from enum import _EnumDict as EnumDict # noqa: PLC2701

from numpy import (
bool_ as BoolDType,
Expand Down Expand Up @@ -34,4 +36,6 @@
"DTypeLike",
"Enum",
"EnumArray",
"EnumDict",
"EnumType",
]
10 changes: 8 additions & 2 deletions openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,13 @@ def plural(self, /) -> None | RolePlural: ...
# Indexed enums


class Enum(enum.Enum, metaclass=enum.EnumMeta):
class EnumType(enum.EnumMeta):
indices: Array[DTypeEnum]
names: Array[DTypeStr]
enums: Array[DTypeObject]


class Enum(enum.Enum, metaclass=EnumType):
index: int
_member_names_: list[str]

Expand All @@ -118,7 +124,7 @@ class EnumArray(Array[DTypeEnum], metaclass=abc.ABCMeta):

@abc.abstractmethod
def __new__(
cls, input_array: Array[DTypeEnum], possible_values: None | type[Enum] = ...
cls, input_array: Array[DTypeEnum], possible_values: type[Enum]
) -> Self: ...


Expand Down

0 comments on commit ef85e1f

Please sign in to comment.