diff --git a/examples/hmcecs.py b/examples/hmcecs.py
index 010898009..5e0abc696 100644
--- a/examples/hmcecs.py
+++ b/examples/hmcecs.py
@@ -75,9 +75,9 @@ def run_hmc(mcmc_key, args, data, obs, kernel):
 
 
 def main(args):
-    assert (
-        11_000_000 >= args.num_datapoints
-    ), "11,000,000 data points in the Higgs dataset"
+    assert 11_000_000 >= args.num_datapoints, (
+        "11,000,000 data points in the Higgs dataset"
+    )
     # full dataset takes hours for plain hmc!
     if args.dataset == "higgs":
         _, fetch = load_dataset(
diff --git a/examples/ssbvm_mixture.py b/examples/ssbvm_mixture.py
index bc6891346..71b7243f2 100644
--- a/examples/ssbvm_mixture.py
+++ b/examples/ssbvm_mixture.py
@@ -293,7 +293,7 @@ def main(args):
     parser.add_argument("--device", default="gpu", type=str, help='use "cpu" or "gpu".')
 
     args = parser.parse_args()
-    assert all(
-        aa in AMINO_ACIDS for aa in args.amino_acids
-    ), f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
+    assert all(aa in AMINO_ACIDS for aa in args.amino_acids), (
+        f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
+    )
     main(args)
diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py
index 155eb3fd3..2f012e71e 100644
--- a/examples/stein_bnn.py
+++ b/examples/stein_bnn.py
@@ -25,7 +25,6 @@
 from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
 from numpyro.distributions import Gamma, Normal
 from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
-from numpyro.infer import init_to_uniform
 from numpyro.infer.autoguide import AutoNormal
 from numpyro.optim import Adagrad
 
@@ -121,12 +120,11 @@ def main(args):
     rng_key, inf_key = random.split(inf_key)
 
-    # We find that SteinVI benefits from a small radius when inferring BNNs.
-    guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
+    guide = AutoNormal(model)
 
     stein = SteinVI(
         model,
         guide,
-        Adagrad(0.5),
+        Adagrad(1.0),
         RBFKernel(),
         repulsion_temperature=args.repulsion,
         num_stein_particles=args.num_stein_particles,
diff --git a/notebooks/source/lotka_volterra_multiple.ipynb b/notebooks/source/lotka_volterra_multiple.ipynb
index 331e82e33..46c21df7e 100644
--- a/notebooks/source/lotka_volterra_multiple.ipynb
+++ b/notebooks/source/lotka_volterra_multiple.ipynb
@@ -404,10 +404,10 @@
    "source": [
     "print(f\"The dataset has the shape {data.shape}, (n_datasets, n_points, n_observables)\")\n",
     "print(f\"The time matrix has the shape {ts.shape}, (n_datasets, n_timepoints)\")\n",
-    "print(f\"The time matrix has different spacing between timepoints: \\n {ts[:,:5]}\")\n",
-    "print(f\"The final timepoints are: {jnp.nanmax(ts,1)} years.\")\n",
+    "print(f\"The time matrix has different spacing between timepoints: \\n {ts[:, :5]}\")\n",
+    "print(f\"The final timepoints are: {jnp.nanmax(ts, 1)} years.\")\n",
     "print(\n",
-    "    f\"The dataset has {jnp.sum(jnp.isnan(data))/jnp.size(data):.0%} missing observations\"\n",
+    "    f\"The dataset has {jnp.sum(jnp.isnan(data)) / jnp.size(data):.0%} missing observations\"\n",
     ")\n",
     "print(f\"True params mean: {sample['theta'][0]}\")"
    ]
@@ -550,7 +550,7 @@
    "mcmc.print_summary()\n",
    "\n",
    "print(f\"True params mean: {sample['theta'][0]}\")\n",
-   "print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
+   "print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
   ]
  },
 {
@@ -591,7 +591,7 @@
    "\n",
    "\n",
    "print(f\"True params mean: {sample['theta'][0]}\")\n",
-   "print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
+   "print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
   ]
  },
 {
diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb
index e7fbe0c05..c2f00d181 100644
--- a/notebooks/source/ordinal_regression.ipynb
+++ b/notebooks/source/ordinal_regression.ipynb
@@ -104,7 +104,7 @@
    "print(df.Y.value_counts())\n",
    "\n",
    "for i in range(nclasses):\n",
-   "    print(f\"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}\")"
+   "    print(f\"mean(X) for Y == {i}: {X[np.where(Y == i)].mean():.3f}\")"
   ]
  },
 {
diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py
index 6c0711142..74a1ebc5b 100644
--- a/numpyro/contrib/control_flow/scan.py
+++ b/numpyro/contrib/control_flow/scan.py
@@ -82,7 +82,7 @@ def _subs_wrapper(subs_map, i, length, site):
             )
         else:
             raise RuntimeError(
-                f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
+                f"Something went wrong. Expected ndim = {fn_ndim} or {fn_ndim + 1},"
                 f" but got {value_ndim}. This might happen when you use nested scan,"
                 " which is currently not supported. Please report the issue to us!"
             )
diff --git a/numpyro/contrib/ecs_proxies.py b/numpyro/contrib/ecs_proxies.py
index c17b2d167..65c125abe 100644
--- a/numpyro/contrib/ecs_proxies.py
+++ b/numpyro/contrib/ecs_proxies.py
@@ -12,13 +12,11 @@
 
 TaylorTwoProxyState = namedtuple(
     "TaylorProxyState",
-    "ref_subsample_log_liks,"
-    "ref_subsample_log_lik_grads,"
-    "ref_subsample_log_lik_hessians",
+    "ref_subsample_log_liks,ref_subsample_log_lik_grads,ref_subsample_log_lik_hessians",
 )
 
 TaylorOneProxyState = namedtuple(
-    "TaylorOneProxyState", "ref_subsample_log_liks," "ref_subsample_log_lik_grads,"
+    "TaylorOneProxyState", "ref_subsample_log_liks,ref_subsample_log_lik_grads,"
 )
diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py
index c4d69ed16..873502c3a 100644
--- a/numpyro/contrib/einstein/steinvi.py
+++ b/numpyro/contrib/einstein/steinvi.py
@@ -10,6 +10,7 @@
 
 from jax import grad, numpy as jnp, random, tree, vmap
 from jax.flatten_util import ravel_pytree
+from jax.lax import scan
 
 from numpyro import handlers
 from numpyro.contrib.einstein.stein_loss import SteinLoss
@@ -33,7 +34,6 @@ def _numel(shape):
 
 class SteinVI:
     """Variational inference with Stein mixtures inference.
-
     **Example:**
 
     .. doctest::
@@ -138,9 +138,9 @@ def __init__(
         if isinstance(guide.init_loc_fn, partial):
             init_fn_name = guide.init_loc_fn.func.__name__
             if init_fn_name == "init_to_uniform":
-                assert (
-                    guide.init_loc_fn.keywords.get("radius", None) != 0.0
-                ), init_loc_error_message
+                assert guide.init_loc_fn.keywords.get("radius", None) != 0.0, (
+                    init_loc_error_message
+                )
         else:
             init_fn_name = guide.init_loc_fn.__name__
         assert init_fn_name not in [
@@ -230,25 +230,29 @@ def local_trace(key):
 
         return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))
 
     def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
-        # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
-        non_mixture_uparams = {  # Includes any marked guide parameters and all model parameters
+        # Split parameters into model and guide components: only unflagged guide
+        # parameters are updated via Stein forces; model parameters (and any flagged
+        # guide parameters) are optimized with ordinary gradients.
+        nonmix_uparams = {  # Includes any marked guide parameters and all model parameters
             p: v
             for p, v in unconstr_params.items()
             if p not in self.guide_sites or self.non_mixture_params_fn(p)
         }
+
         stein_uparams = {
-            p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
+            p: v for p, v in unconstr_params.items() if p not in nonmix_uparams
         }
-        # 1. Collect each guide parameter into monolithic particles that capture correlations
-        # between parameter values across each individual particle
+
+        # Collect guide parameters into a monolithic particle.
         stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
             stein_uparams, nbatch_dims=1
         )
+
+        # Kernel behavior varies based on particle site locations. The particle_info dictionary
+        # maps site names to their corresponding dimensional ranges as (start, end) tuples.
         particle_info, _ = self._calc_particle_info(
             stein_uparams, stein_particles.shape[0]
         )
-        attractive_key, classic_key = random.split(rng_key)
 
         def particle_transform_fn(particle):
             params = unravel_pytree(particle)
@@ -256,78 +260,54 @@ def particle_transform_fn(particle):
             ctparticle, _ = ravel_pytree(ctparams)
             return ctparticle
 
-        # 2. Calculate gradients for each particle
-        def kernel_particles_loss_fn(rng_key, particles):
-            particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
-            grads = vmap(
-                lambda i: grad(
-                    lambda particle: self.stein_loss.particle_loss(
-                        rng_key=particle_keys[i],
-                        model=handlers.scale(
-                            self._inference_model, self.loss_temperature
-                        ),
-                        guide=self.guide,
-                        selected_particle=self.constrain_fn(unravel_pytree(particle)),
-                        unravel_pytree=unravel_pytree,
-                        flat_particles=vmap(particle_transform_fn)(particles),
-                        select_index=i,
-                        model_args=args,
-                        model_kwargs=kwargs,
-                        param_map=self.constrain_fn(non_mixture_uparams),
-                    )
-                )(particles[i])
-            )(jnp.arange(self.stein_loss.stein_num_particles))
-
-            return grads
-
-        # 2.1 Compute particle gradients (for attractive force)
-        particle_ljp_grads = kernel_particles_loss_fn(attractive_key, stein_particles)
-
-        # 2.3 Lift particles to constraint space
-        ctstein_particles = vmap(particle_transform_fn)(stein_particles)
-
-        # 2.4 Compute non-mixture parameter gradients
-        non_mixture_param_grads = grad(
-            lambda cps: -self.stein_loss.loss(
-                classic_key,
-                self.constrain_fn(cps),
-                handlers.scale(self._inference_model, self.loss_temperature),
-                self.guide,
-                unravel_pytree_batched(ctstein_particles),
-                *args,
-                **kwargs,
-            )
-        )(non_mixture_uparams)
+        model = handlers.scale(self._inference_model, self.loss_temperature)
 
-        # 3. Calculate kernel of particles
-        def loss_fn(particle, i):
+        def stein_loss_fn(key, particle, particle_idx):
             return self.stein_loss.particle_loss(
-                rng_key=rng_key,
-                model=handlers.scale(self._inference_model, self.loss_temperature),
+                rng_key=key,
+                model=model,
                 guide=self.guide,
+                # Stein particles evolve in unconstrained space, but gradient computations must account
+                # for the transformation to constrained space.
                 selected_particle=self.constrain_fn(unravel_pytree(particle)),
                 unravel_pytree=unravel_pytree,
-                flat_particles=ctstein_particles,
-                select_index=i,
+                flat_particles=vmap(particle_transform_fn)(stein_particles),
+                select_index=particle_idx,
                 model_args=args,
                 model_kwargs=kwargs,
-                param_map=self.constrain_fn(non_mixture_uparams),
+                param_map=self.constrain_fn(nonmix_uparams),
             )
 
         kernel = self.kernel_fn.compute(
-            rng_key, stein_particles, particle_info, loss_fn
+            rng_key, stein_particles, particle_info, stein_loss_fn
         )
+        attractive_key, classic_key = random.split(rng_key)
 
-        # 4. Calculate the attractive force and repulsive force on the particles
-        attractive_force = vmap(
-            lambda y: jnp.sum(
-                vmap(
-                    lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
-                )(stein_particles, particle_ljp_grads),
-                axis=0,
-            )
-        )(stein_particles)
+        def compute_attr_force(rng_key, particles):
+            # Second term of eq. 9 from https://arxiv.org/pdf/2410.22948.
+            def body(attr_force, state, y):
+                key, x, i = state
+                x_grad = grad(stein_loss_fn, argnums=1)(key, x, i)
+                attr_force = attr_force + self._apply_kernel(kernel, x, y, x_grad)
+                return attr_force, None
+
+            particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
+            init = jnp.zeros_like(particles[0])
+            idxs = jnp.arange(self.num_stein_particles)
+
+            attr_force, _ = vmap(
+                lambda y, key: scan(
+                    partial(body, y=y),
+                    init,
+                    (random.split(key, self.num_stein_particles), particles, idxs),
+                )
+            )(particles, particle_keys)
+
+            return attr_force
+        attractive_force = compute_attr_force(attractive_key, stein_particles)
+
+        # Third term of eq. 9 from https://arxiv.org/pdf/2410.22948.
         repulsive_force = vmap(
             lambda y: jnp.mean(
                 vmap(
@@ -338,15 +318,27 @@ def loss_fn(particle, i):
             )
         )(stein_particles)
 
-        # 6. Compute the stein force
         particle_grads = attractive_force + repulsive_force
 
-        # 7. Decompose the monolithic particle forces back to concrete parameter values
-        stein_param_grads = unravel_pytree_batched(particle_grads)
+        # Compute non-mixture parameter gradients.
+        nonmix_uparam_grads = grad(
+            lambda cps: -self.stein_loss.loss(
+                classic_key,
+                self.constrain_fn(cps),
+                model,
+                self.guide,
+                unravel_pytree_batched(vmap(particle_transform_fn)(stein_particles)),
+                *args,
+                **kwargs,
+            )
+        )(nonmix_uparams)
+
+        # Decompose the monolithic particle forces back to concrete parameter values.
+        stein_uparam_grads = unravel_pytree_batched(particle_grads)
 
-        # 8. Return loss and gradients (based on parameter forces)
+        # Return loss and gradients (based on parameter forces).
         res_grads = tree.map(
-            lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
+            lambda x: -x, {**nonmix_uparam_grads, **stein_uparam_grads}
         )
 
         return jnp.linalg.norm(particle_grads), res_grads
diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py
index cf6d4ebd4..025d4931d 100644
--- a/numpyro/contrib/funsor/enum_messenger.py
+++ b/numpyro/contrib/funsor/enum_messenger.py
@@ -436,9 +436,9 @@ class BaseEnumMessenger(NamedMessenger):
     """
 
     def __init__(self, fn=None, first_available_dim=None):
-        assert (
-            first_available_dim is None or first_available_dim < 0
-        ), first_available_dim
+        assert first_available_dim is None or first_available_dim < 0, (
+            first_available_dim
+        )
         self.first_available_dim = first_available_dim
         super().__init__(fn)
 
diff --git a/numpyro/contrib/tfp/mcmc.py b/numpyro/contrib/tfp/mcmc.py
index 7a2312e70..416b6806d 100644
--- a/numpyro/contrib/tfp/mcmc.py
+++ b/numpyro/contrib/tfp/mcmc.py
@@ -56,9 +56,9 @@ def log_prob_fn(x):
 
 class _TFPKernelMeta(ABCMeta):
     def __getitem__(cls, kernel_class):
         assert issubclass(kernel_class, tfp.mcmc.TransitionKernel)
-        assert (
-            "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args
-        ), f"the first argument of {kernel_class} must be `target_log_prob_fn`"
+        assert "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args, (
+            f"the first argument of {kernel_class} must be `target_log_prob_fn`"
+        )
 
         _PyroKernel = type(kernel_class.__name__, (TFPKernel,), {})
         _PyroKernel.kernel_class = kernel_class
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index af17c4794..c4b5e1ea1 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -592,13 +592,13 @@ def __init__(
         *,
         validate_args=None,
     ):
-        assert (
-            isinstance(num_steps, int) and num_steps > 0
-        ), "`num_steps` argument should be an positive integer."
+        assert isinstance(num_steps, int) and num_steps > 0, (
+            "`num_steps` argument should be a positive integer."
+        )
         self.num_steps = num_steps
-        assert (
-            transition_matrix.ndim == 2
-        ), "`transition_matrix` argument should be a square matrix"
+        assert transition_matrix.ndim == 2, (
+            "`transition_matrix` argument should be a square matrix"
+        )
         self.transition_matrix = transition_matrix
         # Expand the covariance/precision/scale matrices to the right number of steps.
         args = {
@@ -661,9 +661,9 @@ class GaussianRandomWalk(Distribution):
     pytree_aux_fields = ("num_steps",)
 
     def __init__(self, scale=1.0, num_steps=1, *, validate_args=None):
-        assert (
-            isinstance(num_steps, int) and num_steps > 0
-        ), "`num_steps` argument should be an positive integer."
+        assert isinstance(num_steps, int) and num_steps > 0, (
+            "`num_steps` argument should be a positive integer."
+        )
         self.scale = scale
         self.num_steps = num_steps
         batch_shape, event_shape = jnp.shape(scale), (num_steps,)
@@ -1762,9 +1762,9 @@ def __init__(
             # TODO: look into future jax sparse csr functionality and other developments
             self.adj_matrix = _to_sparse(adj_matrix)
         else:
-            assert not _is_sparse(
-                adj_matrix
-            ), "adj_matrix is a sparse matrix so please specify `is_sparse=True`."
+            assert not _is_sparse(adj_matrix), (
+                "adj_matrix is a sparse matrix so please specify `is_sparse=True`."
+            )
             # TODO: look into static jax ndarray representation
             (self.adj_matrix,) = promote_shapes(
                 adj_matrix, shape=batch_shape + adj_matrix.shape[-2:]
@@ -1783,14 +1783,14 @@ def __init__(
         )
 
         if self._validate_args and (isinstance(adj_matrix, np.ndarray) or is_sparse):
-            assert (
-                self.adj_matrix.sum(axis=-1) > 0
-            ).all() > 0, "all sites in adjacency matrix must have neighbours"
+            assert (self.adj_matrix.sum(axis=-1) > 0).all(), (
+                "all sites in adjacency matrix must have neighbours"
+            )
 
         if self.is_sparse:
-            assert (
-                self.adj_matrix != self.adj_matrix.T
-            ).nnz == 0, "adjacency matrix must be symmetric"
+            assert (self.adj_matrix != self.adj_matrix.T).nnz == 0, (
+                "adjacency matrix must be symmetric"
+            )
         else:
             assert np.array_equal(
                 self.adj_matrix, np.swapaxes(self.adj_matrix, -2, -1)
diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py
index 9bc19a4d1..70dc6bee6 100644
--- a/numpyro/distributions/directional.py
+++ b/numpyro/distributions/directional.py
@@ -219,9 +219,9 @@ def model(obs):
     support = constraints.independent(constraints.circular, 1)
 
     def __init__(self, base_dist: Distribution, skewness, *, validate_args=None):
-        assert (
-            base_dist.event_shape == skewness.shape[-1:]
-        ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."
+        assert base_dist.event_shape == skewness.shape[-1:], (
+            "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."
+        )
 
         batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
         event_shape = skewness.shape[-1:]
@@ -600,8 +600,7 @@ def log_prob(self, value):
         event_shape = value.shape[-1:]
         if event_shape != self.event_shape:
             raise ValueError(
-                f"Expected event shape {self.event_shape}, "
-                f"but got {event_shape}"
+                f"Expected event shape {self.event_shape}, but got {event_shape}"
             )
         self._validate_sample(value)
         dim = int(self.concentration.shape[-1])
diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index e56d97e6b..4f0fea56a 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -228,7 +228,7 @@ def enumerate_support(self, expand=True):
         # NB: the error can't be raised if inhomogeneous issue happens when tracing
         if np.amin(self.total_count) != total_count:
             raise NotImplementedError(
-                "Inhomogeneous total count not supported" " by `enumerate_support`."
+                "Inhomogeneous total count not supported by `enumerate_support`."
             )
         else:
             total_count = jnp.amax(self.total_count)
diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py
index 19a821c89..d8447b35a 100644
--- a/numpyro/distributions/mixtures.py
+++ b/numpyro/distributions/mixtures.py
@@ -359,9 +359,9 @@ def __init__(
                 "All component distributions must have the same support."
             )
         else:
-            assert isinstance(
-                support, constraints.Constraint
-            ), "support must be a Constraint object"
+            assert isinstance(support, constraints.Constraint), (
+                "support must be a Constraint object"
+            )
 
         self._mixing_distribution = mixing_distribution
         self._component_distributions = component_distributions
diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py
index 78db31d6a..d9b86a465 100644
--- a/numpyro/distributions/truncated.py
+++ b/numpyro/distributions/truncated.py
@@ -34,9 +34,9 @@ class LeftTruncatedDistribution(Distribution):
 
     def __init__(self, base_dist, low=0.0, *, validate_args=None):
         assert isinstance(base_dist, self.supported_types)
-        assert (
-            base_dist.support is constraints.real
-        ), "The base distribution should be univariate and have real support."
+        assert base_dist.support is constraints.real, (
+            "The base distribution should be univariate and have real support."
+        )
         batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
         self.base_dist = jax.tree.map(
             lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
@@ -117,9 +117,9 @@ class RightTruncatedDistribution(Distribution):
 
     def __init__(self, base_dist, high=0.0, *, validate_args=None):
         assert isinstance(base_dist, self.supported_types)
-        assert (
-            base_dist.support is constraints.real
-        ), "The base distribution should be univariate and have real support."
+        assert base_dist.support is constraints.real, (
+            "The base distribution should be univariate and have real support."
+        )
         batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
         self.base_dist = jax.tree.map(
             lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
@@ -188,9 +188,9 @@ class TwoSidedTruncatedDistribution(Distribution):
 
     def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None):
         assert isinstance(base_dist, self.supported_types)
-        assert (
-            base_dist.support is constraints.real
-        ), "The base distribution should be univariate and have real support."
+        assert base_dist.support is constraints.real, (
+            "The base distribution should be univariate and have real support."
+        )
         batch_shape = lax.broadcast_shapes(
             base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
         )
diff --git a/numpyro/handlers.py b/numpyro/handlers.py
index 326ec8f32..1f739d5b6 100644
--- a/numpyro/handlers.py
+++ b/numpyro/handlers.py
@@ -434,9 +434,7 @@ def __init__(
         self.condition_fn = condition_fn
         self.data = data
         if sum((x is not None for x in (data, condition_fn))) != 1:
-            raise ValueError(
-                "Only one of `data` or `condition_fn` " "should be provided."
-            )
+            raise ValueError("Only one of `data` or `condition_fn` should be provided.")
         super(condition, self).__init__(fn)
 
     def process_message(self, msg):
@@ -874,7 +872,7 @@ def __init__(
         self.data = data
         if sum((x is not None for x in (data, substitute_fn))) != 1:
             raise ValueError(
-                "Only one of `data` or `substitute_fn` " "should be provided."
+                "Only one of `data` or `substitute_fn` should be provided."
             )
         super(substitute, self).__init__(fn)
diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py
index 6df31055a..694f7c8b3 100644
--- a/numpyro/infer/autoguide.py
+++ b/numpyro/infer/autoguide.py
@@ -104,9 +104,9 @@ def _create_plates(self, *args, **kwargs):
             plates = self.create_plates(*args, **kwargs)
             if isinstance(plates, numpyro.plate):
                 plates = [plates]
-            assert all(
-                isinstance(p, numpyro.plate) for p in plates
-            ), "create_plates() returned a non-plate"
+            assert all(isinstance(p, numpyro.plate) for p in plates), (
+                "create_plates() returned a non-plate"
+            )
             self.plates = {p.name: p for p in plates}
         for name, frame in sorted(self._prototype_frames.items()):
             if name not in self.plates:
@@ -181,9 +181,9 @@ def _setup_prototype(self, *args, **kwargs):
                 biject_to(site["fn"].support)
                 for frame in site["cond_indep_stack"]:
                     if frame.name in self._prototype_frames:
-                        assert (
-                            frame == self._prototype_frames[frame.name]
-                        ), f"The plate {frame.name} has inconsistent dim or size. Please check your model again."
+                        assert frame == self._prototype_frames[frame.name], (
+                            f"The plate {frame.name} has inconsistent dim or size. Please check your model again."
+                        )
                     else:
                         self._prototype_frames[frame.name] = frame
             elif site["type"] == "plate":
@@ -786,9 +786,9 @@ def get_transform(self, params):
         :rtype: :class:`~numpyro.distributions.transforms.Transform`
         """
         posterior = handlers.substitute(self._get_posterior, params)()
-        assert isinstance(
-            posterior, dist.TransformedDistribution
-        ), "posterior is not a transformed distribution"
+        assert isinstance(posterior, dist.TransformedDistribution), (
+            "posterior is not a transformed distribution"
+        )
         if len(posterior.transforms) > 0:
             return ComposeTransform(posterior.transforms)
         else:
@@ -1355,9 +1355,9 @@ def _setup_prototype(self, *args, **kwargs):
             if site["type"] == "plate"
         }
         num_plates = len(subsample_plates)
-        assert (
-            num_plates == 1
-        ), f"AutoSemiDAIS assumes that the model contains exactly 1 plate with data subsampling but got {num_plates}."
+        assert num_plates == 1, (
+            f"AutoSemiDAIS assumes that the model contains exactly 1 plate with data subsampling but got {num_plates}."
+        )
         plate_name = list(subsample_plates.keys())[0]
         local_vars = []
         subsample_axes = {}
diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py
index 52babdf11..c7f978e7e 100644
--- a/numpyro/infer/barker.py
+++ b/numpyro/infer/barker.py
@@ -181,7 +181,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
             )
         if self._potential_fn and init_params is None:
             raise ValueError(
-                "Valid value of `init_params` must be provided with" " `potential_fn`."
+                "Valid value of `init_params` must be provided with `potential_fn`."
             )
 
         pe, grad = jax.value_and_grad(self._potential_fn)(init_params)
diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py
index 64a2870ad..d12689c84 100644
--- a/numpyro/infer/ensemble.py
+++ b/numpyro/infer/ensemble.py
@@ -289,12 +289,12 @@ def __init__(
             moves
         )
 
-        assert all(
-            [hasattr(move, "__call__") for move in self._moves]
-        ), "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove())."
-        assert jnp.all(
-            self._weights >= 0
-        ), "Each specified move must have probability >= 0"
+        assert all([hasattr(move, "__call__") for move in self._moves]), (
+            "Each move must be a callable (one of AIES.DEMove() or AIES.StretchMove())."
+        )
+        assert jnp.all(self._weights >= 0), (
+            "Each specified move must have probability >= 0"
+        )
 
         super().__init__(
             model,
@@ -515,9 +515,9 @@ def __init__(
                 "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)"
             )
 
-        assert jnp.all(
-            self._weights >= 0
-        ), "Each specified move must have probability >= 0"
+        assert jnp.all(self._weights >= 0), (
+            "Each specified move must have probability >= 0"
+        )
         assert init_mu > 0, "Scale factor should be strictly positive"
 
         self._max_steps = max_steps  # max number of stepping out steps
diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py
index a299dc89c..50a7cbd53 100644
--- a/numpyro/infer/hmc.py
+++ b/numpyro/infer/hmc.py
@@ -751,7 +751,7 @@ def init(
             )
         if self._potential_fn and init_params is None:
             raise ValueError(
-                "Valid value of `init_params` must be provided with" " `potential_fn`."
+                "Valid value of `init_params` must be provided with `potential_fn`."
             )
 
         # change dense_mass to a structural form
diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py
index f6b95389b..4395c7251 100644
--- a/numpyro/infer/hmc_gibbs.py
+++ b/numpyro/infer/hmc_gibbs.py
@@ -89,9 +89,9 @@ def __init__(self, inner_kernel, gibbs_fn, gibbs_sites):
            raise ValueError("inner_kernel must be a HMC or NUTS sampler.")
         if not callable(gibbs_fn):
             raise ValueError("gibbs_fn must be a callable")
-        assert (
-            inner_kernel.model is not None
-        ), "HMCGibbs does not support models specified via a potential function."
+        assert inner_kernel.model is not None, (
+            "HMCGibbs does not support models specified via a potential function."
+        )
 
         self.inner_kernel = copy.copy(inner_kernel)
         self.inner_kernel._model = partial(_wrap_model, inner_kernel.model)
@@ -442,9 +442,9 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
             and not site["is_observed"]
             and site["infer"].get("enumerate", "") != "parallel"
         ]
-        assert (
-            self._gibbs_sites
-        ), "Cannot detect any discrete latent variables in the model."
+        assert self._gibbs_sites, (
+            "Cannot detect any discrete latent variables in the model."
+        )
         return super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
 
     def sample(self, state, model_args, model_kwargs):
diff --git a/numpyro/infer/sa.py b/numpyro/infer/sa.py
index 91f963f41..e8d893f11 100644
--- a/numpyro/infer/sa.py
+++ b/numpyro/infer/sa.py
@@ -345,7 +345,7 @@ def init(
             )
         if self._potential_fn and init_params is None:
             raise ValueError(
-                "Valid value of `init_params` must be provided with" " `potential_fn`."
+                "Valid value of `init_params` must be provided with `potential_fn`."
             )
 
         # NB: init args is different from HMC
diff --git a/numpyro/optim.py b/numpyro/optim.py
index 8d1717b16..8ba11dfe4 100644
--- a/numpyro/optim.py
+++ b/numpyro/optim.py
@@ -241,13 +241,11 @@ class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])
 )
 
 
-def _minimize_wrapper() -> (
-    tuple[
-        Callable[[_Params], _MinimizeState],
-        Callable[[Any, Any, _MinimizeState], _MinimizeState],
-        Callable[[_MinimizeState], _Params],
-    ]
-):
+def _minimize_wrapper() -> tuple[
+    Callable[[_Params], _MinimizeState],
+    Callable[[Any, Any, _MinimizeState], _MinimizeState],
+    Callable[[_MinimizeState], _Params],
+]:
     def init_fn(params: _Params) -> _MinimizeState:
         flat_params, unravel_fn = ravel_pytree(params)
         return _MinimizeState(flat_params, unravel_fn)
diff --git a/numpyro/primitives.py b/numpyro/primitives.py
index dc02e8210..157cf069d 100644
--- a/numpyro/primitives.py
+++ b/numpyro/primitives.py
@@ -178,9 +178,9 @@ def sample(
         argument is not intended to be used with MCMC.
     :return: sample from the stochastic `fn`.
     """
-    assert isinstance(
-        sample_shape, tuple
-    ), "sample_shape needs to be a tuple of integers"
+    assert isinstance(sample_shape, tuple), (
+        "sample_shape needs to be a tuple of integers"
+    )
     if not isinstance(fn, numpyro.distributions.Distribution):
         type_error = TypeError(
             "It looks like you tried to use a fn that isn't an instance of "
@@ -280,9 +280,9 @@ def param(
     """
     # if there are no active Messengers, we just draw a sample and return it as expected:
     if not _PYRO_STACK:
-        assert not callable(
-            init_value
-        ), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
+        assert not callable(init_value), (
+            "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
+        )
         return init_value
 
     if callable(init_value):
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 03c5813d6..04cd4d7a0 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -3446,9 +3446,9 @@ def _assert_not_jax_issue_19885(
         if block_until_ready:
             result = block_until_ready()
         _, err = capfd.readouterr()
-        assert (
-            "MatMul reference implementation being executed" not in err
-        ), f"jit: {jit}"
+        assert "MatMul reference implementation being executed" not in err, (
+            f"jit: {jit}"
+        )
     return result
diff --git a/test/test_distributions_mixture.py b/test/test_distributions_mixture.py
index 82064f67b..c1beaf2fd 100644
--- a/test/test_distributions_mixture.py
+++ b/test/test_distributions_mixture.py
@@ -129,18 +129,18 @@ def _test_mixture(mixing_distribution, component_distribution):
         mixing_distribution=mixing_distribution,
         component_distributions=component_distribution,
     )
-    assert (
-        mixture.mixture_size == mixing_distribution.probs.shape[-1]
-    ), "Mixture size needs to be the size of the probability vector"
+    assert mixture.mixture_size == mixing_distribution.probs.shape[-1], (
+        "Mixture size needs to be the size of the probability vector"
+    )
     if isinstance(component_distribution, dist.Distribution):
-        assert (
-            mixture.batch_shape == component_distribution.batch_shape[:-1]
-        ), "Mixture batch shape needs to be the component batch shape without the mixture dimension."
+        assert mixture.batch_shape == component_distribution.batch_shape[:-1], (
+            "Mixture batch shape needs to be the component batch shape without the mixture dimension."
+        )
     else:
-        assert (
-            mixture.batch_shape == component_distribution[0].batch_shape
-        ), "Mixture batch shape needs to be the component batch shape."
+        assert mixture.batch_shape == component_distribution[0].batch_shape, (
+            "Mixture batch shape needs to be the component batch shape."
+        )
     # Test samples
     sample_shape = (11,)
     # Samples from component distribution(s)