Skip to content

Commit

Permalink
small bug fix to make our dataset classes subclass torch dataset (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
joyce-yuan authored Oct 23, 2024
1 parent 93833a0 commit 5355672
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from utils.corruptions import corrupt_mapping


class CacheDataset:
class CacheDataset(Dataset):
"""
Caches the entire dataset in memory.
"""
Expand All @@ -27,7 +27,7 @@ def __len__(self):
return len(self.data)


class TransformDataset:
class TransformDataset(Dataset):
"""
Applies a transformation to the dataset.
"""
Expand All @@ -45,7 +45,7 @@ def __len__(self):
return len(self.dset)

# Custom dataset wrapper to apply corruption
class CorruptDataset:
class CorruptDataset(Dataset):
def __init__(self, dset: CacheDataset, corruption_fn_name, severity: int = 1):
print("Initialized CorruptDataset with corruption_fn_name: ", corruption_fn_name)
self.dset = dset # Original dataset
Expand Down

0 comments on commit 5355672

Please sign in to comment.