diff --git a/evofr/models/mlr_hierarchical_gp.py b/evofr/models/mlr_hierarchical_gp.py index a37e852..2997843 100644 --- a/evofr/models/mlr_hierarchical_gp.py +++ b/evofr/models/mlr_hierarchical_gp.py @@ -5,8 +5,7 @@ import numpy as np import numpyro import numpyro.distributions as dist -from jax import jit, vmap -from jax._src.interpreters.batching import Array +from jax import jit, vmap, Array from jax.nn import softmax from jax.scipy.special import gammaln from numpyro.infer.reparam import TransformReparam diff --git a/evofr/models/mlr_hierarchical_time_varying.py b/evofr/models/mlr_hierarchical_time_varying.py index 6d84604..d576d7f 100644 --- a/evofr/models/mlr_hierarchical_time_varying.py +++ b/evofr/models/mlr_hierarchical_time_varying.py @@ -5,8 +5,7 @@ import numpy as np import numpyro import numpyro.distributions as dist -from jax import vmap -from jax._src.interpreters.batching import Array +from jax import vmap, Array from jax.nn import softmax from numpyro.infer.reparam import TransformReparam diff --git a/evofr/models/mlr_innovation.py b/evofr/models/mlr_innovation.py index 4b46a43..c45473f 100644 --- a/evofr/models/mlr_innovation.py +++ b/evofr/models/mlr_innovation.py @@ -8,7 +8,7 @@ import numpyro.distributions as dist import pandas as pd from jax import vmap -from jax._src.nn.functions import softmax +from jax.nn import softmax from evofr.data.data_helpers import prep_dates, prep_sequence_counts from evofr.data.data_spec import DataSpec diff --git a/evofr/models/renewal_model/renewal_regression.py b/evofr/models/renewal_model/renewal_regression.py index 0dcc440..88bb935 100644 --- a/evofr/models/renewal_model/renewal_regression.py +++ b/evofr/models/renewal_model/renewal_regression.py @@ -5,7 +5,7 @@ import numpyro import numpyro.distributions as dist from jax import jit, lax -from jax._src.nn.functions import softmax +from jax.nn import softmax from evofr.models.model_spec import ModelSpec from evofr.models.renewal_model.basis_functions.basis_fns import BasisFunction