diff --git a/aepsych/generators/manual_generator.py b/aepsych/generators/manual_generator.py index f794db2ab..a1c0ea778 100644 --- a/aepsych/generators/manual_generator.py +++ b/aepsych/generators/manual_generator.py @@ -6,15 +6,16 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union import numpy as np import torch +from torch.quasirandom import SobolEngine + from aepsych.config import Config from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import AEPsychMixin from aepsych.utils import _process_bounds -from torch.quasirandom import SobolEngine class ManualGenerator(AEPsychGenerator): @@ -95,6 +96,10 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: return options + @property + def finished(self): + return self._idx >= len(self.points) + class SampleAroundPointsGenerator(ManualGenerator): """Generator that samples in a window around reference points in a predefined list.""" @@ -131,9 +136,9 @@ def __init__( grid = self.engine.draw(samples_per_point) grid = p_lb + (p_ub - p_lb) * grid generated.append(grid) - generated = torch.Tensor(np.vstack(generated)) #type: ignore + generated = torch.Tensor(np.vstack(generated)) # type: ignore - super().__init__(lb, ub, generated, dim, shuffle, seed) #type: ignore + super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore @classmethod def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: diff --git a/tests/generators/test_manual_generator.py b/tests/generators/test_manual_generator.py index ee467a20d..06ed01143 100644 --- a/tests/generators/test_manual_generator.py +++ b/tests/generators/test_manual_generator.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt + from aepsych.config import Config from aepsych.generators import ManualGenerator, SampleAroundPointsGenerator @@ -50,6 +51,7 @@ def test_manual_generator(self): gen = ManualGenerator.from_config(config) npt.assert_equal(gen.lb.numpy(), np.array([0, 0])) npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) + self.assertFalse(gen.finished) p1 = list(gen.gen()[0]) p2 = list(gen.gen()[0]) @@ -60,6 +62,7 @@ def test_manual_generator(self): self.assertEqual(sorted([p1, p2, p3, p4]), points) self.assertEqual(gen.max_asks, len(points)) self.assertEqual(gen.seed, 123) + self.assertTrue(gen.finished) class TestSampleAroundPointsGenerator(unittest.TestCase): @@ -86,6 +89,7 @@ def test_sample_around_points_generator(self): npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) self.assertEqual(gen.max_asks, len(points * samples_per_point)) self.assertEqual(gen.seed, 123) + self.assertFalse(gen.finished) points = gen.gen(gen.max_asks) for i in range(len(window)): @@ -93,6 +97,8 @@ def test_sample_around_points_generator(self): npt.assert_array_less(np.array([0] * len(points)), points[:, i]) npt.assert_array_less(points[:, i], np.array([1] * len(points))) + self.assertTrue(gen.finished) + if __name__ == "__main__": unittest.main()