Skip to content

Commit

Permalink
Allow parsing of nested numbers from config (facebookresearch#400)
Browse files Browse the repository at this point in the history
Summary:

Nested arrays were not directly parseable from the config. E.g., [12, [2, 4], 2, [2.0, 4]] would raise an exception.

If a nested array is detected, it will check whether the desired type is either a float or an int and if so, it'll evaluate the string as a number then recursively cast.

If the desired type is neither a float nor an int, it will raise a ValueException.

Differential Revision: D64138450
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Oct 10, 2024
1 parent d52d713 commit 06dfaa7
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
25 changes: 24 additions & 1 deletion aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,34 @@ def update(
self["common"][i] = self["experiment"][i]
del self["experiment"]

def _str_to_list(self, v: str, element_type: _T = float) -> List[_T]:
def _str_to_list(
self, v: str, element_type: _T = float
) -> List[_T] | List[List[_T]]:
v = re.sub(r"\n ", ",", v)
v = re.sub(r"(?<!,)\s+", ",", v)
v = re.sub(r",]", "]", v)
if re.search(r"^\[.*\]$", v, flags=re.DOTALL):
if len(re.findall(r"\[", v)) > 1:
if element_type in [float, int]: # Easy to handle nested numbers

def _nested_cast(v, element_type):
result = []
for item in v:
if isinstance(item, list):
# Recursively convert nested lists
result.append(_nested_cast(item, element_type))
else:
# Convert individual values
result.append(element_type(item))
return result

v = ast.literal_eval(v)
return _nested_cast(v, element_type)

else:
raise ValueError(
"Nested list detected with an element_type that is hard to evaluate"
)
if v == "[]": # empty list
return []
else:
Expand Down
33 changes: 31 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def test_nonmonotonic_optimization_config_file(self):
self.assertTrue(
isinstance(strat.strat_list[1].generator, OptimizeAcqfGenerator)
)
self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement)
self.assertTrue(
strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement
)
self.assertTrue(
set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"}
)
Expand Down Expand Up @@ -411,6 +413,31 @@ def test_warn_about_refit(self):
with self.assertWarns(UserWarning):
Strategy.from_config(config, "init_strat")

def test_nested_tensor(self):
points = [[0.25, 0.75], [0.5, 0.9]]
config_str = f"""
[common]
parnames = [par1, par2]
[par1]
par_type = continuous
lower_bound = 0
upper_bound = 1
[par2]
par_type = continuous
lower_bound = 0
upper_bound = 1
[SampleAroundPointsGenerator]
points = {points}
"""
config = Config()
config.update(config_str=config_str)

config_points = config.gettensor("SampleAroundPointsGenerator", "points")
self.assertTrue(torch.all(config_points == torch.tensor(points)))

def test_pairwise_probit_config(self):
config_str = """
[common]
Expand Down Expand Up @@ -581,7 +608,9 @@ def test_pairwise_opt_config(self):
self.assertTrue(strat.strat_list[0].model is None)

self.assertTrue(isinstance(strat.strat_list[1].model, PairwiseProbitModel))
self.assertTrue(strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement)
self.assertTrue(
strat.strat_list[1].generator.acqf is qLogNoisyExpectedImprovement
)
self.assertTrue(
set(strat.strat_list[1].generator.acqf_kwargs.keys()) == {"objective"}
)
Expand Down

0 comments on commit 06dfaa7

Please sign in to comment.