Skip to content

Commit

Permalink
refactor(enums): fix linter warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 5, 2024
1 parent 4515480 commit 6480a2c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
14 changes: 9 additions & 5 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ def encode(
return array

# String array
if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}:
if array.dtype.kind in {"U", "S"}:
array = numpy.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(numpy.int16)

# Enum items arrays
elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O":
elif array.dtype.kind == "O":
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
Expand All @@ -84,14 +84,18 @@ def encode(
# name to check that the values in the array, if non-empty, are of
# the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
klass = array[0].__class__

else:
klass = cls

array = numpy.select(
[array == item for item in cls],
[item.index for item in cls],
[array == item for item in klass],
[item.index for item in klass],
).astype(numpy.int16)

array = numpy.asarray(array, dtype=numpy.int16)
return EnumArray(array, cls)


__all__ = ["Enum"]
18 changes: 13 additions & 5 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from typing import NoReturn, overload
from typing_extensions import TypeGuard, Self
from typing_extensions import Self, TypeGuard

import numpy

Expand Down Expand Up @@ -149,18 +149,26 @@ def _is_an_enum(self, other: object) -> TypeGuard[t.Enum]:
if self.possible_values is None:
raise NotImplementedError

if other is None:
raise NotImplementedError

return (
not hasattr(other, "__name__")
and other.__class__.__name__ is self.possible_values.__name__
)

def _is_an_enum_type(self, other: object) -> TypeGuard[type[t.Enum]]:
name: None | str

if self.possible_values is None:
raise NotImplementedError

return (
hasattr(other, "__name__")
and other.__name__ is self.possible_values.__name__
)
if other is None:
raise NotImplementedError

name = getattr(other, "__name__", None)

return isinstance(name, str) and name is self.possible_values.__name__


__all__ = ["EnumArray"]

0 comments on commit 6480a2c

Please sign in to comment.