Skip to content

Commit

Permalink
Merge pull request #89 from lincc-frameworks/numerical_inverse_sampli…
Browse files Browse the repository at this point in the history
…ng_node

Adapt HostmassX1Func to use random node infrastructure
  • Loading branch information
jeremykubica authored Aug 29, 2024
2 parents aa81e3f + 3ca35fb commit 4a87261
Show file tree
Hide file tree
Showing 6 changed files with 388 additions and 27 deletions.
32 changes: 6 additions & 26 deletions src/tdastro/astro_utils/snia_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
from astropy.cosmology import FlatLambdaCDM
from scipy.stats import norm
from scipy.stats.sampling import NumericalInversePolynomial

from tdastro.base_models import FunctionNode
from tdastro.util_nodes.scipy_random import NumericalInversePolynomialFunc


class HostmassX1Distr:
Expand Down Expand Up @@ -66,26 +66,6 @@ def pdf(self, x1):
return self._p(x1, hostmass=self.hostmass) * norm.pdf(x1, loc=0, scale=1)


def _hostmass_x1func(hostmass):
    """Draw a single x1 value from the hostmass-dependent distribution.

    Parameters
    ----------
    hostmass : `float`
        The hostmass value.

    Returns
    -------
    x1 : `float`
        The x1 parameter in the SALT3 model
    """
    # A new sampler is built on every call because the distribution object
    # is constructed from the given hostmass.
    sampler = NumericalInversePolynomial(HostmassX1Distr(hostmass))

    # rvs(1) returns a length-1 array; unwrap it to a scalar.
    return sampler.rvs(1)[0]


def _x0_from_distmod(distmod, x1, c, alpha, beta, m_abs):
"""Calculate the SALT3 x0 parameter given distance modulus based on Tripp relation.
distmod = -2.5*log10(x0) + alpha * x1 - beta * c - m_abs
Expand Down Expand Up @@ -140,23 +120,23 @@ def _distmod_from_redshift(redshift, H0=73.0, Omega_m=0.3):
return distmod


class HostmassX1Func(NumericalInversePolynomialFunc):
    """A node for sampling from the HostmassX1Distr.

    Parameters
    ----------
    hostmass : function or constant
        The function or constant providing the hostmass value.
    **kwargs : `dict`, optional
        Any additional keyword arguments. These are forwarded unchanged to
        the parent class's constructor.

    Notes
    -----
    NOTE(review): an earlier revision of this docstring listed a ``skewness``
    parameter. This constructor does not accept or forward one explicitly;
    confirm against HostmassX1Distr before re-adding it.
    """

    def __init__(self, hostmass, **kwargs):
        # Call the super class's constructor with the needed information.
        # We pass the HostmassX1Distr class (not an instance) so a new
        # distribution object will be created for each sample.
        super().__init__(
            dist=HostmassX1Distr,
            hostmass=hostmass,
            **kwargs,
        )
Expand Down
32 changes: 32 additions & 0 deletions src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,35 @@ def extract_single_sample(self, sample_num):
else:
new_state.states[node_name][var_name] = value[sample_num]
return new_state


def transpose_dict_of_list(input_dict, num_elem):
    """Convert a dictionary of equal-length sequences into a list of dictionaries.

    Parameters
    ----------
    input_dict : `dict`
        A dictionary whose values each have length num_elem and support
        integer indexing.
    num_elem : `int`
        The expected length of every value in input_dict.

    Returns
    -------
    output_list : `list`
        A list of num_elem dictionaries. Entry i maps each key of input_dict
        to the i-th element of that key's value.

    Raises
    ------
    ``ValueError`` if num_elem is less than 1 or if any value's length
    differs from num_elem.
    """
    if num_elem < 1:
        raise ValueError(f"Trying to transpose a dictionary with {num_elem} elements")

    # Start with one empty dictionary per output row and fill them in
    # column by column.
    rows = [{} for _ in range(num_elem)]
    for key, values in input_dict.items():
        if len(values) != num_elem:
            raise ValueError(f"Entry {key} has length {len(values)}. Expected {num_elem}.")
        for idx, row in enumerate(rows):
            row[key] = values[idx]
    return rows
167 changes: 167 additions & 0 deletions src/tdastro/util_nodes/scipy_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Wrapper classes for some of scipy's sampling functions."""

from os import urandom

import numpy as np
from scipy.stats.sampling import NumericalInversePolynomial

from tdastro.base_models import FunctionNode
from tdastro.graph_state import transpose_dict_of_list


class NumericalInversePolynomialFunc(FunctionNode):
    """A class for sampling from scipy's NumericalInversePolynomial
    given a distribution object or a class from which to create
    such an object.

    Note
    ----
    If a class is provided, then the sampling function will create a new
    object (with the sampled parameters) for each sampling.

    Attributes
    ----------
    _dist : object or class
        An object or class with either a pdf() or logpdf() method that defines
        the distribution from which to sample.
    _inv_poly: `scipy.stats.sampling.NumericalInversePolynomial`
        The scipy object to use for sampling. Set to ``None`` if _dist is a class.
    _vect_sample : `numpy.vectorize`
        The vectorized function to create a distribution from a class and sample it.
        Set to ``None`` if _dist is an object.
    _rng : `numpy.random._generator.Generator`
        This object's random number generator.

    Parameters
    ----------
    dist : object or class
        An object or class with either a pdf() or logpdf() method that defines
        the distribution from which to sample.
    seed : `int`, optional
        The seed to use. If ``None``, a seed is drawn from ``os.urandom``.
    **kwargs : `dict`, optional
        Any additional keyword arguments. These are passed to the parent
        constructor as the node's parameters.

    Raises
    ------
    ``ValueError`` if dist has neither a pdf() nor a logpdf() attribute.
    """

    def __init__(self, dist=None, seed=None, **kwargs):
        # Check that the distribution object/class has a pdf or logpdf function.
        if not hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
            raise ValueError("Distribution must have either pdf() or logpdf().")
        self._dist = dist

        # Classes show up as type="type"
        if isinstance(dist, type):
            self._inv_poly = None
            # Declare the output type up front. Without otypes, np.vectorize
            # calls the function one extra time on the first element just to
            # infer the output type, which would build a throwaway sampler and
            # discard a random draw on every multi-sample call.
            self._vect_sample = np.vectorize(self._create_and_sample, otypes=[float])
        else:
            self._inv_poly = NumericalInversePolynomial(self._dist)
            self._vect_sample = None

        # Get a default random number generator for this object, using the
        # given seed if one is provided.
        if seed is None:
            seed = int.from_bytes(urandom(4), "big")
        self._rng = np.random.default_rng(seed=seed)

        # Set the function and add all the kwargs as parameters.
        super().__init__(self._rvs, **kwargs)

    def _rvs(self):
        """A place holder function to use for object naming."""
        pass

    def set_seed(self, new_seed):
        """Update the random number generator's seed to a given value.

        Parameters
        ----------
        new_seed : `int`
            The given seed
        """
        self._rng = np.random.default_rng(seed=new_seed)

    def _create_and_sample(self, args, rng):
        """Create the distribution function and sample it.

        Parameters
        ----------
        args : `dict`
            A dictionary mapping argument name to individual values.
        rng : `numpy.random._generator.Generator`
            The random number generator to use.

        Returns
        -------
        sample : `float`
            The result of sampling the function.
        """
        dist = self._dist(**args)
        # rvs(1, ...) returns a length-1 array; unwrap it to a scalar.
        sample = NumericalInversePolynomial(dist).rvs(1, rng)[0]
        return sample

    def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
        """Execute the wrapped function.

        The input arguments are taken from the current graph_state and the outputs
        are written to graph_state.

        Parameters
        ----------
        graph_state : `GraphState`
            An object mapping graph parameters to their values. This object is modified
            in place as it is sampled.
        given_args : `dict`, optional
            A dictionary representing the given arguments for this sample run.
            This can be used as the JAX PyTree for differentiation.
        rng_info : `dict`, optional
            A dictionary of random number generator information for each node, such as
            the JAX keys or the numpy rngs. If provided, it is handed to scipy's
            rvs() as the random state in place of this node's own generator.
        **kwargs : `dict`, optional
            Additional function arguments.

        Returns
        -------
        results : any
            The result of the computation. This return value is provided so that testing
            functions can easily access the results.
        """
        rng = rng_info if rng_info is not None else self._rng

        if self._inv_poly is not None:
            # Batch sample all the results. A size of None makes scipy
            # return a single value for the single-sample case.
            num_samples = None if graph_state.num_samples == 1 else graph_state.num_samples
            results = self._inv_poly.rvs(num_samples, rng)
        else:
            # This is a class so we will need to create a new distribution object
            # for each sample (with a single instance of the input parameters).
            args = self._build_inputs(graph_state, given_args, **kwargs)

            if graph_state.num_samples == 1:
                results = self._create_and_sample(args, rng)
            else:
                # Transpose the dict of arrays to a list of dicts.
                arg_list = transpose_dict_of_list(args, graph_state.num_samples)
                results = self._vect_sample(arg_list, rng)

        # Save and return the results.
        self._save_results(results, graph_state)
        return results

    def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
        """A helper function for testing that regenerates the output.

        Parameters
        ----------
        given_args : `dict`, optional
            A dictionary representing the given arguments for this sample run.
            This can be used as the JAX PyTree for differentiation.
        num_samples : `int`
            A count of the number of samples to compute.
            Default: 1
        rng_info : `dict`, optional
            A dictionary of random number generator information for each node, such as
            the JAX keys or the numpy rngs.
        **kwargs : `dict`, optional
            Additional function arguments.

        Returns
        -------
        results : any
            The value returned by compute() on the freshly sampled state.
        """
        state = self.sample_parameters(given_args, num_samples, rng_info)
        return self.compute(state, given_args, rng_info, **kwargs)
45 changes: 45 additions & 0 deletions tests/tdastro/astro_utils/test_snia_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
from tdastro.astro_utils.snia_utils import HostmassX1Func
from tdastro.util_nodes.np_random import NumpyRandomFunc


def test_sample_hostmass_x1c():
    """Test that HostmassX1Func samples are reproducible and follow the hostmass input."""
    num_samples = 5

    def _draw(hostmass_seed):
        # Build a new node (sampler seed fixed at 101) and return both the
        # sampled x1 values and the hostmass values that were fed into it.
        node = HostmassX1Func(
            hostmass=NumpyRandomFunc("uniform", low=7, high=12, seed=hostmass_seed),
            seed=101,
        )
        state = node.sample_parameters(num_samples=num_samples)
        x1_values = node.get_param(state, "function_node_result")
        hostmasses = node.get_param(state, "hostmass")
        return x1_values, hostmasses

    values1, hostmass1 = _draw(100)
    assert len(values1) == num_samples
    assert len(np.unique(values1)) == num_samples

    # A second node with the same hostmass seed and the same sampler seed
    # reproduces both the results and the hostmasses.
    values2, hostmass2 = _draw(100)
    assert np.allclose(values1, values2)
    assert np.allclose(hostmass1, hostmass2)

    # Changing only the hostmass seed changes the results (i.e. the hostmass
    # parameter is being resampled).
    values3, hostmass3 = _draw(102)
    assert not np.allclose(values1, values3)
    assert not np.allclose(hostmass1, hostmass3)
26 changes: 25 additions & 1 deletion tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from tdastro.graph_state import GraphState
from tdastro.graph_state import GraphState, transpose_dict_of_list


def test_create_single_sample_graph_state():
Expand Down Expand Up @@ -117,3 +117,27 @@ def test_create_multi_sample_graph_state_reference():
state2["a"]["v2"][2] = 5.0
assert np.allclose(state2["a"]["v2"], [2.0, 2.5, 5.0, 3.5, 4.0])
assert np.allclose(state2["b"]["v1"], [2.0, 2.5, 3.0, 3.5, 4.0])


def test_transpose_dict_of_list():
    """Check transpose_dict_of_list on a 3x3 example and on mismatched lengths."""
    columns = {
        "a": [1, 2, 3],
        "b": [4, 5, 6],
        "c": [7, 8, 9],
    }
    expected_rows = [
        {"a": 1, "b": 4, "c": 7},
        {"a": 2, "b": 5, "c": 8},
        {"a": 3, "b": 6, "c": 9},
    ]

    rows = transpose_dict_of_list(columns, 3)
    assert len(rows) == 3
    for want, got in zip(expected_rows, rows):
        assert want == got

    # We fail if num_elem does not match the list lengths.
    for bad_length in (0, 4):
        with pytest.raises(ValueError):
            _ = transpose_dict_of_list(columns, bad_length)
Loading

0 comments on commit 4a87261

Please sign in to comment.