-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #89 from lincc-frameworks/numerical_inverse_sampli…
…ng_node Adapt HostmassX1Func to use random node infrastructure
- Loading branch information
Showing
6 changed files
with
388 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
"""Wrapper classes for some of scipy's sampling functions.""" | ||
|
||
from os import urandom | ||
|
||
import numpy as np | ||
from scipy.stats.sampling import NumericalInversePolynomial | ||
|
||
from tdastro.base_models import FunctionNode | ||
from tdastro.graph_state import transpose_dict_of_list | ||
|
||
|
||
class NumericalInversePolynomialFunc(FunctionNode): | ||
"""A class for sampling from scipy's NumericalInversePolynomial | ||
given an distribution object or a class from which to create | ||
such an object. | ||
Note | ||
---- | ||
If a class is provided, then the sampling function will create a new | ||
object (with the sampled parameters) for each sampling. | ||
Attributes | ||
---------- | ||
_dist : object or class | ||
An object or class with either a pdf() or logpdf() method that defines | ||
the distribution from which to sample. | ||
_inv_poly: `scipy.stats.sampling.NumericalInversePolynomial` | ||
The scipy object to use for sampling. Set to ``None`` if _dist is a class. | ||
_vect_sample : `numpy.vectorize` | ||
The vectorized function to create a distribution from a class and sample it. | ||
Set to ``None`` if _dist is an object. | ||
_rng : `numpy.random._generator.Generator` | ||
This object's random number generator. | ||
Parameters | ||
---------- | ||
dist : object or class | ||
An object or class with either a pdf() or logpdf() method that defines | ||
the distribution from which to sample. | ||
seed : `int`, optional | ||
The seed to use. | ||
""" | ||
|
||
def __init__(self, dist=None, seed=None, **kwargs): | ||
# Check that the distribution object/class has a pdf or logpdf function. | ||
if not hasattr(dist, "pdf") and not hasattr(dist, "logpdf"): | ||
raise ValueError("Distribution must have either pdf() or logpdf().") | ||
self._dist = dist | ||
|
||
# Classes show up as type="type" | ||
if isinstance(dist, type): | ||
self._inv_poly = None | ||
self._vect_sample = np.vectorize(self._create_and_sample) | ||
else: | ||
self._inv_poly = NumericalInversePolynomial(self._dist) | ||
self._vect_sample = None | ||
|
||
# Get a default random number generator for this object, using the | ||
# given seed if one is provided. | ||
if seed is None: | ||
seed = int.from_bytes(urandom(4), "big") | ||
self._rng = np.random.default_rng(seed=seed) | ||
|
||
# Set the function and add all the kwargs as parameters. | ||
super().__init__(self._rvs, **kwargs) | ||
|
||
def _rvs(self): | ||
"""A place holder function to use for object naming.""" | ||
pass | ||
|
||
def set_seed(self, new_seed): | ||
"""Update the random number generator's seed to a given value. | ||
Parameters | ||
---------- | ||
new_seed : `int` | ||
The given seed | ||
""" | ||
self._rng = np.random.default_rng(seed=new_seed) | ||
|
||
def _create_and_sample(self, args, rng): | ||
"""Create the distribution function and sample it. | ||
Parameters | ||
---------- | ||
args : `dict` | ||
A dictionary mapping argument name to individual values. | ||
rng : `numpy.random._generator.Generator` | ||
The random number generator to use. | ||
Returns | ||
------- | ||
sample : `float` | ||
The result of sampling the function. | ||
""" | ||
dist = self._dist(**args) | ||
sample = NumericalInversePolynomial(dist).rvs(1, rng)[0] | ||
return sample | ||
|
||
def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): | ||
"""Execute the wrapped function. | ||
The input arguments are taken from the current graph_state and the outputs | ||
are written to graph_state. | ||
Parameters | ||
---------- | ||
graph_state : `GraphState` | ||
An object mapping graph parameters to their values. This object is modified | ||
in place as it is sampled. | ||
given_args : `dict`, optional | ||
A dictionary representing the given arguments for this sample run. | ||
This can be used as the JAX PyTree for differentiation. | ||
rng_info : `dict`, optional | ||
A dictionary of random number generator information for each node, such as | ||
the JAX keys or the numpy rngs. | ||
**kwargs : `dict`, optional | ||
Additional function arguments. | ||
Returns | ||
------- | ||
results : any | ||
The result of the computation. This return value is provided so that testing | ||
functions can easily access the results. | ||
""" | ||
rng = rng_info if rng_info is not None else self._rng | ||
|
||
if self._inv_poly is not None: | ||
# Batch sample all the results. | ||
num_samples = None if graph_state.num_samples == 1 else graph_state.num_samples | ||
results = self._inv_poly.rvs(num_samples, rng) | ||
else: | ||
# This is a class so we will need to create a new distribution object | ||
# for each sample (with a single instance of the input parameters). | ||
args = self._build_inputs(graph_state, given_args, **kwargs) | ||
|
||
if graph_state.num_samples == 1: | ||
dist = self._dist(**args) | ||
results = NumericalInversePolynomial(dist).rvs(1, rng)[0] | ||
else: | ||
# Transpose the dict of arrays to a list of dicts. | ||
arg_list = transpose_dict_of_list(args, graph_state.num_samples) | ||
results = self._vect_sample(arg_list, rng) | ||
|
||
# Save and return the results. | ||
self._save_results(results, graph_state) | ||
return results | ||
|
||
def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): | ||
"""A helper function for testing that regenerates the output. | ||
Parameters | ||
---------- | ||
given_args : `dict`, optional | ||
A dictionary representing the given arguments for this sample run. | ||
This can be used as the JAX PyTree for differentiation. | ||
num_samples : `int` | ||
A count of the number of samples to compute. | ||
Default: 1 | ||
rng_info : `dict`, optional | ||
A dictionary of random number generator information for each node, such as | ||
the JAX keys or the numpy rngs. | ||
**kwargs : `dict`, optional | ||
Additional function arguments. | ||
""" | ||
state = self.sample_parameters(given_args, num_samples, rng_info) | ||
return self.compute(state, given_args, rng_info, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
from tdastro.astro_utils.snia_utils import HostmassX1Func | ||
from tdastro.util_nodes.np_random import NumpyRandomFunc | ||
|
||
|
||
def test_sample_hostmass_x1c(): | ||
"""Test that we can sample correctly from HostmassX1Func.""" | ||
num_samples = 5 | ||
|
||
hm_node1 = HostmassX1Func( | ||
hostmass=NumpyRandomFunc("uniform", low=7, high=12, seed=100), | ||
seed=101, | ||
) | ||
states1 = hm_node1.sample_parameters(num_samples=num_samples) | ||
values1 = hm_node1.get_param(states1, "function_node_result") | ||
assert len(values1) == num_samples | ||
assert len(np.unique(values1)) == num_samples | ||
|
||
# If we create a new node with the same hostmas and the same seeds, we get the | ||
# same results and the same hostmasses. | ||
hm_node2 = HostmassX1Func( | ||
hostmass=NumpyRandomFunc("uniform", low=7, high=12, seed=100), | ||
seed=101, | ||
) | ||
states2 = hm_node2.sample_parameters(num_samples=num_samples) | ||
values2 = hm_node2.get_param(states2, "function_node_result") | ||
assert np.allclose(values1, values2) | ||
assert np.allclose( | ||
hm_node1.get_param(states1, "hostmass"), | ||
hm_node2.get_param(states2, "hostmass"), | ||
) | ||
|
||
# If we use a different seed for the hostmass function only, we get | ||
# different results (i.e. the hostmass parameter is being resampled). | ||
hm_node3 = HostmassX1Func( | ||
hostmass=NumpyRandomFunc("uniform", low=7, high=12, seed=102), | ||
seed=101, | ||
) | ||
states3 = hm_node3.sample_parameters(num_samples=num_samples) | ||
values3 = hm_node3.get_param(states3, "function_node_result") | ||
assert not np.allclose(values1, values3) | ||
assert not np.allclose( | ||
hm_node1.get_param(states1, "hostmass"), | ||
hm_node3.get_param(states3, "hostmass"), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.