Skip to content

Commit

Permalink
pass generator every sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 20, 2024
1 parent e987fce commit 46e1f11
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
5 changes: 2 additions & 3 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,12 @@ def test_weighted_sampling(self) -> None:
def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
generator = torch.manual_seed(0)
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator)
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample1 = bbox
break

sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator)
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample2 = bbox
break
Expand Down
13 changes: 8 additions & 5 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,18 @@ def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
generator = torch.manual_seed(0)
sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler:
generator = torch.manual_seed(2)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler1:
sample1 = bbox
print(sample1)
break

sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler:
generator = torch.manual_seed(2)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler2:
sample2 = bbox
print(sample2)
break
assert sample1 == sample2

Expand Down

0 comments on commit 46e1f11

Please sign in to comment.