Skip to content

Commit

Permalink
fix sobol generator multi stimuli reshape (#422)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #422

Sobol generators would return the incorrect shape when handling multi stimuli generation. This would not cause problem because of the ask converted inadvertantly avoided the problem.

Fixed and clarified the docstring what should happen.

Sort of a bandaid fix, tensor shapes may need to be unified more carefully when n-choice models are implemented.

Differential Revision: D65239074
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Oct 30, 2024
1 parent a8dc757 commit c813464
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 21 deletions.
Binary file added 1572972c92344d14bbc755b79df3bf99.db
Binary file not shown.
12 changes: 3 additions & 9 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,17 @@ def gen(
Args:
num_points (int, optional): Number of points to query.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim] or [num_points x dim x stimuli_per_trial] if stimuli_per_trial != 1.
"""
grid = self.engine.draw(num_points)
grid = self.lb + (self.ub - self.lb) * grid
if self.stimuli_per_trial == 1:
return grid

return torch.tensor(
np.moveaxis(
grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(),
-1,
-self.stimuli_per_trial,
)
)
return grid.reshape(num_points, self.stimuli_per_trial, -1).swapaxes(-1, -2)

@classmethod
def from_config(cls, config: Config) -> 'SobolGenerator':
def from_config(cls, config: Config) -> "SobolGenerator":
classname = cls.__name__

lb = config.gettensor(classname, "lb")
Expand Down
35 changes: 23 additions & 12 deletions tests/models/test_pairwise_probit.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,9 @@ def test_2d_pairwise_probit_pure_exploration(self):
next_pair, [bernoulli.rvs(f_pairwise(new_novel_det, next_pair))]
)

xy = torch.stack(torch.meshgrid(
torch.linspace(-1, 1, 30),
torch.linspace(-1, 1, 30)
), dim=-1).view(-1, 2)
xy = torch.stack(
torch.meshgrid(torch.linspace(-1, 1, 30), torch.linspace(-1, 1, 30)), dim=-1
).view(-1, 2)

zhat, _ = strat.predict(xy)

Expand Down Expand Up @@ -508,9 +507,9 @@ def test_1d_pairwise_server(self):

for _i in range(n_init + n_opt):
next_config = ask(server)

next_x = torch.tensor(next_config["x"], dtype=torch.float64)

next_y = bernoulli.rvs(f_pairwise(f_1d, next_x, noise_scale=0.1))
tell(server, config=next_config, outcome=next_y)

Expand Down Expand Up @@ -561,7 +560,9 @@ def test_2d_pairwise_server(self):
)
for _i in range(n_init + n_opt):
next_config = ask(server)
next_pair = torch.stack((torch.tensor(next_config["x"]), torch.tensor(next_config["y"])), dim=0)
next_pair = torch.stack(
(torch.tensor(next_config["x"]), torch.tensor(next_config["y"])), dim=0
)
next_y = bernoulli.rvs(f_pairwise(f_2d, next_pair, noise_scale=0.1))
tell(server, config=next_config, outcome=next_y)

Expand Down Expand Up @@ -614,7 +615,7 @@ def test_serialization_1d(self):
for _i in range(n_init + n_opt):
next_config = ask(server)
next_x = torch.tensor(next_config["x"], dtype=torch.float64)

next_y = bernoulli.rvs(f_pairwise(f_1d, next_x))
tell(server, config=next_config, outcome=next_y)

Expand Down Expand Up @@ -677,7 +678,9 @@ def test_serialization_2d(self):

for _i in range(n_init + n_opt):
next_config = ask(server)
next_pair = torch.stack((torch.tensor(next_config["x"]), torch.tensor(next_config["y"])), dim=0)
next_pair = torch.stack(
(torch.tensor(next_config["x"]), torch.tensor(next_config["y"])), dim=0
)
next_y = bernoulli.rvs(f_pairwise(f_2d, next_pair))
tell(server, config=next_config, outcome=next_y)

Expand Down Expand Up @@ -772,8 +775,8 @@ def test_config_to_tensor(self):

config_str = """
[common]
lb = [-1, -1, -1]
ub = [1, 1, 1]
lb = [-1, 1e-6, 10]
ub = [-1e-6, 1, 100]
stimuli_per_trial=2
outcome_types=[binary]
parnames = [x, y, z]
Expand Down Expand Up @@ -804,7 +807,15 @@ def test_config_to_tensor(self):

conf = ask(server)

self.assertTrue(server._config_to_tensor(conf).shape == (3, 2))
tensor = server._config_to_tensor(conf)
self.assertTrue(tensor.shape == (3, 2))

# Check if reshapes were correct
self.assertTrue(torch.all(tensor[0, :] <= -1e-6))
self.assertTrue(
torch.all(torch.logical_and(tensor[1, :] >= 1e-6, tensor[1, :] <= 1))
)
self.assertTrue(torch.all(tensor[2, :] >= 10))


if __name__ == "__main__":
Expand Down

0 comments on commit c813464

Please sign in to comment.