From f7f165874108c4d7bafa922e0e49a614edc8f3f6 Mon Sep 17 00:00:00 2001 From: Jover Lee Date: Fri, 6 Dec 2024 10:50:09 -0800 Subject: [PATCH] Replace private jax modules with public API Replacing type and function imported from a private jax modules with the public API to avoid potential future breakages due to changes in the private modules similar to https://github.com/blab/evofr/issues/43. - `jax._src.interpreters.batching.Array` => `jax.Array` - `jax._src.nn.functions.softmax` => `jax.nn.softmax` --- evofr/models/mlr_hierarchical_gp.py | 3 +-- evofr/models/mlr_hierarchical_time_varying.py | 3 +-- evofr/models/mlr_innovation.py | 2 +- evofr/models/renewal_model/renewal_regression.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/evofr/models/mlr_hierarchical_gp.py b/evofr/models/mlr_hierarchical_gp.py index a37e852..2997843 100644 --- a/evofr/models/mlr_hierarchical_gp.py +++ b/evofr/models/mlr_hierarchical_gp.py @@ -5,8 +5,7 @@ import numpy as np import numpyro import numpyro.distributions as dist -from jax import jit, vmap -from jax._src.interpreters.batching import Array +from jax import jit, vmap, Array from jax.nn import softmax from jax.scipy.special import gammaln from numpyro.infer.reparam import TransformReparam diff --git a/evofr/models/mlr_hierarchical_time_varying.py b/evofr/models/mlr_hierarchical_time_varying.py index 6d84604..d576d7f 100644 --- a/evofr/models/mlr_hierarchical_time_varying.py +++ b/evofr/models/mlr_hierarchical_time_varying.py @@ -5,8 +5,7 @@ import numpy as np import numpyro import numpyro.distributions as dist -from jax import vmap -from jax._src.interpreters.batching import Array +from jax import vmap, Array from jax.nn import softmax from numpyro.infer.reparam import TransformReparam diff --git a/evofr/models/mlr_innovation.py b/evofr/models/mlr_innovation.py index 4b46a43..c45473f 100644 --- a/evofr/models/mlr_innovation.py +++ b/evofr/models/mlr_innovation.py @@ -8,7 +8,7 @@ import numpyro.distributions as dist import pandas as pd from jax import vmap -from jax._src.nn.functions import softmax +from jax.nn import softmax from evofr.data.data_helpers import prep_dates, prep_sequence_counts from evofr.data.data_spec import DataSpec diff --git a/evofr/models/renewal_model/renewal_regression.py b/evofr/models/renewal_model/renewal_regression.py index 0dcc440..88bb935 100644 --- a/evofr/models/renewal_model/renewal_regression.py +++ b/evofr/models/renewal_model/renewal_regression.py @@ -5,7 +5,7 @@ import numpyro import numpyro.distributions as dist from jax import jit, lax -from jax._src.nn.functions import softmax +from jax.nn import softmax from evofr.models.model_spec import ModelSpec from evofr.models.renewal_model.basis_functions.basis_fns import BasisFunction