From 63c9d11ef06a5266e12787cfcdfad03d0baeb714 Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 8 Aug 2024 17:40:58 +0200 Subject: [PATCH] first draft of MultipleTrainEpochsSampler --- src/schnetpack/data/sampler.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/schnetpack/data/sampler.py b/src/schnetpack/data/sampler.py index 0e353ef88..1fc1fca47 100644 --- a/src/schnetpack/data/sampler.py +++ b/src/schnetpack/data/sampler.py @@ -1,7 +1,7 @@ from typing import Iterator, List, Callable import numpy as np -from torch.utils.data import Sampler, WeightedRandomSampler +from torch.utils.data import Sampler, WeightedRandomSampler, RandomSampler from schnetpack import properties from schnetpack.data import BaseAtomsData @@ -11,6 +11,7 @@ "StratifiedSampler", "NumberOfAtomsCriterion", "PropertyCriterion", + "MultipleTrainEpochsSampler", ] @@ -95,3 +96,19 @@ def calculate_weights(self, partition_criterion): weights = bin_weights[bin_indices] return weights + + +class MultipleTrainEpochsSampler(RandomSampler): + def __init__( + self, + data_source, + num_samples=None, + n_train_epochs=1, + generator=None, + ): + super().__init__( + data_source=data_source, + replacement=True, + num_samples=len(data_source) * n_train_epochs, + generator=generator, + )