Skip to content

Commit

Permalink
Merge pull request #119 from PolicyEngine/automatic-period-adjustment
Browse files Browse the repository at this point in the history
Automatic period adjustment
  • Loading branch information
nikhilwoodruff authored Oct 4, 2023
2 parents 98576a6 + 3079c6e commit 87cd38a
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 118 deletions.
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: minor
changes:
added:
- Automatic period adjustment helper functionality.
changed:
- Default error threshold for tests widened to 1e-3.
29 changes: 4 additions & 25 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
if self.variable.set_input:
if (
self.variable.set_input
and period.unit != self.variable.definition_period
):
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)

Expand Down Expand Up @@ -285,30 +288,6 @@ def _set(
raise ValueError(
"A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition."
)
if (
self.variable.definition_period != period.unit
or period.size > 1
):
name = self.variable.name
period_size_adj = (
f"{period.unit}"
if (period.size == 1)
else f"{period.size}-{period.unit}s"
)
error_message = os.linesep.join(
[
f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".',
f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.',
f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.',
]
)

raise PeriodMismatchError(
self.variable.name,
period,
self.variable.definition_period,
error_message,
)

should_store_on_disk = (
self._on_disk_storable
Expand Down
39 changes: 30 additions & 9 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from policyengine_core.errors import CycleError, SpiralError
from policyengine_core.holders.holder import Holder
from policyengine_core.periods import Period
from policyengine_core.periods.config import ETERNITY
from policyengine_core.periods.config import ETERNITY, MONTH, YEAR
from policyengine_core.periods.helpers import period
from policyengine_core.tracers import (
FullTracer,
Expand All @@ -27,7 +27,7 @@
from policyengine_core.experimental import MemoryConfig
from policyengine_core.populations import Population
from policyengine_core.tracers import SimpleTracer
from policyengine_core.variables import Variable
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter

Expand Down Expand Up @@ -454,13 +454,14 @@ def _calculate(
variable_name, check_existence=True
)

self._check_period_consistency(period, variable)

# Check if we've neutralized via parameters.
try:
if self.tax_benefit_system.parameters(period).gov.abolitions[
variable.name
]:
if (
variable.is_neutralized
or self.tax_benefit_system.parameters(period).gov.abolitions[
variable.name
]
):
return holder.default_array()
except Exception as e:
pass
Expand All @@ -470,6 +471,20 @@ def _calculate(
if cached_array is not None:
return cached_array

if variable.definition_period == MONTH and period.unit == YEAR:
if variable.quantity_type == QuantityType.STOCK:
contained_months = period.get_subperiods(MONTH)
return self.calculate(variable_name, contained_months[-1])
else:
return self.calculate_add(variable_name, period)
elif variable.definition_period == YEAR and period.unit == MONTH:
if variable.quantity_type == QuantityType.STOCK:
return self.calculate(variable_name, period.this_year)
else:
return self.calculate_divide(variable_name, period)

self._check_period_consistency(period, variable)

if variable.defined_for is not None:
mask = (
self.calculate(
Expand Down Expand Up @@ -607,10 +622,13 @@ def calculate_add(
)
)

return sum(
result = sum(
self.calculate(variable_name, sub_period)
for sub_period in period.get_subperiods(variable.definition_period)
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result

def calculate_divide(
self,
Expand Down Expand Up @@ -640,9 +658,12 @@ def calculate_divide(

if period.unit == periods.MONTH:
computation_period = period.this_year
return (
result = (
self.calculate(variable_name, period=computation_period) / 12.0
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result
elif period.unit == periods.YEAR:
return self.calculate(variable_name, period)

Expand Down
3 changes: 2 additions & 1 deletion policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from policyengine_core.scripts import build_tax_benefit_system
from policyengine_core.reforms import Reform, set_parameter
from policyengine_core.populations import ADD, DIVIDE

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -464,7 +465,7 @@ def assert_near(
import numpy as np

if absolute_error_margin is None and relative_error_margin is None:
absolute_error_margin = 0
absolute_error_margin = 1e-3
if not isinstance(value, np.ndarray):
value = np.array(value)
if isinstance(value, EnumArray):
Expand Down
35 changes: 27 additions & 8 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from policyengine_core.entities import Entity
from policyengine_core.enums import Enum, EnumArray
from policyengine_core.periods import Period
from policyengine_core.holders import (
set_input_dispatch_by_period,
set_input_divide_by_period,
)
from policyengine_core.periods import DAY, ETERNITY

from . import config, helpers

Expand Down Expand Up @@ -176,13 +181,6 @@ def __init__(self, baseline_variable=None):
periods.ETERNITY,
),
)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.FLOW,
)
self.label = self.set(
attr, "label", allowed_type=str, setter=self.set_label
)
Expand All @@ -192,13 +190,34 @@ def __init__(self, baseline_variable=None):
attr, "cerfa_field", allowed_type=(str, dict)
)
self.unit = self.set(attr, "unit", allowed_type=str)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.STOCK
if (
self.value_type in (bool, int, Enum, str, datetime.date)
or self.unit == "/1"
)
else QuantityType.FLOW,
)
self.documentation = self.set(
attr,
"documentation",
allowed_type=str,
setter=self.set_documentation,
)
self.set_input = self.set_set_input(attr.pop("set_input", None))
self.set_input = self.set_set_input(
attr.pop(
"set_input",
set_input_dispatch_by_period
if self.quantity_type == QuantityType.STOCK
else set_input_divide_by_period,
)
)
if self.definition_period in (DAY, ETERNITY):
self.set_input = None
self.calculate_output = self.set_calculate_output(
attr.pop("calculate_output", None)
)
Expand Down
5 changes: 0 additions & 5 deletions tests/core/test_calculate_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def simulation(tax_benefit_system):
)


def test_calculate_output_default(simulation):
with pytest.raises(ValueError):
simulation.calculate_output("simple_variable", 2017)


def test_calculate_output_add(simulation):
simulation.set_input("variable_with_calculate_output_add", "2017-01", [10])
simulation.set_input("variable_with_calculate_output_add", "2017-05", [20])
Expand Down
26 changes: 0 additions & 26 deletions tests/core/test_countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def test_non_existing_variable(simulation):
simulation.calculate("non_existent_variable", PERIOD)


@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
def test_calculate_variable_with_wrong_definition_period(simulation):
year = str(PERIOD.this_year)

with pytest.raises(ValueError) as error:
simulation.calculate("basic_income", year)

error_message = str(error.value)
expected_words = ["period", year, "month", "basic_income", "ADD"]

for word in expected_words:
assert (
word in error_message
), f"Expected '{word}' in error message '{error_message}'"


@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
def test_divide_option_on_month_defined_variable(simulation):
with pytest.raises(ValueError):
Expand All @@ -107,16 +91,6 @@ def test_divide_option_with_complex_period(simulation):
), f"Expected '{word}' in error message '{error_message}'"


def test_input_with_wrong_period(tax_benefit_system):
year = str(PERIOD.this_year)
variables = {"basic_income": {year: 12000}}
simulation_builder = SimulationBuilder()
simulation_builder.set_default_period(PERIOD)

with pytest.raises(ValueError):
simulation_builder.build_from_variables(tax_benefit_system, variables)


def test_variable_with_reference(make_simulation, isolated_tax_benefit_system):
variables = {"salary": 4000}
simulation = make_simulation(
Expand Down
31 changes: 2 additions & 29 deletions tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ def test_set_input_enum_item(couple):
assert result == housing.HousingOccupancyStatus.free_lodger


def test_yearly_input_month_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("rent", 2019, 3000)
assert (
'Unable to set a value for variable "rent" for year-long period'
in error.value.message
)


def test_3_months_input_month_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("rent", "month:2019-01:3", 3000)
assert (
'Unable to set a value for variable "rent" for 3-months-long period'
in error.value.message
)


def test_month_input_year_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("housing_tax", "2019-01", 3000)
assert (
'Unable to set a value for variable "housing_tax" for month-long period'
in error.value.message
)


def test_enum_dtype(couple):
simulation = couple
status_occupancy = numpy.asarray([2], dtype=numpy.int16)
Expand Down Expand Up @@ -157,8 +130,8 @@ def test_get_memory_usage_with_trace(single):
memory_usage = salary_holder.get_memory_usage()
assert memory_usage["nb_requests"] == 15
assert (
memory_usage["nb_requests_by_array"] == 1.25
) # 15 calculations / 12 arrays
memory_usage["nb_requests_by_array"] == 15 / 13
) # 15 calculations / 13 arrays (12 months plus the year is cached too)


def test_set_input_dispatch_by_period(single):
Expand Down
15 changes: 0 additions & 15 deletions tests/core/test_reforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,6 @@ def apply(self):
assert_near(goes_to_school, [True], absolute_error_margin=0)


def test_neutralization_optimization(make_simulation, tax_benefit_system):
reform = WithBasicIncomeNeutralized(tax_benefit_system)

period = "2017-01"
simulation = make_simulation(reform, {}, period)
simulation.debug = True

simulation.calculate("basic_income", period="2013-01")
simulation.calculate_add("basic_income", period="2013")

# As basic_income is neutralized, it should not be cached
basic_income_holder = simulation.persons.get_holder("basic_income")
assert basic_income_holder.get_known_periods() == []


def test_input_variable_neutralization(make_simulation, tax_benefit_system):
class test_salary_neutralization(Reform):
def apply(self):
Expand Down

0 comments on commit 87cd38a

Please sign in to comment.