From a19f1bafb7a8c41b8b48a50536975d22a1f5f684 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Fri, 20 Sep 2024 15:36:17 -0500 Subject: [PATCH] use init_to_sample as init_strategy --- pyrenew/metaclass.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index eb7119eb..3bfdba81 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -10,7 +10,7 @@ import numpy as np import polars as pl from jax.typing import ArrayLike -from numpyro.infer import MCMC, NUTS, Predictive +from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample from pyrenew.mcmcutils import plot_posterior, spread_draws @@ -176,6 +176,9 @@ def _init_model( if "find_heuristic_step_size" not in nuts_args: nuts_args["find_heuristic_step_size"] = True + if "init_strategy" not in nuts_args: + nuts_args["init_strategy"] = init_to_sample + if mcmc_args is None: mcmc_args = dict()