Skip to content

Commit

Permalink
Improve commons module typing (#1219)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko authored Sep 18, 2024
2 parents bbc402f + 5aaca86 commit 77cc91f
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 224 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 41.5.4 [#1219](https://github.com/openfisca/openfisca-core/pull/1219)

#### Technical changes

- Fix doc & type definitions in the commons module

### 41.5.3 [#1218](https://github.com/openfisca/openfisca-core/pull/1218)

#### Technical changes
Expand Down
32 changes: 17 additions & 15 deletions openfisca_core/commons/formulas.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Any, Dict, Sequence, TypeVar

from openfisca_core.types import Array, ArrayLike
from collections.abc import Mapping
from typing import Union

import numpy

T = TypeVar("T")
from openfisca_core import types as t


def apply_thresholds(
input: Array[float],
thresholds: ArrayLike[float],
choices: ArrayLike[float],
) -> Array[float]:
input: t.Array[numpy.float_],
thresholds: t.ArrayLike[float],
choices: t.ArrayLike[float],
) -> t.Array[numpy.float_]:
"""Makes a choice based on an input and thresholds.
From a list of ``choices``, this function selects one of these values
Expand Down Expand Up @@ -40,7 +39,7 @@ def apply_thresholds(
"""

condlist: Sequence[Array[bool]]
condlist: list[Union[t.Array[numpy.bool_], bool]]
condlist = [input <= threshold for threshold in thresholds]

if len(condlist) == len(choices) - 1:
Expand All @@ -58,15 +57,18 @@ def apply_thresholds(
return numpy.select(condlist, choices)


def concat(this: ArrayLike[str], that: ArrayLike[str]) -> Array[str]:
def concat(
this: Union[t.Array[numpy.str_], t.ArrayLike[str]],
that: Union[t.Array[numpy.str_], t.ArrayLike[str]],
) -> t.Array[numpy.str_]:
"""Concatenates the values of two arrays.
Args:
this: An array to concatenate.
that: Another array to concatenate.
Returns:
:obj:`numpy.ndarray` of :obj:`float`:
:obj:`numpy.ndarray` of :obj:`numpy.str_`:
An array with the concatenated values.
Examples:
Expand All @@ -87,9 +89,9 @@ def concat(this: ArrayLike[str], that: ArrayLike[str]) -> Array[str]:


def switch(
conditions: Array[Any],
value_by_condition: Dict[float, T],
) -> Array[T]:
conditions: t.Array[numpy.float_],
value_by_condition: Mapping[float, float],
) -> t.Array[numpy.float_]:
"""Mimicks a switch statement.
Given an array of conditions, returns an array of the same size,
Expand Down Expand Up @@ -120,4 +122,4 @@ def switch(

condlist = [conditions == condition for condition in value_by_condition.keys()]

return numpy.select(condlist, value_by_condition.values())
return numpy.select(condlist, tuple(value_by_condition.values()))
8 changes: 5 additions & 3 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import TypeVar
from typing import Optional, TypeVar

from openfisca_core.types import Array
import numpy

from openfisca_core import types as t

T = TypeVar("T")

Expand Down Expand Up @@ -43,7 +45,7 @@ def empty_clone(original: T) -> T:
return new


def stringify_array(array: Array) -> str:
def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str:
"""Generates a clean string representation of a numpy array.
Args:
Expand Down
Empty file added openfisca_core/commons/py.typed
Empty file.
14 changes: 7 additions & 7 deletions openfisca_core/commons/rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@


def average_rate(
target: Array[float],
target: Array[numpy.float_],
varying: ArrayLike[float],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
) -> Array[numpy.float_]:
"""Computes the average rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -41,7 +41,7 @@ def average_rate(
"""

average_rate: Array[float]
average_rate: Array[numpy.float_]

average_rate = 1 - target / varying

Expand All @@ -62,10 +62,10 @@ def average_rate(


def marginal_rate(
target: Array[float],
varying: Array[float],
target: Array[numpy.float_],
varying: Array[numpy.float_],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
) -> Array[numpy.float_]:
"""Computes the marginal rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -97,7 +97,7 @@ def marginal_rate(
"""

marginal_rate: Array[float]
marginal_rate: Array[numpy.float_]

marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:])

Expand Down
130 changes: 76 additions & 54 deletions openfisca_core/types/_domain.py → openfisca_core/types.py
Original file line number Diff line number Diff line change
@@ -1,146 +1,168 @@
from __future__ import annotations

import typing_extensions
from typing import Any, Optional
from typing_extensions import Protocol
from collections.abc import Sequence
from numpy.typing import NDArray
from typing import Any, TypeVar
from typing_extensions import Protocol, TypeAlias

import abc

import numpy

N = TypeVar("N", bound=numpy.generic, covariant=True)

class Entity(Protocol):
"""Entity protocol."""
#: Type representing an numpy array.
Array: TypeAlias = NDArray[N]

L = TypeVar("L")

#: Type representing an array-like object.
ArrayLike: TypeAlias = Sequence[L]

#: Type variable representing an error.
E = TypeVar("E", covariant=True)

#: Type variable representing a value.
A = TypeVar("A", covariant=True)


# Entities


class Entity(Protocol):
key: Any
plural: Any

@abc.abstractmethod
def check_role_validity(self, role: Any) -> None:
"""Abstract method."""
...

@abc.abstractmethod
def check_variable_defined_for_entity(self, variable_name: Any) -> None:
"""Abstract method."""
...

@abc.abstractmethod
def get_variable(
self,
variable_name: Any,
check_existence: Any = ...,
) -> Optional[Any]:
"""Abstract method."""
) -> Any | None:
...


class Formula(Protocol):
"""Formula protocol."""
class Role(Protocol):
entity: Any
subroles: Any

@abc.abstractmethod
def __call__(
self,
population: Population,
instant: Instant,
params: Params,
) -> numpy.ndarray:
"""Abstract method."""

# Holders

class Holder(Protocol):
"""Holder protocol."""

class Holder(Protocol):
@abc.abstractmethod
def clone(self, population: Any) -> Holder:
"""Abstract method."""
...

@abc.abstractmethod
def get_memory_usage(self) -> Any:
"""Abstract method."""
...


class Instant(Protocol):
"""Instant protocol."""
# Parameters


@typing_extensions.runtime_checkable
class ParameterNodeAtInstant(Protocol):
"""ParameterNodeAtInstant protocol."""
...


class Params(Protocol):
"""Params protocol."""
# Periods

@abc.abstractmethod
def __call__(self, instant: Instant) -> ParameterNodeAtInstant:
"""Abstract method."""

class Instant(Protocol):
...


@typing_extensions.runtime_checkable
class Period(Protocol):
"""Period protocol."""

@property
@abc.abstractmethod
def start(self) -> Any:
"""Abstract method."""
...

@property
@abc.abstractmethod
def unit(self) -> Any:
"""Abstract method."""
...


class Population(Protocol):
"""Population protocol."""
# Populations


class Population(Protocol):
entity: Any

@abc.abstractmethod
def get_holder(self, variable_name: Any) -> Any:
"""Abstract method."""

...

class Role(Protocol):
"""Role protocol."""

entity: Any
subroles: Any
# Simulations


class Simulation(Protocol):
"""Simulation protocol."""

@abc.abstractmethod
def calculate(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def calculate_add(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def calculate_divide(self, variable_name: Any, period: Any) -> Any:
"""Abstract method."""
...

@abc.abstractmethod
def get_population(self, plural: Optional[Any]) -> Any:
"""Abstract method."""
def get_population(self, plural: Any | None) -> Any:
...


class TaxBenefitSystem(Protocol):
"""TaxBenefitSystem protocol."""
# Tax-Benefit systems


class TaxBenefitSystem(Protocol):
person_entity: Any

@abc.abstractmethod
def get_variable(
self,
variable_name: Any,
check_existence: Any = ...,
) -> Optional[Any]:
) -> Any | None:
"""Abstract method."""


class Variable(Protocol):
"""Variable protocol."""
# Variables


class Variable(Protocol):
entity: Any


class Formula(Protocol):
@abc.abstractmethod
def __call__(
self,
population: Population,
instant: Instant,
params: Params,
) -> Array[Any]:
...


class Params(Protocol):
@abc.abstractmethod
def __call__(self, instant: Instant) -> ParameterNodeAtInstant:
...
Loading

0 comments on commit 77cc91f

Please sign in to comment.