diff --git a/aepsych/generators/manual_generator.py b/aepsych/generators/manual_generator.py index a1c0ea778..ed5c96d80 100644 --- a/aepsych/generators/manual_generator.py +++ b/aepsych/generators/manual_generator.py @@ -6,9 +6,8 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Dict, Optional, Union +from typing import Dict, Optional -import numpy as np import torch from torch.quasirandom import SobolEngine @@ -25,27 +24,29 @@ class ManualGenerator(AEPsychGenerator): def __init__( self, - lb: Union[np.ndarray, torch.Tensor], - ub: Union[np.ndarray, torch.Tensor], - points: Union[np.ndarray, torch.Tensor], + lb: torch.Tensor, + ub: torch.Tensor, + points: torch.Tensor, dim: Optional[int] = None, shuffle: bool = True, seed: Optional[int] = None, ): """Iniatialize ManualGenerator. Args: - lb (Union[np.ndarray, torch.Tensor]): Lower bounds of each parameter. - ub (Union[np.ndarray, torch.Tensor]): Upper bounds of each parameter. - points (Union[np.ndarray, torch.Tensor]): The points that will be generated. + lb torch.Tensor: Lower bounds of each parameter. + ub torch.Tensor: Upper bounds of each parameter. + points torch.Tensor: The points that will be generated. dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub. shuffle (bool): Whether or not to shuffle the order of the points. True by default. """ self.seed = seed self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim) + self.points = points if shuffle: - np.random.seed(self.seed) - np.random.shuffle(points) - self.points = torch.tensor(points) + if seed is not None: + torch.manual_seed(seed) + self.points = points[torch.randperm(len(points))] + self.max_asks = len(self.points) self._idx = 0 @@ -81,7 +82,7 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: lb = config.gettensor(name, "lb") ub = config.gettensor(name, "ub") dim = config.getint(name, "dim", fallback=None) - points = config.getarray(name, "points") + points = config.gettensor(name, "points") shuffle = config.getboolean(name, "shuffle", fallback=True) seed = config.getint(name, "seed", fallback=None) @@ -106,10 +107,10 @@ class SampleAroundPointsGenerator(ManualGenerator): def __init__( self, - lb: Union[np.ndarray, torch.Tensor], - ub: Union[np.ndarray, torch.Tensor], - window: Union[np.ndarray, torch.Tensor], - points: Union[np.ndarray, torch.Tensor], + lb: torch.Tensor, + ub: torch.Tensor, + window: torch.Tensor, + points: torch.Tensor, samples_per_point: int, dim: Optional[int] = None, shuffle: bool = True, @@ -127,16 +128,24 @@ def __init__( seed (int, optional): Random seed. """ lb, ub, dim = _process_bounds(lb, ub, dim) - points = torch.Tensor(points) self.engine = SobolEngine(dimension=dim, scramble=True, seed=seed) generated = [] + if len(points.shape) > 2: + # We need to determine how many stimuli there are per trial to maintain the proper tensor shape + n_draws = points.shape[1] + else: + n_draws = 1 for point in points: p_lb = torch.max(point - window, lb) p_ub = torch.min(point + window, ub) - grid = self.engine.draw(samples_per_point) - grid = p_lb + (p_ub - p_lb) * grid - generated.append(grid) - generated = torch.Tensor(np.vstack(generated)) # type: ignore + for _ in range(samples_per_point): + grid = self.engine.draw(n_draws) + grid = p_lb + (p_ub - p_lb) * grid + generated.append(grid) + if len(points.shape) > 2: + generated = torch.stack(generated) # type: ignore + else: + generated = torch.vstack(generated) super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore diff --git a/tests/generators/test_manual_generator.py b/tests/generators/test_manual_generator.py index 06ed01143..9499bde79 100644 --- a/tests/generators/test_manual_generator.py +++ b/tests/generators/test_manual_generator.py @@ -99,6 +99,39 @@ def test_sample_around_points_generator(self): self.assertTrue(gen.finished) + def test_sample_around_points_generator_high_dim(self): + points = [[[0.5, 0], [0.5, 1], [0, 0.5]], [[0.25, 0], [0.25, 1], [0, 0.25]]] + window = [0.1, 2] + samples_per_point = 2 + config_str = f""" + [common] + lb = [0, 0] + ub = [1, 1] + parnames = [par1, par2] + + [SampleAroundPointsGenerator] + points = {points} + window = {window} + samples_per_point = {samples_per_point} + seed = 123 + """ + config = Config() + config.update(config_str=config_str) + gen = SampleAroundPointsGenerator.from_config(config) + npt.assert_equal(gen.lb.numpy(), np.array([0, 0])) + npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) + self.assertEqual(gen.max_asks, len(points * samples_per_point)) + self.assertEqual(gen.seed, 123) + self.assertFalse(gen.finished) + + points = gen.gen(gen.max_asks) + for i in range(len(window)): + npt.assert_array_less(points[:, i], points[:, i] + window[i]) + npt.assert_array_less(np.zeros(points[..., i].shape), points[..., i]) + npt.assert_array_less(points[..., i], np.ones(points[..., i].shape)) + + self.assertTrue(gen.finished) + if __name__ == "__main__": unittest.main()