# NOTE(review): this is a unified git diff whose line breaks and indentation
# appear to have been collapsed into single spaces; it presumably needs its
# original layout restored before it can be applied -- confirm against the
# upstream commit.
#
# Summary of the four file changes visible below:
#   * examples/sampling/graphbolt/link_prediction.py, train(): drops the
#     create_dataloader(...) call (and its comment) that ran just before
#     `break` when `step + 1 == args.early_stop`.
#   * notebooks/stochastic_training/link_prediction.ipynb: replaces the
#     create_train_dataloader() factory with a single `train_dataloader`
#     built once in the cell; the sample-minibatch cell and the
#     `for epoch in range(3)` training loop now iterate that object directly.
#   * python/dgl/graphbolt/base.py: adds __getstate__, __setstate__ and
#     reset() to the datapipe class that owns `self.buffer` (its class header
#     is outside this hunk). __getstate__ returns
#     (self.datapipe, self.buffer.maxlen), routed through
#     IterDataPipe.getstate_hook when one is set; __setstate__ restores the
#     datapipe and creates a new empty deque with the saved maxlen, so
#     buffered items are not carried over; reset() clears the buffer.
#   * tests/python/pytorch/graphbolt/test_dataloader.py: after the existing
#     length check, breaks out of a pass over `dataloader` once i >= 1, then
#     asserts that another full pass still yields N // B batches.
#
# NOTE(review): presumably reset() is what allows an abandoned iteration to be
# restarted on the same dataloader, which would explain the example and
# notebook changes -- confirm how and when reset() gets invoked.
# NOTE(review): __setstate__ unpacks a plain 2-tuple; verify that state
# produced via getstate_hook is never fed back into it.
diff --git a/examples/sampling/graphbolt/link_prediction.py b/examples/sampling/graphbolt/link_prediction.py index 25a4f3b0f5e3..a48fbc25f115 100644 --- a/examples/sampling/graphbolt/link_prediction.py +++ b/examples/sampling/graphbolt/link_prediction.py @@ -328,8 +328,6 @@ def train(args, model, graph, features, train_set): total_loss += loss.item() if step + 1 == args.early_stop: - # Early stopping requires a new dataloader to reset its state. - dataloader = create_dataloader(args, graph, features, train_set) break end_epoch_time = time.time() diff --git a/notebooks/stochastic_training/link_prediction.ipynb b/notebooks/stochastic_training/link_prediction.ipynb index c94ac357a15f..72902cf3eaeb 100644 --- a/notebooks/stochastic_training/link_prediction.ipynb +++ b/notebooks/stochastic_training/link_prediction.ipynb @@ -129,14 +129,13 @@ "outputs": [], "source": [ "from functools import partial\n", - "def create_train_dataloader():\n", - " datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n", - " datapipe = datapipe.copy_to(device)\n", - " datapipe = datapipe.sample_uniform_negative(graph, 5)\n", - " datapipe = datapipe.sample_neighbor(graph, [5, 5])\n", - " datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n", - " datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n", - " return gb.DataLoader(datapipe)" + "datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\n", + "datapipe = datapipe.copy_to(device)\n", + "datapipe = datapipe.sample_uniform_negative(graph, 5)\n", + "datapipe = datapipe.sample_neighbor(graph, [5, 5])\n", + "datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\n", + "datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n", + "train_dataloader = gb.DataLoader(datapipe)" ] }, { @@ -157,7 +156,7 @@ }, "outputs": [], "source": [ - "data = next(iter(create_train_dataloader()))\n", + "data = 
next(iter(train_dataloader))\n", "print(f\"MiniBatch: {data}\")" ] }, @@ -253,7 +252,7 @@ "for epoch in range(3):\n", " model.train()\n", " total_loss = 0\n", - " for step, data in tqdm(enumerate(create_train_dataloader())):\n", + " for step, data in tqdm(enumerate(train_dataloader)):\n", " # Get node pairs with labels for loss calculation.\n", " compacted_seeds = data.compacted_seeds.T\n", " labels = data.labels\n", diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index 13f493756d48..66c5d2e8e94d 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -313,6 +313,20 @@ def __iter__(self): while len(self.buffer) > 0: yield self.buffer.popleft() + def __getstate__(self): + state = (self.datapipe, self.buffer.maxlen) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + self.datapipe, buffer_size = state + self.buffer = deque(maxlen=buffer_size) + + def reset(self): + """Resets the state of the datapipe.""" + self.buffer.clear() + @functional_datapipe("wait") class Waiter(IterDataPipe): diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index bc7d16c25d3e..bf8b79d40c95 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -111,3 +111,9 @@ def test_gpu_sampling_DataLoader( ) assert len(bufferers) == bufferer_awaiter_cnt assert len(list(dataloader)) == N // B + + for i, _ in enumerate(dataloader): + if i >= 1: + break + + assert len(list(dataloader)) == N // B