Skip to content

Commit

Permalink
[GraphBolt][CUDA] Make dataloader pickleable. (#7391)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored May 10, 2024
1 parent 7a2e873 commit 6f2ccbf
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
2 changes: 0 additions & 2 deletions examples/sampling/graphbolt/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ def train(args, model, graph, features, train_set):

total_loss += loss.item()
if step + 1 == args.early_stop:
# Early stopping requires a new dataloader to reset its state.
dataloader = create_dataloader(args, graph, features, train_set)
break

end_epoch_time = time.time()
Expand Down
19 changes: 9 additions & 10 deletions notebooks/stochastic_training/link_prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,13 @@
"outputs": [],
"source": [
"from functools import partial\n",
"def create_train_dataloader():\n",
" datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n",
" datapipe = datapipe.copy_to(device)\n",
" datapipe = datapipe.sample_uniform_negative(graph, 5)\n",
" datapipe = datapipe.sample_neighbor(graph, [5, 5])\n",
" datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n",
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
" return gb.DataLoader(datapipe)"
"datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n",
"datapipe = datapipe.copy_to(device)\n",
"datapipe = datapipe.sample_uniform_negative(graph, 5)\n",
"datapipe = datapipe.sample_neighbor(graph, [5, 5])\n",
"datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n",
"datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
"train_dataloader = gb.DataLoader(datapipe)"
]
},
{
Expand All @@ -157,7 +156,7 @@
},
"outputs": [],
"source": [
"data = next(iter(create_train_dataloader()))\n",
"data = next(iter(train_dataloader))\n",
"print(f\"MiniBatch: {data}\")"
]
},
Expand Down Expand Up @@ -253,7 +252,7 @@
"for epoch in range(3):\n",
" model.train()\n",
" total_loss = 0\n",
" for step, data in tqdm(enumerate(create_train_dataloader())):\n",
" for step, data in tqdm(enumerate(train_dataloader)):\n",
" # Get node pairs with labels for loss calculation.\n",
" compacted_seeds = data.compacted_seeds.T\n",
" labels = data.labels\n",
Expand Down
14 changes: 14 additions & 0 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,20 @@ def __iter__(self):
while len(self.buffer) > 0:
yield self.buffer.popleft()

def __getstate__(self):
state = (self.datapipe, self.buffer.maxlen)
if IterDataPipe.getstate_hook is not None:
return IterDataPipe.getstate_hook(state)
return state

def __setstate__(self, state):
self.datapipe, buffer_size = state
self.buffer = deque(maxlen=buffer_size)

def reset(self):
"""Resets the state of the datapipe."""
self.buffer.clear()


@functional_datapipe("wait")
class Waiter(IterDataPipe):
Expand Down
6 changes: 6 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,9 @@ def test_gpu_sampling_DataLoader(
)
assert len(bufferers) == bufferer_awaiter_cnt
assert len(list(dataloader)) == N // B

for i, _ in enumerate(dataloader):
if i >= 1:
break

assert len(list(dataloader)) == N // B

0 comments on commit 6f2ccbf

Please sign in to comment.