From 06dfaa7736961dfba209462173acefb2f013f80c Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Thu, 10 Oct 2024 09:06:10 -0700 Subject: [PATCH] Allow parsing of nested numbers from config (#400) Summary: Nested arrays were not directly parseable from the config. E.g., [12, [2, 4], 2, [2.0, 4]] would raise an exception. If a nested array is detected, it will check whether the desired type is either a float or an int and if so, it'll evaluate the string as a number then recursively cast. If the desired type is neither a float nor an int, it will raise a ValueException. Differential Revision: D64138450 --- aepsych/config.py | 25 ++++++++++++++++++++++++- tests/test_config.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/aepsych/config.py b/aepsych/config.py index a7a5b69bb..9fd04f44d 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -182,11 +182,34 @@ def update( self["common"][i] = self["experiment"][i] del self["experiment"] - def _str_to_list(self, v: str, element_type: _T = float) -> List[_T]: + def _str_to_list( + self, v: str, element_type: _T = float + ) -> List[_T] | List[List[_T]]: v = re.sub(r"\n ", ",", v) v = re.sub(r"(? 1: + if element_type in [float, int]: # Easy to handle nested numbers + + def _nested_cast(v, element_type): + result = [] + for item in v: + if isinstance(item, list): + # Recursively convert nested lists + result.append(_nested_cast(item, element_type)) + else: + # Convert individual values + result.append(element_type(item)) + return result + + v = ast.literal_eval(v) + return _nested_cast(v, element_type) + + else: + raise ValueError( + "Nested list detected with an element_type that is hard to evaluate" + ) if v == "[]": # empty list return [] else: diff --git a/tests/test_config.py b/tests/test_config.py index f2df367d8..fe6234e08 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -186,7 +186,9 @@ def test_nonmonotonic_optimization_config_file(self): self.assertTrue( isinstance(strat.strat_list[1].generator, OptimizeAcqfGenerator) ) - self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement) + self.assertTrue( + strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement + ) self.assertTrue( set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"} ) @@ -411,6 +413,31 @@ def test_warn_about_refit(self): with self.assertWarns(UserWarning): Strategy.from_config(config, "init_strat") + def test_nested_tensor(self): + points = [[0.25, 0.75], [0.5, 0.9]] + config_str = f""" + [common] + parnames = [par1, par2] + + [par1] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [par2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [SampleAroundPointsGenerator] + points = {points} + """ + config = Config() + config.update(config_str=config_str) + + config_points = config.gettensor("SampleAroundPointsGenerator", "points") + self.assertTrue(torch.all(config_points == torch.tensor(points))) + def test_pairwise_probit_config(self): config_str = """ [common] @@ -581,7 +608,9 @@ def test_pairwise_opt_config(self): self.assertTrue(strat.strat_list[0].model is None) self.assertTrue(isinstance(strat.strat_list[1].model, PairwiseProbitModel)) - self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement) + self.assertTrue( + strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement + ) self.assertTrue( set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"} )