Skip to content

Commit

Permalink
Merge pull request #675 from Maltimore/fix_splitting
Browse files Browse the repository at this point in the history
fix dataset split into train/val/test
  • Loading branch information
jnsLs authored Nov 26, 2024
2 parents 47cfa5e + b34398e commit 37f9282
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/schnetpack/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def random_split(dsize: int, *split_sizes: Union[int, float]) -> List[torch.tens
Args:
dsize - Size of dataset.
split_sizes - Sizes for each split. One can be set to -1 to assign all
remaining data. Values in [0, 1] can be used to give relative partition
split_sizes - Sizes for each split. One value can be set to None to assign all
remaining data. Values in [0, 1) can be used to give relative partition
sizes.
"""
split_sizes = absolute_split_sizes(dsize, split_sizes)
offsets = torch.cumsum(torch.tensor(split_sizes), dim=0)
indices = torch.randperm(sum(split_sizes)).tolist()
indices = torch.randperm(dsize).tolist()
partition_sizes_idx = [
indices[offset - length : offset]
for offset, length in zip(offsets, split_sizes)
Expand Down

0 comments on commit 37f9282

Please sign in to comment.