diff --git a/aepsych/generators/sobol_generator.py b/aepsych/generators/sobol_generator.py index 73cf47589..8f17df984 100644 --- a/aepsych/generators/sobol_generator.py +++ b/aepsych/generators/sobol_generator.py @@ -55,20 +55,14 @@ def gen( Args: num_points (int, optional): Number of points to query. Returns: - torch.Tensor: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, [num_points x dim] or [num_points x dim x stimuli_per_trial] if stimuli_per_trial != 1. """ grid = self.engine.draw(num_points) grid = self.lb + (self.ub - self.lb) * grid if self.stimuli_per_trial == 1: return grid - return torch.tensor( - np.moveaxis( - grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(), - -1, - -self.stimuli_per_trial, - ) - ) + return grid.reshape(num_points, -1, self.stimuli_per_trial) @classmethod def from_config(cls, config: Config) -> 'SobolGenerator':