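For NPE (amortized), the API is:

```python
from sbi.inference import NPE, DirectPosterior

# `theta`, `x` (simulated pairs) and `prior` are assumed to be defined.
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
```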
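Amortized NPE fits a conditional density q(θ | x) by maximizing log q(θ | x) over simulated pairs, after which the same fit serves every observation without retraining. The sketch below illustrates that idea in plain Python, without sbi. It swaps the neural density estimator for a linear-Gaussian family q(θ | x) = N(a·x + b, s²), whose maximum-likelihood fit is ordinary least squares, and uses a hypothetical toy model θ ~ N(0, 1), x = θ + N(0, 0.5²). The helpers `simulate`, `train_linear_gaussian` and `posterior` are illustrative names, not part of sbi.

```python
import random


def simulate(n, rng):
    """Hypothetical toy model: theta ~ N(0, 1), x = theta + N(0, 0.5^2)."""
    theta = [rng.gauss(0.0, 1.0) for _ in range(n)]
    x = [t + rng.gauss(0.0, 0.5) for t in theta]
    return theta, x


def train_linear_gaussian(theta, x):
    """Maximum-likelihood fit of q(theta | x) = N(a * x + b, s2).

    Stand-in for a neural density estimator: for this family the MLE is
    ordinary least squares for (a, b) and the mean squared residual for s2.
    """
    n = len(x)
    mean_x = sum(x) / n
    mean_t = sum(theta) / n
    cov = sum((xi - mean_x) * (ti - mean_t) for xi, ti in zip(x, theta)) / n
    var_x = sum((xi - mean_x) ** 2 for xi in x) / n
    a = cov / var_x
    b = mean_t - a * mean_x
    s2 = sum((ti - (a * xi + b)) ** 2 for xi, ti in zip(x, theta)) / n
    return a, b, s2


def posterior(net, x_o):
    """Amortized evaluation: reuse the trained fit for any observation x_o."""
    a, b, s2 = net
    return a * x_o + b, s2  # posterior mean and variance


rng = random.Random(0)
theta, x = simulate(20_000, rng)
net = train_linear_gaussian(theta, x)
for x_o in (-1.0, 0.0, 1.0):  # no retraining between observations
    mean, var = posterior(net, x_o)
    print(f"x_o={x_o:+.1f}: mean={mean:.3f}, var={var:.3f}")
```

For this toy model the exact posterior is N(0.8·x_o, 0.2), so the fitted slope and residual variance should land close to 0.8 and 0.2.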
For SNPE_A, it is:
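```python
from sbi.inference import NPE, DirectPosterior, snpe_a_correction

# `prior`, `simulator` and the observation `x_o` are assumed to be defined.
proposal = prior  # The first round draws parameters from the prior.
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)
    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
    # Reweight by prior / proposal to undo the proposal used for training.
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    proposal = corrected_posterior
```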
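SNPE-A trains on parameters drawn from a proposal p̃(θ), so the fitted density approximates the proposal posterior q̃(θ | x) and has to be reweighted by p(θ)/p̃(θ) afterwards. When the fitted density, the proposal and the prior are all Gaussian, the result is again Gaussian: precision 1/σ_q² - 1/σ_p̃² + 1/σ_p², and mean equal to the variance times μ_q/σ_q² - μ_p̃/σ_p̃² + μ_p/σ_p². The 1-D sketch below assumes exactly that case. `gaussian_snpe_a_correction` is a hypothetical helper, not the proposed `snpe_a_correction`; a mixture output would need this per component plus a reweighting of the mixture weights.

```python
def gaussian_snpe_a_correction(q_mean, q_var, proposal_mean, proposal_var,
                               prior_mean, prior_var):
    """1-D SNPE-A correction: renormalized N(q) * prior / proposal.

    Returns the mean and variance of the corrected Gaussian posterior.
    """
    precision = 1.0 / q_var - 1.0 / proposal_var + 1.0 / prior_var
    if precision <= 0.0:
        # The proposal is too narrow relative to the fitted density and prior.
        raise ValueError("correction yields an improper density")
    eta = q_mean / q_var - proposal_mean / proposal_var + prior_mean / prior_var
    return eta / precision, 1.0 / precision


# Hypothetical check: prior N(0, 1), x | theta ~ N(theta, 0.5^2),
# proposal N(1, 0.5^2), observation x_o = 1.
x_o = 1.0
proposal_mean, proposal_var = 1.0, 0.25
# Exact proposal posterior for this model (likelihood precision 4).
q_var = 1.0 / (1.0 / proposal_var + 4.0)
q_mean = q_var * (proposal_mean / proposal_var + 4.0 * x_o)
mean, var = gaussian_snpe_a_correction(q_mean, q_var, proposal_mean, proposal_var,
                                       prior_mean=0.0, prior_var=1.0)
print(mean, var)  # should match the exact posterior under the prior, N(0.8, 0.2)
```

A useful sanity check on the design: if the proposal equals the prior, the two terms cancel and the fitted density is returned unchanged.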
For SNPE_C (atomic), it is:
```python
from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss

# First round is standard NPE (`prior`, `simulator` and `x_o` assumed defined).
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`

# Later rounds use the APT loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
```
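The atomic loss compares each simulated pair (θ_i, x_i) against a small set of atoms, the true θ_i plus other parameters from the same batch, and scores atom θ_j with the logit log q(θ_j | x_i) - log p(θ_j); the loss is the negative log-softmax of the true atom. The sketch below implements that in plain Python for any user-supplied conditional log density. As a simplification it picks the contrasting atoms cyclically from the batch instead of at random, and the toy model (the same θ ~ N(0, 1), x = θ + N(0, 0.5²) as an assumption) and all function names are hypothetical.

```python
import math
import random


def normal_logpdf(value, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (value - mean) ** 2 / var)


def atomic_loss(log_q, log_prior, theta, x, num_atoms):
    """Atomic (APT / SNPE-C) loss averaged over a batch.

    For pair i the atoms are theta_i and the next num_atoms - 1 parameters
    in the batch (cyclic; a simplification of random atom sampling). Atom j
    gets the logit log q(theta_j | x_i) - log p(theta_j), and the loss is
    the negative log-softmax of the true atom theta_i.
    """
    n = len(theta)
    if not 2 <= num_atoms <= n:
        raise ValueError("need 2 <= num_atoms <= batch size")
    total = 0.0
    for i in range(n):
        atoms = [theta[(i + k) % n] for k in range(num_atoms)]  # atoms[0] = theta_i
        logits = [log_q(t, x[i]) - log_prior(t) for t in atoms]
        top = max(logits)
        log_norm = top + math.log(sum(math.exp(v - top) for v in logits))
        total -= logits[0] - log_norm
    return total / n


def log_prior(t):
    return normal_logpdf(t, 0.0, 1.0)


def log_q_exact(t, xi):
    # Exact posterior of the hypothetical toy model: N(0.8 * x, 0.2).
    return normal_logpdf(t, 0.8 * xi, 0.2)


def log_q_ignores_x(t, xi):
    # Baseline that ignores the data: every atom gets the same logit.
    return log_prior(t)


rng = random.Random(1)
theta = [rng.gauss(0.0, 1.0) for _ in range(500)]
x = [t + rng.gauss(0.0, 0.5) for t in theta]
good_loss = atomic_loss(log_q_exact, log_prior, theta, x, num_atoms=10)
baseline_loss = atomic_loss(log_q_ignores_x, log_prior, theta, x, num_atoms=10)
print(f"exact: {good_loss:.3f}, ignores x: {baseline_loss:.3f}, log K: {math.log(10):.3f}")
```

A density that ignores x scores exactly log K for K atoms, while a density that actually uses x, such as the exact posterior here, should score lower. Atoms here come from the prior; in later rounds they come from the proposal, and the loss takes the same form.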
For SNPE_C (non-atomic), the API differs from the atomic version only in that one also passes `proposal=proposal` to `append_simulations()` and has to use a mixture density network (MDN) as the density estimator.
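The MDN requirement comes from the non-atomic loss, which needs the proposal posterior q̃(θ | x) ∝ q(θ | x) p̃(θ) / p(θ) in normalized form. With a Gaussian proposal and a Gaussian prior, each mixture component multiplied by p̃/p is again an unnormalized Gaussian, so q̃ is itself a Gaussian mixture with recomputed weights, means and variances. The 1-D sketch below computes that mixture and the resulting loss -log q̃(θ | x) for one simulated parameter; the helper names and all numbers (MDN output, proposal, prior) are hypothetical.

```python
import math


def proposal_posterior_mog(weights, means, variances,
                           proposal_mean, proposal_var, prior_mean, prior_var):
    """Closed form of q(theta | x) * proposal(theta) / prior(theta), renormalized.

    q is a 1-D Gaussian mixture (the MDN output for one x); proposal and
    prior are 1-D Gaussians. Returns the new (weights, means, variances).
    """
    log_ws, new_means, new_vars = [], [], []
    for w, m, v in zip(weights, means, variances):
        precision = 1.0 / v + 1.0 / proposal_var - 1.0 / prior_var
        if precision <= 0.0:
            raise ValueError("component becomes improper for this prior/proposal")
        eta = m / v + proposal_mean / proposal_var - prior_mean / prior_var
        c = m * m / v + proposal_mean ** 2 / proposal_var - prior_mean ** 2 / prior_var
        # log of the integral over theta of N(m, v) * N(proposal) / N(prior)
        log_z = (-0.5 * (c - eta * eta / precision) - 0.5 * math.log(precision)
                 - 0.5 * math.log(v * proposal_var) + 0.5 * math.log(prior_var))
        log_ws.append(math.log(w) + log_z)
        new_means.append(eta / precision)
        new_vars.append(1.0 / precision)
    top = max(log_ws)
    total = sum(math.exp(a - top) for a in log_ws)
    new_weights = [math.exp(a - top) / total for a in log_ws]
    return new_weights, new_means, new_vars


def mog_logpdf(value, weights, means, variances):
    """Log density of a 1-D Gaussian mixture, computed with log-sum-exp."""
    terms = [math.log(w) - 0.5 * (math.log(2.0 * math.pi * v) + (value - m) ** 2 / v)
             for w, m, v in zip(weights, means, variances)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


# Hypothetical MDN output for one observation x, plus example proposal and prior.
tilde = proposal_posterior_mog([0.3, 0.7], [-1.0, 1.5], [0.4, 0.2],
                               proposal_mean=0.5, proposal_var=0.5,
                               prior_mean=0.0, prior_var=4.0)
theta_sim = 1.2  # a parameter that was simulated from the proposal
loss = -mog_logpdf(theta_sim, *tilde)  # non-atomic loss for this (theta, x) pair
print(tilde, loss)
```

With a flow-based estimator this product has no closed form, which is why the non-atomic variant is restricted to MDNs.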
File layout:

```
inference/
└── trainers/
    └── npe/
        ├── npe.py
        ├── snpe_a_correction.py
        └── snpe_c_loss.py
```