From c41bbac072fde6953d4c89abec9c799310b10ee0 Mon Sep 17 00:00:00 2001 From: Kai Krajsek Date: Thu, 20 Jun 2024 20:45:20 +0200 Subject: [PATCH] added optional dataset.ishuffle (#1529) Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- heat/utils/data/datatools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/heat/utils/data/datatools.py b/heat/utils/data/datatools.py index 91b9a98f81..6bc92f4b75 100644 --- a/heat/utils/data/datatools.py +++ b/heat/utils/data/datatools.py @@ -83,7 +83,8 @@ def __init__( f"dataset must be a torch Dataset, heat Dataset, heat PartialH5Dataset, currently: {type(dataset)}" ) self.dataset = dataset - self.ishuffle = self.dataset.ishuffle + if hasattr(self.dataset, "ishuffle"): + self.ishuffle = self.dataset.ishuffle if isinstance(self.dataset, partial_dataset.PartialH5Dataset): drop_last = True @@ -110,7 +111,7 @@ def __iter__(self) -> Iterator: """ if isinstance(self.dataset, partial_dataset.PartialH5Dataset): return partial_dataset.PartialH5DataLoaderIter(self) - if hasattr(self, "_full_dataset_shuffle_iter"): + if hasattr(self, "_full_dataset_shuffle_iter") and hasattr(self.dataset, "ishuffle"): # if it is a normal heat dataset then this is defined self._full_dataset_shuffle_iter() return self.DataLoader.__iter__()