Performance enhancements for init_strategy may lead to unexpected behavior. #1970

Open
tillahoffmann opened this issue Feb 4, 2025 · 0 comments
Labels
enhancement New feature or request refactor

There is some dedicated logic to improve performance for specific init strategies. This logic requires that init_strategy is either a partial or a function that returns a partial when called with no arguments. E.g., I expected, maybe naively, that the following init strategy would work.

>>> import jax
>>> import numpyro
>>> 
>>> 
>>> def model():
...     numpyro.sample("x", numpyro.distributions.Normal())
...     numpyro.sample("y", numpyro.distributions.Normal())
>>> 
>>> 
>>> def init_and_get_auto_loc(init_strategy):
...     guide = numpyro.infer.autoguide.AutoDiagonalNormal(model, init_loc_fn=init_strategy)
...     
...     svi = numpyro.infer.SVI(model, guide, numpyro.optim.Adam(0.1), numpyro.infer.Trace_ELBO())
...     state = svi.init(jax.random.key(9))
...     return svi.get_params(state)["auto_loc"]
>>>
>>>
>>> init_strategy = lambda site: 3.0 if site["name"] == "x" else 7.0
>>> init_and_get_auto_loc(init_strategy)
TypeError: <lambda>() missing 1 required positional argument: 'site'

But wrapping in a partial works.

>>> from functools import partial
>>>
>>> init_and_get_auto_loc(partial(init_strategy))
Array([3., 7.], dtype=float32)

I came across this while trying to write an init strategy that initializes some sites by value and the remaining sites uniformly, but with a radius different from the default of 2. Is this the intended behavior?
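One possible workaround is to follow the convention the built-in strategies appear to use, where calling the strategy without a site returns a partial of itself. The sketch below is only an illustration under that assumption: init_to_value_or_fallback and its fallback parameter are hypothetical names, not part of NumPyro, and sites are modelled as plain dicts with a "name" key.

```python
from functools import partial


def init_to_value_or_fallback(site=None, values=None, fallback=None):
    """Hypothetical strategy: fixed values for named sites, `fallback` for the rest."""
    values = {} if values is None else values
    if site is None:
        # Called without a site: return a partial so that an
        # isinstance(init_strategy, partial) check passes.
        return partial(init_to_value_or_fallback, values=values, fallback=fallback)
    if site["name"] in values:
        return values[site["name"]]
    # Defer every other site to the fallback strategy, if one was given.
    return None if fallback is None else fallback(site)


# Stand-in fallback for illustration; it always returns 7.0.
strategy = init_to_value_or_fallback(values={"x": 3.0}, fallback=lambda site: 7.0)
x_init = strategy({"name": "x"})
y_init = strategy({"name": "y"})
```

With NumPyro, the fallback would presumably be something like partial(init_to_uniform, radius=0.5), though I have not checked how that interacts with the special-cased strategies.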

The relevant logic is in the two snippets below.

init_strategy = (
    init_strategy if isinstance(init_strategy, partial) else init_strategy()
)
# handle those init strategies differently to save computation
if init_strategy.func is init_to_uniform:
    radius = init_strategy.keywords.get("radius")
    init_values = {}
elif init_strategy.func is _init_to_unconstrained_value:
    radius = 2
    init_values = init_strategy.keywords.get("values")
else:
    radius = None

init_strategy = (
    init_strategy if isinstance(init_strategy, partial) else init_strategy()
)
if (init_strategy.func is init_to_value) and not replay_model:
    init_values = init_strategy.keywords.get("values")
    unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
    init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
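
The failure can be reproduced without NumPyro by isolating the first statement of both snippets. This is a minimal sketch: normalize and by_name are hypothetical names introduced here, and only the isinstance check is taken from the code above.

```python
from functools import partial


def normalize(init_strategy):
    # Same expression as in the snippets: a partial passes through unchanged,
    # anything else is called with no arguments.
    return init_strategy if isinstance(init_strategy, partial) else init_strategy()


def by_name(site):
    # Plain per-site strategy with a required `site` argument.
    return 3.0 if site["name"] == "x" else 7.0


try:
    normalize(by_name)  # calls by_name() without a site
    outcome = "ok"
except TypeError:
    outcome = "TypeError"

# Wrapping in a partial skips the zero-argument call entirely.
wrapped = normalize(partial(by_name))
value = wrapped({"name": "x"})
```

Here the plain function raises TypeError because it is called without a site, while the partial passes through and its .func attribute is available for the identity checks that follow.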

@fehiepsi added the enhancement and refactor labels on Feb 4, 2025