Skip to content

Commit

Permalink
Allow parsing of nested numbers from config (#400)
Browse files Browse the repository at this point in the history
Summary:

Nested arrays were not directly parseable from the config. E.g., [12, [2, 4], 2, [2.0, 4]] would raise an exception.

If a nested array is detected, it will check whether the desired type is either a float or an int and if so, it'll evaluate the string as a number then recursively cast.

If the desired type is neither a float nor an int, it will raise a ValueException.

Reviewed By: crasanders

Differential Revision: D64138450
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Oct 18, 2024
1 parent 5b18619 commit 8d5d543
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _str_to_array(self, v: str) -> np.ndarray:
return np.array(v, dtype=float)

def _str_to_tensor(self, v: str) -> torch.Tensor:
return torch.Tensor(self._str_to_list(v)).to(torch.float64)
return torch.Tensor(self._str_to_array(v)).to(torch.float64)

def _str_to_obj(self, v: str, fallback_type: _T = str, warn: bool = True) -> object:

Expand Down
33 changes: 31 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def test_nonmonotonic_optimization_config_file(self):
self.assertTrue(
isinstance(strat.strat_list[1].generator, OptimizeAcqfGenerator)
)
self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement)
self.assertTrue(
strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement
)
self.assertTrue(
set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"}
)
Expand Down Expand Up @@ -411,6 +413,31 @@ def test_warn_about_refit(self):
with self.assertWarns(UserWarning):
Strategy.from_config(config, "init_strat")

def test_nested_tensor(self):
points = [[0.25, 0.75], [0.5, 0.9]]
config_str = f"""
[common]
parnames = [par1, par2]
[par1]
par_type = continuous
lower_bound = 0
upper_bound = 1
[par2]
par_type = continuous
lower_bound = 0
upper_bound = 1
[SampleAroundPointsGenerator]
points = {points}
"""
config = Config()
config.update(config_str=config_str)

config_points = config.gettensor("SampleAroundPointsGenerator", "points")
self.assertTrue(torch.all(config_points == torch.tensor(points)))

def test_pairwise_probit_config(self):
config_str = """
[common]
Expand Down Expand Up @@ -581,7 +608,9 @@ def test_pairwise_opt_config(self):
self.assertTrue(strat.strat_list[0].model is None)

self.assertTrue(isinstance(strat.strat_list[1].model, PairwiseProbitModel))
self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement)
self.assertTrue(
strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement
)
self.assertTrue(
set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"}
)
Expand Down

0 comments on commit 8d5d543

Please sign in to comment.