New organization of SNPE methods #1241

Open
michaeldeistler opened this issue Aug 28, 2024 · 0 comments
Labels
enhancement New feature or request


michaeldeistler commented Aug 28, 2024

Proposed file layout:

```
inference/
└── trainers/
    └── npe/
        ├── npe.py
        ├── snpe_a_correction.py
        └── snpe_c_loss.py
```

With this layout, the API for amortized NPE would be:

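```python
from sbi.inference import NPE, DirectPosterior

# `theta`, `x` (simulated parameters and data) and `prior` are assumed to be defined.
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
```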
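To make the proposed split between training and posterior construction concrete, here is a self-contained toy sketch in plain Python. `ToyNPE`, `LinearGaussianNet`, `UniformPrior` and `ToyDirectPosterior` are hypothetical stand-ins, not sbi classes: the "network" is a linear-Gaussian conditional density q(θ | x) = N(θ; a·x + b, s²), whose maximum-likelihood fit reduces to least-squares regression of θ on x, and the posterior wrapper rejects samples that fall outside the prior's support.

```python
import math
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class LinearGaussianNet:
    """Hypothetical 'network': q(theta | x) = N(theta; slope * x + intercept, var)."""
    slope: float
    intercept: float
    var: float

    def sample(self, x: float, rng: random.Random) -> float:
        return rng.gauss(self.slope * x + self.intercept, math.sqrt(self.var))


@dataclass
class UniformPrior:
    low: float
    high: float

    def sample(self, n: int, rng: random.Random) -> List[float]:
        return [rng.uniform(self.low, self.high) for _ in range(n)]

    def contains(self, theta: float) -> bool:
        return self.low <= theta <= self.high


class ToyNPE:
    """Hypothetical trainer: stores simulations and fits the network, nothing else."""

    def __init__(self) -> None:
        self.theta: List[float] = []
        self.x: List[float] = []

    def append_simulations(self, theta: Sequence[float], x: Sequence[float]) -> "ToyNPE":
        self.theta.extend(theta)
        self.x.extend(x)
        return self  # Returning self allows chaining `.train()`.

    def train(self) -> LinearGaussianNet:
        # Maximum likelihood for a linear-Gaussian q(theta | x) is least squares.
        n = len(self.theta)
        mean_x = sum(self.x) / n
        mean_t = sum(self.theta) / n
        sxx = sum((xi - mean_x) ** 2 for xi in self.x)
        sxt = sum((xi - mean_x) * (ti - mean_t) for xi, ti in zip(self.x, self.theta))
        slope = sxt / sxx
        intercept = mean_t - slope * mean_x
        residuals = [ti - slope * xi - intercept for xi, ti in zip(self.x, self.theta)]
        var = sum(r * r for r in residuals) / n
        return LinearGaussianNet(slope, intercept, var)


class ToyDirectPosterior:
    """Hypothetical posterior: wraps a trained net together with the prior."""

    def __init__(self, net: LinearGaussianNet, prior: UniformPrior) -> None:
        self.net = net
        self.prior = prior

    def sample(self, n: int, x: float, rng: random.Random) -> List[float]:
        samples: List[float] = []
        while len(samples) < n:
            theta = self.net.sample(x, rng)
            if self.prior.contains(theta):  # Reject samples outside the prior support.
                samples.append(theta)
        return samples


rng = random.Random(0)
prior = UniformPrior(-1.0, 1.0)
theta = prior.sample(2000, rng)
x = [t + rng.gauss(0.0, 0.1) for t in theta]  # Toy simulator: x = theta + noise.

trainer = ToyNPE()
net = trainer.append_simulations(theta, x).train()
posterior = ToyDirectPosterior(net, prior)
samples = posterior.sample(100, x=0.5, rng=rng)
```

The point of the sketch is that the trainer returns only the fitted network; everything prior-dependent (here, support rejection) lives in the posterior wrapper.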

For SNPE_A, it is:

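```python
from sbi.inference import NPE, DirectPosterior, snpe_a_correction

# `prior`, `simulator` and the observation `x_o` are assumed to be defined.
proposal = prior  # The first round samples from the prior.
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)

    # Single Gaussian in the first two rounds, mixture density network in the last.
    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    # Condition on x_o so that the next round simulates near the observation.
    proposal = corrected_posterior.set_default_x(x_o)
```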
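The correction step can be illustrated in one dimension. SNPE-A trains q(θ | x) on samples from a proposal p̃(θ) and recovers the posterior as p(θ | x) ∝ q(θ | x) p(θ) / p̃(θ); when q, the prior and the proposal are all Gaussian, the result is again Gaussian in closed form. The function `gaussian_snpe_a_correction` below is a hypothetical 1-D sketch of that computation, not the signature of the proposed `snpe_a_correction`.

```python
import math
from dataclasses import dataclass


@dataclass
class Gaussian1D:
    mean: float
    var: float


def gaussian_snpe_a_correction(
    q: Gaussian1D, prior: Gaussian1D, proposal: Gaussian1D
) -> Gaussian1D:
    """Return the Gaussian proportional to q(theta) * prior(theta) / proposal(theta)."""
    # Precisions add for products and subtract for ratios of Gaussians.
    precision = 1.0 / q.var + 1.0 / prior.var - 1.0 / proposal.var
    if precision <= 0.0:
        # The ratio is not normalizable, so no valid Gaussian posterior exists.
        raise ValueError("proposal is too narrow relative to q and the prior")
    # Precision-weighted means combine the same way.
    eta = q.mean / q.var + prior.mean / prior.var - proposal.mean / proposal.var
    return Gaussian1D(mean=eta / precision, var=1.0 / precision)


# If the proposal equals the prior (first round), the correction leaves q unchanged.
first_round = gaussian_snpe_a_correction(
    Gaussian1D(1.0, 0.25), Gaussian1D(0.0, 4.0), Gaussian1D(0.0, 4.0)
)
# A later round with a narrower proposal centered at 1.0.
later_round = gaussian_snpe_a_correction(
    Gaussian1D(2.0, 1.0), Gaussian1D(0.0, 4.0), Gaussian1D(1.0, 2.0)
)
```

The `precision <= 0` branch is the known weak spot of this correction: if the proposal is narrower than the learned density allows, dividing by it does not yield a proper distribution.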

For SNPE_C (atomic), it is:

```python
from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss

# `prior`, `simulator` and the observation `x_o` are assumed to be defined.

# First round is standard NPE.
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`

# Later rounds use the APT loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
```
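For reference, the atomic APT loss itself can be sketched without any deep-learning library. For each pair (θ_i, x_i), the ratio q(θ | x_i) / p(θ) is evaluated on a set of atoms consisting of θ_i and other parameters from the batch, and the loss is the negative log of the softmax share assigned to θ_i. The function `atomic_loss` below is hypothetical and works on scalars; sbi draws the contrasting atoms at random, whereas this sketch takes the next indices in the batch so that the result is deterministic.

```python
import math
from typing import Callable, Sequence


def logsumexp(values: Sequence[float]) -> float:
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def atomic_loss(
    log_q: Callable[[float, float], float],  # log q(theta | x) of the network
    log_prior: Callable[[float], float],     # log p(theta)
    thetas: Sequence[float],
    xs: Sequence[float],
    num_atoms: int,
) -> float:
    n = len(thetas)
    if not 2 <= num_atoms <= n:
        raise ValueError("num_atoms must be between 2 and the batch size")
    total = 0.0
    for i in range(n):
        # Atom 0 is the theta that generated x_i; the others come from the rest
        # of the batch (here: the next indices, cyclically, for determinism).
        atoms = [thetas[(i + k) % n] for k in range(num_atoms)]
        logits = [log_q(t, xs[i]) - log_prior(t) for t in atoms]
        # Negative log of the softmax probability assigned to the true atom.
        total += -(logits[0] - logsumexp(logits))
    return total / n


thetas = [0.0, 1.0, 2.0, 3.0]
xs = [0.0, 1.0, 2.0, 3.0]
# A network whose ratio is flat cannot tell atoms apart: loss = log(num_atoms).
flat = atomic_loss(lambda t, x: -t * t, lambda t: -t * t, thetas, xs, num_atoms=3)
# A sharply peaked network that matches x_i to theta_i gives a loss close to zero.
sharp = atomic_loss(lambda t, x: -50.0 * (t - x) ** 2, lambda t: 0.0, thetas, xs, num_atoms=2)
```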

For SNPE_C (non-atomic), the only difference from the atomic snippet is that one would also pass `proposal=proposal` to `append_simulations()`, and one has to use MDNs (mixture density networks) as the density estimator.
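The likely reason MDNs are required is that the non-atomic loss needs the proposal posterior q̃(θ | x) ∝ q(θ | x) p̃(θ) / p(θ) in closed form. With a Gaussian prior and a Gaussian proposal, multiplying each mixture component by p̃ / p gives another Gaussian, and the mixture weights are rescaled by each component's integral. The sketch below does this in one dimension; `proposal_posterior_mdn` and the `(weight, mean, var)` tuples are hypothetical, and Gaussian prior and proposal are assumptions of the sketch.

```python
import math
from typing import List, Tuple

Component = Tuple[float, float, float]  # (weight, mean, var) of one mixture component


def proposal_posterior_mdn(
    components: List[Component],
    proposal: Tuple[float, float],  # (mean, var) of the Gaussian proposal
    prior: Tuple[float, float],     # (mean, var) of the Gaussian prior
) -> List[Component]:
    """Mixture proportional to q(theta | x) * proposal(theta) / prior(theta)."""
    mp, vp = proposal
    m0, v0 = prior
    new_means, new_vars, log_masses = [], [], []
    for w, m, v in components:
        precision = 1.0 / v + 1.0 / vp - 1.0 / v0
        if precision <= 0.0:
            raise ValueError("component * proposal / prior is not normalizable")
        eta = m / v + mp / vp - m0 / v0
        # Log of the integral over theta of N(m, v) * N(mp, vp) / N(m0, v0).
        log_z = (
            0.5 * (math.log(v0) - math.log(v) - math.log(vp) - math.log(precision))
            - 0.5 * (m * m / v + mp * mp / vp - m0 * m0 / v0)
            + 0.5 * eta * eta / precision
        )
        new_means.append(eta / precision)
        new_vars.append(1.0 / precision)
        log_masses.append(math.log(w) + log_z)
    # Renormalize the weights (subtracting the max for numerical stability).
    top = max(log_masses)
    unnorm = [math.exp(lm - top) for lm in log_masses]
    total = sum(unnorm)
    return [(u / total, m, v) for u, m, v in zip(unnorm, new_means, new_vars)]


mdn = [(0.3, -1.0, 0.5), (0.7, 2.0, 1.0)]
# Proposal equal to the prior: the mixture is returned unchanged.
unchanged = proposal_posterior_mdn(mdn, proposal=(0.0, 4.0), prior=(0.0, 4.0))
# A proposal centered at 1.0 shifts weight toward the component near it.
shifted = proposal_posterior_mdn(mdn, proposal=(1.0, 1.0), prior=(0.0, 4.0))
```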
