Skip to content

Commit

Permalink
make manual generators require tensors and support higher-dimensional…
Browse files Browse the repository at this point in the history
… tensors (#410)

Summary:

This removes the numpy dependency from the manual generators and changes the logic slightly so that we can support higher dimensional tensors of points, allowing us to support pairwise experiments

Differential Revision: D64607853
  • Loading branch information
crasanders authored and facebook-github-bot committed Oct 21, 2024
1 parent a74d426 commit 3ad4413
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
51 changes: 30 additions & 21 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict, Optional, Union
from typing import Dict, Optional

import numpy as np
import torch
from torch.quasirandom import SobolEngine

Expand All @@ -25,27 +24,29 @@ class ManualGenerator(AEPsychGenerator):

def __init__(
self,
lb: Union[np.ndarray, torch.Tensor],
ub: Union[np.ndarray, torch.Tensor],
points: Union[np.ndarray, torch.Tensor],
lb: torch.Tensor,
ub: torch.Tensor,
points: torch.Tensor,
dim: Optional[int] = None,
shuffle: bool = True,
seed: Optional[int] = None,
):
"""Iniatialize ManualGenerator.
Args:
lb (Union[np.ndarray, torch.Tensor]): Lower bounds of each parameter.
ub (Union[np.ndarray, torch.Tensor]): Upper bounds of each parameter.
points (Union[np.ndarray, torch.Tensor]): The points that will be generated.
lb torch.Tensor: Lower bounds of each parameter.
ub torch.Tensor: Upper bounds of each parameter.
points torch.Tensor: The points that will be generated.
dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub.
shuffle (bool): Whether or not to shuffle the order of the points. True by default.
"""
self.seed = seed
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
self.points = points
if shuffle:
np.random.seed(self.seed)
np.random.shuffle(points)
self.points = torch.tensor(points)
if seed is not None:
torch.manual_seed(seed)
self.points = points[torch.randperm(len(points))]

self.max_asks = len(self.points)
self._idx = 0

Expand Down Expand Up @@ -81,7 +82,7 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:
lb = config.gettensor(name, "lb")
ub = config.gettensor(name, "ub")
dim = config.getint(name, "dim", fallback=None)
points = config.getarray(name, "points")
points = config.gettensor(name, "points")
shuffle = config.getboolean(name, "shuffle", fallback=True)
seed = config.getint(name, "seed", fallback=None)

Expand All @@ -106,10 +107,10 @@ class SampleAroundPointsGenerator(ManualGenerator):

def __init__(
self,
lb: Union[np.ndarray, torch.Tensor],
ub: Union[np.ndarray, torch.Tensor],
window: Union[np.ndarray, torch.Tensor],
points: Union[np.ndarray, torch.Tensor],
lb: torch.Tensor,
ub: torch.Tensor,
window: torch.Tensor,
points: torch.Tensor,
samples_per_point: int,
dim: Optional[int] = None,
shuffle: bool = True,
Expand All @@ -127,16 +128,24 @@ def __init__(
seed (int, optional): Random seed.
"""
lb, ub, dim = _process_bounds(lb, ub, dim)
points = torch.Tensor(points)
self.engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
generated = []
if len(points.shape) > 2:
# We need to determine how many stimuli there are per trial to maintain the proper tensor shape
n_draws = points.shape[1]
else:
n_draws = 1
for point in points:
p_lb = torch.max(point - window, lb)
p_ub = torch.min(point + window, ub)
grid = self.engine.draw(samples_per_point)
grid = p_lb + (p_ub - p_lb) * grid
generated.append(grid)
generated = torch.Tensor(np.vstack(generated)) # type: ignore
for _ in range(samples_per_point):
grid = self.engine.draw(n_draws)
grid = p_lb + (p_ub - p_lb) * grid
generated.append(grid)
if len(points.shape) > 2:
generated = torch.stack(generated) # type: ignore
else:
generated = torch.vstack(generated)

super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore

Expand Down
33 changes: 33 additions & 0 deletions tests/generators/test_manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,39 @@ def test_sample_around_points_generator(self):

self.assertTrue(gen.finished)

def test_sample_around_points_generator_high_dim(self):
points = [[[0.5, 0], [0.5, 1], [0, 0.5]], [[0.25, 0], [0.25, 1], [0, 0.25]]]
window = [0.1, 2]
samples_per_point = 2
config_str = f"""
[common]
lb = [0, 0]
ub = [1, 1]
parnames = [par1, par2]
[SampleAroundPointsGenerator]
points = {points}
window = {window}
samples_per_point = {samples_per_point}
seed = 123
"""
config = Config()
config.update(config_str=config_str)
gen = SampleAroundPointsGenerator.from_config(config)
npt.assert_equal(gen.lb.numpy(), np.array([0, 0]))
npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
self.assertEqual(gen.max_asks, len(points * samples_per_point))
self.assertEqual(gen.seed, 123)
self.assertFalse(gen.finished)

points = gen.gen(gen.max_asks)
for i in range(len(window)):
npt.assert_array_less(points[:, i], points[:, i] + window[i])
npt.assert_array_less(np.zeros(points[..., i].shape), points[..., i])
npt.assert_array_less(points[..., i], np.ones(points[..., i].shape))

self.assertTrue(gen.finished)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3ad4413

Please sign in to comment.