Skip to content

Commit

Permalink
modify the generator object to process np.array argument (and convert…
Browse files Browse the repository at this point in the history
… from string to list) (#355)

Summary:
Pull Request resolved: #355

We modified the generator object to process np.array argument: In the config dictionary, we specify the target argument of any child class of the `LookaheadAcquisitionFunction` class as a numpy array. The config processed the numpy array as a string using self.from_dict.

We have modified (1) the `_str_to_list` function of the Config class in config.py to accept a broader range of string representations of lists, including those with "\n" escape characters and white spaces. The function will properly convert these strings to list. (2) the `_get_acqf_options` function of the `AEPsychGenerator` class in base.py. The function uses regex to detect list argument in a string for any acquisition arguments and calls the `_str_to_list` function to convert it to a list.

Reviewed By: crasanders

Differential Revision: D58163715

fbshipit-source-id: e0654a011c8168f88321fd022ba08704b3c91b01
  • Loading branch information
wenx-guo authored and facebook-github-bot committed Jun 6, 2024
1 parent 53e7a6d commit 6a5ddb8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.
import abc
import ast
import re
import configparser
import json
import warnings
Expand Down Expand Up @@ -163,7 +164,9 @@ def update(
del self["experiment"]

def _str_to_list(self, v: str, element_type: _T = float) -> List[_T]:
if v[0] == "[" and v[-1] == "]":
v = re.sub(r"\n ", ",", v)
v = re.sub(r"(?<!,)\s+", ",", v)
if re.search(r"^\[.*\]$", v, flags=re.DOTALL):
if v == "[]": # empty list
return []
else:
Expand Down
6 changes: 6 additions & 0 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import abc
from inspect import signature
from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional
import re

import torch
from aepsych.config import Config
Expand Down Expand Up @@ -73,9 +74,14 @@ def _get_acqf_options(cls, acqf: AcquisitionFunction, config: Config):
for k in acqf_args_expected:
# if this thing is configured
if k in full_section.keys():
v = config.get(acqf_name, k)
# if it's an object make it an object
if full_section[k] in Config.registered_names.keys():
extra_acqf_args[k] = config.getobj(acqf_name, k)
elif re.search(
r"^\[.*\]$", v, flags=re.DOTALL
): # use regex to check if the value is a list
extra_acqf_args[k] = config._str_to_list(v) # type: ignore
else:
# otherwise try a float
try:
Expand Down

0 comments on commit 6a5ddb8

Please sign in to comment.