diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index fada6bc9f..236e2068d 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,16 +1,19 @@ -from collections.abc import Iterable +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array import copy import dpath.util import numpy -from openfisca_core import entities, errors, periods, populations, variables +from openfisca_core import errors, periods from . import helpers from ._axis import _Axis from .simulation import Simulation -from .typing import AxisParams +from .typing import AxisParams, Entity, Population, Role class SimulationBuilder: @@ -23,26 +26,24 @@ def __init__(self): ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: dict[ - variables.Variable.name, dict[str(periods.period), numpy.array] - ] = {} - self.populations: dict[entities.Entity.key, populations.Population] = {} + self.input_buffer: dict[str, dict[str, Array]] = {} + self.populations: dict[str, Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: dict[entities.Entity.plural, int] = {} + self.entity_counts: dict[str, int] = {} # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: dict[entities.Entity.plural, list[int]] = {} + self.entity_ids: dict[str, list[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: dict[entities.Entity.plural, list[int]] = {} - self.roles: dict[entities.Entity.plural, list[int]] = {} + self.memberships: dict[str, list[int]] = {} + self.roles: dict[str, list[int]] = {} - self.variable_entities: dict[variables.Variable.name, entities.Entity] = {} + self.variable_entities: dict[str, Entity] = {} self.axes = [[]] - self.axes_entity_counts: dict[entities.Entity.plural, int] = {} - self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {} - self.axes_memberships: dict[entities.Entity.plural, list[int]] = {} - self.axes_roles: dict[entities.Entity.plural, list[int]] = {} + self.axes_entity_counts: dict[str, int] = {} + self.axes_entity_ids: dict[str, list[str]] = {} + self.axes_memberships: dict[str, list[int]] = {} + self.axes_roles: dict[str, list[int]] = {} def build_from_dict(self, tax_benefit_system, input_dict): """ @@ -395,9 +396,10 @@ def set_default_period(self, period_str): if period_str: self.default_period = str(periods.period(period_str)) - def get_input(self, variable, period_str): + def get_input(self, variable: str, period_str: str) -> Array | None: if variable not in self.input_buffer: self.input_buffer[variable] = {} + return self.input_buffer[variable].get(period_str) def check_persons_to_allocate( @@ -535,11 +537,11 @@ def raise_period_mismatch(self, entity, json, e): raise errors.SituationParsingError(path, e.message) # Returns the total number of instances of this entity, including when there is replication along axes - def get_count(self, entity_name): + def get_count(self, entity_name: str) -> int: return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name]) # Returns the ids of instances of this entity, including when there is replication along axes - def get_ids(self, entity_name): + def get_ids(self, entity_name: str) -> list[str]: return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name]) # Returns the memberships of individuals in this entity, including when there is replication along axes @@ -550,7 +552,7 @@ def get_memberships(self, entity_name): ) # Returns the roles of individuals in this entity, including when there is replication along axes - def get_roles(self, entity_name): + def get_roles(self, entity_name: str) -> Sequence[Role]: # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) @@ -563,14 +565,14 @@ def add_perpendicular_axis(self, axis: AxisParams) -> None: # This adds an axis perpendicular to all previous dimensions self.axes.append([_Axis(**axis)]) - def expand_axes(self): + def expand_axes(self) -> None: # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + perpendicular_dimensions: list[list[_Axis]] = self.axes + cell_count: int = 1 - cell_count = 1 for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis.count + first_axis: _Axis = parallel_axes[0] + axis_count: int = first_axis.count cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times @@ -580,10 +582,16 @@ def expand_axes(self): self.get_count(entity_name) * cell_count ) # Adjust ids - original_ids = self.get_ids(entity_name) * cell_count - indices = numpy.arange(0, cell_count * self.entity_counts[entity_name]) - adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)] + original_ids: list[str] = self.get_ids(entity_name) * cell_count + indices: Array[numpy.int_] = numpy.arange( + 0, cell_count * self.entity_counts[entity_name] + ) + adjusted_ids: list[str] = [ + original_id + str(index) + for original_id, index in zip(original_ids, indices) + ] self.axes_entity_ids[entity_name] = adjusted_ids + # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = original_roles * cell_count @@ -659,8 +667,8 @@ def expand_axes(self): ) * (axis.max - axis.min) / (axis_count - 1) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name: str): + def get_variable_entity(self, variable_name: str) -> Entity: return self.variable_entities[variable_name] - def register_variable(self, variable_name: str, entity): + def register_variable(self, variable_name: str, entity: Entity) -> None: self.variable_entities[variable_name] = entity diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py index 42debeff3..7bd35e713 100644 --- a/openfisca_core/simulations/typing.py +++ b/openfisca_core/simulations/typing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TypedDict +from numpy.typing import NDArray as Array +from typing import Protocol, TypedDict class AxisParams(TypedDict, total=False): @@ -10,3 +11,27 @@ class AxisParams(TypedDict, total=False): max: float period: str | int index: int + + +class Entity(Protocol): + plural: str | None + + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> Variable | None: + ... + + +class Population(Protocol): + ... + + +class Role(Protocol): + ... + + +class Variable(Protocol): + def default_array(self, array_size: int) -> Array: + ... diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index 2693a3121..c3118b55d 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -1,5 +1,6 @@ from __future__ import annotations +from numpy.typing import NDArray as Array from openfisca_core.types import Formula, Instant from typing import Optional, Union @@ -467,7 +468,7 @@ def check_set_value(self, value): return value - def default_array(self, array_size): + def default_array(self, array_size: int) -> Array: array = numpy.empty(array_size, dtype=self.dtype) if self.value_type == Enum: array.fill(self.default_value.index)