[GraphBolt][Doc] Update docs related to seeds. (#7351)
Co-authored-by: Ubuntu <[email protected]>
yxy235 and Ubuntu committed Apr 25, 2024
1 parent 6133cec commit 3afa105
Showing 6 changed files with 42 additions and 41 deletions.
4 changes: 2 additions & 2 deletions docs/source/guide/minibatch-custom-sampler.rst
@@ -79,11 +79,11 @@ can be used on heterogeneous graphs:
     {
         "user": gb.ItemSet(
             (torch.arange(0, 5), torch.arange(5, 10)),
-            names=("seed_nodes", "labels"),
+            names=("seeds", "labels"),
         ),
         "item": gb.ItemSet(
             (torch.arange(5, 10), torch.arange(10, 15)),
-            names=("seed_nodes", "labels"),
+            names=("seeds", "labels"),
         ),
     }
 )
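For reference, iterating such a per-type item set yields ``(seed, label)`` pairs; a minimal sketch, assuming only ``torch`` and ``dgl.graphbolt`` (it mirrors the ``print(list(item_set))`` pattern from the walkthrough notebook in this same commit):

    import torch
    import dgl.graphbolt as gb

    item_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seeds", "labels"),
    )
    # Each item pairs one seed node with its label,
    # e.g. (tensor(0), tensor(5)).
    print(list(item_set))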
42 changes: 21 additions & 21 deletions docs/source/guide/minibatch-edge.rst
@@ -30,9 +30,9 @@ edges (namely, node pairs) in the training set instead of the nodes.
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 g = gb.SamplingGraph()
-node_paris = torch.arange(0, 1000).reshape(-1, 2)
+seeds = torch.arange(0, 1000).reshape(-1, 2)
 labels = torch.randint(0, 2, (5,))
-train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
 datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
 datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
 # Or equivalently:
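For context, such a datapipe is typically finished into a data loader as in the following minimal sketch, assuming a hypothetical feature store ``features`` holding a ``"feat"`` tensor per node (the ``fetch_feature`` step mirrors the heterogeneous example further down in this diff):

    # `g`, `datapipe` and `device` as above; `features` is hypothetical.
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)
    for data in dataloader:
        ...  # each `data` is a gb.MiniBatch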
@@ -83,9 +83,9 @@ You can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside with
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 g = gb.SamplingGraph()
-node_paris = torch.arange(0, 1000).reshape(-1, 2)
+seeds = torch.arange(0, 1000).reshape(-1, 2)
 labels = torch.randint(0, 2, (5,))
-train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
 datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
 datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
 exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
@@ -138,9 +138,9 @@ concatenating the incident node features and projecting it with a dense layer.
         super().__init__()
         self.W = nn.Linear(2 * in_features, num_classes)

-    def forward(self, node_pairs, x):
-        src_x = x[node_pairs[0]]
-        dst_x = x[node_pairs[1]]
+    def forward(self, seeds, x):
+        src_x = x[seeds[:, 0]]
+        dst_x = x[seeds[:, 1]]
         data = torch.cat([src_x, dst_x], 1)
         return self.W(data)
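A self-contained shape check for the new column-wise indexing, with hypothetical sizes:

    import torch
    import torch.nn as nn

    class ScorePredictor(nn.Module):
        def __init__(self, num_classes, in_features):
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)

        def forward(self, seeds, x):
            # seeds has shape (N, 2): one (src, dst) pair per row.
            src_x = x[seeds[:, 0]]
            dst_x = x[seeds[:, 1]]
            return self.W(torch.cat([src_x, dst_x], 1))

    x = torch.randn(100, 16)               # representations of 100 nodes
    seeds = torch.randint(0, 100, (8, 2))  # 8 (src, dst) pairs
    print(ScorePredictor(2, 16)(seeds, x).shape)  # torch.Size([8, 2])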
@@ -157,9 +157,9 @@ loader, as well as the input node features as follows:
             in_features, hidden_features, out_features)
         self.predictor = ScorePredictor(num_classes, out_features)

-    def forward(self, blocks, x, node_pairs):
+    def forward(self, blocks, x, seeds):
         x = self.gcn(blocks, x)
-        return self.predictor(node_pairs, x)
+        return self.predictor(seeds, x)
 DGL ensures that the nodes in the edge subgraph are the same as the
 output nodes of the last MFG in the generated list of MFGs.
@@ -182,7 +182,7 @@ their incident node representations.
 for data in dataloader:
     blocks = data.blocks
     x = data.edge_features("feat")
-    y_hat = model(data.blocks, x, data.positive_node_pairs)
+    y_hat = model(data.blocks, x, data.compacted_seeds)
     loss = F.cross_entropy(data.labels, y_hat)
     opt.zero_grad()
     loss.backward()
@@ -226,10 +226,10 @@ over the edge types.
         super().__init__()
         self.W = nn.Linear(2 * in_features, num_classes)

-    def forward(self, node_pairs, x):
+    def forward(self, seeds, x):
         scores = {}
-        for etype in node_pairs.keys():
-            src, dst = node_pairs[etype]
+        for etype in seeds.keys():
+            src, dst = seeds[etype].T
             data = torch.cat([x[etype][src], x[etype][dst]], 1)
             scores[etype] = self.W(data)
         return scores
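In the heterogeneous case each ``seeds[etype]`` is likewise an ``(N, 2)`` tensor, and transposing unpacks the two columns into ``src`` and ``dst`` rows; a small sketch with hypothetical values:

    import torch

    seeds = {"user:like:item": torch.randint(0, 100, (8, 2))}
    # Transposing the (8, 2) tensor yields two rows of length 8.
    src, dst = seeds["user:like:item"].T
    assert src.shape == dst.shape == (8,)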
@@ -242,9 +242,9 @@ over the edge types.
             in_features, hidden_features, out_features, etypes)
         self.pred = ScorePredictor(num_classes, out_features)

-    def forward(self, node_pairs, blocks, x):
+    def forward(self, seeds, blocks, x):
         x = self.rgcn(blocks, x)
-        return self.pred(node_pairs, x)
+        return self.pred(seeds, x)

 Data loader definition is almost identical to that of the homogeneous graph.
 The only difference is that the train_set is now an instance of
@@ -256,17 +256,17 @@ only difference is that the train_set is now an instance of
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 g = gb.SamplingGraph()
-node_pairs = torch.arange(0, 1000).reshape(-1, 2)
+seeds = torch.arange(0, 1000).reshape(-1, 2)
 labels = torch.randint(0, 3, (1000,))
-node_pairs_labels = {
+seeds_labels = {
     "user:like:item": gb.ItemSet(
-        (node_pairs, labels), names=("node_pairs", "labels")
+        (seeds, labels), names=("seeds", "labels")
     ),
     "user:follow:user": gb.ItemSet(
-        (node_pairs, labels), names=("node_pairs", "labels")
+        (seeds, labels), names=("seeds", "labels")
     ),
 }
-train_set = gb.ItemSetDict(node_pairs_labels)
+train_set = gb.ItemSetDict(seeds_labels)
 datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
 datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
 datapipe = datapipe.fetch_feature(
@@ -316,7 +316,7 @@ dictionaries of node types and predictions here.
 for data in dataloader:
     blocks = data.blocks
     x = data.edge_features(("user:like:item", "feat"))
-    y_hat = model(data.blocks, x, data.positive_node_pairs)
+    y_hat = model(data.blocks, x, data.compacted_seeds)
     loss = F.cross_entropy(data.labels, y_hat)
     opt.zero_grad()
     loss.backward()
2 changes: 1 addition & 1 deletion docs/source/guide/minibatch-inference.rst
@@ -106,7 +106,7 @@ and combined as well.
                 hidden_x = self.dropout(hidden_x)
             # By design, our output nodes are contiguous.
             y[
-                data.seed_nodes[0] : data.seed_nodes[-1] + 1
+                data.seeds[0] : data.seeds[-1] + 1
             ] = hidden_x.to(device)
         feature = y
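The slice works because the seeds of each batch are contiguous node IDs; a small illustration with hypothetical values:

    import torch

    y = torch.zeros(10, 4)              # preallocated output for 10 nodes
    seeds = torch.tensor([4, 5, 6, 7])  # contiguous IDs, as assumed above
    hidden_x = torch.randn(4, 4)        # representations for these seeds
    y[seeds[0] : seeds[-1] + 1] = hidden_x  # one slice covers rows 4..7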
22 changes: 12 additions & 10 deletions docs/source/guide/minibatch-link.rst
@@ -53,8 +53,8 @@ proportional to a power of degrees.
         self.weights = node_degrees ** 0.75
         self.k = k

-    def _sample_with_etype(node_pairs, etype=None):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype=None):
+        src, _ = seeds.T
         src = src.repeat_interleave(self.k)
         dst = self.weights.multinomial(len(src), replacement=True)
         return src, dst
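A standalone rendition of what ``_sample_with_etype`` produces, with hypothetical numbers and ``k`` negatives per positive pair:

    import torch

    k = 2
    weights = torch.tensor([1.0, 2.0, 3.0, 4.0]) ** 0.75  # degree-based weights
    seeds = torch.tensor([[0, 1], [2, 3]])  # two positive (src, dst) pairs

    src, _ = seeds.T
    src = src.repeat_interleave(k)  # each source repeated k times
    dst = weights.multinomial(len(src), replacement=True)
    print(src, dst)  # 4 negative pairs: 2 per positive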
@@ -95,7 +95,7 @@ Define a GraphSAGE model for minibatch training
 When a negative sampler is provided, the data loader will generate positive and
 negative node pairs for each minibatch besides the *Message Flow Graphs* (MFGs).
-Use `node_pairs_with_labels` to get compact node pairs with corresponding
+Use `compacted_seeds` and `labels` to get compact node pairs and corresponding
 labels.


@@ -116,15 +116,16 @@ above.
 start_epoch_time = time.time()
 for step, data in enumerate(dataloader):
     # Unpack MiniBatch.
-    compacted_pairs, labels = data.node_pairs_with_labels
+    compacted_seeds = data.compacted_seeds.T
+    labels = data.labels
     node_feature = data.node_features["feat"]
     # Convert sampled subgraphs to DGL blocks.
     blocks = data.blocks
     # Get the embeddings of the input nodes.
     y = model(blocks, node_feature)
     logits = model.predictor(
-        y[compacted_pairs[0]] * y[compacted_pairs[1]]
+        y[compacted_seeds[0]] * y[compacted_seeds[1]]
     ).squeeze()
     # Compute loss.
@@ -217,8 +218,8 @@ If you want to give your own negative sampling function, just inherit from the
         }
         self.k = k

-    def _sample_with_etype(node_pairs, etype):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype):
+        src, _ = seeds.T
         src = src.repeat_interleave(self.k)
         dst = self.weights[etype].multinomial(len(src), replacement=True)
         return src, dst
@@ -241,7 +242,8 @@ loss on a specific edge type.
 start_epoch_time = time.time()
 for step, data in enumerate(dataloader):
     # Unpack MiniBatch.
-    compacted_pairs, labels = data.node_pairs_with_labels
+    compacted_seeds = data.compacted_seeds
+    labels = data.labels
     node_features = {
         ntype: data.node_features[(ntype, "feat")]
         for ntype in data.blocks[0].srctypes
@@ -251,8 +253,8 @@
     # Get the embeddings of the input nodes.
     y = model(blocks, node_feature)
     logits = model.predictor(
-        y[category][compacted_pairs[category][0]]
-        * y[category][compacted_pairs[category][1]]
+        y[category][compacted_seeds[category][:, 0]]
+        * y[category][compacted_seeds[category][:, 1]]
     ).squeeze()
     # Compute loss.
@@ -201,9 +201,8 @@ such as ``num_classes`` and all these fields will be passed to the

 The ``name`` field is used to specify the name of the data. It is mandatory
 and used to specify the data fields of ``MiniBatch`` for sampling. It can
-be either ``seed_nodes``, ``labels``, ``node_pairs``, ``negative_srcs`` or
-``negative_dsts``. If any other name is used, it will be added into the
-``MiniBatch`` data fields.
+be either ``seeds``, ``labels`` or ``indexes``. If any other name is used,
+it will be added into the ``MiniBatch`` data fields.
 - ``format``: ``string``

   The ``format`` field is used to specify the format of the data. It can be
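The naming convention behind the ``name`` field carries over to item sets built in Python; a minimal sketch (the ``indexes`` grouping shown is a hypothetical illustration):

    import torch
    import dgl.graphbolt as gb

    # "seeds", "labels" and "indexes" are the names MiniBatch recognizes;
    # any other name is carried through as an extra data field.
    seeds = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.tensor([1, 1, 0, 0, 0])
    indexes = torch.tensor([0, 0, 1, 1, 2])  # hypothetical grouping of seeds
    item_set = gb.ItemSet(
        (seeds, labels, indexes), names=("seeds", "labels", "indexes")
    )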
8 changes: 4 additions & 4 deletions notebooks/graphbolt/walkthrough.ipynb
@@ -61,12 +61,12 @@
 },
 "outputs": [],
 "source": [
-    "node_pairs = torch.tensor(\n",
+    "seeds = torch.tensor(\n",
     " [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
     " [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\n",
     " [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\n",
     ")\n",
-    "item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n",
+    "item_set = gb.ItemSet(seeds, names=\"seeds\")\n",
     "print(list(item_set))"
 ]
 },
@@ -262,7 +262,7 @@
"num_nodes = 10\n",
"nodes = torch.arange(num_nodes)\n",
"labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\n",
"item_set = gb.ItemSet((nodes, labels), names=(\"seed_nodes\", \"labels\"))\n",
"item_set = gb.ItemSet((nodes, labels), names=(\"seeds\", \"labels\"))\n",
"\n",
"indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
"indices = torch.tensor(\n",
@@ -311,4 +311,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
