From 69677f4c589a60dfe0d8d5362c51df7e3f50c26f Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Wed, 3 Apr 2024 15:49:04 +0200 Subject: [PATCH] expose progress bar in IS posterior. --- sbi/inference/posteriors/importance_posterior.py | 10 ++++++++++ sbi/samplers/importance/importance_sampling.py | 7 ++++++- tests/linearGaussian_snle_test.py | 2 +- tests/linearGaussian_snre_test.py | 2 +- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index ccad0e64d..ace31ba2f 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -160,12 +160,18 @@ def sample( oversampling_factor: int = 32, max_sampling_batch_size: int = 10_000, sample_with: Optional[str] = None, + show_progress_bars: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Return samples from the approximate posterior distribution. Args: sample_shape: _description_ x: _description_ + oversampling_factor: Number of proposed samples from which only one is + selected based on its importance weight. + max_sampling_batch_size: The batch size of samples being drawn from the + proposal at every iteration. + show_progress_bars: Whether to show a progressbar during sampling. """ if sample_with is not None: raise ValueError( @@ -181,6 +187,7 @@ def sample( sample_shape, oversampling_factor=oversampling_factor, max_sampling_batch_size=max_sampling_batch_size, + show_progress_bars=show_progress_bars, ) elif self.method == "importance": return self._importance_sample(sample_shape) @@ -190,6 +197,7 @@ def sample( def _importance_sample( self, sample_shape: Shape = torch.Size(), + show_progress_bars: bool = False, ) -> Tuple[Tensor, Tensor]: """Returns samples from the proposal and log of their importance weights. @@ -197,6 +205,7 @@ def _importance_sample( sample_shape: Desired shape of samples that are drawn from posterior. sample_with: This argument only exists to keep backward-compatibility with `sbi` v0.17.2 or older. If it is set, we instantly raise an error. + show_progress_bars: Whether to show sampling progress monitor. Returns: Samples and logarithm of corresponding importance weights. @@ -206,6 +215,7 @@ def _importance_sample( self.potential_fn, proposal=self.proposal, num_samples=num_samples, + show_progress_bars=show_progress_bars, ) samples = samples.reshape((*sample_shape, -1)).to(self._device) diff --git a/sbi/samplers/importance/importance_sampling.py b/sbi/samplers/importance/importance_sampling.py index 20199b68b..2660ab310 100644 --- a/sbi/samplers/importance/importance_sampling.py +++ b/sbi/samplers/importance/importance_sampling.py @@ -9,6 +9,7 @@ def importance_sample( potential_fn, proposal, num_samples: int = 1, + show_progress_bars: bool = False, ) -> Tuple[Tensor, Tensor]: """Returns samples from proposal and log(importance weights). @@ -20,7 +21,11 @@ def importance_sample( Returns: Samples and logarithm of importance weights. """ - samples = proposal.sample((num_samples,)) + # Use progress bars when available (e.g., for multi-round proposals) + try: + samples = proposal.sample((num_samples,), show_progress_bar=show_progress_bars) + except TypeError: + samples = proposal.sample((num_samples,)) potential_logprobs = potential_fn(samples) proposal_logprobs = proposal.log_prob(samples) diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 44ec03824..aae1370b7 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -502,4 +502,4 @@ def test_api_snl_sampling_methods( ) posterior.train(max_num_iters=10) - posterior.sample(sample_shape=(num_samples,)) + posterior.sample(sample_shape=(num_samples,), show_progress_bars=False) diff --git a/tests/linearGaussian_snre_test.py b/tests/linearGaussian_snre_test.py index 086c7ce73..9380f8890 100644 --- a/tests/linearGaussian_snre_test.py +++ b/tests/linearGaussian_snre_test.py @@ -419,4 +419,4 @@ def test_api_sre_sampling_methods( ) posterior.train(max_num_iters=10) - posterior.sample(sample_shape=(num_samples,)) + posterior.sample(sample_shape=(num_samples,), show_progress_bars=False)