Skip to content

Commit

Permalink
Merge pull request #338 from mj-will/truncate-log-q
Browse files Browse the repository at this point in the history
Rework truncation based on log q
  • Loading branch information
mj-will authored Nov 22, 2023
2 parents d5cd2b7 + a1dcbb2 commit ea59056
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 73 deletions.
17 changes: 17 additions & 0 deletions nessai/flowmodel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,23 @@ def sample(self, n: int = 1) -> np.ndarray:
x = self.model.sample(int(n))
return x.cpu().numpy().astype(np.float64)

def sample_latent_distribution(self, n: int = 1) -> np.ndarray:
"""Sample from the latent distribution.
Parameters
----------
n : int
Number of samples to draw
Returns
-------
numpy.ndarray
Array of samples
"""
with torch.inference_mode():
z = self.model.sample_latent_distribution(n)
return z.cpu().numpy().astype(np.float64)

def sample_and_log_prob(self, N=1, z=None, alt_dist=None):
"""
Generate samples from samples drawn from the base distribution or
Expand Down
10 changes: 10 additions & 0 deletions nessai/flows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def log_prob(self, x, context=None):
"""
raise NotImplementedError()

@abstractmethod
def sample_latent_distribution(self, n, context=None):
"""Sample from the latent distribution."""
raise NotImplementedError

@abstractmethod
def base_distribution_log_prob(self, z, context=None):
"""
Expand Down Expand Up @@ -237,6 +242,11 @@ def log_prob(self, inputs, context=None):
log_prob = self._distribution.log_prob(noise)
return log_prob + logabsdet

def sample_latent_distribution(self, n, context=None):
if context is not None:
raise NotImplementedError
return self._distribution.sample(n)

def base_distribution_log_prob(self, z, context=None):
"""
Computes the log probability of samples in the latent for
Expand Down
109 changes: 67 additions & 42 deletions nessai/proposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
save_live_points,
)
from ..utils.sampling import NDimensionalTruncatedGaussian
from ..utils.structures import get_subset_arrays

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,9 +101,8 @@ class FlowProposal(RejectionProposal):
Similar to ``fuzz`` but instead a scaling factor applied to the radius
this specifies a rescaling for volume of the n-ball used to draw
samples. This is translated to a value for ``fuzz``.
truncate : bool, optional
Truncate proposals using probability compute for worst point.
Not recommended.
truncate_log_q : bool, optional
Truncate proposals using minimum log-probability of the training data.
rescale_parameters : list or bool, optional
If True live points are rescaled to `rescale_bounds` before training.
If an instance of `list` then must contain names of parameters to
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
fixed_radius=False,
drawsize=None,
check_acceptance=False,
truncate=False,
truncate_log_q=False,
rescale_bounds=[-1, 1],
expansion_fraction=4.0,
boundary_inversion=False,
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
self.update_bounds = update_bounds
self.check_acceptance = check_acceptance
self.rescale_bounds = rescale_bounds
self.truncate = truncate
self.truncate_log_q = truncate_log_q
self.boundary_inversion = boundary_inversion
self.inversion_type = inversion_type
self.flow_config = flow_config
Expand Down Expand Up @@ -420,6 +420,8 @@ def configure_latent_prior(self):
from ..utils import draw_nsphere

self._draw_latent_prior = draw_nsphere
elif self.latent_prior == "flow":
self._draw_latent_prior = None
else:
raise RuntimeError(
f"Unknown latent prior: {self.latent_prior}, choose from: "
Expand Down Expand Up @@ -1088,7 +1090,9 @@ def forward_pass(self, x, rescale=True, compute_radius=True):

return z, log_prob + log_J

def backward_pass(self, z, rescale=True):
def backward_pass(
self, z, rescale=True, discard_nans=True, return_z=False
):
"""
A backwards pass from the model (latent -> real)
Expand All @@ -1098,14 +1102,21 @@ def backward_pass(self, z, rescale=True):
Structured array of points in the latent space
rescale : bool, optional (True)
Apply inverse rescaling function
discard_nan: bool
If True, samples with NaNs or Infs in log_q are removed.
return_z : bool
If True, return the array of latent samples, this may differ from
the input since samples can be discarded.
Returns
-------
x : array_like
Samples in the latent space
Samples in the data space
log_prob : array_like
Log probabilities corresponding to each sample (including the
Jacobian)
z : array_like
Samples in the latent space, only returned if :code:`return_z=True`
"""
# Compute the log probability
try:
Expand All @@ -1115,8 +1126,9 @@ def backward_pass(self, z, rescale=True):
except AssertionError:
return np.array([]), np.array([])

valid = np.isfinite(log_prob)
x, log_prob = x[valid], log_prob[valid]
if discard_nans:
valid = np.isfinite(log_prob)
x, log_prob = x[valid], log_prob[valid]
x = numpy_array_to_live_points(
x.astype(config.livepoints.default_float_dtype),
self.rescaled_names,
Expand All @@ -1126,10 +1138,13 @@ def backward_pass(self, z, rescale=True):
x, log_J = self.inverse_rescale(x)
# Include Jacobian for the rescaling
log_prob -= log_J
x, log_prob = self.check_prior_bounds(x, log_prob)
return x, log_prob
x, z, log_prob = self.check_prior_bounds(x, z, log_prob)
if return_z:
return x, log_prob, z
else:
return x, log_prob

def radius(self, z, log_q=None):
def radius(self, z, *arrays):
"""
Calculate the radius of a latent point or set of latent points.
If multiple points are parsed the maximum radius is returned.
Expand All @@ -1138,22 +1153,21 @@ def radius(self, z, log_q=None):
----------
z : :obj:`np.ndarray`
Array of points in the latent space
log_q : :obj:`np.ndarray`, optional (None)
Array of corresponding probabilities. If specified
then probability of the maximum radius is also returned.
*arrays :
Additional arrays to return the corresponding value
Returns
-------
tuple of arrays
Tuple of array with the maximum radius and corresponding log_q
if it was a specified input.
Tuple of array with the maximum radius and corresponding values
from any additional arrays that were passed.
"""
if log_q is not None:
r = np.sqrt(np.sum(z**2.0, axis=-1))
i = np.argmax(r)
return r[i], log_q[i]
r = np.sqrt(np.sum(z**2.0, axis=-1))
i = np.nanargmax(r)
if arrays:
return (r[i],) + tuple(a[i] for a in arrays)
else:
return np.nanmax(np.sqrt(np.sum(z**2.0, axis=-1)))
return r[i]

def log_prior(self, x):
"""
Expand Down Expand Up @@ -1219,7 +1233,7 @@ def compute_weights(self, x, log_q):
log_w -= np.max(log_w)
return log_w

def rejection_sampling(self, z, worst_q=None):
def rejection_sampling(self, z, min_log_q=None):
"""
Perform rejection sampling.
Expand All @@ -1230,9 +1244,9 @@ def rejection_sampling(self, z, worst_q=None):
----------
z : ndarray
Samples from the latent space
worst_q : float, optional
min_log_q : float, optional
Lower bound on the log-probability computed using the flow that
is used to truncate new samples. Not recommended.
is used to truncate new samples.
Returns
-------
Expand All @@ -1241,20 +1255,24 @@ def rejection_sampling(self, z, worst_q=None):
array_like
Array of accepted samples in the X space.
"""
x, log_q = self.backward_pass(z, rescale=not self.use_x_prime_prior)
x, log_q, z = self.backward_pass(
z,
rescale=not self.use_x_prime_prior,
discard_nans=False,
return_z=True,
)

if not x.size:
return np.array([]), x

if self.truncate:
if worst_q is None:
raise ValueError(
"`worst_q` is None but truncation is enabled."
)
cut = log_q >= worst_q
x = x[cut]
z = z[cut]
log_q = log_q[cut]
if min_log_q:
above = log_q >= min_log_q
x = x[above]
z = z[above]
log_q = log_q[above]
else:
valid = np.isfinite(log_q)
x, z, log_q = get_subset_arrays(valid, x, z, log_q)

# rescale given priors used initially, need for priors
log_w = self.compute_weights(x, log_q)
Expand Down Expand Up @@ -1311,6 +1329,8 @@ def prep_latent_prior(self):
fuzz=self.fuzz,
)
self._draw_func = self._populate_dist.sample
elif self.latent_prior == "flow":
self._draw_func = lambda N: self.flow.sample_latent_distribution(N)
else:
self._draw_func = partial(
self._draw_latent_prior,
Expand Down Expand Up @@ -1347,25 +1367,30 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
)
if r is not None:
logger.debug(f"Using user inputs for radius {r}")
worst_q = None
elif self.fixed_radius:
r = self.fixed_radius
worst_q = None
else:
logger.debug(f"Populating with worst point: {worst_point}")
if self.compute_radius_with_all:
logger.debug("Using previous live points to compute radius")
worst_point = self.training_data
worst_z, worst_q = self.forward_pass(
worst_z = self.forward_pass(
worst_point, rescale=True, compute_radius=True
)
r, worst_q = self.radius(worst_z, worst_q)
)[0]
r = self.radius(worst_z)
if self.max_radius and r > self.max_radius:
r = self.max_radius
if self.min_radius and r < self.min_radius:
r = self.min_radius

logger.debug(f"Populating proposal with lantent radius: {r:.5}")
if self.truncate_log_q:
log_q_live_points = self.forward_pass(self.training_data)[1]
min_log_q = log_q_live_points.min()
logger.debug(f"Truncating with log_q={min_log_q:.3f}")
else:
min_log_q = None

logger.debug(f"Populating proposal with latent radius: {r:.5}")
self.r = r

self.alt_dist = self.get_alt_distribution()
Expand All @@ -1390,7 +1415,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
z = self.draw_latent_prior(self.drawsize)
proposed += z.shape[0]

z, x = self.rejection_sampling(z, worst_q)
z, x = self.rejection_sampling(z, min_log_q=min_log_q)

if not x.size:
continue
Expand Down
11 changes: 11 additions & 0 deletions tests/test_flowmodel/test_flowmodel_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@ def test_sample(model):
np.testing.assert_array_equal(out, x.numpy())


def test_sample_latent_distribution(model):
"""Assert the correct method is called"""
n = 10
z = torch.randn(n, 2)
model.model = MagicMock()
model.model.sample_latent_distribution = MagicMock(return_value=z)
out = FlowModel.sample_latent_distribution(model, n)
model.model.sample_latent_distribution.assert_called_once_with(n)
np.testing.assert_array_equal(out, z.numpy())


def test_move_to_update_default(model):
"""Ensure the stored device is updated"""
model.device = "cuda"
Expand Down
15 changes: 15 additions & 0 deletions tests/test_flows/test_base_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Test the base flow class
"""
import pytest
import torch
from unittest.mock import MagicMock, create_autospec, patch

from nessai.flows.base import BaseFlow, NFlow
Expand Down Expand Up @@ -42,6 +43,7 @@ def test_base_flow_abstract_methods():
"base_distribution_log_prob",
"forward_and_log_prob",
"sample_and_log_prob",
"sample_latent_distribution",
],
)
def test_base_flow_methods(method, flow):
Expand Down Expand Up @@ -140,3 +142,16 @@ def test_nflow_unfreeze(nflow):
nflow._transform.requires_grad_ = MagicMock()
NFlow.unfreeze_transform(nflow)
nflow._transform.requires_grad_.assert_called_once_with(True)


def test_nflow_sample_latent_distribution(nflow):
n = 10
NFlow.sample_latent_distribution(nflow, n)
nflow._distribution.sample.assert_called_once_with(n)


def test_nflow_sample_latent_distribution_context(nflow):
n = 10
context = torch.randn(10)
with pytest.raises(NotImplementedError):
NFlow.sample_latent_distribution(nflow, n, context=context)
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,18 @@ def test_configure_plotting(proposal, plot, plot_pool, plot_train):
("uniform", "draw_uniform"),
("uniform_nsphere", "draw_nsphere"),
("uniform_nball", "draw_nsphere"),
("flow", None),
],
)
def test_configure_latent_prior(proposal, latent_prior, prior_func):
"""Test to make sure the correct latent priors are used."""
proposal.latent_prior = latent_prior
proposal.flow_config = {"model_config": {}}
FlowProposal.configure_latent_prior(proposal)
assert proposal._draw_latent_prior == getattr(utils, prior_func)
if prior_func:
assert proposal._draw_latent_prior == getattr(utils, prior_func)
else:
assert proposal._draw_latent_prior is None


def test_configure_latent_prior_unknown(proposal):
Expand Down
Loading

0 comments on commit ea59056

Please sign in to comment.