Skip to content

Commit

Permalink
fix(enums): preserve order in encode (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 10, 2024
1 parent c54d3a5 commit 62c29b7
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion openfisca_core/data_storage/on_disk_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _decode_file(self, file: str) -> t.Array[t.DTypeGeneric]:
... storage = data_storage.OnDiskStorage(directory)
... storage.put(value, period)
... storage._decode_file(storage._files[period])
EnumArray(Housing.TENANT)
EnumArray([Housing.TENANT])
"""
enum = self._enums.get(file)
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class EnumType(t.EnumType):
rec.array([(0, 'OWNER', Housing.OWNER), (1, 'TENANT', Housing.TENAN...)
>>> Housing.indices
array([0, 1], dtype=int16)
array([0, 1], dtype=uint8)
>>> Housing.names
array(['OWNER', 'TENANT'], dtype='<U6')
Expand Down
20 changes: 11 additions & 9 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ def __hash__(self) -> int:
def encode(
cls,
array: (
EnumArray
t.EnumArray
| t.IntArray
| t.StrArray
| t.ObjArray
| t.ArrayLike[int]
| t.ArrayLike[str]
| t.ArrayLike[t.Enum]
),
) -> EnumArray:
) -> t.EnumArray:
"""Encode an encodable array into an :class:`.EnumArray`.
Args:
Expand All @@ -161,7 +161,7 @@ def encode(
>>> array = numpy.array([1])
>>> enum_array = enum.EnumArray(array, Housing)
>>> Housing.encode(enum_array)
EnumArray(Housing.TENANT)
EnumArray([Housing.TENANT])
# Array of Enum
Expand Down Expand Up @@ -213,13 +213,14 @@ def encode(

# Integer array
if _is_int_array(array):
indices = numpy.array(array[array < len(cls.items)], dtype=t.EnumDType)
return EnumArray(indices, cls)
indices = numpy.array(array[array < cls.indices.size])
return EnumArray(indices.astype(t.EnumDType), cls)

# String array
if _is_str_array(array): # type: ignore[unreachable]
indices = cls.items[numpy.isin(cls.names, array)].index
return EnumArray(indices, cls)
names = array[numpy.isin(array, cls.names)]
indices = numpy.array([cls[name].index for name in names])
return EnumArray(indices.astype(t.EnumDType), cls)

# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
Expand All @@ -233,8 +234,9 @@ def encode(
# name to check that the values in the array, if non-empty, are of
# the right type.
if cls.__name__ is array[0].__class__.__name__:
indices = cls.items[numpy.isin(cls.enums, array)].index
return EnumArray(indices, cls)
enums = array[numpy.isin(array, cls.enums)]
indices = numpy.array([enum.index for enum in enums])
return EnumArray(indices.astype(t.EnumDType), cls)

msg = (
f"Failed to encode \"{array}\" of type '{array[0].__class__.__name__}', "
Expand Down
9 changes: 3 additions & 6 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class EnumArray(t.EnumArray):
"<class 'openfisca_core.indexed_enums.enum_array.EnumArray'>"
>>> repr(enum_array)
'EnumArray(Housing.TENANT)'
'EnumArray([Housing.TENANT])'
>>> str(enum_array)
"['TENANT']"
Expand All @@ -62,7 +62,7 @@ class EnumArray(t.EnumArray):
... possible_values = Housing
>>> enum.EnumArray(array, OccupancyStatus.possible_values)
EnumArray(Housing.TENANT)
EnumArray([Housing.TENANT])
.. _Subclassing ndarray:
https://numpy.org/doc/stable/user/basics.subclassing.html
Expand Down Expand Up @@ -270,7 +270,6 @@ def decode(self) -> t.ObjArray:
"""
result: t.ObjArray

if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
Expand Down Expand Up @@ -307,7 +306,6 @@ def decode_to_str(self) -> t.StrArray:
"""
result: t.StrArray

if self.possible_values is None:
msg = (
f"The possible values of the {self.__class__.__name__} are "
Expand All @@ -320,8 +318,7 @@ def decode_to_str(self) -> t.StrArray:
return result

def __repr__(self) -> str:
items = ", ".join(str(item) for item in self.decode())
return f"{self.__class__.__name__}({items})"
return f"{self.__class__.__name__}({self.decode()!s})"

def __str__(self) -> str:
return str(self.decode_to_str())
Expand Down
13 changes: 7 additions & 6 deletions openfisca_core/indexed_enums/tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy
import pytest
from numpy.testing import assert_array_equal

from openfisca_core import indexed_enums as enum

Expand All @@ -20,9 +21,9 @@ class Colour(enum.Enum):

def test_enum_encode_with_array_of_enum():
"""Does encode when called with an array of enums."""
array = numpy.array([Animal.DOG])
array = numpy.array([Animal.DOG, Animal.DOG, Animal.CAT, Colour.AMARANTH])
enum_array = Animal.encode(array)
assert enum_array == Animal.DOG
assert_array_equal(enum_array, numpy.array([1, 1, 0]))


def test_enum_encode_with_enum_sequence():
Expand Down Expand Up @@ -51,9 +52,9 @@ def test_enum_encode_with_enum_with_bad_value():

def test_enum_encode_with_array_of_int():
"""Does encode when called with an array of int."""
array = numpy.array([1])
array = numpy.array([1, 1, 0, 2])
enum_array = Animal.encode(array)
assert enum_array == Animal.DOG
assert_array_equal(enum_array, numpy.array([1, 1, 0]))


def test_enum_encode_with_int_sequence():
Expand Down Expand Up @@ -82,9 +83,9 @@ def test_enum_encode_with_int_with_bad_value():

def test_enum_encode_with_array_of_string():
"""Does encode when called with an array of string."""
array = numpy.array(["DOG"])
array = numpy.array(["DOG", "DOG", "CAT", "AMARANTH"])
enum_array = Animal.encode(array)
assert enum_array == Animal.DOG
assert_array_equal(enum_array, numpy.array([1, 1, 0]))


def test_enum_encode_with_str_sequence():
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from numpy import (
bool_ as BoolDType,
generic as AnyDType,
int16 as EnumDType,
int32 as IntDType,
object_ as ObjDType,
str_ as StrDType,
uint8 as EnumDType,
)

#: Type for the non-vectorised list of enum items.
Expand Down
2 changes: 1 addition & 1 deletion openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
DTypeBytes: TypeAlias = numpy.bytes_

#: Type for Enum arrays.
DTypeEnum: TypeAlias = numpy.int16
DTypeEnum: TypeAlias = numpy.uint8

#: Type for date arrays.
DTypeDate: TypeAlias = numpy.datetime64
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
"PyYAML >=6.0, <7.0",
"StrEnum >=0.4.8, <0.5.0", # 3.11.x backport
"dpath >=2.1.4, <3.0",
"numexpr >=2.8.4, <3.0",
"numpy >=1.24.2, <2.0",
"numexpr >=2.10.0, <2.10.1",
"numpy >=1.24.2, <1.26.4",
"pendulum >=3.0.0, <4.0.0",
"psutil >=5.9.4, <6.0",
"pytest >=8.3.3, <9.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def test_log_aggregate_with_enum(tracer) -> None:

assert (
lines[0]
== " A<2017> >> {'avg': EnumArray(HousingOccupancyStatus.tenant), 'max': EnumArray(HousingOccupancyStatus.tenant), 'min': EnumArray(HousingOccupancyStatus.tenant)}"
== " A<2017> >> {'avg': EnumArray([HousingOccupancyStatus.tenant]), 'max': EnumArray([HousingOccupancyStatus.tenant]), 'min': EnumArray([HousingOccupancyStatus.tenant])}"
)


Expand Down
2 changes: 1 addition & 1 deletion tests/core/tools/test_assert_near.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def test_enum_2(tax_benefit_system) -> None:
"housing_occupancy_status"
].possible_values
value = possible_values.encode(numpy.array(["tenant", "owner"]))
expected_value = ["owner", "tenant"]
expected_value = ["tenant", "owner"]
assert_near(value, expected_value)

0 comments on commit 62c29b7

Please sign in to comment.