fix #1364: conditional posterior shape and device bugs. (#1373)
* fix shape and device bugs.

* fix: do not test batched and iid x

* fix coverage and testing bugs
janfb authored Jan 22, 2025
1 parent 80740b2 commit 448cef2
Showing 4 changed files with 94 additions and 13 deletions.
16 changes: 13 additions & 3 deletions sbi/inference/potentials/likelihood_based_potential.py
```diff
@@ -143,9 +143,14 @@ def condition_on_theta(
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
-            assert len(dims_global_theta) == theta.shape[1], (
+            assert len(dims_global_theta) == theta.shape[-1], (
                 "dims_global_theta must match the number of parameters to sample."
             )
+            if theta.dim() > 2:
+                assert theta.shape[0] == 1, (
+                    "condition_on_theta does not support sample shape for theta."
+                )
+                theta = theta.squeeze(0)
             global_theta = theta[:, dims_global_theta]
             x_o = x_o if x_o is not None else self.x_o
             # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
```
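Note on this hunk: samplers can call the potential with `theta` carrying a leading sample dimension, so the parameter count must be read from the last axis rather than axis 1. A minimal sketch of the convention the fix enforces, with hypothetical tensor sizes:

```python
import torch

# Hypothetical tensors: a sampler passes theta with a leading sample
# dimension, i.e. (sample_shape, batch, dim) instead of (batch, dim).
theta = torch.randn(1, 10, 2)
dims_global_theta = [0, 1]

# The old check read theta.shape[1] (the batch size, 10) and failed;
# the number of parameters lives on the last axis.
assert len(dims_global_theta) == theta.shape[-1]

# A sample dimension is tolerated only as a singleton and then dropped.
if theta.dim() > 2:
    assert theta.shape[0] == 1, "sample shape is not supported for theta"
    theta = theta.squeeze(0)  # -> (10, 2)
```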
```diff
@@ -155,7 +160,7 @@ def conditioned_potential(
             )

             return _log_likelihood_over_iid_trials_and_local_theta(
-                x=x_o,
+                x=x_o.to(self.device),
                 global_theta=global_theta,
                 local_theta=local_theta,
                 estimator=self.likelihood_estimator,
```
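The `.to(self.device)` change is the device half of #1364: the stored observation can still live on the CPU while the estimator sits on the GPU. A toy illustration of the failure mode, with assumed tensors rather than the library's API:

```python
import torch

# Assumed setup: model weights on the GPU when available,
# observation created on the CPU by default.
device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(3, 3, device=device)
x_o = torch.zeros(1, 3)  # CPU tensor

# Without the explicit move, `x_o @ weight` raises a device-mismatch
# RuntimeError whenever CUDA is available; with it, the matmul succeeds.
out = x_o.to(device) @ weight
```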
```diff
@@ -266,6 +271,10 @@ def _log_likelihood_over_iid_trials_and_local_theta(
     assert local_theta.shape[0] == num_trials, (
         "Condition batch size must match the number of iid trials in x."
     )
+    if num_xs > 1:
+        raise NotImplementedError(
+            "Batched sampling for multiple `x` is not supported for iid conditions."
+        )

     # move the iid batch dimension onto the batch dimension of theta and repeat it there
     x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
```
```diff
@@ -289,7 +298,8 @@ def _log_likelihood_over_iid_trials_and_local_theta(
         num_xs, num_trials, num_thetas
     ).sum(1)

-    return log_likelihood_trial_sum
+    # remove xs batch dimension
+    return log_likelihood_trial_sum.squeeze(0)


 def mixed_likelihood_estimator_based_potential(
```
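Taken together, the last two hunks pin the `x` batch to a single observation and strip that singleton dimension from the returned potential. A sketch of the resulting shapes under illustrative sizes:

```python
import torch

# Illustrative sizes; the new guard rejects num_xs > 1 for iid conditions.
num_xs, num_trials, num_thetas = 1, 5, 10

# Per-trial log-likelihoods reshaped to (num_xs, num_trials, num_thetas),
# then summed over the iid-trial dimension -> (num_xs, num_thetas).
log_likelihood_trial_sum = torch.randn(num_xs, num_trials, num_thetas).sum(1)

# With num_xs pinned to 1, squeezing the x batch dimension returns a flat
# (num_thetas,) potential, matching what downstream samplers expect.
assert log_likelihood_trial_sum.squeeze(0).shape == torch.Size([num_thetas])
```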
2 changes: 1 addition & 1 deletion sbi/utils/sbiutils.py
```diff
@@ -947,7 +947,7 @@ def gradient_ascent(
                 )
                 best_theta_iter = optimize_inits[  # type: ignore
                     torch.argmax(log_probs_of_optimized)
-                ].view(1, -1)
+                ].unsqueeze(0)  # add batch dim
                 best_log_prob_iter = potential_fn(
                     theta_transform.inv(best_theta_iter)
                 )
```
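Why `.unsqueeze(0)` rather than `.view(1, -1)`: for a flat parameter vector the two agree, but `view` flattens any structured event shape (and requires a contiguous tensor), while `unsqueeze` only prepends the batch dimension. A standalone comparison with assumed shapes:

```python
import torch

# For a flat parameter vector the two calls agree:
best_theta = torch.randn(4)
assert torch.equal(best_theta.view(1, -1), best_theta.unsqueeze(0))  # both (1, 4)

# For a structured event shape they differ: view flattens, unsqueeze preserves.
best_theta_2d = torch.randn(2, 3)
print(best_theta_2d.view(1, -1).shape)   # torch.Size([1, 6]) -- event shape lost
print(best_theta_2d.unsqueeze(0).shape)  # torch.Size([1, 2, 3]) -- event shape kept
```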
53 changes: 52 additions & 1 deletion tests/inference_on_device_test.py
```diff
@@ -24,6 +24,7 @@
     ratio_estimator_based_potential,
 )
 from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.neural_nets.embedding_nets import FCEmbedding
 from sbi.neural_nets.factory import (
@@ -33,7 +34,11 @@
     posterior_nn,
 )
 from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
-from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
+from sbi.utils.torchutils import (
+    BoxUniform,
+    gpu_available,
+    process_device,
+)
 from sbi.utils.user_input_checks import (
     validate_theta_and_x,
 )
```
```diff
@@ -465,3 +470,49 @@ def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device:
     proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
     theta = proposal.sample((num_simulations,))
     x = simulator(theta)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_conditioned_posterior_on_gpu(device: str, mcmc_params_fast: dict):
+    device = process_device(device)
+    num_dims = 3
+
+    proposal = BoxUniform(
+        low=-torch.ones(num_dims, device=device),
+        high=torch.ones(num_dims, device=device),
+    )
+
+    inference = NPE_C(device=device, show_progress_bars=False)
+
+    num_simulations = 100
+    theta = proposal.sample((num_simulations,))
+    x = torch.randn_like(theta)
+    x_o = torch.zeros(1, num_dims).to(device)
+    inference = inference.append_simulations(theta, x)
+
+    estimator = inference.train(max_num_epochs=2)
+
+    # condition on one dim of theta
+    condition_o = torch.ones(1, 1).to(device)
+    prior = BoxUniform(
+        low=-torch.ones(num_dims - 1, device=device),
+        high=torch.ones(num_dims - 1, device=device),
+    )
+    prior_transform = utils.mcmc_transform(prior, device=device)
+
+    potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
+    conditioned_potential_fn = potential_fn.condition_on_theta(
+        condition_o, dims_global_theta=[0, 1]
+    )
+
+    conditional_posterior = MCMCPosterior(
+        potential_fn=conditioned_potential_fn,
+        theta_transform=prior_transform,
+        proposal=prior,
+        device=device,
+        **mcmc_params_fast,
+    ).set_default_x(x_o)
+    samples = conditional_posterior.sample((1,), x=x_o)
+    conditional_posterior.potential_fn(samples)
+    conditional_posterior.map()
```
36 changes: 28 additions & 8 deletions tests/mnle_test.py
```diff
@@ -40,9 +40,10 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)


-def wrapped_simulator(
+def mixed_simulator_with_conditions(
     theta_and_condition: Tensor, last_idx_parameters: int = 2
 ) -> Tensor:
+    """Simulator for mixed data with experimental conditions."""
     # simulate with experiment conditions
     theta = theta_and_condition[:, :last_idx_parameters]
     condition = theta_and_condition[:, last_idx_parameters:]
```
```diff
@@ -278,7 +279,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     )

     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     assert x.shape == (num_simulations, 2)

     num_trials = 10
@@ -289,7 +290,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)

-    x_o = wrapped_simulator(theta_and_conditions_o)
+    x_o = mixed_simulator_with_conditions(theta_and_conditions_o)

     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -313,6 +314,9 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
         ],
         validate_args=False,
     )
+    # test theta with sample shape.
+    conditioned_potential_fn(prior.sample((10,)).unsqueeze(0))
+
     prior_transform = mcmc_transform(prior)
     true_posterior_samples = MCMCPosterior(
         BinomialGammaPotential(
```

@pytest.mark.parametrize("num_thetas", [1, 10])
@pytest.mark.parametrize("num_trials", [1, 5])
@pytest.mark.parametrize("num_xs", [1, 3])
@pytest.mark.parametrize(
"num_xs",
[
1,
pytest.param(
2,
marks=pytest.mark.xfail(
reason="Batched x not supported for iid trials.",
raises=NotImplementedError,
),
),
],
)
@pytest.mark.parametrize(
"num_conditions",
[
1,
pytest.param(
2,
marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
marks=pytest.mark.xfail(
reason="Batched theta_condition is not supported",
),
),
],
)
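For reference, `pytest.param(..., marks=pytest.mark.xfail(raises=...))` marks a single parametrized case as an expected failure, and `raises=` makes the xfail count only that exception type. A minimal self-contained sketch of the pattern used above:

```python
import pytest

# The second case is expected to fail, and only a NotImplementedError
# counts as the expected failure.
@pytest.mark.parametrize(
    "num_xs",
    [
        1,
        pytest.param(
            2,
            marks=pytest.mark.xfail(
                reason="Batched x not supported.",
                raises=NotImplementedError,
            ),
        ),
    ],
)
def test_num_xs(num_xs: int):
    if num_xs > 1:
        raise NotImplementedError("batched x is not supported")
```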
```diff
@@ -376,7 +394,7 @@ def test_log_likelihood_over_local_iid_theta(

     num_simulations = 100
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)

     # condition on multiple conditions
@@ -407,8 +425,10 @@ def test_log_likelihood_over_local_iid_theta(
         )
         x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
         ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
-    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+    ll_single = (
+        torch.stack(ll_single).sum(0).squeeze(0)
+    )  # sum over trials, squeeze x batch.

-    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == torch.Size([num_thetas])
     assert ll_batched.shape == ll_single.shape
     assert torch.allclose(ll_batched, ll_single, atol=1e-5)
```