-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #188 from MaxGhenis/MaxGhenis/issue100
- Loading branch information
Showing
2 changed files
with
50 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
- bump: patch | ||
changes: | ||
fixed: | ||
- FutureWarning emitted by Enum.encode when handling bools and enums. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,110 +1,79 @@ | ||
from __future__ import annotations | ||
|
||
import enum | ||
from typing import Union | ||
|
||
import numpy | ||
|
||
import numpy as np | ||
from .config import ENUM_ARRAY_DTYPE | ||
from .enum_array import EnumArray | ||
|
||
import warnings | ||
|
||
warnings.simplefilter("ignore", category=FutureWarning) | ||
|
||
|
||
class Enum(enum.Enum): | ||
""" | ||
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items | ||
have an index. | ||
""" | ||
|
||
# Tweak enums to add an index attribute to each enum item | ||
def __init__(self, name: str) -> None: | ||
# When the enum item is initialized, self._member_names_ contains the | ||
# names of the previously initialized items, so its length is the index | ||
# of this item. | ||
""" | ||
Initialize an Enum item with a name and an index. | ||
The index is automatically assigned based on the order of the Enum items. | ||
""" | ||
self.index = len(self._member_names_) | ||
|
||
# Bypass the slow Enum.__eq__ | ||
__eq__ = object.__eq__ | ||
|
||
# In Python 3, __hash__ must be defined if __eq__ is defined to stay | ||
# hashable. | ||
__hash__ = object.__hash__ | ||
|
||
@classmethod | ||
def encode( | ||
cls, | ||
array: Union[ | ||
EnumArray, | ||
numpy.int_, | ||
numpy.float_, | ||
numpy.object_, | ||
], | ||
) -> EnumArray: | ||
def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: | ||
""" | ||
Encode a string numpy array, an enum item numpy array, or an int numpy | ||
array into an :any:`EnumArray`. See :any:`EnumArray.decode` for | ||
decoding. | ||
:param numpy.ndarray array: Array of string identifiers, or of enum | ||
items, to encode. | ||
Encode an array of enum items or string identifiers into an EnumArray. | ||
:returns: An :any:`EnumArray` encoding the input array values. | ||
:rtype: :any:`EnumArray` | ||
Args: | ||
array: The input array to encode. Can be an EnumArray, a NumPy array | ||
of enum items, or a NumPy array of string identifiers. | ||
For instance: | ||
Returns: | ||
An EnumArray containing the encoded values. | ||
>>> string_identifier_array = asarray(['free_lodger', 'owner']) | ||
>>> encoded_array = HousingOccupancyStatus.encode(string_identifier_array) | ||
>>> encoded_array[0] | ||
2 # Encoded value | ||
Examples: | ||
>>> string_array = np.array(["ITEM_1", "ITEM_2", "ITEM_3"]) | ||
>>> encoded_array = MyEnum.encode(string_array) | ||
>>> encoded_array | ||
EnumArray([1, 2, 3], dtype=int8) | ||
>>> free_lodger = HousingOccupancyStatus.free_lodger | ||
>>> owner = HousingOccupancyStatus.owner | ||
>>> enum_item_array = asarray([free_lodger, owner]) | ||
>>> encoded_array = HousingOccupancyStatus.encode(enum_item_array) | ||
>>> encoded_array[0] | ||
2 # Encoded value | ||
>>> item_array = np.array([MyEnum.ITEM_1, MyEnum.ITEM_2, MyEnum.ITEM_3]) | ||
>>> encoded_array = MyEnum.encode(item_array) | ||
>>> encoded_array | ||
EnumArray([1, 2, 3], dtype=int8) | ||
""" | ||
if isinstance(array, EnumArray): | ||
return array | ||
|
||
if isinstance(array == 0, bool): | ||
if array.dtype.kind == "b": | ||
# Convert boolean array to string array | ||
array = array.astype(str) | ||
|
||
# String array | ||
if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}: | ||
array = numpy.select( | ||
if array.dtype.kind in {"U", "S"}: | ||
# String array | ||
indices = np.select( | ||
[array == item.name for item in cls], | ||
[item.index for item in cls], | ||
).astype(ENUM_ARRAY_DTYPE) | ||
|
||
# Enum items arrays | ||
elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O": | ||
# Ensure we are comparing the comparable. The problem this fixes: | ||
# On entering this method "cls" will generally come from | ||
# variable.possible_values, while the array values may come from | ||
# directly importing a module containing an Enum class. However, | ||
# variables (and hence their possible_values) are loaded by a call | ||
# to load_module, which gives them a different identity from the | ||
# ones imported in the usual way. | ||
# | ||
# So, instead of relying on the "cls" passed in, we use only its | ||
# name to check that the values in the array, if non-empty, are of | ||
# the right type. | ||
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__: | ||
cls = array[0].__class__ | ||
if array[0].__class__.__name__ != "bytes": | ||
array = numpy.select( | ||
[array == item for item in cls], | ||
[item.index for item in cls], | ||
).astype(ENUM_ARRAY_DTYPE) | ||
else: | ||
array = numpy.select( | ||
[array.astype(str) == item.name for item in cls], | ||
[item.index for item in cls], | ||
).astype(ENUM_ARRAY_DTYPE) | ||
|
||
return EnumArray(array, cls) | ||
) | ||
elif array.dtype.kind == "O": | ||
# Enum items array | ||
if len(array) > 0: | ||
first_item = array[0] | ||
if cls.__name__ == type(first_item).__name__: | ||
# Use the same Enum class as the array items | ||
cls = type(first_item) | ||
indices = np.select( | ||
[array == item for item in cls], | ||
[item.index for item in cls], | ||
) | ||
elif array.dtype.kind in {"i", "u"}: | ||
# Integer array | ||
indices = array | ||
else: | ||
raise ValueError(f"Unsupported array dtype: {array.dtype}") | ||
|
||
return EnumArray(indices.astype(ENUM_ARRAY_DTYPE), cls) |