From b34398ea1276f2f84ed839e219392fbc2bcd77aa Mon Sep 17 00:00:00 2001 From: Maltimore Date: Mon, 25 Nov 2024 17:55:39 +0100 Subject: [PATCH] fix dataset split into train/test/val --- src/schnetpack/data/splitting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/schnetpack/data/splitting.py b/src/schnetpack/data/splitting.py index c931b4783..26215e742 100644 --- a/src/schnetpack/data/splitting.py +++ b/src/schnetpack/data/splitting.py @@ -48,13 +48,13 @@ def random_split(dsize: int, *split_sizes: Union[int, float]) -> List[torch.tens Args: dsize - Size of dataset. - split_sizes - Sizes for each split. One can be set to -1 to assign all - remaining data. Values in [0, 1] can be used to give relative partition + split_sizes - Sizes for each split. One value can be set to None to assign all + remaining data. Values in [0, 1) can be used to give relative partition sizes. """ split_sizes = absolute_split_sizes(dsize, split_sizes) offsets = torch.cumsum(torch.tensor(split_sizes), dim=0) - indices = torch.randperm(sum(split_sizes)).tolist() + indices = torch.randperm(dsize).tolist() partition_sizes_idx = [ indices[offset - length : offset] for offset, length in zip(offsets, split_sizes)