Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve commons module typing #1219

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 41.5.4 [#1219](https://github.com/openfisca/openfisca-core/pull/1219)

#### Technical changes

- Fix doc & type definitions in the commons module

### 41.5.3 [#1218](https://github.com/openfisca/openfisca-core/pull/1218)

#### Technical changes
Expand Down
32 changes: 17 additions & 15 deletions openfisca_core/commons/formulas.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Any, Dict, Sequence, TypeVar

from openfisca_core.types import Array, ArrayLike
from collections.abc import Mapping
from typing import Union

import numpy

T = TypeVar("T")
from openfisca_core import types as t


def apply_thresholds(
input: Array[float],
thresholds: ArrayLike[float],
choices: ArrayLike[float],
) -> Array[float]:
input: t.Array[numpy.float_],
thresholds: t.ArrayLike[float],
choices: t.ArrayLike[float],
) -> t.Array[numpy.float_]:
"""Makes a choice based on an input and thresholds.

From a list of ``choices``, this function selects one of these values
Expand Down Expand Up @@ -40,7 +39,7 @@ def apply_thresholds(

"""

condlist: Sequence[Array[bool]]
condlist: list[Union[t.Array[numpy.bool_], bool]]
condlist = [input <= threshold for threshold in thresholds]

if len(condlist) == len(choices) - 1:
Expand All @@ -58,15 +57,18 @@ def apply_thresholds(
return numpy.select(condlist, choices)


def concat(this: ArrayLike[str], that: ArrayLike[str]) -> Array[str]:
def concat(
this: Union[t.Array[numpy.str_], t.ArrayLike[str]],
that: Union[t.Array[numpy.str_], t.ArrayLike[str]],
) -> t.Array[numpy.str_]:
"""Concatenates the values of two arrays.

Args:
this: An array to concatenate.
that: Another array to concatenate.

Returns:
:obj:`numpy.ndarray` of :obj:`float`:
:obj:`numpy.ndarray` of :obj:`numpy.str_`:
An array with the concatenated values.

Examples:
Expand All @@ -87,9 +89,9 @@ def concat(this: ArrayLike[str], that: ArrayLike[str]) -> Array[str]:


def switch(
conditions: Array[Any],
value_by_condition: Dict[float, T],
) -> Array[T]:
conditions: t.Array[numpy.float_],
value_by_condition: Mapping[float, float],
) -> t.Array[numpy.float_]:
"""Mimicks a switch statement.

Given an array of conditions, returns an array of the same size,
Expand Down Expand Up @@ -120,4 +122,4 @@ def switch(

condlist = [conditions == condition for condition in value_by_condition.keys()]

return numpy.select(condlist, value_by_condition.values())
return numpy.select(condlist, tuple(value_by_condition.values()))
8 changes: 5 additions & 3 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import TypeVar
from typing import Optional, TypeVar

from openfisca_core.types import Array
import numpy

from openfisca_core import types as t

T = TypeVar("T")

Expand Down Expand Up @@ -43,7 +45,7 @@ def empty_clone(original: T) -> T:
return new


def stringify_array(array: Array) -> str:
def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str:
"""Generates a clean string representation of a numpy array.

Args:
Expand Down
Empty file added openfisca_core/commons/py.typed
Empty file.
14 changes: 7 additions & 7 deletions openfisca_core/commons/rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@


def average_rate(
target: Array[float],
target: Array[numpy.float_],
varying: ArrayLike[float],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
) -> Array[numpy.float_]:
"""Computes the average rate of a target net income.

Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -41,7 +41,7 @@ def average_rate(

"""

average_rate: Array[float]
average_rate: Array[numpy.float_]

average_rate = 1 - target / varying

Expand All @@ -62,10 +62,10 @@ def average_rate(


def marginal_rate(
target: Array[float],
varying: Array[float],
target: Array[numpy.float_],
varying: Array[numpy.float_],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
) -> Array[numpy.float_]:
"""Computes the marginal rate of a target net income.

Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -97,7 +97,7 @@ def marginal_rate(

"""

marginal_rate: Array[float]
marginal_rate: Array[numpy.float_]

marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:])

Expand Down
130 changes: 76 additions & 54 deletions openfisca_core/types/_domain.py → openfisca_core/types.py
Original file line number Diff line number Diff line change
@@ -1,146 +1,168 @@
from __future__ import annotations

import typing_extensions
from typing import Any, Optional
from typing_extensions import Protocol
from collections.abc import Sequence
from numpy.typing import NDArray
from typing import Any, TypeVar
from typing_extensions import Protocol, TypeAlias

import abc

import numpy

N = TypeVar("N", bound=numpy.generic, covariant=True)

class Entity(Protocol):
"""Entity protocol."""
#: Type representing an numpy array.
Array: TypeAlias = NDArray[N]

L = TypeVar("L")

#: Type representing an array-like object.
ArrayLike: TypeAlias = Sequence[L]

#: Type variable representing an error.
E = TypeVar("E", covariant=True)

#: Type variable representing a value.
A = TypeVar("A", covariant=True)


# Entities


class Entity(Protocol):
key: Any
plural: Any

@abc.abstractmethod
def check_role_validity(self, role: Any) -> None:
"""Abstract method."""
...

@abc.abstractmethod
def check_variable_defined_for_entity(self, variable_name: Any) -> None:
"""Abstract method."""
...

@abc.abstractmethod
def get_variable(
self,
variable_name: Any,
check_existence: Any = ...,
) -> Optional[Any]:
"""Abstract method."""
) -> Any | None:
...


class Formula(Protocol):
"""Formula protocol."""
class Role(Protocol):
entity: Any
subroles: Any

@abc.abstractmethod
def __call__(
self,
population: Population,
instant: Instant,
params: Params,
) -> numpy.ndarray:
"""Abstract method."""

# Holders

class Holder(Protocol):
"""Holder protocol."""

class Holder(Protocol):
@abc.abstractmethod
def clone(self, population: Any) -> Holder:
"""Abstract method."""
...

@abc.abstractmethod
def get_memory_usage(self) -> Any:
"""Abstract method."""
...


class Instant(Protocol):
"""Instant protocol."""
# Parameters


@typing_extensions.runtime_checkable
class ParameterNodeAtInstant(Protocol):
"""ParameterNodeAtInstant protocol."""
...


class Params(Protocol):
"""Params protocol."""
# Periods

@abc.abstractmethod
def __call__(self, instant: Instant) -> ParameterNodeAtInstant:
"""Abstract method."""

class Instant(Protocol):
...


@typing_extensions.runtime_checkable
class Period(Protocol):
"""Period protocol."""

@property
@abc.abstractmethod
def start(self) -> Any:
"""Abstract method."""
...

@property
@abc.abstractmethod
def unit(self) -> Any:
"""Abstract method."""
...


class Population(Protocol):
"""Population protocol."""
# Populations


class Population(Protocol):
entity: Any

@abc.abstractmethod
def get_holder(self, variable_name: Any) -> Any:
"""Abstract method."""

...

class Role(Protocol):
"""Role protocol."""

entity: Any
subroles: Any
# Simulations


class Simulation(Protocol):
"""Simulation protocol."""

@abc.abstractmethod
def calculate(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def calculate_add(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def calculate_divide(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def get_population(self, plural: Optional[Any]) -> Any:
"""Abstract method."""
def get_population(self, plural: Any | None) -> Any:
...


class TaxBenefitSystem(Protocol):
"""TaxBenefitSystem protocol."""
# Tax-Benefit systems


class TaxBenefitSystem(Protocol):
person_entity: Any

@abc.abstractmethod
def get_variable(
self,
variable_name: Any,
check_existence: Any = ...,
) -> Optional[Any]:
) -> Any | None:
"""Abstract method."""


class Variable(Protocol):
"""Variable protocol."""
# Variables


class Variable(Protocol):
entity: Any


class Formula(Protocol):
@abc.abstractmethod
def __call__(
self,
population: Population,
instant: Instant,
params: Params,
) -> Array[Any]:
...


class Params(Protocol):
@abc.abstractmethod
def __call__(self, instant: Instant) -> ParameterNodeAtInstant:
...
Loading
Loading