Skip to content

Commit

Permalink
added optional dataset.ishuffle (#1529)
Browse files Browse the repository at this point in the history
Co-authored-by: Claudia Comito <[email protected]>
  • Loading branch information
krajsek and ClaudiaComito authored Jun 20, 2024
1 parent 25bddd1 commit c41bbac
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions heat/utils/data/datatools.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
f"dataset must be a torch Dataset, heat Dataset, heat PartialH5Dataset, currently: {type(dataset)}"
)
self.dataset = dataset
self.ishuffle = self.dataset.ishuffle
if hasattr(self.dataset, "ishuffle"):
self.ishuffle = self.dataset.ishuffle
if isinstance(self.dataset, partial_dataset.PartialH5Dataset):
drop_last = True

Expand All @@ -110,7 +111,7 @@ def __iter__(self) -> Iterator:
"""
if isinstance(self.dataset, partial_dataset.PartialH5Dataset):
return partial_dataset.PartialH5DataLoaderIter(self)
if hasattr(self, "_full_dataset_shuffle_iter"):
if hasattr(self, "_full_dataset_shuffle_iter") and hasattr(self.dataset, "ishuffle"):
# if it is a normal heat dataset then this is defined
self._full_dataset_shuffle_iter()
return self.DataLoader.__iter__()
Expand Down

0 comments on commit c41bbac

Please sign in to comment.