Skip to content

Commit

Permalink
expose progress bar in IS posterior.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Apr 8, 2024
1 parent a126b1d commit 69677f4
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
10 changes: 10 additions & 0 deletions sbi/inference/posteriors/importance_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,18 @@ def sample(
oversampling_factor: int = 32,
max_sampling_batch_size: int = 10_000,
sample_with: Optional[str] = None,
show_progress_bars: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Return samples from the approximate posterior distribution.
Args:
sample_shape: _description_
x: _description_
oversampling_factor: Number of proposed samples from which only one is
selected based on its importance weight.
max_sampling_batch_size: The batch size of samples being drawn from the
proposal at every iteration.
show_progress_bars: Whether to show a progressbar during sampling.
"""
if sample_with is not None:
raise ValueError(
Expand All @@ -181,6 +187,7 @@ def sample(
sample_shape,
oversampling_factor=oversampling_factor,
max_sampling_batch_size=max_sampling_batch_size,
show_progress_bars=show_progress_bars,
)
elif self.method == "importance":
return self._importance_sample(sample_shape)
Expand All @@ -190,13 +197,15 @@ def sample(
def _importance_sample(
self,
sample_shape: Shape = torch.Size(),
show_progress_bars: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Returns samples from the proposal and log of their importance weights.
Args:
sample_shape: Desired shape of samples that are drawn from posterior.
sample_with: This argument only exists to keep backward-compatibility with
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
show_progress_bars: Whether to show sampling progress monitor.
Returns:
Samples and logarithm of corresponding importance weights.
Expand All @@ -206,6 +215,7 @@ def _importance_sample(
self.potential_fn,
proposal=self.proposal,
num_samples=num_samples,
show_progress_bars=show_progress_bars,
)

samples = samples.reshape((*sample_shape, -1)).to(self._device)
Expand Down
7 changes: 6 additions & 1 deletion sbi/samplers/importance/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def importance_sample(
potential_fn,
proposal,
num_samples: int = 1,
show_progress_bars: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Returns samples from proposal and log(importance weights).
Expand All @@ -20,7 +21,11 @@ def importance_sample(
Returns:
Samples and logarithm of importance weights.
"""
samples = proposal.sample((num_samples,))
# Use progress bars when available (e.g., for multi-round proposals)
try:
samples = proposal.sample((num_samples,), show_progress_bar=show_progress_bars)
except TypeError:
samples = proposal.sample((num_samples,))

potential_logprobs = potential_fn(samples)
proposal_logprobs = proposal.log_prob(samples)
Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,4 +502,4 @@ def test_api_snl_sampling_methods(
)
posterior.train(max_num_iters=10)

posterior.sample(sample_shape=(num_samples,))
posterior.sample(sample_shape=(num_samples,), show_progress_bars=False)
2 changes: 1 addition & 1 deletion tests/linearGaussian_snre_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,4 +419,4 @@ def test_api_sre_sampling_methods(
)
posterior.train(max_num_iters=10)

posterior.sample(sample_shape=(num_samples,))
posterior.sample(sample_shape=(num_samples,), show_progress_bars=False)

0 comments on commit 69677f4

Please sign in to comment.