diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py
index 8acb02e3..5291d145 100644
--- a/src/tdastro/base_models.py
+++ b/src/tdastro/base_models.py
@@ -63,6 +63,8 @@ class ParameterSource:
     ----------
     parameter_name : `str`
         The name of the parameter within the node (short name).
+    node_name : `str`
+        The name of the parent node.
     source_type : `int`
        The type of source as defined by the class variables.
        Default = 0
@@ -77,8 +79,6 @@ class ParameterSource:
     required : `bool`
         The attribute must exist and be non-None.
         Default = ``False``
-    full_name : `str`
-        The full name of the parameter including the node information.
     """

     # Class variables for the source enum.
@@ -89,31 +89,13 @@ class ParameterSource:
     COMPUTE_OUTPUT = 4

     def __init__(self, parameter_name, source_type=0, fixed=False, required=False, node_name=""):
+        self.parameter_name = parameter_name
+        self.node_name = node_name
         self.source_type = source_type
         self.fixed = fixed
         self.required = required
         self.value = None
         self.dependency = None
-        self.set_name(parameter_name, node_name)
-
-    def set_name(self, parameter_name="", node_name=""):
-        """Set the name of the parameter field.
-
-        Parameter
-        ---------
-        parameter_name : `str`
-            The name of the parameter within the node (short name).
-        node_name : `str`
-            The node string for the node containing this parameter.
-        """
-        if len(parameter_name) == 0:
-            raise ValueError(f"Invalid parameter name: {parameter_name}")
-
-        self.parameter_name = parameter_name
-        if len(node_name) > 0:
-            self.full_name = f"{node_name}.{parameter_name}"
-        else:
-            self.full_name = f"{parameter_name}"

     def set_as_constant(self, value):
         """Set the parameter as a constant value.
@@ -184,7 +166,7 @@ class ParameterizedNode:
         The full string used to identify a node. This is a combination of the nodes position
         in the graph (if known), node_label (if provided), and class information.
     node_hash : `int`
-        A hashed version of ``node_string`` used for fast lookups.
+        A precomputed hashed version of ``node_string``.
     setters : `dict`
         A dictionary mapping the parameters' names to information about the setters
         (ParameterSource). The model parameters are stored in the order in which they
@@ -225,7 +207,7 @@ def _update_node_string(self, extra_tag=None):
         if self.node_label is not None:
             self.node_string = f"{pos_string}{self.node_label}"
         else:
-            self.node_string = f"{pos_string}{self.__class__.__module__}.{self.__class__.__qualname__}"
+            self.node_string = f"{pos_string}{self.__class__.__qualname__}"

         # Allow for the appending of an extra tag.
         if extra_tag is not None:
@@ -235,9 +217,9 @@ def _update_node_string(self, extra_tag=None):
         hashed_object_name = md5(self.node_string.encode()).hexdigest()
         self.node_hash = int(hashed_object_name, base=16)

-        # Update the full_name of all node's parameter setters.
-        for name, setter_info in self.setters.items():
-            setter_info.set_name(name, self.node_string)
+        # Update the node_name of all node's parameter setters.
+        for _, setter_info in self.setters.items():
+            setter_info.node_name = self.node_string

     def set_graph_positions(self, seen_nodes=None):
         """Force an update of the graph structure (numbering of each node).
@@ -289,8 +271,8 @@ def get_param(self, graph_state, name):
             ``ValueError`` if graph_state is None.
         """
         if graph_state is None:
-            raise ValueError(f"Unable to look ip parameter={name}. No graph_state given.")
-        return graph_state[self.node_hash][name]
+            raise ValueError(f"Unable to look up parameter={name}. No graph_state given.")
+        return graph_state[self.node_string][name]

     def get_local_params(self, graph_state):
         """Get a dictionary of all parameters local to this node.
No graph_state given.") + return graph_state[self.node_string][name] def get_local_params(self, graph_state): """Get a dictionary of all parameters local to this node. @@ -317,7 +299,7 @@ def get_local_params(self, graph_state): """ if graph_state is None: raise ValueError("No graph_state given.") - return graph_state[self.node_hash] + return graph_state[self.node_string] def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -463,7 +445,7 @@ def getter(): getter.__name__ = name setattr(self, name, getter) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Placeholder for a general compute function. Parameters @@ -471,9 +453,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -482,7 +461,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): """ return None - def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None): + def _sample_helper(self, graph_state, seen_nodes, rng_info=None): """Internal recursive function to sample the model's underlying parameters if they are provided by a function or ParameterizedNode. All sampled parameters for all nodes are stored in the graph_state dictionary, which is @@ -496,9 +475,6 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None seen_nodes : `dict` A dictionary mapping nodes seen during this sampling run to their ID. Used to avoid sampling nodes multiple times and to validity check the graph. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -516,41 +492,35 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, rng_info=None # so this will iterate through model parameters in the order they were inserted. any_compute = False for name, setter in self.setters.items(): - # If we are given the argument use that and do not worry about the dependencies. - if given_args is not None and setter.full_name in given_args: - if setter.fixed: - raise ValueError(f"Trying to override fixed parameter {setter.full_name}") - graph_state.set(self.node_hash, name, given_args[setter.full_name]) + # Check if we need to sample this parameter's dependency node. + if setter.dependency is not None and setter.dependency != self: + setter.dependency._sample_helper(graph_state, seen_nodes, rng_info) + + # Set the result from the correct source. 
+            if setter.source_type == ParameterSource.CONSTANT:
+                graph_state.set(self.node_string, name, setter.value)
+            elif setter.source_type == ParameterSource.MODEL_PARAMETER:
+                graph_state.set(
+                    self.node_string,
+                    name,
+                    graph_state[setter.dependency.node_string][setter.value],
+                )
+            elif setter.source_type == ParameterSource.FUNCTION_NODE:
+                graph_state.set(
+                    self.node_string,
+                    name,
+                    graph_state[setter.dependency.node_string][setter.value],
+                )
+            elif setter.source_type == ParameterSource.COMPUTE_OUTPUT:
+                # Computed parameters are set only after all the other (input) parameters.
+                any_compute = True
             else:
-                # Check if we need to sample this parameter's dependency node.
-                if setter.dependency is not None and setter.dependency != self:
-                    setter.dependency._sample_helper(graph_state, seen_nodes, given_args, rng_info)
-
-                # Set the result from the correct source.
-                if setter.source_type == ParameterSource.CONSTANT:
-                    graph_state.set(self.node_hash, name, setter.value)
-                elif setter.source_type == ParameterSource.MODEL_PARAMETER:
-                    graph_state.set(
-                        self.node_hash,
-                        name,
-                        graph_state[setter.dependency.node_hash][setter.value],
-                    )
-                elif setter.source_type == ParameterSource.FUNCTION_NODE:
-                    graph_state.set(
-                        self.node_hash,
-                        name,
-                        graph_state[setter.dependency.node_hash][setter.value],
-                    )
-                elif setter.source_type == ParameterSource.COMPUTE_OUTPUT:
-                    # Computed parameters are set only after all the other (input) parameters.
-                    any_compute = True
-                else:
-                    raise ValueError(f"Invalid ParameterSource type {setter.source_type}")
+                raise ValueError(f"Invalid ParameterSource type {setter.source_type}")

         # If this is a function node and the parameters depend on the result of its own computation
         # call the compute function to fill them in.
         if any_compute:
-            self.compute(graph_state, given_args, rng_info)
+            self.compute(graph_state, rng_info)

     def sample_parameters(self, given_args=None, num_samples=1, rng_info=None):
         """Sample the model's underlying parameters if they are provided by a function
@@ -584,10 +554,14 @@ def sample_parameters(self, given_args=None, num_samples=1, rng_info=None):
         nodes = set()
         self.set_graph_positions(seen_nodes=nodes)

+        # Create space for the results and set all the given_args as fixed parameters.
+        results = GraphState(num_samples)
+        if given_args is not None:
+            results.update(given_args, all_fixed=True)
+
         # Resample the nodes. All information is stored in the returned results dictionary.
         seen_nodes = {}
-        results = GraphState(num_samples)
-        self._sample_helper(results, seen_nodes, given_args, rng_info)
+        self._sample_helper(results, seen_nodes, rng_info)
         return results

     def get_all_node_info(self, field, seen_nodes=None):
@@ -627,7 +601,7 @@ def get_all_node_info(self, field, seen_nodes=None):
                 result.extend(dep.get_all_node_info(field, seen_nodes))
         return result

-    def build_pytree(self, graph_state, seen=None):
+    def build_pytree(self, graph_state, partial=None):
         """Build a JAX PyTree representation of the variables in this graph.

         Parameters
@@ -635,14 +609,17 @@ def build_pytree(self, graph_state, seen=None):
         graph_state : `dict`
             A dictionary of dictionaries mapping node->hash, variable_name to value.
             This data structure is modified in place to represent the current state.
-        seen : `set`
-            A set of objects that have already been processed.
+        partial : `dict`
+            The partial results so far. This is modified in place by the function.
+            A dictionary mapping node name to a dictionary mapping each variable's name
+            to its value.
+            Default : ``None``
+
         Returns
         -------
         values : `dict`
-            The dictionary mapping the combination of the object identifier and
-            model parameter name to its value.
+            A dictionary mapping node name to a dictionary mapping each variable's name
+            to its value.
         """
         # Check if the node might have incomplete information.
         if self.node_pos is None:
@@ -652,20 +629,20 @@ def build_pytree(self, graph_state, seen=None):
             )

         # Skip nodes that we have already seen.
-        if seen is None:
-            seen = set()
-        if self in seen:
-            return {}
-        seen.add(self)
+        if partial is None:
+            partial = {}
+        if self.node_string in partial:
+            return partial

-        all_values = {}
+        # Add new values to the pytree, recursively exploring dependencies.
+        partial[self.node_string] = {}
         for name, setter_info in self.setters.items():
             if setter_info.dependency is not None:
-                all_values.update(setter_info.dependency.build_pytree(graph_state, seen))
+                partial = setter_info.dependency.build_pytree(graph_state, partial)
             elif setter_info.source_type == ParameterSource.CONSTANT and not setter_info.fixed:
                 # Only the non-fixed, constants go into the PyTree.
-                all_values[setter_info.full_name] = graph_state[self.node_hash][name]
-        return all_values
+                partial[self.node_string][name] = graph_state[self.node_string][name]
+        return partial


 class SingleVariableNode(ParameterizedNode):
@@ -769,7 +746,7 @@ def _update_node_string(self, extra_tag=None):
         else:
             super()._update_node_string()

-    def _build_inputs(self, graph_state, given_args=None, **kwargs):
+    def _build_inputs(self, graph_state, **kwargs):
         """Build the input arguments for the function.

         Parameters
@@ -777,9 +754,6 @@ def _build_inputs(self, graph_state, given_args=None, **kwargs):
         graph_state : `GraphState`
             An object mapping graph parameters to their values. This object is modified
             in place as it is sampled.
-        given_args : `dict`, optional
-            A dictionary representing the given arguments for this sample run.
-            This can be used as the JAX PyTree for differentiation.
         **kwargs : `dict`, optional
             Additional function arguments.

@@ -790,13 +764,10 @@ def _build_inputs(self, graph_state, given_args=None, **kwargs):
         """
         args = {}
         for key in self.arg_names:
-            # Override with the given arg or kwarg in that order.
-            if given_args is not None and self.setters[key].full_name in given_args:
-                args[key] = given_args[self.setters[key].full_name]
-            elif key in kwargs:
+            if key in kwargs:
                 args[key] = kwargs[key]
             else:
-                args[key] = graph_state[self.node_hash][key]
+                args[key] = graph_state[self.node_string][key]
         return args

     def _save_results(self, results, graph_state):
@@ -811,7 +782,7 @@ def _save_results(self, results, graph_state):
             in place as it is sampled.
         """
         if len(self.outputs) == 1:
-            graph_state.set(self.node_hash, self.outputs[0], results)
+            graph_state.set(self.node_string, self.outputs[0], results)
         else:
             if len(results) != len(self.outputs):
                 raise ValueError(
                     f"Expected {len(self.outputs)}, but got {results}."
                 )
             for i in range(len(self.outputs)):
-                graph_state.set(self.node_hash, self.outputs[i], results[i])
+                graph_state.set(self.node_string, self.outputs[i], results[i])

-    def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
+    def compute(self, graph_state, rng_info=None, **kwargs):
         """Execute the wrapped function.

         The input arguments are taken from the current graph_state and the outputs
@@ -832,9 +803,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
         graph_state : `GraphState`
             An object mapping graph parameters to their values. This object is modified
             in place as it is sampled.
-        given_args : `dict`, optional
-            A dictionary representing the given arguments for this sample run.
-            This can be used as the JAX PyTree for differentiation.
         rng_info : `dict`, optional
             A dictionary of random number generator information for each node, such as
             the JAX keys or the numpy rngs.
@@ -859,7 +827,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):

         # Build a dictionary of arguments for the function, call the function, and save
         # the results in the graph state.
-        args = self._build_inputs(graph_state, given_args, **kwargs)
+        args = self._build_inputs(graph_state, **kwargs)
         results = self.func(**args)
         self._save_results(results, graph_state)
         return results
@@ -877,4 +845,4 @@ def resample_and_compute(self, given_args=None, rng_info=None):
             the JAX keys or the numpy rngs.
         """
         graph_state = self.sample_parameters(given_args, 1, rng_info)
-        return self.compute(graph_state, given_args, rng_info)
+        return self.compute(graph_state, rng_info)
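With this change `build_pytree` returns a nested mapping keyed first by node name and then by parameter name, instead of flat `"node.param"` keys, and only non-fixed constants are included. A minimal, hedged sketch of the new access pattern (not part of the patch; it assumes the `"A"` parameter of `SingleVariableNode` is stored as a non-fixed constant, as the test changes later in this diff suggest):

```python
# Hedged sketch of the new nested PyTree layout (not part of the patch).
from tdastro.base_models import SingleVariableNode

node = SingleVariableNode("A", 10.0)   # "A" becomes a constant, non-fixed parameter
state = node.sample_parameters()       # assigns graph positions and the node_string
pytree = node.build_pytree(state)

# Old layout:  {"<node_string>.A": 10.0}
# New layout:  {"<node_string>": {"A": 10.0}}
assert pytree[node.node_string]["A"] == 10.0
```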
diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py
index 066bd064..16c8243b 100644
--- a/src/tdastro/graph_state.py
+++ b/src/tdastro/graph_state.py
@@ -56,6 +56,14 @@ def __init__(self, num_samples=1):
     def __len__(self):
         return self.num_parameters

+    def __str__(self):
+        str_lines = []
+        for node_name, node_vars in self.states.items():
+            str_lines.append(f"{node_name}:")
+            for var_name, value in node_vars.items():
+                str_lines.append(f" {var_name}: {value}")
+        return "\n".join(str_lines)
+
     def __getitem__(self, key):
         """Access the dictionary of parameter values for a node name."""
         return self.states[key]
@@ -139,7 +147,7 @@ def set(self, node_name, var_name, value, force_copy=False, fixed=False):
         if fixed:
             self.fixed_vars[node_name].add(var_name)

-    def update(self, inputs, force_copy=False):
+    def update(self, inputs, force_copy=False, all_fixed=False):
         """Set multiple parameters' value in the GraphState from a GraphState or a
         dictionary of the same form.

@@ -156,6 +164,9 @@ def update(self, inputs, force_copy=False):
             Make a copy of data in an array. If set to ``False`` this will link
             to the array, saving memory and computation time.
             Default: ``False``
+        all_fixed : `bool`
+            Treat all the parameters in inputs as fixed.
+            Default: ``False``

         Raises
         ------
@@ -173,7 +184,7 @@ def update(self, inputs, force_copy=False):
         # number of samples.
         for node_name, node_vars in new_states.items():
             for var_name, value in node_vars.items():
-                self.set(node_name, var_name, value, force_copy=force_copy)
+                self.set(node_name, var_name, value, force_copy=force_copy, fixed=all_fixed)

     def extract_single_sample(self, sample_num):
         """Create a new GraphState with a single sample state.
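For reference, a short sketch of the `GraphState` behavior added above: states are keyed by node name, `update` can mark everything it ingests as fixed (which is how `sample_parameters` now pins `given_args`), and `__str__` gives a human-readable dump. The printed layout in the comments follows the `__str__` implementation in the hunk above:

```python
# Hedged usage sketch for the updated GraphState API (not part of the patch).
from tdastro.graph_state import GraphState

state = GraphState(num_samples=1)
state.set("a", "v1", 1.0)
state.set("a", "v2", 2.0)

# update() accepts a GraphState or a nested dict; all_fixed=True marks the
# incoming parameters as fixed.
state.update({"b": {"v1": 3.0}}, all_fixed=True)

print(state)
# a:
#  v1: 1.0
#  v2: 2.0
# b:
#  v1: 3.0
assert state["b"]["v1"] == 3.0
```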
diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py
index 0520d43c..e77eaf86 100644
--- a/src/tdastro/sources/physical_model.py
+++ b/src/tdastro/sources/physical_model.py
@@ -222,22 +222,19 @@ def sample_parameters(self, given_args=None, num_samples=1, rng_info=None, **kwa
         if self.node_pos is None:
             self.set_graph_positions()

-        args_to_use = {}
-        if given_args is not None:
-            args_to_use.update(given_args)
-        if kwargs is not None:
-            args_to_use.update(kwargs)
-
         # We use the same seen_nodes for all sampling calls so each node
         # is sampled at most one time regardless of link structure.
         graph_state = GraphState(num_samples)
+        if given_args is not None:
+            graph_state.update(given_args, all_fixed=True)
+
         seen_nodes = {}
         if self.background is not None:
-            self.background._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
-        self._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
+            self.background._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)
+        self._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)
         for effect in self.effects:
-            effect._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
+            effect._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)

         return graph_state
diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py
index bf60a57e..b10638ad 100644
--- a/src/tdastro/sources/sncomso_models.py
+++ b/src/tdastro/sources/sncomso_models.py
@@ -58,7 +58,7 @@ def parameter_values(self):

     def _update_sncosmo_model_parameters(self, graph_state):
         """Update the parameters for the wrapped sncosmo model."""
-        local_params = graph_state.get_node_state(self.node_hash, 0)
+        local_params = graph_state.get_node_state(self.node_string, 0)
         sn_params = {}
         for name in self.source_param_names:
             sn_params[name] = local_params[name]
@@ -97,7 +97,7 @@ def set(self, **kwargs):
                 self.source_param_names.append(key)
         self.source.set(**kwargs)

-    def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1, rng_info=None):
+    def _sample_helper(self, graph_state, seen_nodes, num_samples=1, rng_info=None):
         """Internal recursive function to sample the model's underlying parameters
         if they are provided by a function or ParameterizedNode.

@@ -112,9 +112,6 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1
         seen_nodes : `dict`
             A dictionary mapping nodes seen during this sampling run to their ID.
             Used to avoid sampling nodes multiple times and to validity check the graph.
-        given_args : `dict`, optional
-            A dictionary representing the given arguments for this sample run.
-            This can be used as the JAX PyTree for differentiation.
         num_samples : `int`
             A count of the number of samples to compute.
             Default: 1
@@ -126,7 +123,7 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1
         ------
         Raise a ``ValueError`` the sampling encounters a problem with the order of dependencies.
         """
-        super()._sample_helper(graph_state, seen_nodes, given_args, rng_info)
+        super()._sample_helper(graph_state, seen_nodes, rng_info)
         self._update_sncosmo_model_parameters(graph_state)

     def _evaluate(self, times, wavelengths, graph_state=None, **kwargs):
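Since `given_args` is now folded into the `GraphState` up front (and marked fixed) instead of being threaded through every `_sample_helper` call, overrides are supplied as a nested mapping of node name to parameter values rather than flat `"node.param"` keys. A hedged sketch of the calling convention, where `model` and the node name `"0:source"` are placeholders standing in for a real model and its `node_string`:

```python
# Hypothetical names for illustration only; this relies on GraphState treating
# fixed entries as pinned (not overwritten) during the sampling pass.
given = {"0:source": {"t0": 55000.0}}
state = model.sample_parameters(given_args=given, num_samples=10)
assert state["0:source"]["t0"] == 55000.0
```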
""" - super()._sample_helper(graph_state, seen_nodes, given_args, rng_info) + super()._sample_helper(graph_state, seen_nodes, rng_info) self._update_sncosmo_model_parameters(graph_state) def _evaluate(self, times, wavelengths, graph_state=None, **kwargs): diff --git a/src/tdastro/util_nodes/jax_random.py b/src/tdastro/util_nodes/jax_random.py index 4e7be0f9..ceae5e94 100644 --- a/src/tdastro/util_nodes/jax_random.py +++ b/src/tdastro/util_nodes/jax_random.py @@ -89,7 +89,7 @@ class JaxRandomFunc(FunctionNode): def __init__(self, func, **kwargs): super().__init__(func, **kwargs) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped JAX sampling function. Parameters @@ -97,9 +97,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. rng_info : `dict`, optional A dictionary of random number generator information for each node, such as the JAX keys or the numpy rngs. @@ -125,13 +122,13 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): rng_info[self.node_hash] = next_key # Generate the results. - args = self._build_inputs(graph_state, given_args, **kwargs) + args = self._build_inputs(graph_state, **kwargs) if graph_state.num_samples == 1: results = float(self.func(current_key, **args)) else: use_shape = [graph_state.num_samples] results = self.func(current_key, shape=use_shape, **args) - graph_state.set(self.node_hash, self.outputs[0], results) + graph_state.set(self.node_string, self.outputs[0], results) return results def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): @@ -152,7 +149,7 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Additional function arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) class JaxRandomNormal(FunctionNode): @@ -196,4 +193,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): Any additional keyword arguments. """ state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) diff --git a/src/tdastro/util_nodes/np_random.py b/src/tdastro/util_nodes/np_random.py index 6e2ea2c2..3fc35ad3 100644 --- a/src/tdastro/util_nodes/np_random.py +++ b/src/tdastro/util_nodes/np_random.py @@ -126,7 +126,7 @@ def set_seed(self, new_seed): self._rng = np.random.default_rng(seed=new_seed) self.func = getattr(self._rng, self.func_name) - def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): + def compute(self, graph_state, rng_info=None, **kwargs): """Execute the wrapped function. The input arguments are taken from the current graph_state and the outputs @@ -137,9 +137,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs): graph_state : `GraphState` An object mapping graph parameters to their values. This object is modified in place as it is sampled. - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. 
-            This can be used as the JAX PyTree for differentiation.
         rng_info : `dict`, optional
             A dictionary of random number generator information for each node, such as
             the JAX keys or the numpy rngs.
@@ -156,7 +153,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
         ------
         ``ValueError`` is ``func`` attribute is ``None``.
         """
-        args = self._build_inputs(graph_state, given_args, **kwargs)
+        args = self._build_inputs(graph_state, **kwargs)
         num_samples = None if graph_state.num_samples == 1 else graph_state.num_samples

         # If a random number generator is given use that. Otherwise use the default one.
@@ -186,4 +183,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
             Additional function arguments.
         """
         state = self.sample_parameters(given_args, num_samples, rng_info)
-        return self.compute(state, given_args, rng_info, **kwargs)
+        return self.compute(state, rng_info, **kwargs)
diff --git a/src/tdastro/util_nodes/scipy_random.py b/src/tdastro/util_nodes/scipy_random.py
index 8571c2a0..6bb16fd0 100644
--- a/src/tdastro/util_nodes/scipy_random.py
+++ b/src/tdastro/util_nodes/scipy_random.py
@@ -97,7 +97,7 @@ def _create_and_sample(self, args, rng):
         sample = NumericalInversePolynomial(dist).rvs(1, rng)[0]
         return sample

-    def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
+    def compute(self, graph_state, rng_info=None, **kwargs):
         """Execute the wrapped function.

         The input arguments are taken from the current graph_state and the outputs
@@ -108,9 +108,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
         graph_state : `GraphState`
             An object mapping graph parameters to their values. This object is modified
             in place as it is sampled.
-        given_args : `dict`, optional
-            A dictionary representing the given arguments for this sample run.
-            This can be used as the JAX PyTree for differentiation.
         rng_info : `dict`, optional
             A dictionary of random number generator information for each node, such as
             the JAX keys or the numpy rngs.
@@ -132,7 +129,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
         else:
             # This is a class so we will need to create a new distribution object
             # for each sample (with a single instance of the input parameters).
-            args = self._build_inputs(graph_state, given_args, **kwargs)
+            args = self._build_inputs(graph_state, **kwargs)

             if graph_state.num_samples == 1:
                 dist = self._dist(**args)
@@ -164,4 +161,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
             Additional function arguments.
         """
         state = self.sample_parameters(given_args, num_samples, rng_info)
-        return self.compute(state, given_args, rng_info, **kwargs)
+        return self.compute(state, rng_info, **kwargs)
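The same signature change is applied uniformly across the random-number nodes: `compute()` no longer accepts `given_args`, so callers route overrides through `sample_parameters()` (or `generate()`, which wraps both steps). In sketch form, where `node` stands for any of the `FunctionNode` subclasses touched above:

```python
# Before this patch:
#     state = node.sample_parameters(given_args, num_samples, rng_info)
#     results = node.compute(state, given_args, rng_info)
#
# After this patch, given_args lives only in the GraphState:
state = node.sample_parameters(given_args, num_samples, rng_info)
results = node.compute(state, rng_info)
```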
""" state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, given_args, rng_info, **kwargs) + return self.compute(state, rng_info, **kwargs) diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index 95875e8a..06b357d8 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -9,7 +9,7 @@ def test_sncomso_models_hsiao() -> None: state = model.sample_parameters() assert model.get_param(state, "amplitude") == 2.0e10 assert model.get_param(state, "t0") == 0.0 - assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel" + assert str(model) == "0:SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameter_values, [2.0e10]) @@ -29,7 +29,6 @@ def test_sncomso_models_hsiao_t0() -> None: state = model.sample_parameters() assert model.get_param(state, "amplitude") == 2.0e10 assert model.get_param(state, "t0") == 55000.0 - assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameter_values, [2.0e10]) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py index 1a8a0ecc..d42f74f7 100644 --- a/tests/tdastro/sources/test_spline_source.py +++ b/tests/tdastro/sources/test_spline_source.py @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None: wavelengths = np.linspace(100.0, 500.0, 25) fluxes = np.full((len(times), len(wavelengths)), 1.0) model = SplineModel(times, wavelengths, fluxes) - assert str(model) == "tdastro.sources.spline_model.SplineModel" + assert str(model) == "SplineModel" test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index c1947850..050763df 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -52,10 +52,10 @@ def test_static_source_host() -> None: assert model.get_param(state, "ra") == 1.0 assert model.get_param(state, "dec") == 2.0 assert model.get_param(state, "distance") == 3.0 - assert str(model) == "0:tdastro.sources.static_source.StaticSource" + assert str(model) == "0:StaticSource" # Test that we have given a different name to the host. 
- assert str(host) == "1:tdastro.sources.static_source.StaticSource" + assert str(host) == "1:StaticSource" def test_static_source_resample() -> None: diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 5fe6397d..f763efef 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -61,7 +61,7 @@ def test_parameter_source(): """Test the ParameterSource creation and setter functions.""" source = ParameterSource("test") assert source.parameter_name == "test" - assert source.full_name == "test" + assert source.node_name == "" assert source.source_type == ParameterSource.UNDEFINED assert source.dependency is None assert source.value is None @@ -70,17 +70,12 @@ def test_parameter_source(): source.set_as_constant(10.0) assert source.parameter_name == "test" - assert source.full_name == "test" assert source.source_type == ParameterSource.CONSTANT assert source.dependency is None assert source.value == 10.0 assert not source.fixed assert not source.required - source.set_name("my_var", "my_node") - assert source.parameter_name == "my_var" - assert source.full_name == "my_node.my_var" - with pytest.raises(ValueError): source.set_as_constant(_test_func) @@ -104,7 +99,7 @@ def test_parameterized_node(): """Test that we can sample and create a PairModel object.""" # Simple addition model1 = PairModel(value1=0.5, value2=0.5) - assert str(model1) == "test_base_models.PairModel" + assert str(model1) == "PairModel" state = model1.sample_parameters() assert model1.get_param(state, "value1") == 0.5 @@ -219,26 +214,25 @@ def test_parameterized_node_build_pytree(): model1 = PairModel(value1=0.5, value2=1.5, node_label="A") model2 = PairModel(value1=model1.value1, value2=3.0, node_label="B") graph_state = model2.sample_parameters() - pytree = model2.build_pytree(graph_state) - assert len(pytree) == 3 - assert pytree["1:A.value1"] == 0.5 - assert pytree["1:A.value2"] == 1.5 - assert pytree["0:B.value2"] == 3.0 + pytree = model2.build_pytree(graph_state) + assert pytree["1:A"]["value1"] == 0.5 + assert pytree["1:A"]["value2"] == 1.5 + assert pytree["0:B"]["value2"] == 3.0 # Manually set value2 to fixed and check that it no longer appears in the pytree. 
model1.setters["value2"].fixed = True pytree = model2.build_pytree(graph_state) - assert len(pytree) == 2 - assert pytree["1:A.value1"] == 0.5 - assert pytree["0:B.value2"] == 3.0 + assert pytree["1:A"]["value1"] == 0.5 + assert pytree["0:B"]["value2"] == 3.0 + assert "value2" not in pytree["1:A"] def test_single_variable_node(): """Test that we can create and query a SingleVariableNode.""" node = SingleVariableNode("A", 10.0) - assert str(node) == "tdastro.base_models.SingleVariableNode" + assert str(node) == "SingleVariableNode" state = node.sample_parameters() assert node.get_param(state, "A") == 10 @@ -253,7 +247,7 @@ def test_function_node_basic(): assert my_func.compute(state, value2=3.0) == 4.0 assert my_func.compute(state, value2=3.0, unused_param=5.0) == 4.0 assert my_func.compute(state, value2=3.0, value1=1.0) == 4.0 - assert str(my_func) == "0:tdastro.base_models.FunctionNode:_test_func" + assert str(my_func) == "0:FunctionNode:_test_func" def test_function_node_chain(): @@ -347,14 +341,11 @@ def _test_func2(value1, value2): graph_state = sum_node.sample_parameters() pytree = sum_node.build_pytree(graph_state) - assert len(pytree) == 3 print(pytree) - gr_func = jax.value_and_grad(sum_node.resample_and_compute) values, gradients = gr_func(pytree) - assert len(gradients) == 3 print(gradients) assert values == 9.0 - assert gradients["0:sum:_test_func.value1"] == 1.0 - assert gradients["1:div:_test_func2.value1"] == 2.0 - assert gradients["1:div:_test_func2.value2"] == -16.0 + assert gradients["0:sum:_test_func"]["value1"] == 1.0 + assert gradients["1:div:_test_func2"]["value1"] == 2.0 + assert gradients["1:div:_test_func2"]["value2"] == -16.0 diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 6b565ec2..ec122abb 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -22,6 +22,10 @@ def test_create_single_sample_graph_state(): with pytest.raises(KeyError): _ = state["c"]["v1"] + # We can create a human readable string representation of the GraphState. + debug_str = str(state) + assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0" + # Check that we can get all the values for a specific node. a_vals = state.get_node_state("a") assert len(a_vals) == 2