[GraphBolt] Update docstring related to cleaning up `seed_nodes` and `node_pairs`. (#7341)

Co-authored-by: Ubuntu <[email protected]>
yxy235 and Ubuntu authored Apr 24, 2024
1 parent 2251683 commit ba9c152
Showing 6 changed files with 104 additions and 123 deletions.
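
The diffs below replace the retired item-set names `seed_nodes` and `node_pairs` with the single `seeds` name throughout the GraphBolt docstrings. As a quick orientation, a minimal before/after sketch of the renaming, assuming only the `gb.ItemSet` API already shown in the changed docstrings:

import torch
from dgl import graphbolt as gb

# Before this change the docstrings used two separate names:
#   gb.ItemSet(torch.arange(0, 10), names="seed_nodes")                 # node IDs
#   gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")  # edges
# After the cleanup a single name covers both; the tensor shape
# (1-D for node IDs, N x 2 for node pairs) tells them apart.
node_seeds = gb.ItemSet(torch.arange(0, 10), names="seeds")
edge_seeds = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="seeds")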
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/impl/in_subgraph_sampler.py
@@ -34,7 +34,7 @@ class InSubgraphSampler(SubgraphSampler):
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
- >>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes")
+ >>> item_set = gb.ItemSet(len(indptr) - 1, names="seeds")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for _, data in enumerate(insubgraph_sampler):
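The example's output is truncated by the diff view above. A self-contained sketch of the same in-subgraph pipeline with the renamed field, assuming the MiniBatch attributes (`seeds`, `sampled_subgraphs`) documented in minibatch.py further down:

import torch
from dgl import graphbolt as gb

# Same toy graph as in the docstring above.
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
graph = gb.fused_csc_sampling_graph(indptr, indices)
item_set = gb.ItemSet(len(indptr) - 1, names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=2)
insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
for data in insubgraph_sampler:
    # Each minibatch carries the seed node IDs of the batch and the
    # in-subgraphs sampled around them.
    print(data.seeds)
    print(data.sampled_subgraphs)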
16 changes: 10 additions & 6 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -407,8 +407,8 @@ class NeighborSampler(NeighborSamplerImpl):
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
- >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
- >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
+ >>> seeds = torch.LongTensor([[0, 1], [1, 2]])
+ >>> item_set = gb.ItemSet(seeds, names="seeds")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
@@ -534,8 +534,8 @@ class LayerNeighborSampler(NeighborSamplerImpl):
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
- >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
- >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
+ >>> seeds = torch.LongTensor([[0, 1], [1, 2]])
+ >>> item_set = gb.ItemSet(seeds, names="seeds")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=1,)
>>> neg_sampler = gb.UniformNegativeSampler(item_sampler, graph, 2)
>>> fanouts = [torch.LongTensor([5]),
@@ -566,8 +566,12 @@ class LayerNeighborSampler(NeighborSamplerImpl):
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 5, 2]),
)]
- >>> next(iter(subgraph_sampler)).compacted_node_pairs
- (tensor([0]), tensor([1]))
+ >>> next(iter(subgraph_sampler)).compacted_seeds
+ tensor([[0, 1], [0, 2], [0, 3]])
+ >>> next(iter(subgraph_sampler)).labels
+ tensor([1., 0., 0.])
+ >>> next(iter(subgraph_sampler)).indexes
+ tensor([0, 0, 0])
"""

def __init__(
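For context, a sketch of the full link-prediction datapipe the two docstrings above describe, using the functional `sample_uniform_negative` / `sample_neighbor` calls shown in the NeighborSampler example; the printed fields assume the minibatch layout shown in the updated LayerNeighborSampler output:

import torch
from dgl import graphbolt as gb

indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
graph = gb.fused_csc_sampling_graph(indptr, indices)
seeds = torch.LongTensor([[0, 1], [1, 2]])
item_set = gb.ItemSet(seeds, names="seeds")
datapipe = gb.ItemSampler(item_set, batch_size=1)
datapipe = datapipe.sample_uniform_negative(graph, 2)    # 2 negatives per seed pair
datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])  # 3-hop fanouts
for minibatch in datapipe:
    # Positive seed pairs come first, followed by the sampled negatives;
    # `labels` marks positives (1.) vs. negatives (0.) and `indexes` maps
    # every row back to the positive pair it was generated for.
    print(minibatch.compacted_seeds)
    print(minibatch.labels)
    print(minibatch.indexes)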
4 changes: 0 additions & 4 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
@@ -42,11 +42,7 @@
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]

NAMES_INDICATING_NODE_IDS = [
"seed_nodes",
"node_pairs",
"seeds",
"negative_srcs",
"negative_dsts",
]


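With this change, `seeds` is the only item-set name treated as carrying node IDs during preprocessing. The diff does not show where the constant is consulted, so purely as a hedged illustration, a helper of this shape is what such a name list is typically used for:

# Hypothetical helper, not part of the commit: shows how a name list like
# NAMES_INDICATING_NODE_IDS could be checked when loading item sets.
NAMES_INDICATING_NODE_IDS = [
    "seeds",
]

def has_node_id_field(names):
    """Return True if any item-set field name indicates node IDs."""
    if isinstance(names, str):
        names = (names,)
    return any(name in NAMES_INDICATING_NODE_IDS for name in names)

assert has_node_id_field("seeds")
assert not has_node_id_field(("labels", "indexes"))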
18 changes: 9 additions & 9 deletions python/dgl/graphbolt/impl/uniform_negative_sampler.py
@@ -36,20 +36,20 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 1, 2, 3, 4])
>>> indices = torch.LongTensor([1, 2, 3, 0])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
- >>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
- >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
+ >>> seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
+ >>> item_set = gb.ItemSet(seeds, names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4,)
>>> neg_sampler = gb.UniformNegativeSampler(
... item_sampler, graph, 2)
>>> for minibatch in neg_sampler:
- ...     print(minibatch.negative_srcs)
- ...     print(minibatch.negative_dsts)
- None
- tensor([[2, 1],
-         [2, 1],
-         [3, 2],
-         [1, 3]])
+ ...     print(minibatch.seeds)
+ ...     print(minibatch.labels)
+ ...     print(minibatch.indexes)
+ tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3], [1, 1], [1, 2],
+         [2, 1], [2, 0], [3, 0], [3, 2]])
+ tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
+ tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3])
"""

def __init__(
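A short sketch of reading the new output back out, using the exact values printed in the docstring above (positives first, then negatives, with `indexes` pointing at the positive pair each negative was drawn for):

import torch

seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3],
                      [1, 1], [1, 2], [2, 1], [2, 0], [3, 0], [3, 2]])
labels = torch.tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
indexes = torch.tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3])

positives = seeds[labels == 1]   # the original seed pairs
negatives = seeds[labels == 0]   # corrupted pairs from the negative sampler
owners = indexes[labels == 0]    # index of the positive each negative belongs to
for i, pos in enumerate(positives):
    print(pos.tolist(), "->", negatives[owners == i].tolist())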
166 changes: 74 additions & 92 deletions python/dgl/graphbolt/item_sampler.py
@@ -50,7 +50,7 @@ def minibatcher_default(batch, names):
return batch
if len(names) == 1:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# ("seeds",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data = {names[0]: batch}
else:
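A tiny illustration of the corner case this comment guards against (hedged: only part of `minibatcher_default` is visible in this hunk):

import torch

batch = torch.tensor([0, 1, 2, 3])
names = ("seeds",)

# Naively zipping the tensor with a one-element name tuple pairs each
# *element* of the tensor with the name and stops after one item:
print(list(zip(batch, names)))   # [(tensor(0), 'seeds')]
# The single-name branch instead keeps the whole batch under that name:
init_data = {names[0]: batch}
print(init_data)                 # {'seeds': tensor([0, 1, 2, 3])}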
@@ -313,68 +313,61 @@ class ItemSampler(IterDataPipe):
>>> import torch
>>> from dgl import graphbolt as gb
- >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
+ >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
- negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
- input_nodes=None, node_features=None, edge_features=None,
- compacted_node_pairs=None, compacted_negative_srcs=None,
- compacted_negative_dsts=None)
+ MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None,
+ node_features=None, labels=None, input_nodes=None,
+ indexes=None, edge_features=None, compacted_seeds=None,
+ blocks=None,)
2. Node pairs.
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
... names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
- labels=None, negative_srcs=None, negative_dsts=None,
- sampled_subgraphs=None, input_nodes=None, node_features=None,
- edge_features=None, compacted_node_pairs=None,
- compacted_negative_srcs=None, compacted_negative_dsts=None)
+ MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
+ sampled_subgraphs=None, node_features=None, labels=None,
+ input_nodes=None, indexes=None, edge_features=None,
+ compacted_seeds=None, blocks=None,)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
... names=("node_pairs", "labels")
... names=("seeds", "labels")
... )
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
- labels=tensor([10, 11, 12, 13]), negative_srcs=None,
- negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
- node_features=None, edge_features=None, compacted_node_pairs=None,
- compacted_negative_srcs=None, compacted_negative_dsts=None)
- 4. Node pairs and negative destinations.
- >>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
- >>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
- >>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
- ... "negative_dsts"))
+ MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
+ sampled_subgraphs=None, node_features=None,
+ labels=tensor([10, 11, 12, 13]), input_nodes=None,
+ indexes=None, edge_features=None, compacted_seeds=None,
+ blocks=None,)
+ 4. Node pairs, labels and indexes.
+ >>> seeds = torch.arange(0, 20).reshape(-1, 2)
+ >>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
+ >>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
+ >>> item_set = gb.ItemSet((seeds, labels, indexes), names=("seeds",
+ ... "labels", "indexes"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
- labels=None, negative_srcs=None,
- negative_dsts=tensor([[10, 11],
-         [12, 13],
-         [14, 15],
-         [16, 17]]), sampled_subgraphs=None, input_nodes=None,
- node_features=None, edge_features=None, compacted_node_pairs=None,
- compacted_negative_srcs=None, compacted_negative_dsts=None)
+ MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
+ sampled_subgraphs=None, node_features=None,
+ labels=tensor([1, 1, 0, 0]), input_nodes=None,
+ indexes=tensor([0, 1, 0, 0]), edge_features=None,
+ compacted_seeds=None, blocks=None,)
5. DGLGraphs.
@@ -404,85 +397,74 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seeds"),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
- labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
- input_nodes=None, node_features=None, edge_features=None,
- compacted_node_pairs=None, compacted_negative_srcs=None,
- compacted_negative_dsts=None)
+ MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None,
+ node_features=None, labels=None, input_nodes=None, indexes=None,
+ edge_features=None, compacted_seeds=None, blocks=None,)
8. Heterogeneous node pairs.
- >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
- >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
+ >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
+ >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(
- ...         node_pairs_like, names="node_pairs"),
+ ...         seeds_like, names="seeds"),
... "user:follow:user": gb.ItemSet(
- ...         node_pairs_follow, names="node_pairs"),
+ ...         seeds_follow, names="seeds"),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs={'user:like:item':
- (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
- labels=None, negative_srcs=None, negative_dsts=None,
- sampled_subgraphs=None, input_nodes=None, node_features=None,
- edge_features=None, compacted_node_pairs=None,
- compacted_negative_srcs=None, compacted_negative_dsts=None)
+ MiniBatch(seeds={'user:like:item':
+ tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
+ node_features=None, labels=None, input_nodes=None, indexes=None,
+ edge_features=None, compacted_seeds=None, blocks=None,)
9. Heterogeneous node pairs and labels.
- >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
- >>> labels_like = torch.arange(0, 10)
- >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
- >>> labels_follow = torch.arange(10, 20)
+ >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
+ >>> labels_like = torch.arange(0, 5)
+ >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
+ >>> labels_follow = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... "user:like:item": gb.ItemSet((seeds_like, labels_like),
... names=("seeds", "labels")),
... "user:follow:user": gb.ItemSet((seeds_follow, labels_follow),
... names=("seeds", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs={'user:like:item':
- (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
- labels={'user:like:item': tensor([0, 1, 2, 3])},
- negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
- input_nodes=None, node_features=None, edge_features=None,
- compacted_node_pairs=None, compacted_negative_srcs=None,
- compacted_negative_dsts=None)
- 10. Heterogeneous node pairs and negative destinations.
- >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
- >>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
- >>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
- >>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
+ MiniBatch(seeds={'user:like:item':
+ tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
+ node_features=None, labels={'user:like:item': tensor([0, 1, 2, 3])},
+ input_nodes=None, indexes=None, edge_features=None,
+ compacted_seeds=None, blocks=None,)
+ 10. Heterogeneous node pairs, labels and indexes.
+ >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
+ >>> labels_like = torch.tensor([1, 1, 0, 0, 0])
+ >>> indexes_like = torch.tensor([0, 1, 0, 0, 1])
+ >>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)
+ >>> labels_follow = torch.tensor([1, 1, 0, 0, 0])
+ >>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... "user:like:item": gb.ItemSet((seeds_like, labels_like,
... indexes_like), names=("seeds", "labels", "indexes")),
... "user:follow:user": gb.ItemSet((seeds_follow,labels_follow,
... indexes_follow), names=("seeds", "labels", "indexes")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
- MiniBatch(seed_nodes=None,
- node_pairs={'user:like:item':
- (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
- labels=None, negative_srcs=None,
- negative_dsts={'user:like:item': tensor([[10, 11],
-         [12, 13],
-         [14, 15],
-         [16, 17]])}, sampled_subgraphs=None, input_nodes=None,
- node_features=None, edge_features=None, compacted_node_pairs=None,
- compacted_negative_srcs=None, compacted_negative_dsts=None)
+ MiniBatch(seeds={'user:like:item':
+ tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
+ node_features=None, labels={'user:like:item': tensor([1, 1, 0, 0])},
+ input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},
+ edge_features=None, compacted_seeds=None, blocks=None,)
"""

def __init__(
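As a follow-up to the heterogeneous examples above, a sketch of how downstream code can walk the dict-valued fields of such a minibatch, reusing the item set from example 10 (field layout as shown in that example's output):

import torch
from dgl import graphbolt as gb

seeds_like = torch.arange(0, 10).reshape(-1, 2)
labels_like = torch.tensor([1, 1, 0, 0, 0])
indexes_like = torch.tensor([0, 1, 0, 0, 1])
seeds_follow = torch.arange(20, 30).reshape(-1, 2)
labels_follow = torch.tensor([1, 1, 0, 0, 0])
indexes_follow = torch.tensor([0, 1, 0, 0, 1])
item_set = gb.ItemSetDict({
    "user:like:item": gb.ItemSet(
        (seeds_like, labels_like, indexes_like),
        names=("seeds", "labels", "indexes")),
    "user:follow:user": gb.ItemSet(
        (seeds_follow, labels_follow, indexes_follow),
        names=("seeds", "labels", "indexes")),
})
item_sampler = gb.ItemSampler(item_set, batch_size=4)
for minibatch in item_sampler:
    # For heterogeneous input every field is a dict keyed by edge type.
    for etype, seeds in minibatch.seeds.items():
        print(etype, seeds.tolist(),
              minibatch.labels[etype].tolist(),
              minibatch.indexes[etype].tolist())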
21 changes: 10 additions & 11 deletions python/dgl/graphbolt/minibatch.py
@@ -26,11 +26,11 @@ class MiniBatch:

labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
- Labels associated with seed nodes / node pairs in the graph.
+ Labels associated with seeds in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The value
- should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
+ should be corresponding labels to given 'seeds'.
- If `labels` is a dictionary: The keys should be node or edge type and the
- value should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
+ value should be corresponding labels to given 'seeds'.
"""

seeds: Union[
@@ -61,15 +61,14 @@ class MiniBatch:

indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
- Indexes associated with seed nodes / node pairs in the graph, which
- indicates to which query a seed node / node pair belongs.
+ Indexes associated with seeds in the graph, which
+ indicate to which query each seed belongs.
- If `indexes` is a tensor: It indicates the graph is homogeneous. The
- value should be corresponding query to given 'seed_nodes' or
- 'node_pairs'.
- - If `indexes` is a dictionary: It indicates the graph is
- heterogeneous. The keys should be node or edge type and the value should
- be corresponding query to given 'seed_nodes' or 'node_pairs'. For each
- key, indexes are consecutive integers starting from zero.
+ value should be corresponding query to given 'seeds'.
+ - If `indexes` is a dictionary: It indicates the graph is heterogeneous.
+ The keys should be node or edge type and the value should be
+ corresponding query to given 'seeds'. For each key, indexes are
+ consecutive integers starting from zero.
"""

sampled_subgraphs: List[SampledSubgraph] = None
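To make the `indexes` field concrete, a small hedged sketch of its typical use: grouping per-edge outputs (for example link-prediction scores) back under the query they were generated for.

import torch

# Shaped like a homogeneous MiniBatch: three rows of `seeds` belong to
# query 0 and three to query 1; `scores` stands in for model output.
indexes = torch.tensor([0, 0, 0, 1, 1, 1])
scores = torch.tensor([0.9, 0.2, 0.1, 0.8, 0.4, 0.3])

num_queries = int(indexes.max()) + 1
per_query_sum = torch.zeros(num_queries).index_add_(0, indexes, scores)
print(per_query_sum)   # tensor([1.2000, 1.5000])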

