Skip to content

Commit

Permalink
Move policy_function decorator to policy function module.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Jan 25, 2025
1 parent ca25a3a commit da4aa6a
Show file tree
Hide file tree
Showing 36 changed files with 122 additions and 124 deletions.
89 changes: 85 additions & 4 deletions src/_gettsim/functions/policy_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import functools
import inspect
import re
from collections.abc import Callable
from typing import TYPE_CHECKING, TypeVar
from datetime import date
from typing import TypeVar

import numpy

if TYPE_CHECKING:
from datetime import date

T = TypeVar("T")


Expand Down Expand Up @@ -83,6 +82,88 @@ def is_active_at_date(self, date: date) -> bool:
return self.start_date <= date <= self.end_date


def policy_function(
*,
start_date: str = "0001-01-01",
end_date: str = "9999-12-31",
leaf_name: str | None = None,
params_key_for_rounding: str | None = None,
skip_vectorization: bool = False,
) -> PolicyFunction:
"""
Decorator that wraps a callable into a `PolicyFunction`.
**Dates active (start_date, end_date, leaf_name):**
Specifies that a PolicyFunction is only active between two dates, `start` and `end`.
By using the `leaf_name` argument, you can specify a different name for the
PolicyFunction in the functions tree.
Note that even if you use this decorator with the `leaf_name` argument, you must
ensure that the function name is unique in the file where it is defined. Otherwise,
the function would be overwritten by the last function with the same name.
**Rounding spec (params_key_for_rounding):**
Adds the location of the rounding specification to a PolicyFunction.
Parameters
----------
start_date
The start date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
end_date
The end date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
leaf_name
The name that should be used as the PolicyFunction's leaf name in the DAG. If
omitted, we use the name of the function as defined.
params_key_for_rounding
Key of the parameters dictionary where rounding specifications are found. For
functions that are not user-written this is just the name of the respective
.yaml file.
skip_vectorization
Whether the function is already vectorized and, thus, should not be vectorized
again.
Returns
-------
PolicyFunction
A PolicyFunction object.
"""

_validate_dashed_iso_date(start_date)
_validate_dashed_iso_date(end_date)

start_date = date.fromisoformat(start_date)
end_date = date.fromisoformat(end_date)

_validate_date_range(start_date, end_date)

def inner(func: Callable) -> PolicyFunction:
return PolicyFunction(
func,
leaf_name=leaf_name if leaf_name else func.__name__,
start_date=start_date,
end_date=end_date,
params_key_for_rounding=params_key_for_rounding,
skip_vectorization=skip_vectorization,
)

return inner


_dashed_iso_date = re.compile(r"\d{4}-\d{2}-\d{2}")


def _validate_dashed_iso_date(date_str: str):
if not _dashed_iso_date.match(date_str):
raise ValueError(f"Date {date_str} does not match the format YYYY-MM-DD.")


def _validate_date_range(start: date, end: date):
if start > end:
raise ValueError(f"The start date {start} must be before the end date {end}.")


def _vectorize_func(func: Callable) -> Callable:
# What should work once that Jax backend is fully supported
signature = inspect.signature(func)
Expand Down
83 changes: 0 additions & 83 deletions src/_gettsim/shared.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools
import inspect
import operator
import re
import textwrap
from collections.abc import Callable
from datetime import date
Expand All @@ -28,88 +27,6 @@ def __repr__(self):
TIME_DEPENDENT_FUNCTIONS: dict[str, list[Callable]] = {}


def policy_function(
*,
start_date: str = "0001-01-01",
end_date: str = "9999-12-31",
leaf_name: str | None = None,
params_key_for_rounding: str | None = None,
skip_vectorization: bool = False,
) -> PolicyFunction:
"""
Decorator that wraps a callable into a `PolicyFunction`.
**Dates active (start_date, end_date, leaf_name):**
Specifies that a PolicyFunction is only active between two dates, `start` and `end`.
By using the `leaf_name` argument, you can specify a different name for the
PolicyFunction in the functions tree.
Note that even if you use this decorator with the `leaf_name` argument, you must
ensure that the function name is unique in the file where it is defined. Otherwise,
the function would be overwritten by the last function with the same name.
**Rounding spec (params_key_for_rounding):**
Adds the location of the rounding specification to a PolicyFunction.
Parameters
----------
start_date
The start date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
end_date
The end date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
leaf_name
The name that should be used as the PolicyFunction's leaf name in the DAG. If
omitted, we use the name of the function as defined.
params_key_for_rounding
Key of the parameters dictionary where rounding specifications are found. For
functions that are not user-written this is just the name of the respective
.yaml file.
skip_vectorization
Whether the function is already vectorized and, thus, should not be vectorized
again.
Returns
-------
PolicyFunction
A PolicyFunction object.
"""

_validate_dashed_iso_date(start_date)
_validate_dashed_iso_date(end_date)

start_date = date.fromisoformat(start_date)
end_date = date.fromisoformat(end_date)

_validate_date_range(start_date, end_date)

def inner(func: Callable) -> PolicyFunction:
return PolicyFunction(
func,
leaf_name=leaf_name if leaf_name else func.__name__,
start_date=start_date,
end_date=end_date,
params_key_for_rounding=params_key_for_rounding,
skip_vectorization=skip_vectorization,
)

return inner


_dashed_iso_date = re.compile(r"\d{4}-\d{2}-\d{2}")


def _validate_dashed_iso_date(date_str: str):
if not _dashed_iso_date.match(date_str):
raise ValueError(f"Date {date_str} does not match the format YYYY-MM-DD.")


def _validate_date_range(start: date, end: date):
if start > end:
raise ValueError(f"The start date {start} must be before the end date {end}.")


def _check_for_conflicts_in_time_dependent_functions(
dag_key: str, function_name: str, start: date, end: date
):
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/social_insurance_contributions/arbeitsl_v.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for modeling unemployment and pension insurance."""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def sozialv_beitr_arbeitnehmer_m(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(end_date="2003-03-31", leaf_name="ges_krankenv_beitr_arbeitnehmer_m")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/social_insurance_contributions/ges_pflegev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(start_date="2005-01-01")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/social_insurance_contributions/ges_rentenv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(end_date="2003-03-31", leaf_name="ges_rentenv_beitr_arbeitnehmer_m")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/abgelt_st.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(start_date="2009-01-01")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/eink_st.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function

aggregate_by_p_id_eink_st = {
"eink_st_rel_kindergeld_anz_ansprüche_1": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/lohnst.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function
from _gettsim.taxes.eink_st import _eink_st_tarif


Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/soli_st.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function


@policy_function(end_date="2008-12-31", leaf_name="soli_st_y_sn")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/zu_verst_eink/eink.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function


def eink_selbst_y(eink_selbst_m: float) -> float:
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/zu_verst_eink/freibetraege.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.config import numpy_or_jax as np
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function

aggregate_by_p_id_freibeträge = {
"_eink_st_kinderfreib_anz_anspruch_1": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/zu_verst_eink/vorsorgeaufw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/taxes/zu_verst_eink/zu_verst_eink.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
applying the tax schedule.
"""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def freibeträge_ind_y(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function


def arbeitsl_geld_2_eink_m(
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/arbeitsl_geld_2/bedarf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions to calculate basic needs according to SGB II
(i.e., where Arbeitslosengeld 2 is defined)."""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def arbeitsl_geld_2_regelbedarf_m(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy

from _gettsim.shared import join_numpy, policy_function
from _gettsim.functions.policy_function import policy_function
from _gettsim.shared import join_numpy

aggregate_by_p_id_kindergeldübertrag = {
"kindergeldübertrag_m": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/arbeitsl_geld_2/kost_unterk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(end_date="2022-12-31", leaf_name="arbeitsl_geld_2_kost_unterk_m")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/benefit_checks/vermoegens_checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def _kinderzuschl_nach_vermög_check_m_bg(
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/elterngeld.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module provides functions to compute parental leave benefits (Elterngeld)."""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function

aggregate_by_group_elterngeld = {
"kind_anspruchsberechtigt_fg": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/erwerbsm_rente.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


@policy_function(start_date="2001-01-01")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/erziehungsgeld.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions to compute parental leave benefits (Erziehungsgeld, -2007)."""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function

aggregate_by_p_id_erziehungsgeld = {
"erziehungsgeld_eltern_m": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/grundrente.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function


@policy_function(params_key_for_rounding="ges_rente")
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/grunds_im_alter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from _gettsim.functions.policy_function import policy_function
from _gettsim.piecewise_functions import piecewise_polynomial
from _gettsim.shared import policy_function


def grunds_im_alter_m_eg( # noqa: PLR0913
Expand Down
3 changes: 2 additions & 1 deletion src/_gettsim/transfers/kindergeld.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy

from _gettsim.shared import join_numpy, policy_function
from _gettsim.functions.policy_function import policy_function
from _gettsim.shared import join_numpy

aggregate_by_group_kindergeld = {
"anz_kinder_mit_kindergeld_fg": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/kinderzuschl/kinderzuschl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Kinderzuschlag / Additional Child Benefit."""

from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def kinderzuschl_m_bg(
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/kinderzuschl/kinderzuschl_eink.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function

aggregate_by_group_kinderzuschl_eink = {
"_kinderzuschl_anz_kinder_anspruch_bg": {
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim/transfers/rente.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _gettsim.shared import policy_function
from _gettsim.functions.policy_function import policy_function


def sum_ges_rente_priv_rente_m(priv_rente_m: float, ges_rente_m: float) -> float:
Expand Down
Loading

0 comments on commit da4aa6a

Please sign in to comment.