Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable JAX backend #812

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9,901 changes: 3,944 additions & 5,957 deletions pixi.lock

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,35 @@ jaxlib = ">=0.4.20"
jax = { version = ">=0.4.20", extras = ["cpu"] }
jaxlib = ">=0.4.20"

[tool.pixi.feature.cuda]
platforms = ["linux-64"]
system-requirements = {cuda = "12"}

[tool.pixi.feature.cuda.target.linux-64.dependencies]
cuda-nvcc = ">=12"
jax = ">=0.4.34"
jaxlib = { version = ">=0.4.34", build = "cuda12*" }

# Tasks
# --------------------------------------------------------------------------------------

[tool.pixi.feature.test.tasks]
tests = "pytest"

[tool.pixi.feature.jax.tasks]
tests = "pytest --use-jax"

[tool.pixi.feature.cuda.tasks]
tests = "pytest --use-jax"

# Environments
# --------------------------------------------------------------------------------------

[tool.pixi.environments]
py311 = ["test", "py311"]
py312 = ["test", "py312"]
py312-jax = ["py312", "jax"]
py312-jax-cuda = ["py312", "cuda"]


# ======================================================================================
Expand Down
20 changes: 9 additions & 11 deletions src/_gettsim/functions/policy_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import functools
import inspect
from collections.abc import Callable
from datetime import date
from typing import Any, TypeVar

import numpy
from _gettsim.config import USE_JAX
from _gettsim.vectorization import make_vectorizable

T = TypeVar("T")

Expand Down Expand Up @@ -109,17 +109,15 @@ def is_active_at_date(self, date: date) -> bool:


def _vectorize_func(func: Callable) -> Callable:
# What should work once that Jax backend is fully supported
signature = inspect.signature(func)
func_vec = numpy.vectorize(func)
# If the function is already vectorized, return it as is
if hasattr(func, "__info__") and func.__info__.get("skip_vectorization", False):
return func

@functools.wraps(func)
def wrapper_vectorize_func(*args, **kwargs):
return func_vec(*args, **kwargs)
if isinstance(func, PolicyFunction):
return func

wrapper_vectorize_func.__signature__ = signature

return wrapper_vectorize_func
backend = "jax" if USE_JAX else "numpy"
return make_vectorizable(func, backend=backend)


def _first_not_none(*values: T) -> T:
Expand Down
6 changes: 6 additions & 0 deletions src/_gettsim/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def compute_taxes_and_transfers( # noqa: PLR0913
check_minimal_specification="ignore",
rounding=True,
debug=False,
jit=False,
):
"""Compute taxes and transfers.

Expand Down Expand Up @@ -127,6 +128,11 @@ def compute_taxes_and_transfers( # noqa: PLR0913
enforce_signature=True,
)

if jit:
from jax import jit

tax_transfer_function = jit(tax_transfer_function)

results = tax_transfer_function(**input_data)

# Prepare results.
Expand Down
25 changes: 1 addition & 24 deletions src/_gettsim/policy_environment_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import inspect
from typing import TYPE_CHECKING

import numpy

from _gettsim.aggregation import (
all_by_p_id,
any_by_p_id,
Expand All @@ -27,7 +25,6 @@
TYPES_INPUT_VARIABLES,
)
from _gettsim.functions.derived_function import DerivedFunction
from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.groupings import create_groupings
from _gettsim.shared import (
format_list_linewise,
Expand All @@ -39,6 +36,7 @@
if TYPE_CHECKING:
from collections.abc import Callable

from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.policy_environment import PolicyEnvironment


Expand Down Expand Up @@ -580,27 +578,6 @@ def aggregate_by_p_id_func(column, p_id_to_aggregate_by, p_id_to_store_by):
)


def _vectorize_func(func):
# If the function is already vectorized, return it as is
if hasattr(func, "__info__") and func.__info__.get("skip_vectorization", False):
return func

if isinstance(func, PolicyFunction):
return func

# What should work once that Jax backend is fully supported
signature = inspect.signature(func)
func_vec = numpy.vectorize(func)

@functools.wraps(func)
def wrapper_vectorize_func(*args, **kwargs):
return func_vec(*args, **kwargs)

wrapper_vectorize_func.__signature__ = signature

return wrapper_vectorize_func


def _fail_if_targets_are_not_among_functions(functions, targets):
"""Fail if some target is not among functions.

Expand Down
33 changes: 9 additions & 24 deletions src/_gettsim_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,22 @@
import importlib

import pytest

import _gettsim


def test_default_backend():
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "numpy"
def test_conftest_set_array_backend_updates_use_jax(request):
expected = request.config.option.USE_JAX
from _gettsim.config import USE_JAX

assert expected == USE_JAX

def test_set_backend():
is_jax_installed = importlib.util.find_spec("jax") is not None

# expect default backend
def test_conftest_set_array_backend_updates_backend(request):
use_jax = request.config.option.USE_JAX
expected = "jax.numpy" if use_jax else "numpy"
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "numpy"

if is_jax_installed:
# set jax backend
_gettsim.config.set_array_backend("jax")
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "jax.numpy"

from _gettsim.config import USE_JAX

assert USE_JAX
else:
with pytest.raises(AssertionError):
_gettsim.config.set_array_backend("jax")
got = numpy_or_jax.__name__
assert expected == got


@pytest.mark.parametrize("backend", ["dask", "jax.numpy"])
Expand Down
23 changes: 23 additions & 0 deletions src/_gettsim_tests/test_full_taxes_and_transfers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,29 @@ def test_full_taxes_and_transfers(
)


@pytest.mark.parametrize(
"test_data",
data.test_data[:1],
ids=str,
)
def test_full_taxes_and_transfers_jitted(
test_data: PolicyTestData,
):
df = test_data.input_df
environment = cached_set_up_policy_environment(date=test_data.date)

out = OUT_COLS.copy()
if test_data.date.year <= 2008:
out.remove("abgelt_st_y_sn")

compute_taxes_and_transfers(
data=df,
environment=environment,
targets=out,
jit=True,
)


@pytest.mark.parametrize(
"test_data",
data.test_data,
Expand Down
3 changes: 1 addition & 2 deletions src/_gettsim_tests/test_functions_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from _gettsim.functions.loader import (
_load_functions,
)
from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.functions.policy_function import PolicyFunction, _vectorize_func
from _gettsim.policy_environment import PolicyEnvironment
from _gettsim.policy_environment_postprocessor import (
_create_derived_functions,
_vectorize_func,
)
from _gettsim.shared import policy_info

Expand Down
22 changes: 17 additions & 5 deletions src/_gettsim_tests/test_policy_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,21 @@
from _gettsim_tests import TEST_DIR


def return_one():
return 1


def return_two():
return 2


def return_three():
return 3


class TestPolicyEnvironment:
def test_get_function_by_name_exists(self):
function = PolicyFunction(lambda: 1, function_name="foo")
function = PolicyFunction(return_one, function_name="foo")
environment = PolicyEnvironment([function])

assert environment.get_function_by_name("foo") == function
Expand All @@ -33,19 +45,19 @@ def test_get_function_by_name_does_not_exist(self):
PolicyEnvironment([], {}),
PolicyEnvironment(
[
PolicyFunction(lambda: 1, function_name="foo"),
PolicyFunction(return_one, function_name="foo"),
]
),
PolicyEnvironment(
[
PolicyFunction(lambda: 1, function_name="foo"),
PolicyFunction(lambda: 2, function_name="bar"),
PolicyFunction(return_one, function_name="foo"),
PolicyFunction(return_two, function_name="bar"),
]
),
],
)
def test_upsert_functions(self, environment: PolicyEnvironment):
new_function = PolicyFunction(lambda: 3, function_name="foo")
new_function = PolicyFunction(return_three, function_name="foo")
new_environment = environment.upsert_functions(new_function)

assert new_environment.get_function_by_name("foo") == new_function
Expand Down
Loading