From 11f2ae8e9f636dcd4f9d3774dbe1395b4738940f Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 24 Oct 2023 17:21:04 +0100 Subject: [PATCH 01/12] feat: add sample_latent_distribution method --- nessai/flows/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nessai/flows/base.py b/nessai/flows/base.py index 6a24fcf3..3bf4d516 100644 --- a/nessai/flows/base.py +++ b/nessai/flows/base.py @@ -76,6 +76,11 @@ def log_prob(self, x, context=None): """ raise NotImplementedError() + @abstractmethod + def sample_latent_distribution(self, n, context=None): + """Sample from the latent distribution.""" + raise NotImplementedError + @abstractmethod def base_distribution_log_prob(self, z, context=None): """ @@ -237,6 +242,11 @@ def log_prob(self, inputs, context=None): log_prob = self._distribution.log_prob(noise) return log_prob + logabsdet + def sample_latent_distribution(self, n, context=None): + if context is not None: + raise NotImplementedError + return self._distribution.sample(n) + def base_distribution_log_prob(self, z, context=None): """ Computes the log probability of samples in the latent for From 8e328265b188078d079f5474f78b66a22de35aec Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 24 Oct 2023 17:21:20 +0100 Subject: [PATCH 02/12] test: add tests for sample_latent_distribution --- tests/test_flows/test_base_flow.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_flows/test_base_flow.py b/tests/test_flows/test_base_flow.py index 1216f081..7b056230 100644 --- a/tests/test_flows/test_base_flow.py +++ b/tests/test_flows/test_base_flow.py @@ -3,6 +3,7 @@ Test the base flow class """ import pytest +import torch from unittest.mock import MagicMock, create_autospec, patch from nessai.flows.base import BaseFlow, NFlow @@ -140,3 +141,16 @@ def test_nflow_unfreeze(nflow): nflow._transform.requires_grad_ = MagicMock() NFlow.unfreeze_transform(nflow) nflow._transform.requires_grad_.assert_called_once_with(True) + + +def test_nflow_sample_latent_distribution(nflow): + n = 10 + NFlow.sample_latent_distribution(nflow, n) + nflow._distribution.sample.assert_called_once_with(n) + + +def test_nflow_sample_latent_distribution_context(nflow): + n = 10 + context = torch.randn(10) + with pytest.raises(NotImplementedError): + NFlow.sample_latent_distribution(nflow, n, context=context) From e186acd5b0f38a397939e0d00e3647eb80d56e56 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 24 Oct 2023 17:22:54 +0100 Subject: [PATCH 03/12] feat: add sample_latent_distribution method --- nessai/flowmodel/base.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/nessai/flowmodel/base.py b/nessai/flowmodel/base.py index c4e37518..65b944b1 100644 --- a/nessai/flowmodel/base.py +++ b/nessai/flowmodel/base.py @@ -775,6 +775,23 @@ def sample(self, n: int = 1) -> np.ndarray: x = self.model.sample(int(n)) return x.cpu().numpy().astype(np.float64) + def sample_latent_distribution(self, n: int = 1) -> np.ndarray: + """Sample from the latent distribution. + + Parameters + ---------- + n : int + Number of samples to draw + + Returns + ------- + numpy.ndarray + Array of samples + """ + with torch.inference_mode(): + z = self.model.sample_latent_distribution(n) + return z.cpu().numpy().astype(np.float64) + def sample_and_log_prob(self, N=1, z=None, alt_dist=None): """ Generate samples from samples drawn from the base distribution or From 9fa34d1d03edbc60e97ff1c9402c6908e824e8a6 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 24 Oct 2023 17:23:09 +0100 Subject: [PATCH 04/12] test: test for sample_latent_distribution --- tests/test_flowmodel/test_flowmodel_base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_flowmodel/test_flowmodel_base.py b/tests/test_flowmodel/test_flowmodel_base.py index 7aca6d50..a1678d48 100644 --- a/tests/test_flowmodel/test_flowmodel_base.py +++ b/tests/test_flowmodel/test_flowmodel_base.py @@ -423,6 +423,17 @@ def test_sample(model): np.testing.assert_array_equal(out, x.numpy()) +def test_sample_latent_distribution(model): + """Assert the correct method is called""" + n = 10 + z = torch.randn(n, 2) + model.model = MagicMock() + model.model.sample_latent_distribution = MagicMock(return_value=z) + out = FlowModel.sample_latent_distribution(model, n) + model.model.sample_latent_distribution.assert_called_once_with(n) + np.testing.assert_array_equal(out, z.numpy()) + + def test_move_to_update_default(model): """Ensure the stored device is updated""" model.device = "cuda" From e4b227fb1a3d07bc07fe00ed541db859bad1aa5e Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 24 Oct 2023 17:28:34 +0100 Subject: [PATCH 05/12] feat: rework truncation based on log_q --- nessai/proposal/flowproposal.py | 107 +++++++++++++++++++------------- 1 file changed, 64 insertions(+), 43 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 1ef1127a..b2192d98 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -34,6 +34,7 @@ save_live_points, ) from ..utils.sampling import NDimensionalTruncatedGaussian +from ..utils.structures import get_subset_arrays logger = logging.getLogger(__name__) @@ -100,9 +101,8 @@ class FlowProposal(RejectionProposal): Similar to ``fuzz`` but instead a scaling factor applied to the radius this specifies a rescaling for volume of the n-ball used to draw samples. This is translated to a value for ``fuzz``. - truncate : bool, optional - Truncate proposals using probability compute for worst point. - Not recommended. + truncate_log_q : bool, optional + Truncate proposals using minimum log-probability of the training data. rescale_parameters : list or bool, optional If True live points are rescaled to `rescale_bounds` before training. If an instance of `list` then must contain names of parameters to @@ -174,7 +174,7 @@ def __init__( fixed_radius=False, drawsize=None, check_acceptance=False, - truncate=False, + truncate_log_q=False, rescale_bounds=[-1, 1], expansion_fraction=4.0, boundary_inversion=False, @@ -245,7 +245,7 @@ def __init__( self.update_bounds = update_bounds self.check_acceptance = check_acceptance self.rescale_bounds = rescale_bounds - self.truncate = truncate + self.truncate_log_q = truncate_log_q self.boundary_inversion = boundary_inversion self.inversion_type = inversion_type self.flow_config = flow_config @@ -420,6 +420,8 @@ def configure_latent_prior(self): from ..utils import draw_nsphere self._draw_latent_prior = draw_nsphere + elif self.latent_prior == "flow": + self._draw_latent_prior = None else: raise RuntimeError( f"Unknown latent prior: {self.latent_prior}, choose from: " @@ -1088,7 +1090,9 @@ def forward_pass(self, x, rescale=True, compute_radius=True): return z, log_prob + log_J - def backward_pass(self, z, rescale=True): + def backward_pass( + self, z, rescale=True, discard_nans=True, return_z=False + ): """ A backwards pass from the model (latent -> real) @@ -1098,14 +1102,21 @@ def backward_pass(self, z, rescale=True): Structured array of points in the latent space rescale : bool, optional (True) Apply inverse rescaling function + discard_nan: bool + If True, samples with NaNs or Infs in log_q are removed. + return_z : bool + If True, return the array of latent samples, this may differ from + the input since samples can be discarded. Returns ------- x : array_like - Samples in the latent space + Samples in the data space log_prob : array_like Log probabilities corresponding to each sample (including the Jacobian) + z : array_like + Samples in the latent space, only returned if :code:`return_z=True` """ # Compute the log probability try: @@ -1115,8 +1126,9 @@ def backward_pass(self, z, rescale=True): except AssertionError: return np.array([]), np.array([]) - valid = np.isfinite(log_prob) - x, log_prob = x[valid], log_prob[valid] + if discard_nans: + valid = np.isfinite(log_prob) + x, log_prob = x[valid], log_prob[valid] x = numpy_array_to_live_points( x.astype(config.livepoints.default_float_dtype), self.rescaled_names, @@ -1126,10 +1138,13 @@ def backward_pass(self, z, rescale=True): x, log_J = self.inverse_rescale(x) # Include Jacobian for the rescaling log_prob -= log_J - x, log_prob = self.check_prior_bounds(x, log_prob) - return x, log_prob + x, z, log_prob = self.check_prior_bounds(x, z, log_prob) + if return_z: + return x, log_prob, z + else: + return x, log_prob - def radius(self, z, log_q=None): + def radius(self, z, *arrays): """ Calculate the radius of a latent point or set of latent points. If multiple points are parsed the maximum radius is returned. @@ -1138,22 +1153,18 @@ def radius(self, z, log_q=None): ---------- z : :obj:`np.ndarray` Array of points in the latent space - log_q : :obj:`np.ndarray`, optional (None) - Array of corresponding probabilities. If specified - then probability of the maximum radius is also returned. + *arrays : + Additional arrays to return the corresponding value Returns ------- tuple of arrays - Tuple of array with the maximum radius and corresponding log_q - if it was a specified input. + Tuple of array with the maximum radius and corresponding values + from any additional arrays that were passed. """ - if log_q is not None: - r = np.sqrt(np.sum(z**2.0, axis=-1)) - i = np.argmax(r) - return r[i], log_q[i] - else: - return np.nanmax(np.sqrt(np.sum(z**2.0, axis=-1))) + r = np.sqrt(np.sum(z**2.0, axis=-1)) + i = np.nanargmax(r) + return (r[i],) + (a[i] for a in arrays) def log_prior(self, x): """ @@ -1219,7 +1230,7 @@ def compute_weights(self, x, log_q): log_w -= np.max(log_w) return log_w - def rejection_sampling(self, z, worst_q=None): + def rejection_sampling(self, z, min_log_q=None): """ Perform rejection sampling. @@ -1230,9 +1241,9 @@ def rejection_sampling(self, z, worst_q=None): ---------- z : ndarray Samples from the latent space - worst_q : float, optional + min_log_q : float, optional Lower bound on the log-probability computed using the flow that - is used to truncate new samples. Not recommended. + is used to truncate new samples. Returns ------- @@ -1241,20 +1252,24 @@ def rejection_sampling(self, z, worst_q=None): array_like Array of accepted samples in the X space. """ - x, log_q = self.backward_pass(z, rescale=not self.use_x_prime_prior) + x, log_q, z = self.backward_pass( + z, + rescale=not self.use_x_prime_prior, + discard_nans=False, + return_z=True, + ) if not x.size: return np.array([]), x - if self.truncate: - if worst_q is None: - raise ValueError( - "`worst_q` is None but truncation is enabled." - ) - cut = log_q >= worst_q - x = x[cut] - z = z[cut] - log_q = log_q[cut] + if min_log_q: + above = log_q >= min_log_q + x = x[above] + z = z[above] + log_q = log_q[above] + else: + valid = np.isfinite(log_q) + x, z, log_q = get_subset_arrays(valid, x, z, log_q) # rescale given priors used initially, need for priors log_w = self.compute_weights(x, log_q) @@ -1311,6 +1326,8 @@ def prep_latent_prior(self): fuzz=self.fuzz, ) self._draw_func = self._populate_dist.sample + elif self.latent_prior == "flow": + self._draw_func = lambda N: self.flow.sample_latent_distribution(N) else: self._draw_func = partial( self._draw_latent_prior, @@ -1347,25 +1364,29 @@ def populate(self, worst_point, N=10000, plot=True, r=None): ) if r is not None: logger.debug(f"Using user inputs for radius {r}") - worst_q = None elif self.fixed_radius: r = self.fixed_radius - worst_q = None else: logger.debug(f"Populating with worst point: {worst_point}") if self.compute_radius_with_all: logger.debug("Using previous live points to compute radius") - worst_point = self.training_data - worst_z, worst_q = self.forward_pass( + worst_z = self.forward_pass( worst_point, rescale=True, compute_radius=True ) - r, worst_q = self.radius(worst_z, worst_q) + r = self.radius(worst_z) if self.max_radius and r > self.max_radius: r = self.max_radius if self.min_radius and r < self.min_radius: r = self.min_radius - logger.debug(f"Populating proposal with lantent radius: {r:.5}") + if self.truncate_log_q: + log_q_live_points = self.forward_pass(self.training_data)[1] + min_log_q = log_q_live_points.min() + logger.debug("Truncating with log_q={min_log_q:.3f}") + else: + min_log_q = None + + logger.debug(f"Populating proposal with latent radius: {r:.5}") self.r = r self.alt_dist = self.get_alt_distribution() @@ -1390,7 +1411,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): z = self.draw_latent_prior(self.drawsize) proposed += z.shape[0] - z, x = self.rejection_sampling(z, worst_q) + z, x = self.rejection_sampling(z, min_log_q=min_log_q) if not x.size: continue From ebff61d18ecca5b23e4f10bac44bcf75f90acb28 Mon Sep 17 00:00:00 2001 From: mj-will Date: Wed, 25 Oct 2023 09:54:18 +0100 Subject: [PATCH 06/12] fix: fix missing f-string --- nessai/proposal/flowproposal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index b2192d98..50363067 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1382,7 +1382,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): if self.truncate_log_q: log_q_live_points = self.forward_pass(self.training_data)[1] min_log_q = log_q_live_points.min() - logger.debug("Truncating with log_q={min_log_q:.3f}") + logger.debug(f"Truncating with log_q={min_log_q:.3f}") else: min_log_q = None From 697130c0fbcf905e26fe060d5ff3aebda2065c89 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 15:32:48 +0000 Subject: [PATCH 07/12] test: update existing flowproposal tests --- .../test_flowproposal_flow.py | 4 +- .../test_flowproposal_population.py | 44 ++++++++----------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py b/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py index 53fcb8fd..d3943bb7 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py @@ -86,7 +86,9 @@ def test_backward_pass(proposal, model, log_p): ) proposal.rescaled_names = model.names proposal.alt_dist = None - proposal.check_prior_bounds = MagicMock(side_effect=lambda a, b: (a, b)) + proposal.check_prior_bounds = MagicMock( + side_effect=lambda a, b, c: (a, b, c) + ) proposal.flow = MagicMock() proposal.flow.sample_and_log_prob = MagicMock(return_value=[x, log_p]) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index 6021796c..9ba8d7f3 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -108,7 +108,7 @@ def test_rejection_sampling(proposal, z, x, log_q): """Test rejection sampling method.""" proposal.use_x_prime_prior = False proposal.truncate = False - proposal.backward_pass = MagicMock(return_value=(x, log_q)) + proposal.backward_pass = MagicMock(return_value=(x, log_q, z)) log_w = np.log(np.array([0.5, 0.5])) proposal.compute_weights = MagicMock(return_value=log_w) @@ -129,7 +129,7 @@ def test_rejection_sampling_empty(proposal, z): proposal.use_x_prime_prior = False proposal.truncate = False proposal.backward_pass = MagicMock( - return_value=(np.array([]), np.array([])) + return_value=(np.array([]), np.array([]), np.array([])) ) z_out, x_out = FlowProposal.rejection_sampling(proposal, z) @@ -144,13 +144,15 @@ def test_rejection_sampling_truncate(proposal, z, x): proposal.use_x_prime_prior = False proposal.truncate = True log_q = np.array([0.0, 1.0]) - proposal.backward_pass = MagicMock(return_value=(x, log_q)) - worst_q = 0.5 + proposal.backward_pass = MagicMock(return_value=(x, log_q, z)) + min_log_q = 0.5 log_w = np.log(np.array([0.5])) proposal.compute_weights = MagicMock(return_value=log_w) z_out, x_out = FlowProposal.rejection_sampling( - proposal, z, worst_q=worst_q + proposal, + z, + min_log_q=min_log_q, ) assert proposal.backward_pass.called_once_with(x, True) @@ -161,18 +163,6 @@ def test_rejection_sampling_truncate(proposal, z, x): assert np.array_equal(z_out[0], z[1]) -def test_rejection_sampling_truncate_missing_q(proposal, z, x, log_q): - """Test rejection sampling method with truncation without without q""" - proposal.use_x_prime_prior = False - proposal.truncate = True - log_q = np.array([0.0, 1.0]) - proposal.backward_pass = MagicMock(return_value=(x, log_q)) - - with pytest.raises(ValueError) as excinfo: - FlowProposal.rejection_sampling(proposal, z, worst_q=None) - assert "`worst_q` is None but truncation is enabled" in str(excinfo.value) - - def test_compute_acceptance(proposal): """Test the compute_acceptance method""" proposal.samples = np.arange(1, 11, dtype=float).view([("logL", "f8")]) @@ -264,7 +254,7 @@ def test_radius(proposal): z = np.array([[1, 2, 3], [0, 1, 2]]) expected_r = np.sqrt(14) r = FlowProposal.radius(proposal, z) - assert r == expected_r + np.testing.assert_equal(r, expected_r) def test_radius_w_log_q(proposal): @@ -272,7 +262,7 @@ def test_radius_w_log_q(proposal): z = np.array([[1, 2, 3], [0, 1, 2]]) log_q = np.array([1, 2]) expected_r = np.sqrt(14) - r, log_q_r = FlowProposal.radius(proposal, z, log_q=log_q) + r, log_q_r = FlowProposal.radius(proposal, z, log_q) assert r == expected_r assert log_q_r == log_q[0] @@ -361,7 +351,6 @@ def test_populate( [[1, 2, 3]], dtype=[("x", "f8"), ("y", "f8"), ("logL", "f8")] ) worst_z = np.random.randn(1, n_dims) - worst_q = np.random.randn(1) if r is None else None z = [ np.random.randn(drawsize, n_dims), np.random.randn(drawsize, n_dims), @@ -376,6 +365,8 @@ def test_populate( r_flow = 1.0 + min_log_q = None + if r is None: r_out = r_flow if min_radius is not None: @@ -401,9 +392,10 @@ def test_populate( proposal._plot_pool = True proposal.populated_count = 1 proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.truncate_log_q = False - proposal.forward_pass = MagicMock(return_value=(worst_z, worst_q)) - proposal.radius = MagicMock(return_value=(r_flow, worst_q)) + proposal.forward_pass = MagicMock(return_value=(worst_z, np.nan)) + proposal.radius = MagicMock(return_value=r_flow) proposal.get_alt_distribution = MagicMock(return_value=None) proposal.prep_latent_prior = MagicMock() proposal.draw_latent_prior = MagicMock(side_effect=z) @@ -439,7 +431,7 @@ def test_populate( rescale=True, compute_radius=True, ) - proposal.radius.assert_called_once_with(worst_z, worst_q) + proposal.radius.assert_called_once_with(worst_z) else: assert proposal.r is r @@ -451,9 +443,9 @@ def test_populate( proposal.draw_latent_prior.assert_has_calls(draw_calls) rejection_calls = [ - call(z[0], worst_q), - call(z[1], worst_q), - call(z[2], worst_q), + call(z[0], min_log_q=min_log_q), + call(z[1], min_log_q=min_log_q), + call(z[2], min_log_q=min_log_q), ] proposal.rejection_sampling.assert_has_calls(rejection_calls) From 46a66e57610a18601e53d8d4ae088406cbd6b8d9 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 15:33:24 +0000 Subject: [PATCH 08/12] fix: do not return tuple for a single input --- nessai/proposal/flowproposal.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 50363067..078fd0f1 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1164,7 +1164,10 @@ def radius(self, z, *arrays): """ r = np.sqrt(np.sum(z**2.0, axis=-1)) i = np.nanargmax(r) - return (r[i],) + (a[i] for a in arrays) + if arrays: + return (r[i],) + tuple(a[i] for a in arrays) + else: + return r[i] def log_prior(self, x): """ @@ -1372,7 +1375,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): logger.debug("Using previous live points to compute radius") worst_z = self.forward_pass( worst_point, rescale=True, compute_radius=True - ) + )[0] r = self.radius(worst_z) if self.max_radius and r > self.max_radius: r = self.max_radius From 12f67fb4178218d9f63c519a16e8240dd63971ab Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 15:55:00 +0000 Subject: [PATCH 09/12] fix: fix bug with compute radius with all --- nessai/proposal/flowproposal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 078fd0f1..7e544e24 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1373,6 +1373,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): logger.debug(f"Populating with worst point: {worst_point}") if self.compute_radius_with_all: logger.debug("Using previous live points to compute radius") + worst_point = self.training_data worst_z = self.forward_pass( worst_point, rescale=True, compute_radius=True )[0] From 88d2b5d85a670b91279f58bec097c507aac523ae Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 16:14:16 +0000 Subject: [PATCH 10/12] test: add tests for truncation changes --- .../test_flowproposal_flow.py | 14 ++- .../test_flowproposal_population.py | 96 +++++++++++++++++++ 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py b/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py index d3943bb7..22c43c1e 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_flow.py @@ -75,10 +75,14 @@ def test_forward_pass(proposal, model, n): @pytest.mark.parametrize("log_p", [np.ones(2), np.array([-1, np.inf])]) -def test_backward_pass(proposal, model, log_p): +@pytest.mark.parametrize("discard_nans", [False, True]) +def test_backward_pass(proposal, model, log_p, discard_nans): """Test the forward pass method""" n = 2 - acc = int(np.isfinite(log_p).sum()) + if discard_nans: + acc = int(np.isfinite(log_p).sum()) + else: + acc = len(log_p) x = np.random.randn(n, model.dims) z = np.random.randn(n, model.dims) proposal.inverse_rescale = MagicMock( @@ -92,7 +96,11 @@ def test_backward_pass(proposal, model, log_p): proposal.flow = MagicMock() proposal.flow.sample_and_log_prob = MagicMock(return_value=[x, log_p]) - x_out, log_p = FlowProposal.backward_pass(proposal, z) + x_out, log_p = FlowProposal.backward_pass( + proposal, + z, + discard_nans=discard_nans, + ) assert len(x_out) == acc proposal.inverse_rescale.assert_called_once() diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index 9ba8d7f3..77ae388c 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -325,6 +325,15 @@ def draw(dims, N=None, r=None, fuzz=None): assert proposal._draw_func(N=10).shape == (10, 2) +def test_prep_latent_prior_flow(proposal): + proposal.latent_prior = "flow" + proposal.flow = MagicMock() + proposal.flow.sample_latent_distribution = MagicMock() + FlowProposal.prep_latent_prior(proposal) + proposal._draw_func(10) + proposal.flow.sample_latent_distribution.assert_called_once_with(10) + + def test_draw_latent_prior(proposal): proposal._draw_func = MagicMock(return_value=[1, 2]) out = FlowProposal.draw_latent_prior(proposal, 2) @@ -479,3 +488,90 @@ def test_populate_not_initialised(proposal): with pytest.raises(RuntimeError) as excinfo: FlowProposal.populate(proposal, 1.0) assert "Proposal has not been initialised. " in str(excinfo.value) + + +def test_populate_truncate_log_q(proposal): + n_dims = 2 + nlive = 8 + poolsize = 10 + drawsize = 5 + names = ["x", "y"] + r_flow = 2.0 + worst_point = np.array( + [[1, 2, 3]], dtype=[("x", "f8"), ("y", "f8"), ("logL", "f8")] + ) + z = [ + np.random.randn(drawsize, n_dims), + np.random.randn(drawsize, n_dims), + np.random.randn(drawsize, n_dims), + ] + x = [ + numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), + numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), + numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), + ] + log_l = np.random.rand(poolsize) + + proposal.initialised = True + proposal.dims = n_dims + proposal.poolsize = poolsize + proposal.drawsize = drawsize + proposal.fuzz = 1.0 + proposal.indices = None + proposal.acceptance = [0.7] + proposal.keep_samples = False + proposal.fixed_radius = 2.0 + proposal.compute_radius_with_all = False + proposal.check_acceptance = False + proposal._plot_pool = False + proposal.populated_count = 1 + proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.truncate_log_q = True + proposal.training_data = numpy_array_to_live_points( + np.random.randn(nlive, n_dims), + names=names, + ) + + log_q_live = np.log(np.random.rand(nlive)) + min_log_q = log_q_live.min() + + proposal.forward_pass = MagicMock( + return_value=(nlive * [None], log_q_live) + ) + proposal.radius = MagicMock(return_value=r_flow) + proposal.get_alt_distribution = MagicMock(return_value=None) + proposal.prep_latent_prior = MagicMock() + proposal.draw_latent_prior = MagicMock(side_effect=z) + proposal.rejection_sampling = MagicMock( + side_effect=[(a[:-1], b[:-1]) for a, b in zip(z, x)] + ) + proposal.compute_acceptance = MagicMock(return_value=0.8) + proposal.model = MagicMock() + proposal.model.batch_evaluate_log_likelihood = MagicMock( + return_value=log_l + ) + + proposal.convert_to_samples = MagicMock( + side_effect=lambda *args, **kwargs: args[0] + ) + + x_empty = np.empty(poolsize, dtype=proposal.population_dtype) + with patch( + "nessai.proposal.flowproposal.empty_structured_array", + return_value=x_empty, + ) as mock_empty: + FlowProposal.populate(proposal, worst_point, N=10, plot=False) + + mock_empty.assert_called_once_with( + poolsize, + dtype=proposal.population_dtype, + ) + + proposal.forward_pass.assert_called_once_with(proposal.training_data) + + rejection_calls = [ + call(z[0], min_log_q=min_log_q), + call(z[1], min_log_q=min_log_q), + call(z[2], min_log_q=min_log_q), + ] + proposal.rejection_sampling.assert_has_calls(rejection_calls) From b4bbeb00dde249c992b72f3734bc3f64b3ff65b3 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 16:40:52 +0000 Subject: [PATCH 11/12] test: tests for latent_prior='flow' --- tests/test_flows/test_base_flow.py | 1 + .../test_flowproposal/test_flowproposal_configuration.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_flows/test_base_flow.py b/tests/test_flows/test_base_flow.py index 7b056230..601f4424 100644 --- a/tests/test_flows/test_base_flow.py +++ b/tests/test_flows/test_base_flow.py @@ -43,6 +43,7 @@ def test_base_flow_abstract_methods(): "base_distribution_log_prob", "forward_and_log_prob", "sample_and_log_prob", + "sample_latent_distribution", ], ) def test_base_flow_methods(method, flow): diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_configuration.py b/tests/test_proposal/test_flowproposal/test_flowproposal_configuration.py index ed8bb0a6..13caf889 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_configuration.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_configuration.py @@ -92,6 +92,7 @@ def test_configure_plotting(proposal, plot, plot_pool, plot_train): ("uniform", "draw_uniform"), ("uniform_nsphere", "draw_nsphere"), ("uniform_nball", "draw_nsphere"), + ("flow", None), ], ) def test_configure_latent_prior(proposal, latent_prior, prior_func): @@ -99,7 +100,10 @@ def test_configure_latent_prior(proposal, latent_prior, prior_func): proposal.latent_prior = latent_prior proposal.flow_config = {"model_config": {}} FlowProposal.configure_latent_prior(proposal) - assert proposal._draw_latent_prior == getattr(utils, prior_func) + if prior_func: + assert proposal._draw_latent_prior == getattr(utils, prior_func) + else: + assert proposal._draw_latent_prior is None def test_configure_latent_prior_unknown(proposal): From a1dcbb28dfe1a7218b4932c5144116e826b97883 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 Nov 2023 16:44:41 +0000 Subject: [PATCH 12/12] test: add sampling integration test for truncate_log_q --- tests/test_sampling/test_standard_sampling.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 0da1c3db..1002659c 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -459,6 +459,24 @@ def test_constant_volume_mode(integration_model, tmpdir): fs.run(plot=False) +@pytest.mark.slow_integration_test +def test_truncate_log_q(integration_model, tmpdir): + """Test sampling with truncate_log_q""" + output = str(tmpdir.mkdir("test")) + fs = FlowSampler( + integration_model, + output=output, + nlive=500, + plot=False, + proposal_plots=False, + constant_volume_mode=False, + latent_prior="flow", + truncate_log_q=True, + ) + fs.run(plot=False) + assert fs.ns.finalised + + @pytest.mark.slow_integration_test def test_prior_sampling(integration_model, tmpdir): """Test prior sampling"""