diff --git a/aepsych/config.py b/aepsych/config.py index 80fa692c9..3215dc394 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -199,7 +199,7 @@ def _str_to_array(self, v: str) -> np.ndarray: return np.array(v, dtype=float) def _str_to_tensor(self, v: str) -> torch.Tensor: - return torch.Tensor(self._str_to_list(v)).to(torch.float64) + return torch.Tensor(self._str_to_array(v)).to(torch.float64) def _str_to_obj(self, v: str, fallback_type: _T = str, warn: bool = True) -> object: diff --git a/tests/test_config.py b/tests/test_config.py index f2df367d8..fe6234e08 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -186,7 +186,9 @@ def test_nonmonotonic_optimization_config_file(self): self.assertTrue( isinstance(strat.strat_list[1].generator, OptimizeAcqfGenerator) ) - self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement) + self.assertTrue( + strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement + ) self.assertTrue( set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"} ) @@ -411,6 +413,31 @@ def test_warn_about_refit(self): with self.assertWarns(UserWarning): Strategy.from_config(config, "init_strat") + def test_nested_tensor(self): + points = [[0.25, 0.75], [0.5, 0.9]] + config_str = f""" + [common] + parnames = [par1, par2] + + [par1] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [par2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [SampleAroundPointsGenerator] + points = {points} + """ + config = Config() + config.update(config_str=config_str) + + config_points = config.gettensor("SampleAroundPointsGenerator", "points") + self.assertTrue(torch.all(config_points == torch.tensor(points))) + def test_pairwise_probit_config(self): config_str = """ [common] @@ -581,7 +608,9 @@ def test_pairwise_opt_config(self): self.assertTrue(strat.strat_list[0].model is None) self.assertTrue(isinstance(strat.strat_list[1].model, PairwiseProbitModel)) - self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement) + self.assertTrue( + strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement + ) self.assertTrue( set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"} )