SteinVI: Recompute Score Function For Each Particle Interaction in the Attractive Force #1947

Merged · 3 commits · Jan 12, 2025 · Changes from all commits
6 changes: 3 additions & 3 deletions examples/hmcecs.py
@@ -75,9 +75,9 @@ def run_hmc(mcmc_key, args, data, obs, kernel):


def main(args):
assert (
11_000_000 >= args.num_datapoints
), "11,000,000 data points in the Higgs dataset"
assert 11_000_000 >= args.num_datapoints, (
"11,000,000 data points in the Higgs dataset"
)
# full dataset takes hours for plain hmc!
if args.dataset == "higgs":
_, fetch = load_dataset(
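The assert rewrites in this file (and the matching ones in ssbvm_mixture.py, steinvi.py, enum_messenger.py, and tfp/mcmc.py below) appear to follow the formatting newer Black releases emit: the condition stays on one line and only the message is parenthesized. A minimal sketch of the two styles; the `check_size` helper is invented for illustration:

```python
MAX_POINTS = 11_000_000  # size of the Higgs dataset, per the assert message


def check_size(num_datapoints: int) -> None:
    # Old style: the condition itself is wrapped and split across lines.
    assert (
        MAX_POINTS >= num_datapoints
    ), "11,000,000 data points in the Higgs dataset"

    # New style: the condition stays inline; only the message wraps.
    assert MAX_POINTS >= num_datapoints, (
        "11,000,000 data points in the Higgs dataset"
    )
```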
6 changes: 3 additions & 3 deletions examples/ssbvm_mixture.py
@@ -293,7 +293,7 @@ def main(args):
parser.add_argument("--device", default="gpu", type=str, help='use "cpu" or "gpu".')

args = parser.parse_args()
assert all(
aa in AMINO_ACIDS for aa in args.amino_acids
), f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
assert all(aa in AMINO_ACIDS for aa in args.amino_acids), (
f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
)
main(args)
5 changes: 2 additions & 3 deletions examples/stein_bnn.py
@@ -25,7 +25,6 @@
from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
from numpyro.infer import init_to_uniform
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad

@@ -121,12 +120,12 @@ def main(args):
rng_key, inf_key = random.split(inf_key)

# We find that SteinVI benefits from a small radius when inferring BNNs.
guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
guide = AutoNormal(model)

stein = SteinVI(
model,
guide,
Adagrad(0.5),
Adagrad(1.0),
RBFKernel(),
repulsion_temperature=args.repulsion,
num_stein_particles=args.num_stein_particles,
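Net effect of the stein_bnn.py change: the guide no longer seeds particles with init_to_uniform(radius=0.1), and the Adagrad step size doubles from 0.5 to 1.0. A self-contained sketch of the resulting setup — the toy model and data below are invented for illustration and stand in for the example's BNN:

```python
from jax import numpy as jnp, random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad


def model(y=None):
    # Toy location model standing in for the example's BNN.
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(loc, 1.0), obs=y)


guide = AutoNormal(model)  # default init; no init_to_uniform radius

stein = SteinVI(
    model,
    guide,
    Adagrad(1.0),  # raised from 0.5 in this PR
    RBFKernel(),
    num_stein_particles=5,
)

result = stein.run(random.PRNGKey(0), 200, y=jnp.array([0.3, -0.1, 0.8]))
```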
10 changes: 5 additions & 5 deletions notebooks/source/lotka_volterra_multiple.ipynb
@@ -404,10 +404,10 @@
"source": [
"print(f\"The dataset has the shape {data.shape}, (n_datasets, n_points, n_observables)\")\n",
"print(f\"The time matrix has the shape {ts.shape}, (n_datasets, n_timepoints)\")\n",
"print(f\"The time matrix has different spacing between timepoints: \\n {ts[:,:5]}\")\n",
"print(f\"The final timepoints are: {jnp.nanmax(ts,1)} years.\")\n",
"print(f\"The time matrix has different spacing between timepoints: \\n {ts[:, :5]}\")\n",
"print(f\"The final timepoints are: {jnp.nanmax(ts, 1)} years.\")\n",
"print(\n",
" f\"The dataset has {jnp.sum(jnp.isnan(data))/jnp.size(data):.0%} missing observations\"\n",
" f\"The dataset has {jnp.sum(jnp.isnan(data)) / jnp.size(data):.0%} missing observations\"\n",
")\n",
"print(f\"True params mean: {sample['theta'][0]}\")"
]
@@ -550,7 +550,7 @@
"mcmc.print_summary()\n",
"\n",
"print(f\"True params mean: {sample['theta'][0]}\")\n",
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
]
},
{
@@ -591,7 +591,7 @@
"\n",
"\n",
"print(f\"True params mean: {sample['theta'][0]}\")\n",
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
]
},
{
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
@@ -104,7 +104,7 @@
"print(df.Y.value_counts())\n",
"\n",
"for i in range(nclasses):\n",
" print(f\"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}\")"
" print(f\"mean(X) for Y == {i}: {X[np.where(Y == i)].mean():.3f}\")"
]
},
{
2 changes: 1 addition & 1 deletion numpyro/contrib/control_flow/scan.py
@@ -82,7 +82,7 @@ def _subs_wrapper(subs_map, i, length, site):
)
else:
raise RuntimeError(
f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim + 1},"
f" but got {value_ndim}. This might happen when you use nested scan,"
" which is currently not supported. Please report the issue to us!"
)
6 changes: 2 additions & 4 deletions numpyro/contrib/ecs_proxies.py
@@ -12,13 +12,11 @@

TaylorTwoProxyState = namedtuple(
"TaylorProxyState",
"ref_subsample_log_liks,"
"ref_subsample_log_lik_grads,"
"ref_subsample_log_lik_hessians",
"ref_subsample_log_liks,ref_subsample_log_lik_grads,ref_subsample_log_lik_hessians",
)

TaylorOneProxyState = namedtuple(
"TaylorOneProxyState", "ref_subsample_log_liks," "ref_subsample_log_lik_grads,"
"TaylorOneProxyState", "ref_subsample_log_liks,ref_subsample_log_lik_grads,"
)


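The namedtuple edit is behavior-preserving: the adjacent string literals were already concatenated into one comma-separated field string, and namedtuple splits that string on commas and whitespace, so even the trailing comma left in TaylorOneProxyState is harmless. A quick, purely illustrative check:

```python
from collections import namedtuple

# Field names as a single comma-separated string, trailing comma included.
State = namedtuple("State", "ref_subsample_log_liks,ref_subsample_log_lik_grads,")

print(State._fields)
# ('ref_subsample_log_liks', 'ref_subsample_log_lik_grads')
```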
140 changes: 66 additions & 74 deletions numpyro/contrib/einstein/steinvi.py
@@ -10,6 +10,7 @@

from jax import grad, numpy as jnp, random, tree, vmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan

from numpyro import handlers
from numpyro.contrib.einstein.stein_loss import SteinLoss
@@ -33,7 +34,6 @@ def _numel(shape):
class SteinVI:
"""Variational inference with Stein mixtures inference.


**Example:**

.. doctest::
@@ -138,9 +138,9 @@ def __init__(
if isinstance(guide.init_loc_fn, partial):
init_fn_name = guide.init_loc_fn.func.__name__
if init_fn_name == "init_to_uniform":
assert (
guide.init_loc_fn.keywords.get("radius", None) != 0.0
), init_loc_error_message
assert guide.init_loc_fn.keywords.get("radius", None) != 0.0, (
init_loc_error_message
)
else:
init_fn_name = guide.init_loc_fn.__name__
assert init_fn_name not in [
@@ -230,104 +230,84 @@ def local_trace(key):
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
non_mixture_uparams = { # Includes any marked guide parameters and all model parameters
# Separate model and guide parameters, since only guide parameters are updated using Stein
# Split parameters into model and guide components - only unflagged guide parameters are
# optimized via Stein forces.
nonmix_uparams = { # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}

stein_uparams = {
p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
p: v for p, v in unconstr_params.items() if p not in nonmix_uparams
}

# 1. Collect each guide parameter into monolithic particles that capture correlations
# between parameter values across each individual particle
# Collect guide parameters into a monolithic particle.
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
stein_uparams, nbatch_dims=1
)

# Kernel behavior varies based on particle site locations. The particle_info dictionary
# maps site names to their corresponding dimensional ranges as (start, end) tuples.
particle_info, _ = self._calc_particle_info(
stein_uparams, stein_particles.shape[0]
)
attractive_key, classic_key = random.split(rng_key)

def particle_transform_fn(particle):
params = unravel_pytree(particle)
ctparams = self.constrain_fn(self.particle_transform_fn(params))
ctparticle, _ = ravel_pytree(ctparams)
return ctparticle

# 2. Calculate gradients for each particle
def kernel_particles_loss_fn(rng_key, particles):
particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
grads = vmap(
lambda i: grad(
lambda particle: self.stein_loss.particle_loss(
rng_key=particle_keys[i],
model=handlers.scale(
self._inference_model, self.loss_temperature
),
guide=self.guide,
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=vmap(particle_transform_fn)(particles),
select_index=i,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
)
)(particles[i])
)(jnp.arange(self.stein_loss.stein_num_particles))

return grads

# 2.1 Compute particle gradients (for attractive force)
particle_ljp_grads = kernel_particles_loss_fn(attractive_key, stein_particles)

# 2.3 Lift particles to constraint space
ctstein_particles = vmap(particle_transform_fn)(stein_particles)

# 2.4 Compute non-mixture parameter gradients
non_mixture_param_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
handlers.scale(self._inference_model, self.loss_temperature),
self.guide,
unravel_pytree_batched(ctstein_particles),
*args,
**kwargs,
)
)(non_mixture_uparams)
model = handlers.scale(self._inference_model, self.loss_temperature)

# 3. Calculate kernel of particles
def loss_fn(particle, i):
def stein_loss_fn(key, particle, particle_idx):
return self.stein_loss.particle_loss(
rng_key=rng_key,
model=handlers.scale(self._inference_model, self.loss_temperature),
rng_key=key,
model=model,
guide=self.guide,
# Stein particles evolve in unconstrained space, but gradient computations must account
# for the transformation to constrained space
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=ctstein_particles,
select_index=i,
flat_particles=vmap(particle_transform_fn)(stein_particles),
select_index=particle_idx,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
param_map=self.constrain_fn(nonmix_uparams),
)

kernel = self.kernel_fn.compute(
rng_key, stein_particles, particle_info, loss_fn
rng_key, stein_particles, particle_info, stein_loss_fn
)
attractive_key, classic_key = random.split(rng_key)

# 4. Calculate the attractive force and repulsive force on the particles
attractive_force = vmap(
lambda y: jnp.sum(
vmap(
lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
)(stein_particles, particle_ljp_grads),
axis=0,
)
)(stein_particles)
def compute_attr_force(rng_key, particles):
# Second term of eq. 9 from https://arxiv.org/pdf/2410.22948.
def body(attr_force, state, y):
key, x, i = state
x_grad = grad(stein_loss_fn, argnums=1)(key, x, i)
attr_force = attr_force + self._apply_kernel(kernel, x, y, x_grad)
return attr_force, None

particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
init = jnp.zeros_like(particles[0])
idxs = jnp.arange(self.num_stein_particles)

attr_force, _ = vmap(
lambda y, key: scan(
partial(body, y=y),
init,
(random.split(key, self.num_stein_particles), particles, idxs),
)
)(particles, particle_keys)

return attr_force

attractive_force = compute_attr_force(attractive_key, stein_particles)

# Third term of eq. 9 from https://arxiv.org/pdf/2410.22948.
repulsive_force = vmap(
lambda y: jnp.mean(
vmap(
@@ -338,15 +318,27 @@ def loss_fn(particle, i):
)
)(stein_particles)

# 6. Compute the stein force
particle_grads = attractive_force + repulsive_force

# 7. Decompose the monolithic particle forces back to concrete parameter values
stein_param_grads = unravel_pytree_batched(particle_grads)
# Compute non-mixture parameter gradients.
nonmix_uparam_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
model,
self.guide,
unravel_pytree_batched(vmap(particle_transform_fn)(stein_particles)),
*args,
**kwargs,
)
)(nonmix_uparams)

# Decompose the monolithic particle forces back to concrete parameter values.
stein_uparam_grads = unravel_pytree_batched(particle_grads)

# 8. Return loss and gradients (based on parameter forces)
# Return loss and gradients (based on parameter forces).
res_grads = tree.map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
lambda x: -x, {**nonmix_uparam_grads, **stein_uparam_grads}
)
return jnp.linalg.norm(particle_grads), res_grads

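The substantive change in steinvi.py is the attractive force: the old code evaluated one stochastic score-function gradient per particle and reused it across all interactions, while the new compute_attr_force recomputes grad(stein_loss_fn) inside a scan for every (x, y) particle pair, each with its own RNG key — a fresh Monte Carlo sample of the per-interaction term k(x, y) ∇ₓ loss(x) from eq. 9 of the linked paper. A stripped-down sketch of the vmap-over-scan pattern; the log density and kernel below are simplified stand-ins for SteinLoss.particle_loss and RBFKernel, not the PR's actual code:

```python
from functools import partial

from jax import grad, numpy as jnp, random, vmap
from jax.lax import scan


def log_density(x):
    # Stand-in for the Stein loss of a single particle.
    return -0.5 * jnp.sum(x**2)


def kernel(x, y):
    # Stand-in for the RBF kernel with unit bandwidth.
    return jnp.exp(-jnp.sum((x - y) ** 2))


def attractive_force(particles):
    # For each particle y, accumulate k(x, y) * grad_x log p(x) over all
    # particles x, recomputing the gradient inside the scan body rather
    # than materializing a table of gradients up front.
    def body(acc, x, y):
        x_grad = grad(log_density)(x)  # fresh score for this interaction
        return acc + kernel(x, y) * x_grad, None

    def per_particle(y):
        force, _ = scan(partial(body, y=y), jnp.zeros_like(y), particles)
        return force

    return vmap(per_particle)(particles)


particles = random.normal(random.PRNGKey(0), (5, 2))
print(attractive_force(particles).shape)  # (5, 2)
```

The scan trades some compute for memory: gradients are now evaluated once per interaction instead of once per particle, but only one gradient is alive per scan step.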
6 changes: 3 additions & 3 deletions numpyro/contrib/funsor/enum_messenger.py
@@ -436,9 +436,9 @@ class BaseEnumMessenger(NamedMessenger):
"""

def __init__(self, fn=None, first_available_dim=None):
assert (
first_available_dim is None or first_available_dim < 0
), first_available_dim
assert first_available_dim is None or first_available_dim < 0, (
first_available_dim
)
self.first_available_dim = first_available_dim
super().__init__(fn)

6 changes: 3 additions & 3 deletions numpyro/contrib/tfp/mcmc.py
@@ -56,9 +56,9 @@ def log_prob_fn(x):
class _TFPKernelMeta(ABCMeta):
def __getitem__(cls, kernel_class):
assert issubclass(kernel_class, tfp.mcmc.TransitionKernel)
assert (
"target_log_prob_fn" in inspect.getfullargspec(kernel_class).args
), f"the first argument of {kernel_class} must be `target_log_prob_fn`"
assert "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args, (
f"the first argument of {kernel_class} must be `target_log_prob_fn`"
)

_PyroKernel = type(kernel_class.__name__, (TFPKernel,), {})
_PyroKernel.kernel_class = kernel_class