From 806f7f2363507271dda3160196f4acf26815da8e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:40:13 -0400 Subject: [PATCH 1/6] Use readable name in GraphState --- src/tdastro/base_models.py | 30 ++++++++++---------- src/tdastro/sources/sncomso_models.py | 2 +- src/tdastro/util_nodes/jax_random.py | 2 +- tests/tdastro/sources/test_sncosmo_models.py | 3 +- tests/tdastro/sources/test_spline_source.py | 2 +- tests/tdastro/sources/test_static_source.py | 4 +-- tests/tdastro/test_base_models.py | 6 ++-- 7 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 8acb02e3..881490b9 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -94,6 +94,7 @@ def __init__(self, parameter_name, source_type=0, fixed=False, required=False, n self.required = required self.value = None self.dependency = None + self.node_name = node_name self.set_name(parameter_name, node_name) def set_name(self, parameter_name="", node_name=""): @@ -184,7 +185,7 @@ class ParameterizedNode: The full string used to identify a node. This is a combination of the nodes position in the graph (if known), node_label (if provided), and class information. node_hash : `int` - A hashed version of ``node_string`` used for fast lookups. + A precomputed hashed version of ``node_string``. setters : `dict` A dictionary mapping the parameters' names to information about the setters (ParameterSource). The model parameters are stored in the order in which they @@ -210,7 +211,6 @@ def __init__(self, node_label=None, **kwargs): self.node_label = node_label self.node_pos = None self.node_string = None - self.node_hash = None def __str__(self): """Return the string representation of the node.""" @@ -225,7 +225,7 @@ def _update_node_string(self, extra_tag=None): if self.node_label is not None: self.node_string = f"{pos_string}{self.node_label}" else: - self.node_string = f"{pos_string}{self.__class__.__module__}.{self.__class__.__qualname__}" + self.node_string = f"{pos_string}{self.__class__.__qualname__}" # Allow for the appending of an extra tag. if extra_tag is not None: @@ -290,7 +290,7 @@ def get_param(self, graph_state, name): """ if graph_state is None: raise ValueError(f"Unable to look ip parameter={name}. No graph_state given.") - return graph_state[self.node_hash][name] + return graph_state[self.node_string][name] def get_local_params(self, graph_state): """Get a dictionary of all parameters local to this node. @@ -317,7 +317,7 @@ def get_local_params(self, graph_state): """ if graph_state is None: raise ValueError("No graph_state given.") - return graph_state[self.node_hash] + return graph_state[self.node_string] def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -520,7 +520,7 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None if given_args is not None and setter.full_name in given_args: if setter.fixed: raise ValueError(f"Trying to override fixed parameter {setter.full_name}") - graph_state.set(self.node_hash, name, given_args[setter.full_name]) + graph_state.set(self.node_string, name, given_args[setter.full_name]) else: # Check if we need to sample this parameter's dependency node. if setter.dependency is not None and setter.dependency != self: @@ -528,18 +528,18 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None # Set the result from the correct source. if setter.source_type == ParameterSource.CONSTANT: - graph_state.set(self.node_hash, name, setter.value) + graph_state.set(self.node_string, name, setter.value) elif setter.source_type == ParameterSource.MODEL_PARAMETER: graph_state.set( - self.node_hash, + self.node_string, name, - graph_state[setter.dependency.node_hash][setter.value], + graph_state[setter.dependency.node_string][setter.value], ) elif setter.source_type == ParameterSource.FUNCTION_NODE: graph_state.set( - self.node_hash, + self.node_string, name, - graph_state[setter.dependency.node_hash][setter.value], + graph_state[setter.dependency.node_string][setter.value], ) elif setter.source_type == ParameterSource.COMPUTE_OUTPUT: # Computed parameters are set only after all the other (input) parameters. @@ -664,7 +664,7 @@ def build_pytree(self, graph_state, seen=None): all_values.update(setter_info.dependency.build_pytree(graph_state, seen)) elif setter_info.source_type == ParameterSource.CONSTANT and not setter_info.fixed: # Only the non-fixed, constants go into the PyTree. - all_values[setter_info.full_name] = graph_state[self.node_hash][name] + all_values[setter_info.full_name] = graph_state[self.node_string][name] return all_values @@ -796,7 +796,7 @@ def _build_inputs(self, graph_state, given_args=None, **kwargs): elif key in kwargs: args[key] = kwargs[key] else: - args[key] = graph_state[self.node_hash][key] + args[key] = graph_state[self.node_string][key] return args def _save_results(self, results, graph_state): @@ -811,7 +811,7 @@ def _save_results(self, results, graph_state): in place as it is sampled. """ if len(self.outputs) == 1: - graph_state.set(self.node_hash, self.outputs[0], results) + graph_state.set(self.node_string, self.outputs[0], results) else: if len(results) != len(self.outputs): raise ValueError( @@ -819,7 +819,7 @@ def _save_results(self, results, graph_state): f"Expected {len(self.outputs)}, but got {results}." ) for i in range(len(self.outputs)): - graph_state.set(self.node_hash, self.outputs[i], results[i]) + graph_state.set(self.node_string, self.outputs[i], results[i]) def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): """Execute the wrapped function. diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index bf60a57e..90d685e4 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -58,7 +58,7 @@ def parameter_values(self): def _update_sncosmo_model_parameters(self, graph_state): """Update the parameters for the wrapped sncosmo model.""" - local_params = graph_state.get_node_state(self.node_hash, 0) + local_params = graph_state.get_node_state(self.node_string, 0) sn_params = {} for name in self.source_param_names: sn_params[name] = local_params[name] diff --git a/src/tdastro/util_nodes/jax_random.py b/src/tdastro/util_nodes/jax_random.py index 4e7be0f9..70e39ebc 100644 --- a/src/tdastro/util_nodes/jax_random.py +++ b/src/tdastro/util_nodes/jax_random.py @@ -131,7 +131,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): else: use_shape = [graph_state.num_samples] results = self.func(current_key, shape=use_shape, **args) - graph_state.set(self.node_hash, self.outputs[0], results) + graph_state.set(self.node_string, self.outputs[0], results) return results def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index 95875e8a..06b357d8 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -9,7 +9,7 @@ def test_sncomso_models_hsiao() -> None: state = model.sample_parameters() assert model.get_param(state, "amplitude") == 2.0e10 assert model.get_param(state, "t0") == 0.0 - assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel" + assert str(model) == "0:SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameter_values, [2.0e10]) @@ -29,7 +29,6 @@ def test_sncomso_models_hsiao_t0() -> None: state = model.sample_parameters() assert model.get_param(state, "amplitude") == 2.0e10 assert model.get_param(state, "t0") == 55000.0 - assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameter_values, [2.0e10]) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py index 1a8a0ecc..d42f74f7 100644 --- a/tests/tdastro/sources/test_spline_source.py +++ b/tests/tdastro/sources/test_spline_source.py @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None: wavelengths = np.linspace(100.0, 500.0, 25) fluxes = np.full((len(times), len(wavelengths)), 1.0) model = SplineModel(times, wavelengths, fluxes) - assert str(model) == "tdastro.sources.spline_model.SplineModel" + assert str(model) == "SplineModel" test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index c1947850..050763df 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -52,10 +52,10 @@ def test_static_source_host() -> None: assert model.get_param(state, "ra") == 1.0 assert model.get_param(state, "dec") == 2.0 assert model.get_param(state, "distance") == 3.0 - assert str(model) == "0:tdastro.sources.static_source.StaticSource" + assert str(model) == "0:StaticSource" # Test that we have given a different name to the host. - assert str(host) == "1:tdastro.sources.static_source.StaticSource" + assert str(host) == "1:StaticSource" def test_static_source_resample() -> None: diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 5fe6397d..14a3ccb3 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -104,7 +104,7 @@ def test_parameterized_node(): """Test that we can sample and create a PairModel object.""" # Simple addition model1 = PairModel(value1=0.5, value2=0.5) - assert str(model1) == "test_base_models.PairModel" + assert str(model1) == "PairModel" state = model1.sample_parameters() assert model1.get_param(state, "value1") == 0.5 @@ -238,7 +238,7 @@ def test_parameterized_node_build_pytree(): def test_single_variable_node(): """Test that we can create and query a SingleVariableNode.""" node = SingleVariableNode("A", 10.0) - assert str(node) == "tdastro.base_models.SingleVariableNode" + assert str(node) == "SingleVariableNode" state = node.sample_parameters() assert node.get_param(state, "A") == 10 @@ -253,7 +253,7 @@ def test_function_node_basic(): assert my_func.compute(state, value2=3.0) == 4.0 assert my_func.compute(state, value2=3.0, unused_param=5.0) == 4.0 assert my_func.compute(state, value2=3.0, value1=1.0) == 4.0 - assert str(my_func) == "0:tdastro.base_models.FunctionNode:_test_func" + assert str(my_func) == "0:FunctionNode:_test_func" def test_function_node_chain(): From e686e826b5316448238e57f288ff9f5aed67787d Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:53:13 -0400 Subject: [PATCH 2/6] Convert pytree to use same format as GraphState --- src/tdastro/base_models.py | 31 +++++++++++++++++-------------- tests/tdastro/test_base_models.py | 24 ++++++++++-------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 881490b9..ac796eb8 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -627,7 +627,7 @@ def get_all_node_info(self, field, seen_nodes=None): result.extend(dep.get_all_node_info(field, seen_nodes)) return result - def build_pytree(self, graph_state, seen=None): + def build_pytree(self, graph_state, partial=None): """Build a JAX PyTree representation of the variables in this graph. Parameters @@ -635,14 +635,17 @@ def build_pytree(self, graph_state, seen=None): graph_state : `dict` A dictionary of dictionaries mapping node->hash, variable_name to value. This data structure is modified in place to represent the current state. - seen : `set` - A set of objects that have already been processed. + partial : `dict` + The partial results so far. This is modified in place by the function. + A dictionary mapping node name to a dictionary mapping each variable's name + to its value. Default : ``None`` + Returns ------- values : `dict` - The dictionary mapping the combination of the object identifier and - model parameter name to its value. + A dictionary mapping node name to a dictionary mapping each variable's name + to its value. """ # Check if the node might have incomplete information. if self.node_pos is None: @@ -652,20 +655,20 @@ def build_pytree(self, graph_state, seen=None): ) # Skip nodes that we have already seen. - if seen is None: - seen = set() - if self in seen: - return {} - seen.add(self) + if partial is None: + partial = {} + if self.node_string in partial: + return partial - all_values = {} + # Add new values to the pytree, recursively exploring dependencies. + partial[self.node_string] = {} for name, setter_info in self.setters.items(): if setter_info.dependency is not None: - all_values.update(setter_info.dependency.build_pytree(graph_state, seen)) + partial = setter_info.dependency.build_pytree(graph_state, partial) elif setter_info.source_type == ParameterSource.CONSTANT and not setter_info.fixed: # Only the non-fixed, constants go into the PyTree. - all_values[setter_info.full_name] = graph_state[self.node_string][name] - return all_values + partial[self.node_string][name] = graph_state[self.node_string][name] + return partial class SingleVariableNode(ParameterizedNode): diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 14a3ccb3..57032288 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -219,20 +219,19 @@ def test_parameterized_node_build_pytree(): model1 = PairModel(value1=0.5, value2=1.5, node_label="A") model2 = PairModel(value1=model1.value1, value2=3.0, node_label="B") graph_state = model2.sample_parameters() + pytree = model2.build_pytree(graph_state) - - assert len(pytree) == 3 - assert pytree["1:A.value1"] == 0.5 - assert pytree["1:A.value2"] == 1.5 - assert pytree["0:B.value2"] == 3.0 + assert pytree["1:A"]["value1"] == 0.5 + assert pytree["1:A"]["value2"] == 1.5 + assert pytree["0:B"]["value2"] == 3.0 # Manually set value2 to fixed and check that it no longer appears in the pytree. model1.setters["value2"].fixed = True pytree = model2.build_pytree(graph_state) - assert len(pytree) == 2 - assert pytree["1:A.value1"] == 0.5 - assert pytree["0:B.value2"] == 3.0 + assert pytree["1:A"]["value1"] == 0.5 + assert pytree["0:B"]["value2"] == 3.0 + assert "value2" not in pytree["1:A"] def test_single_variable_node(): @@ -347,14 +346,11 @@ def _test_func2(value1, value2): graph_state = sum_node.sample_parameters() pytree = sum_node.build_pytree(graph_state) - assert len(pytree) == 3 print(pytree) - gr_func = jax.value_and_grad(sum_node.resample_and_compute) values, gradients = gr_func(pytree) - assert len(gradients) == 3 print(gradients) assert values == 9.0 - assert gradients["0:sum:_test_func.value1"] == 1.0 - assert gradients["1:div:_test_func2.value1"] == 2.0 - assert gradients["1:div:_test_func2.value2"] == -16.0 + assert gradients["0:sum:_test_func"]["value1"] == 1.0 + assert gradients["1:div:_test_func2"]["value1"] == 2.0 + assert gradients["1:div:_test_func2"]["value2"] == -16.0 From 9e7deac5dca67f8203fd04be6978f293712cfddf Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:06:49 -0400 Subject: [PATCH 3/6] Simplify given_args logic by merging it into GraphState --- src/tdastro/base_models.py | 91 +++++++++------------ src/tdastro/graph_state.py | 59 +++++++++++++- src/tdastro/sources/physical_model.py | 15 ++-- src/tdastro/sources/sncomso_models.py | 7 +- src/tdastro/util_nodes/jax_random.py | 11 +-- src/tdastro/util_nodes/np_random.py | 9 +-- src/tdastro/util_nodes/scipy_random.py | 9 +-- tests/tdastro/test_base_models.py | 2 +- tests/tdastro/test_graph_state.py | 106 +++++++++++++++++++++++++ 9 files changed, 220 insertions(+), 89 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index ac796eb8..4b91a772 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -463,7 +463,7 @@ def getter(): getter.__name__ = name setattr(self, name, getter) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Placeholder for a general compute function. Parameters @@ -471,9 +471,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -482,7 +479,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): """ return None - def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None): + def _sample_helper(self, graph_state, seen_nodes, rng_info=None): """Internal recursive function to sample the model's underlying parameters if they are provided by a function or ParameterizedNode. All sampled parameters for all nodes are stored in the graph_state dictionary, which is @@ -496,9 +493,6 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None seen_nodes : `dict` A dictionary mapping nodes seen during this sampling run to their ID. Used to avoid sampling nodes multiple times and to validity check the graph. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -516,41 +510,35 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None # so this will iterate through model parameters in the order they were inserted. any_compute = False for name, setter in self.setters.items(): - # If we are given the argument use that and do not worry about the dependencies. - if given_args is not None and setter.full_name in given_args: - if setter.fixed: - raise ValueError(f"Trying to override fixed parameter {setter.full_name}") - graph_state.set(self.node_string, name, given_args[setter.full_name]) + # Check if we need to sample this parameter's dependency node. + if setter.dependency is not None and setter.dependency != self: + setter.dependency._sample_helper(graph_state, seen_nodes, rng_info) + + # Set the result from the correct source. + if setter.source_type == ParameterSource.CONSTANT: + graph_state.set(self.node_string, name, setter.value) + elif setter.source_type == ParameterSource.MODEL_PARAMETER: + graph_state.set( + self.node_string, + name, + graph_state[setter.dependency.node_string][setter.value], + ) + elif setter.source_type == ParameterSource.FUNCTION_NODE: + graph_state.set( + self.node_string, + name, + graph_state[setter.dependency.node_string][setter.value], + ) + elif setter.source_type == ParameterSource.COMPUTE_OUTPUT: + # Computed parameters are set only after all the other (input) parameters. + any_compute = True else: - # Check if we need to sample this parameter's dependency node. - if setter.dependency is not None and setter.dependency != self: - setter.dependency._sample_helper(graph_state, seen_nodes, given_args, rng_info) - - # Set the result from the correct source. - if setter.source_type == ParameterSource.CONSTANT: - graph_state.set(self.node_string, name, setter.value) - elif setter.source_type == ParameterSource.MODEL_PARAMETER: - graph_state.set( - self.node_string, - name, - graph_state[setter.dependency.node_string][setter.value], - ) - elif setter.source_type == ParameterSource.FUNCTION_NODE: - graph_state.set( - self.node_string, - name, - graph_state[setter.dependency.node_string][setter.value], - ) - elif setter.source_type == ParameterSource.COMPUTE_OUTPUT: - # Computed parameters are set only after all the other (input) parameters. - any_compute = True - else: - raise ValueError(f"Invalid ParameterSource type {setter.source_type}") + raise ValueError(f"Invalid ParameterSource type {setter.source_type}") # If this is a function node and the parameters depend on the result of its own computation # call the compute function to fill them in. if any_compute: - self.compute(graph_state, given_args, rng_info) + self.compute(graph_state, rng_info) def sample_parameters(self, given_args=None, num_samples=1, rng_info=None): """Sample the model's underlying parameters if they are provided by a function @@ -584,10 +572,14 @@ def sample_parameters(self, given_args=None, num_samples=1, rng_info=None): nodes = set() self.set_graph_positions(seen_nodes=nodes) + # Create space for the results and set all the given_args as fixed parameters. + results = GraphState(num_samples) + if given_args is not None: + results.update(given_args, all_fixed=True) + # Resample the nodes. All information is stored in the returned results dictionary. seen_nodes = {} - results = GraphState(num_samples) - self._sample_helper(results, seen_nodes, given_args, rng_info) + self._sample_helper(results, seen_nodes, rng_info) return results def get_all_node_info(self, field, seen_nodes=None): @@ -772,7 +764,7 @@ def _update_node_string(self, extra_tag=None): else: super()._update_node_string() - def _build_inputs(self, graph_state, given_args=None, **kwargs): + def _build_inputs(self, graph_state, **kwargs): """Build the input arguments for the function. Parameters @@ -780,9 +772,6 @@ def _build_inputs(self, graph_state, given_args=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. **kwargs : `dict`, optional Additional function arguments. @@ -793,10 +782,7 @@ def _build_inputs(self, graph_state, given_args=None, **kwargs): """ args = {} for key in self.arg_names: - # Override with the given arg or kwarg in that order. - if given_args is not None and self.setters[key].full_name in given_args: - args[key] = given_args[self.setters[key].full_name] - elif key in kwargs: + if key in kwargs: args[key] = kwargs[key] else: args[key] = graph_state[self.node_string][key] @@ -824,7 +810,7 @@ def _save_results(self, results, graph_state): for i in range(len(self.outputs)): graph_state.set(self.node_string, self.outputs[i], results[i]) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped function. The input arguments are taken from the current graph_state and the outputs @@ -835,9 +821,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -862,7 +845,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): # Build a dictionary of arguments for the function, call the function, and save # the results in the graph state. - args = self._build_inputs(graph_state, given_args, **kwargs) + args = self._build_inputs(graph_state, **kwargs) results = self.func(**args) self._save_results(results, graph_state) return results @@ -880,4 +863,4 @@ def resample_and_compute(self, given_args=None, rng_info=None): the JAX keys or the numpy rngs. """ graph_state = self.sample_parameters(given_args, 1, rng_info) - return self.compute(graph_state, given_args, rng_info) + return self.compute(graph_state, rng_info) diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 59ec1e09..1e16cf44 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -38,6 +38,11 @@ class GraphState: value or array of values. num_samples : `int` A count of the number of samples stored in the GraphState. + num_parameters : `int` + The total number of parameters stored in a single sample within GraphState. + fixed_vars : `dict` + A dictionary mapping the node name to a set of the variable names that + are fixed in this GraphState instance. """ def __init__(self, num_samples=1): @@ -46,6 +51,7 @@ def __init__(self, num_samples=1): self.num_samples = num_samples self.num_parameters = 0 self.states = {} + self.fixed_vars = {} def __len__(self): return self.num_parameters @@ -82,7 +88,7 @@ def get_node_state(self, node_name, sample_num=0): values[var_name] = val[sample_num] return values - def set(self, node_name, var_name, value, force_copy=False): + def set(self, node_name, var_name, value, force_copy=False, fixed=False): """Set a (new) parameter's value(s) in the GraphState from a given constant value or an array of length num_samples (to set all the values at once). @@ -98,13 +104,21 @@ def set(self, node_name, var_name, value, force_copy=False): Make a copy of data in an array. If set to ``False`` this will link to the array, saving memory and computation time. Default: ``False`` + fixed : `bool` + Treat this parameter as fixed and do not change it during subsequent calls to set. + Default: ``False`` """ # Update the meta data. if node_name not in self.states: self.states[node_name] = {} + self.fixed_vars[node_name] = set() if var_name not in self.states[node_name]: self.num_parameters += 1 + # Check if this parameter is fixed. If so, skip the set. + if var_name in self.fixed_vars[node_name]: + return + # Set the actual values. if self.num_samples == 1: # If this GraphState holds only a single sample, set it from the given value. @@ -121,6 +135,49 @@ def set(self, node_name, var_name, value, force_copy=False): # If the GraphState holds N samples and we got a single value, make an array of it. self.states[node_name][var_name] = np.full((self.num_samples), value) + # Mark the variable as fixed if needed. + if fixed: + self.fixed_vars[node_name].add(var_name) + + def update(self, inputs, force_copy=False, all_fixed=False): + """Set multiple parameters' value in the GraphState from a GraphState or a + dictionary of the same form. + + Note + ---- + The number of samples in input must either match the number of samples in the + current object or be 1. + + Parameters + ---------- + inputs : `GraphState` or `dict` + Values to copy. + force_copy : `bool` + Make a copy of data in an array. If set to ``False`` this will link + to the array, saving memory and computation time. + Default: ``False`` + all_fixed : `bool` + Treat all the parameters in inputs as fixed. + Default: ``False`` + + Raises + ------ + ValueError if the input an invalid number of samples. + """ + if isinstance(inputs, GraphState): + if self.num_samples != inputs.num_samples and inputs.num_samples != 1: + raise ValueError("GraphSates must have the same number of samples.") + new_states = inputs.states + else: + new_states = inputs + + # Set the values one by one. The set function takes care of expanding + # any values that are constants (e.g. float or int) to match the correct + # number of samples. + for node_name, node_vars in new_states.items(): + for var_name, value in node_vars.items(): + self.set(node_name, var_name, value, force_copy=force_copy, fixed=all_fixed) + def extract_single_sample(self, sample_num): """Create a new GraphState with a single sample state. diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 0520d43c..e77eaf86 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -222,22 +222,19 @@ def sample_parameters(self, given_args=None, num_samples=1, rng_info=None, **kwa if self.node_pos is None: self.set_graph_positions() - args_to_use = {} - if given_args is not None: - args_to_use.update(given_args) - if kwargs is not None: - args_to_use.update(kwargs) - # We use the same seen_nodes for all sampling calls so each node # is sampled at most one time regardless of link structure. graph_state = GraphState(num_samples) + if given_args is not None: + graph_state.update(given_args, all_fixed=True) + seen_nodes = {} if self.background is not None: - self.background._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs) - self._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs) + self.background._sample_helper(graph_state, seen_nodes, rng_info, **kwargs) + self._sample_helper(graph_state, seen_nodes, rng_info, **kwargs) for effect in self.effects: - effect._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs) + effect._sample_helper(graph_state, seen_nodes, rng_info, **kwargs) return graph_state diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 90d685e4..b10638ad 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -97,7 +97,7 @@ def set(self, **kwargs): self.source_param_names.append(key) self.source.set(**kwargs) - def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1, rng_info=None): + def _sample_helper(self, graph_state, seen_nodes, num_samples=1, rng_info=None): """Internal recursive function to sample the model's underlying parameters if they are provided by a function or ParameterizedNode. @@ -112,9 +112,6 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1 seen_nodes : `dict` A dictionary mapping nodes seen during this sampling run to their ID. Used to avoid sampling nodes multiple times and to validity check the graph. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. num_samples : `int` A count of the number of samples to compute. Default: 1 @@ -126,7 +123,7 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1 ------ Raise a ``ValueError`` the sampling encounters a problem with the order of dependencies. """ - super()._sample_helper(graph_state, seen_nodes, given_args, rng_info) + super()._sample_helper(graph_state, seen_nodes, rng_info) self._update_sncosmo_model_parameters(graph_state) def _evaluate(self, times, wavelengths, graph_state=None, **kwargs): diff --git a/src/tdastro/util_nodes/jax_random.py b/src/tdastro/util_nodes/jax_random.py index 70e39ebc..ceae5e94 100644 --- a/src/tdastro/util_nodes/jax_random.py +++ b/src/tdastro/util_nodes/jax_random.py @@ -89,7 +89,7 @@ class JaxRandomFunc(FunctionNode): def __init__(self, func, **kwargs): super().__init__(func, **kwargs) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped JAX sampling function. Parameters @@ -97,9 +97,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -125,7 +122,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): rng_info[self.node_hash] = next_key # Generate the results. - args = self._build_inputs(graph_state, given_args, **kwargs) + args = self._build_inputs(graph_state, **kwargs) if graph_state.num_samples == 1: results = float(self.func(current_key, **args)) else: @@ -152,7 +149,7 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Additional function arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) class JaxRandomNormal(FunctionNode): @@ -196,4 +193,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Any additional keyword arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) diff --git a/src/tdastro/util_nodes/np_random.py b/src/tdastro/util_nodes/np_random.py index 6e2ea2c2..3fc35ad3 100644 --- a/src/tdastro/util_nodes/np_random.py +++ b/src/tdastro/util_nodes/np_random.py @@ -126,7 +126,7 @@ def set_seed(self, new_seed): self._rng = np.random.default_rng(seed=new_seed) self.func = getattr(self._rng, self.func_name) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped function. The input arguments are taken from the current graph_state and the outputs @@ -137,9 +137,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -156,7 +153,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): ------ ``ValueError`` is ``func`` attribute is ``None``. """ - args = self._build_inputs(graph_state, given_args, **kwargs) + args = self._build_inputs(graph_state, **kwargs) num_samples = None if graph_state.num_samples == 1 else graph_state.num_samples # If a random number generator is given use that. Otherwise use the default one. @@ -186,4 +183,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Additional function arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) diff --git a/src/tdastro/util_nodes/scipy_random.py b/src/tdastro/util_nodes/scipy_random.py index 8571c2a0..6bb16fd0 100644 --- a/src/tdastro/util_nodes/scipy_random.py +++ b/src/tdastro/util_nodes/scipy_random.py @@ -97,7 +97,7 @@ def _create_and_sample(self, args, rng): sample = NumericalInversePolynomial(dist).rvs(1, rng)[0] return sample - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped function. The input arguments are taken from the current graph_state and the outputs @@ -108,9 +108,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -132,7 +129,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): else: # This is a class so we will need to create a new distribution object # for each sample (with a single instance of the input parameters). - args = self._build_inputs(graph_state, given_args, **kwargs) + args = self._build_inputs(graph_state, **kwargs) if graph_state.num_samples == 1: dist = self._dist(**args) @@ -164,4 +161,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Additional function arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 57032288..ea0d5f55 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -219,7 +219,7 @@ def test_parameterized_node_build_pytree(): model1 = PairModel(value1=0.5, value2=1.5, node_label="A") model2 = PairModel(value1=model1.value1, value2=3.0, node_label="B") graph_state = model2.sample_parameters() - + pytree = model2.build_pytree(graph_state) assert pytree["1:A"]["value1"] == 0.5 assert pytree["1:A"]["value2"] == 1.5 diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index d179b636..6b565ec2 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -119,6 +119,112 @@ def test_create_multi_sample_graph_state_reference(): assert np.allclose(state2["b"]["v1"], [2.0, 2.5, 3.0, 3.5, 4.0]) +def test_graph_state_fixed(): + """Test that we respected the 'fixed' flag for GraphState.""" + state = GraphState() + assert len(state) == 0 + state.set("a", "v1", 1.0) + state.set("a", "v2", 2.0) + state.set("b", "v1", 3.0, fixed=True) + assert len(state) == 3 + assert state["a"]["v1"] == 1.0 + assert state["a"]["v2"] == 2.0 + assert state["b"]["v1"] == 3.0 + + # Try changing each of the states. Only two should actually change. + state.set("a", "v1", 4.0) + state.set("a", "v2", 5.0) + state.set("b", "v1", 6.0) + assert state["a"]["v1"] == 4.0 + assert state["a"]["v2"] == 5.0 + assert state["b"]["v1"] == 3.0 + + +def test_graph_state_update(): + """Test that we can update a single sample GraphState.""" + state = GraphState() + state.set("a", "v1", 1.0) + state.set("a", "v2", 2.0) + state.set("b", "v1", 3.0) + + state2 = GraphState() + state2.set("a", "v1", 4.0) + state2.set("a", "v3", 5.0) + state2.set("c", "v1", 6.0) + state2.set("c", "v2", 7.0) + + assert len(state) == 3 + assert len(state2) == 4 + + # We set 3 new parameters and overwrite one. + state.update(state2) + assert len(state) == 6 + assert state["a"]["v1"] == 4.0 + assert state["a"]["v2"] == 2.0 + assert state["a"]["v3"] == 5.0 + assert state["b"]["v1"] == 3.0 + assert state["c"]["v1"] == 6.0 + assert state["c"]["v2"] == 7.0 + + # We set 3 new parameters and overwrite one. + state3 = {"a": {"v2": 8.0, "v4": 9.0}, "d": {"v1": 10.0}} + state.update(state3) + assert len(state) == 8 + assert state["a"]["v1"] == 4.0 + assert state["a"]["v2"] == 8.0 + assert state["a"]["v3"] == 5.0 + assert state["a"]["v4"] == 9.0 + assert state["b"]["v1"] == 3.0 + assert state["c"]["v1"] == 6.0 + assert state["c"]["v2"] == 7.0 + assert state["d"]["v1"] == 10.0 + + # Test we cannot update with mismatched number of samples. + state4 = GraphState(num_samples=2) + state4.set("e", "v1", 1.0) + with pytest.raises(ValueError): + state.update(state4) + + +def test_graph_state_update_multi(): + """Test that we can update a single sample GraphState.""" + state = GraphState(num_samples=3) + state.set("a", "v1", [1.0, 2.0, 3.0]) + state.set("a", "v2", [3.0, 4.0, 5.0]) + state.set("b", "v1", [6.0, 7.0, 8.0]) + + state2 = GraphState(num_samples=3) + state2.set("a", "v1", [9.0, 10.0, 11.0]) + state2.set("c", "v1", [12.0, 13.0, 14.0]) + + assert len(state) == 3 + assert len(state2) == 2 + + # We set one new parameter and overwrite one. + state.update(state2) + assert len(state) == 4 + assert np.allclose(state["a"]["v1"], [9.0, 10.0, 11.0]) + assert np.allclose(state["a"]["v2"], [3.0, 4.0, 5.0]) + assert np.allclose(state["b"]["v1"], [6.0, 7.0, 8.0]) + assert np.allclose(state["c"]["v1"], [12.0, 13.0, 14.0]) + + # If we add a parameter with sample_size = 1, we correctly expand it out. + state3 = {"a": {"v2": 15.0}, "d": {"v1": 16.0}} + state.update(state3) + assert len(state) == 5 + assert np.allclose(state["a"]["v1"], [9.0, 10.0, 11.0]) + assert np.allclose(state["a"]["v2"], [15.0, 15.0, 15.0]) + assert np.allclose(state["b"]["v1"], [6.0, 7.0, 8.0]) + assert np.allclose(state["c"]["v1"], [12.0, 13.0, 14.0]) + assert np.allclose(state["d"]["v1"], [16.0, 16.0, 16.0]) + + # Test we cannot update with mismatched number of samples. + state4 = GraphState(num_samples=2) + state4.set("e", "v1", 1.0) + with pytest.raises(ValueError): + state.update(state4) + + def test_transpose_dict_of_list(): """Test the transpose_dict_of_list helper function""" input_dict = { From 8e88f2f9136702895f3a0c68d1c0490d3f3d7df0 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:15:12 -0400 Subject: [PATCH 4/6] Remove the now unused full_name --- src/tdastro/base_models.py | 34 ++++++++----------------------- tests/tdastro/test_base_models.py | 7 +------ 2 files changed, 9 insertions(+), 32 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 4b91a772..22b47758 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -63,6 +63,8 @@ class ParameterSource: ---------- parameter_name : `str` The name of the parameter within the node (short name). + node_name : `str` + The name of the parent node. source_type : `int` The type of source as defined by the class variables. Default = 0 @@ -77,8 +79,6 @@ class ParameterSource: required : `bool` The attribute must exist and be non-None. Default = ``False`` - full_name : `str` - The full name of the parameter including the node information. """ # Class variables for the source enum. @@ -89,32 +89,13 @@ class ParameterSource: COMPUTE_OUTPUT = 4 def __init__(self, parameter_name, source_type=0, fixed=False, required=False, node_name=""): + self.parameter_name = parameter_name + self.node_name = node_name self.source_type = source_type self.fixed = fixed self.required = required self.value = None self.dependency = None - self.node_name = node_name - self.set_name(parameter_name, node_name) - - def set_name(self, parameter_name="", node_name=""): - """Set the name of the parameter field. - - Parameter - --------- - parameter_name : `str` - The name of the parameter within the node (short name). - node_name : `str` - The node string for the node containing this parameter. - """ - if len(parameter_name) == 0: - raise ValueError(f"Invalid parameter name: {parameter_name}") - - self.parameter_name = parameter_name - if len(node_name) > 0: - self.full_name = f"{node_name}.{parameter_name}" - else: - self.full_name = f"{parameter_name}" def set_as_constant(self, value): """Set the parameter as a constant value. @@ -211,6 +192,7 @@ def __init__(self, node_label=None, **kwargs): self.node_label = node_label self.node_pos = None self.node_string = None + self.node_hash = None def __str__(self): """Return the string representation of the node.""" @@ -235,9 +217,9 @@ def _update_node_string(self, extra_tag=None): hashed_object_name = md5(self.node_string.encode()).hexdigest() self.node_hash = int(hashed_object_name, base=16) - # Update the full_name of all node's parameter setters. - for name, setter_info in self.setters.items(): - setter_info.set_name(name, self.node_string) + # Update the node_name of all node's parameter setters. + for _, setter_info in self.setters.items(): + setter_info.node_name = self.node_string def set_graph_positions(self, seen_nodes=None): """Force an update of the graph structure (numbering of each node). diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index ea0d5f55..f763efef 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -61,7 +61,7 @@ def test_parameter_source(): """Test the ParameterSource creation and setter functions.""" source = ParameterSource("test") assert source.parameter_name == "test" - assert source.full_name == "test" + assert source.node_name == "" assert source.source_type == ParameterSource.UNDEFINED assert source.dependency is None assert source.value is None @@ -70,17 +70,12 @@ def test_parameter_source(): source.set_as_constant(10.0) assert source.parameter_name == "test" - assert source.full_name == "test" assert source.source_type == ParameterSource.CONSTANT assert source.dependency is None assert source.value == 10.0 assert not source.fixed assert not source.required - source.set_name("my_var", "my_node") - assert source.parameter_name == "my_var" - assert source.full_name == "my_node.my_var" - with pytest.raises(ValueError): source.set_as_constant(_test_func) From 0fb425ddc0c9747addd17f5d73b64a7354dfd29e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:35:05 -0400 Subject: [PATCH 5/6] Add a debugging string --- src/tdastro/graph_state.py | 8 ++++++++ tests/tdastro/test_graph_state.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 1e16cf44..16c8243b 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -56,6 +56,14 @@ def __init__(self, num_samples=1): def __len__(self): return self.num_parameters + def __str__(self): + str_lines = [] + for node_name, node_vars in self.states.items(): + str_lines.append(f"{node_name}:") + for var_name, value in node_vars.items(): + str_lines.append(f" {var_name}: {value}") + return "\n".join(str_lines) + def __getitem__(self, key): """Access the dictionary of parameter values for a node name.""" return self.states[key] diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 6b565ec2..ec122abb 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -22,6 +22,10 @@ def test_create_single_sample_graph_state(): with pytest.raises(KeyError): _ = state["c"]["v1"] + # We can create a human readable string representation of the GraphState. + debug_str = str(state) + assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0" + # Check that we can get all the values for a specific node. a_vals = state.get_node_state("a") assert len(a_vals) == 2 From c4a5f5de1772fedd2b44bae5b152b5bb8623da24 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sat, 31 Aug 2024 09:52:04 -0400 Subject: [PATCH 6/6] Update src/tdastro/base_models.py Co-authored-by: Olivia R. Lynn --- src/tdastro/base_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 22b47758..5291d145 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -271,7 +271,7 @@ def get_param(self, graph_state, name): ``ValueError`` if graph_state is None. """ if graph_state is None: - raise ValueError(f"Unable to look ip parameter={name}. No graph_state given.") + raise ValueError(f"Unable to look up parameter={name}. No graph_state given.") return graph_state[self.node_string][name] def get_local_params(self, graph_state):