Skip to content

Commit

Permalink
Merge pull request #188 from MaxGhenis/MaxGhenis/issue100
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxGhenis authored Apr 29, 2024
2 parents 4c67661 + 6803f6e commit a1495e3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 77 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
fixed:
- FutureWarning issue with bools and enums.
123 changes: 46 additions & 77 deletions policyengine_core/enums/enum.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,79 @@
from __future__ import annotations

import enum
from typing import Union

import numpy

import numpy as np
from .config import ENUM_ARRAY_DTYPE
from .enum_array import EnumArray

import warnings

warnings.simplefilter("ignore", category=FutureWarning)


class Enum(enum.Enum):
"""
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items
have an index.
"""

# Tweak enums to add an index attribute to each enum item
def __init__(self, name: str) -> None:
# When the enum item is initialized, self._member_names_ contains the
# names of the previously initialized items, so its length is the index
# of this item.
"""
Initialize an Enum item with a name and an index.
The index is automatically assigned based on the order of the Enum items.
"""
self.index = len(self._member_names_)

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__

# In Python 3, __hash__ must be defined if __eq__ is defined to stay
# hashable.
__hash__ = object.__hash__

@classmethod
def encode(
cls,
array: Union[
EnumArray,
numpy.int_,
numpy.float_,
numpy.object_,
],
) -> EnumArray:
def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
"""
Encode a string numpy array, an enum item numpy array, or an int numpy
array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
decoding.
:param numpy.ndarray array: Array of string identifiers, or of enum
items, to encode.
Encode an array of enum items or string identifiers into an EnumArray.
:returns: An :any:`EnumArray` encoding the input array values.
:rtype: :any:`EnumArray`
Args:
array: The input array to encode. Can be an EnumArray, a NumPy array
of enum items, or a NumPy array of string identifiers.
For instance:
Returns:
An EnumArray containing the encoded values.
>>> string_identifier_array = asarray(['free_lodger', 'owner'])
>>> encoded_array = HousingOccupancyStatus.encode(string_identifier_array)
>>> encoded_array[0]
2 # Encoded value
Examples:
>>> string_array = np.array(["ITEM_1", "ITEM_2", "ITEM_3"])
>>> encoded_array = MyEnum.encode(string_array)
>>> encoded_array
EnumArray([1, 2, 3], dtype=int8)
>>> free_lodger = HousingOccupancyStatus.free_lodger
>>> owner = HousingOccupancyStatus.owner
>>> enum_item_array = asarray([free_lodger, owner])
>>> encoded_array = HousingOccupancyStatus.encode(enum_item_array)
>>> encoded_array[0]
2 # Encoded value
>>> item_array = np.array([MyEnum.ITEM_1, MyEnum.ITEM_2, MyEnum.ITEM_3])
>>> encoded_array = MyEnum.encode(item_array)
>>> encoded_array
EnumArray([1, 2, 3], dtype=int8)
"""
if isinstance(array, EnumArray):
return array

if isinstance(array == 0, bool):
if array.dtype.kind == "b":
# Convert boolean array to string array
array = array.astype(str)

# String array
if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}:
array = numpy.select(
if array.dtype.kind in {"U", "S"}:
# String array
indices = np.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)

# Enum items arrays
elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O":
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
if array[0].__class__.__name__ != "bytes":
array = numpy.select(
[array == item for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)
else:
array = numpy.select(
[array.astype(str) == item.name for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)

return EnumArray(array, cls)
)
elif array.dtype.kind == "O":
# Enum items array
if len(array) > 0:
first_item = array[0]
if cls.__name__ == type(first_item).__name__:
# Use the same Enum class as the array items
cls = type(first_item)
indices = np.select(
[array == item for item in cls],
[item.index for item in cls],
)
elif array.dtype.kind in {"i", "u"}:
# Integer array
indices = array
else:
raise ValueError(f"Unsupported array dtype: {array.dtype}")

return EnumArray(indices.astype(ENUM_ARRAY_DTYPE), cls)

0 comments on commit a1495e3

Please sign in to comment.