[DataLoader] Allow batch_size=None for GraphDataLoader (dmlc#4483)
* overwrite default_collate_fn

* Update dataloader.py

* Update dataloader.py

* Update dataloader.py

* Update dataloader.py

* Update test_dataloader.py

* restore the test code that was reverted in dmlc#4956
BarclayII authored Jan 8, 2023
1 parent e296c46 commit 580c702
Showing 3 changed files with 15 additions and 11 deletions.
10 changes: 4 additions & 6 deletions python/dgl/dataloading/dataloader.py
```diff
@@ -1124,17 +1124,15 @@ def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs
             else:
                 dataloader_kwargs[k] = v

-        if collate_fn is None:
-            self.collate = GraphCollator(**collator_kwargs).collate
-        else:
-            self.collate = collate_fn
-
         self.use_ddp = use_ddp
         if use_ddp:
             self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
             dataloader_kwargs['sampler'] = self.dist_sampler

-        super().__init__(dataset=dataset, collate_fn=self.collate, **dataloader_kwargs)
+        if collate_fn is None and kwargs.get('batch_size', 1) is not None:
+            collate_fn = GraphCollator(**collator_kwargs).collate
+
+        super().__init__(dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs)

     def set_epoch(self, epoch):
         """Sets the epoch number for the underlying sampler which ensures all replicas
```
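The net effect of this hunk: when the caller passes `batch_size=None`, `GraphDataLoader` no longer installs `GraphCollator`'s collate function, so the parent PyTorch `DataLoader` disables automatic batching and yields samples one at a time. A minimal usage sketch (dataset sizes here are arbitrary, chosen only for illustration):

```python
import dgl

# Any graph-classification dataset works; MiniGCDataset is what the tests use.
dataset = dgl.data.MiniGCDataset(8, 10, 20)

# With batch_size=None no GraphCollator is installed, so each iteration
# yields a single (graph, label) pair instead of a batched graph.
loader = dgl.dataloading.GraphDataLoader(dataset, batch_size=None)
for graph, label in loader:
    assert isinstance(graph, dgl.DGLGraph)
    assert label.ndim == 0  # a single scalar label, per PyTorch convention
```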
14 changes: 10 additions & 4 deletions tests/pytorch/test_dataloader.py
```diff
@@ -12,15 +12,21 @@
 import pytest


-def test_graph_dataloader():
-    batch_size = 16
+@pytest.mark.parametrize('batch_size', [None, 16])
+def test_graph_dataloader(batch_size):
     num_batches = 2
-    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
+    num_samples = num_batches * (batch_size if batch_size is not None else 1)
+    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
     data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
     assert isinstance(iter(data_loader), Iterator)
     for graph, label in data_loader:
         assert isinstance(graph, dgl.DGLGraph)
-        assert F.asnumpy(label).shape[0] == batch_size
+        if batch_size is not None:
+            assert F.asnumpy(label).shape[0] == batch_size
+        else:
+            # If batch size is None, the label element will be a single scalar following
+            # PyTorch's practice.
+            assert F.asnumpy(label).ndim == 0

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @pytest.mark.parametrize('num_workers', [0, 4])
```
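The new `else` branch relies on plain PyTorch semantics: `DataLoader` with `batch_size=None` disables automatic batching and returns individual samples, so labels come back as 0-dimensional tensors. A DGL-independent sketch of that behavior:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(4, 3), torch.arange(4))

batched = next(iter(DataLoader(dataset, batch_size=2)))
single = next(iter(DataLoader(dataset, batch_size=None)))

print(batched[1].shape)  # torch.Size([2]): labels stacked into a batch
print(single[1].shape)   # torch.Size([]): one scalar label, no collation
```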
2 changes: 1 addition & 1 deletion tests/pytorch/test_nn.py
```diff
@@ -1516,7 +1516,7 @@ def test_hgt(idtype, in_size, num_heads):
     sorted_y = m(sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False)
     assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
     # mini-batch
-    train_idx = th.randint(0, 100, (10, ), dtype = idtype)
+    train_idx = th.randperm(100, dtype = idtype)[:10]
     sampler = dgl.dataloading.NeighborSampler([-1])
     train_loader = dgl.dataloading.DataLoader(g, train_idx.to(dev), sampler,
                                               batch_size=8, device=dev,
```
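The test_nn.py tweak swaps `th.randint` (sampling with replacement, so seed node IDs may repeat) for `th.randperm(100)[:10]`, which always yields distinct indices. A small illustration of the difference, independent of the test:

```python
import torch as th

th.manual_seed(0)
maybe_dup = th.randint(0, 100, (10,))  # drawn with replacement; duplicates possible
distinct = th.randperm(100)[:10]       # first 10 entries of a permutation; always unique

assert distinct.unique().numel() == 10
print(maybe_dup.unique().numel())      # can be fewer than 10
```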
