From 8c1aac947a863a197d6e51d991b942488ec4b54a Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 30 Oct 2024 15:16:04 -0700 Subject: [PATCH] fix sobol generator multi stimuli reshape Summary: Sobol generators would return the incorrect shape when handling multi stimuli generation. This would not cause problem because of the ask converted inadvertantly avoided the problem. Fixed and clarified the docstring what should happen Differential Revision: D65239074 --- aepsych/generators/sobol_generator.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/aepsych/generators/sobol_generator.py b/aepsych/generators/sobol_generator.py index 73cf47589..8f17df984 100644 --- a/aepsych/generators/sobol_generator.py +++ b/aepsych/generators/sobol_generator.py @@ -55,20 +55,14 @@ def gen( Args: num_points (int, optional): Number of points to query. Returns: - torch.Tensor: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, [num_points x dim] or [num_points x dim x stimuli_per_trial] if stimuli_per_trial != 1. """ grid = self.engine.draw(num_points) grid = self.lb + (self.ub - self.lb) * grid if self.stimuli_per_trial == 1: return grid - return torch.tensor( - np.moveaxis( - grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(), - -1, - -self.stimuli_per_trial, - ) - ) + return grid.reshape(num_points, -1, self.stimuli_per_trial) @classmethod def from_config(cls, config: Config) -> 'SobolGenerator':