Skip to content

Commit

Permalink
add finished property to manual generators (facebookresearch#409)
Browse files Browse the repository at this point in the history
Summary:

This adds a finished property to the manual generator classes so that they can keep track of whether they have generated all their points. Once this is hooked into the strategy's finishing logic, it should make writing configs simpler.

Differential Revision: D64600239
  • Loading branch information
crasanders authored and facebook-github-bot committed Oct 18, 2024
1 parent a3d3a33 commit a463412
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
13 changes: 9 additions & 4 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Optional, Union, Dict
from typing import Dict, Optional, Union

import numpy as np
import torch
from torch.quasirandom import SobolEngine

from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine


class ManualGenerator(AEPsychGenerator):
Expand Down Expand Up @@ -95,6 +96,10 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:

return options

@property
def finished(self):
return self._idx >= len(self.points)


class SampleAroundPointsGenerator(ManualGenerator):
"""Generator that samples in a window around reference points in a predefined list."""
Expand Down Expand Up @@ -131,9 +136,9 @@ def __init__(
grid = self.engine.draw(samples_per_point)
grid = p_lb + (p_ub - p_lb) * grid
generated.append(grid)
generated = torch.Tensor(np.vstack(generated)) #type: ignore
generated = torch.Tensor(np.vstack(generated)) # type: ignore

super().__init__(lb, ub, generated, dim, shuffle, seed) #type: ignore
super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore

@classmethod
def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:
Expand Down
6 changes: 6 additions & 0 deletions tests/generators/test_manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import numpy.testing as npt

from aepsych.config import Config
from aepsych.generators import ManualGenerator, SampleAroundPointsGenerator

Expand Down Expand Up @@ -50,6 +51,7 @@ def test_manual_generator(self):
gen = ManualGenerator.from_config(config)
npt.assert_equal(gen.lb.numpy(), np.array([0, 0]))
npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
self.assertFalse(gen.finished)

p1 = list(gen.gen()[0])
p2 = list(gen.gen()[0])
Expand All @@ -60,6 +62,7 @@ def test_manual_generator(self):
self.assertEqual(sorted([p1, p2, p3, p4]), points)
self.assertEqual(gen.max_asks, len(points))
self.assertEqual(gen.seed, 123)
self.assertTrue(gen.finished)


class TestSampleAroundPointsGenerator(unittest.TestCase):
Expand All @@ -86,13 +89,16 @@ def test_sample_around_points_generator(self):
npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
self.assertEqual(gen.max_asks, len(points * samples_per_point))
self.assertEqual(gen.seed, 123)
self.assertFalse(gen.finished)

points = gen.gen(gen.max_asks)
for i in range(len(window)):
npt.assert_array_less(points[:, i], points[:, i] + window[i])
npt.assert_array_less(np.array([0] * len(points)), points[:, i])
npt.assert_array_less(points[:, i], np.array([1] * len(points)))

self.assertTrue(gen.finished)


if __name__ == "__main__":
unittest.main()

0 comments on commit a463412

Please sign in to comment.