diff --git a/ax/utils/sensitivity/derivative_measures.py b/ax/utils/sensitivity/derivative_measures.py index c3ebe6ccba4..16a9b5a8e01 100644 --- a/ax/utils/sensitivity/derivative_measures.py +++ b/ax/utils/sensitivity/derivative_measures.py @@ -168,7 +168,7 @@ def gradient_measure(self) -> torch.Tensor: else Tensor: (values) x dim """ - return self.aggregation(torch.tensor) + return self.aggregation(torch.as_tensor) def gradient_absolute_measure(self) -> torch.Tensor: r"""Computes the gradient absolute measure: