diff --git a/tests/helpers.py b/tests/helpers.py index 46c6ef93d..f82a8631f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,13 +1,13 @@ from itertools import product import random -from typing import Any +from typing import Any, List import torch test_dims_rng = random.Random(42) -def get_test_dims(min: int, max: int, *, n: int) -> list[int]: +def get_test_dims(min: int, max: int, *, n: int) -> List[int]: return [test_dims_rng.randint(min, max) for _ in range(n)]