[Feat] range partition book (#146)
Zhanghyi authored Nov 1, 2024
1 parent 36ce42b commit bebf64f
Showing 7 changed files with 219 additions and 58 deletions.
1 change: 1 addition & 0 deletions graphlearn_torch/python/partition/__init__.py
@@ -15,4 +15,5 @@

from .base import *
from .frequency_partitioner import FrequencyPartitioner
from .partition_book import *
from .random_partitioner import RandomPartitioner
6 changes: 0 additions & 6 deletions graphlearn_torch/python/partition/base.py
@@ -36,12 +36,6 @@ def __getitem__(self, indices):
def offset(self):
return 0

class GLTPartitionBook(PartitionBook, torch.Tensor):
r""" A partition book of graph nodes or edges.
"""
def __getitem__(self, indices) -> torch.Tensor:
return torch.Tensor.__getitem__(self, indices)

HeteroNodePartitionDict = Dict[NodeType, PartitionBook]
HeteroEdgePartitionDict = Dict[EdgeType, PartitionBook]

72 changes: 72 additions & 0 deletions graphlearn_torch/python/partition/partition_book.py
@@ -0,0 +1,72 @@
import torch
from typing import List, Tuple
from .base import PartitionBook


class RangePartitionBook(PartitionBook):
r"""A class for managing range-based partitions of consecutive IDs.
Suitable when IDs within each partition are consecutive.
Args:
partition_ranges (List[Tuple[int, int]]): A list of tuples representing
the start and end (exclusive) of each partition range.
partition_idx (int): The index of the current partition.
Example:
>>> partition_ranges = [(0, 10), (10, 20), (20, 30)]
>>> range_pb = RangePartitionBook(partition_ranges, partition_idx=1)
>>> indices = torch.tensor([0, 5, 10, 15, 20, 25])
>>> partition_ids = range_pb[indices]
>>> print(partition_ids)
tensor([0, 0, 1, 1, 2, 2])
"""

def __init__(self, partition_ranges: List[Tuple[int, int]], partition_idx: int):
if not all(r[0] < r[1] for r in partition_ranges):
raise ValueError("All partition ranges must have start < end")
if not all(r1[1] == r2[0] for r1, r2 in zip(partition_ranges[:-1], partition_ranges[1:])):
raise ValueError("Partition ranges must be continuous")

self.partition_bounds = torch.tensor(
[end for _, end in partition_ranges], dtype=torch.long)
self.partition_idx = partition_idx
self._id2index = OffsetId2Index(partition_ranges[partition_idx][0])

def __getitem__(self, indices: torch.Tensor) -> torch.Tensor:
return torch.searchsorted(self.partition_bounds, indices, right=True)

@property
def device(self):
return self.partition_bounds.device

@property
def id2index(self):
return self._id2index

def id_filter(self, node_pb: PartitionBook, partition_idx: int):
start = self.partition_bounds[partition_idx-1] if partition_idx > 0 else 0
end = self.partition_bounds[partition_idx]
return torch.arange(start, end)


class OffsetId2Index:
r"""
Convert global IDs to local indices by subtracting a specified offset.
"""

def __init__(self, offset: int):
self.offset = offset

def __getitem__(self, ids: torch.Tensor) -> torch.Tensor:
local_indices = ids - self.offset
return local_indices

def to(self, device):
        # The returned indices stay on the same device as the input ids, so no move is needed.
return self


class GLTPartitionBook(PartitionBook, torch.Tensor):
r""" A partition book of graph nodes or edges.
"""

def __getitem__(self, indices) -> torch.Tensor:
return torch.Tensor.__getitem__(self, indices)
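
A minimal, self-contained sketch of how the new RangePartitionBook can be exercised on its own, assuming the package is importable as graphlearn_torch (aliased glt, as in the tests below) and reusing the ranges from the class docstring:

import torch
import graphlearn_torch as glt

# Three consecutive ID ranges: [0, 10), [10, 20), [20, 30); this process owns partition 1.
range_pb = glt.partition.RangePartitionBook(
    [(0, 10), (10, 20), (20, 30)], partition_idx=1)

# Owning partition of each global ID, found via torch.searchsorted on the range bounds.
ids = torch.tensor([3, 10, 19, 27])
print(range_pb[ids])                                  # tensor([0, 1, 1, 2])

# Global -> local conversion for IDs held by the current partition (the offset 10 is subtracted).
print(range_pb.id2index[torch.tensor([10, 15, 19])])  # tensor([0, 5, 9])

# All global IDs owned by partition 1 (the first argument is unused for range partition books).
print(range_pb.id_filter(range_pb, partition_idx=1))  # tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])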
104 changes: 74 additions & 30 deletions test/python/dist_test_utils.py
@@ -20,9 +20,11 @@

# options for dataset generation
vnum_per_partition = 20
vnum_total = vnum_per_partition * 2
num_partition = 2
vnum_total = vnum_per_partition * num_partition # 40
degree = 2
enum_total = vnum_total * degree
enum_per_partition = vnum_per_partition * degree # 40
enum_total = enum_per_partition * num_partition # 80

# for hetero dataset
user_ntype = 'user'
@@ -36,64 +38,106 @@
device_num = 2


def _prepare_dataset(rank: int, weighted: bool = False):
# partition
node_pb = torch.tensor(
[v % 2 for v in range(0, vnum_total)],
dtype=torch.long
)
edge_pb = torch.tensor(
[((e // degree) % 2) for e in range(0, enum_total)],
dtype=torch.long
)
def _prepare_dataset(rank: int,
weighted: bool = False,
is_range_partition: bool = False):
"""
Prepare a synthetic graph dataset with 40 nodes and 80 edges for unit tests.
Graph topology:
- rows: [0, 0, 1, 1, 2, 2, ... 37, 37, 38, 38, 39, 39]
- cols: [1, 2, 2, 3, 3, 4, ... 38, 39, 39, 0, 0, 1]
- eids: [0, 1, 2, 3, 4, 5, ... 74, 75, 76, 77, 78, 79]
Node features:
[[0., 0., ..., 0., 0.],
[1., 1., ..., 1., 1.],
...
[39., 39., ..., 39., 39.]]
Edge features:
[[0., 0., ..., 0., 0.],
[1., 1., ..., 1., 1.],
...
[79., 79., ..., 79., 79.]]
Two partition strategies are available:
1. Range partition:
- Nodes with IDs [0, 19] and edges with IDs [0, 39] are on partition 0
- Nodes with IDs [20, 39] and edges with IDs [40, 79] are on partition 1
2. Hash partition:
- Even-numbered nodes and edges are on partition 0
- Odd-numbered nodes and edges are on partition 1
The graph topology and features are identical under both partition strategies.
"""
if is_range_partition:
node_ranges = [(0, vnum_per_partition), (vnum_per_partition, vnum_total)]
edge_ranges = [(0, enum_total // 2), (enum_total // 2, enum_total)]
node_pb = glt.partition.RangePartitionBook(
node_ranges, rank)
edge_pb = glt.partition.RangePartitionBook(
edge_ranges, rank)
start, end, step = rank * vnum_per_partition, (rank + 1) * vnum_per_partition, 1
else:
node_pb = torch.tensor(
[v % 2 for v in range(0, vnum_total)],
dtype=torch.long
)
edge_pb = torch.tensor(
[((e // degree) % 2) for e in range(0, enum_total)],
dtype=torch.long
)
start, end, step = rank, vnum_total, 2


# graph
nodes, rows, cols, eids = [], [], [], []
for v in range(rank, vnum_total, 2):
for v in range(start, end, step):
nodes.append(v)
rows.extend([v for _ in range(degree)])
cols.extend([((v + i + 1) % vnum_total) for i in range(degree)])
eids.extend([(v * degree + i) for i in range(degree)])

edge_index = torch.tensor([rows, cols], dtype=torch.int64)
edge_ids = torch.tensor(eids, dtype=torch.int64)
edge_weights = (edge_ids % 2).to(torch.float)
csr_topo = glt.data.Topology(edge_index=edge_index, edge_ids=edge_ids)
graph = glt.data.Graph(csr_topo, 'ZERO_COPY', device=0)

weighted_csr_topo = glt.data.Topology(
edge_index=edge_index, edge_ids=edge_ids, edge_weights=edge_weights)
graph = glt.data.Graph(csr_topo, 'ZERO_COPY', device=0)
weighted_graph = glt.data.Graph(weighted_csr_topo, 'CPU')

# feature
device_group_list = [glt.data.DeviceGroup(0, [0]),
glt.data.DeviceGroup(1, [1])]
split_ratio = 0.2

nfeat = rank + torch.zeros(len(nodes), 512, dtype=torch.float32)
nfeat_id2idx = glt.utils.id2idx(nodes)
nfeat = torch.tensor(nodes, dtype=torch.float32).unsqueeze(1).repeat(1, 512)
nfeat_id2idx = node_pb.id2index if is_range_partition else glt.utils.id2idx(nodes)
node_feature = glt.data.Feature(nfeat, nfeat_id2idx, split_ratio,
device_group_list, device=0)

efeat = rank + torch.ones(len(eids), 10, dtype=torch.float32)
efeat_id2idx = glt.utils.id2idx(eids)
efeat = torch.tensor(eids, dtype=torch.float32).unsqueeze(1).repeat(1, 10)
efeat_id2idx = edge_pb.id2index if is_range_partition else glt.utils.id2idx(eids)
edge_feature = glt.data.Feature(efeat, efeat_id2idx, split_ratio,
device_group_list, device=0)

# whole node label
node_label = torch.arange(vnum_total)

# dist dataset
if weighted:
return glt.distributed.DistDataset(
2, rank,
weighted_graph, node_feature, edge_feature, node_label,
node_pb, edge_pb
)
else:
return glt.distributed.DistDataset(
2, rank,
graph, node_feature, edge_feature, node_label,
node_pb, edge_pb
)
ds = glt.distributed.DistDataset(
2, rank,
weighted_graph if weighted else graph,
node_feature, edge_feature, node_label,
node_pb, edge_pb
)

if is_range_partition:
ds.id_filter = node_pb.id_filter
return ds


def _prepare_hetero_dataset(
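
To make the two partition strategies described in _prepare_dataset concrete, here is a small sketch, assuming graphlearn_torch is importable as glt and reusing the module constants vnum_per_partition = 20 and vnum_total = 40, that compares the node partition books built by the two branches:

import torch
import graphlearn_torch as glt

vnum_per_partition, vnum_total = 20, 40

# Hash partition: even node IDs live on partition 0, odd node IDs on partition 1.
hash_node_pb = torch.tensor([v % 2 for v in range(vnum_total)], dtype=torch.long)

# Range partition: IDs [0, 20) live on partition 0, IDs [20, 40) on partition 1.
range_node_pb = glt.partition.RangePartitionBook(
    [(0, vnum_per_partition), (vnum_per_partition, vnum_total)], partition_idx=0)

# Both books answer the same question ("which partition owns this ID?") but give different answers.
probe = torch.tensor([0, 1, 19, 20, 39])
print(hash_node_pb[probe])    # tensor([0, 1, 1, 0, 1])
print(range_node_pb[probe])   # tensor([0, 0, 0, 1, 1])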
32 changes: 25 additions & 7 deletions test/python/test_dist_link_loader.py
@@ -21,6 +21,7 @@

from dist_test_utils import *
from dist_test_utils import _prepare_dataset, _prepare_hetero_dataset
from parameterized import parameterized

def _check_sample_result(data, edge_dir='out'):
tc = unittest.TestCase()
@@ -338,6 +339,8 @@ class DistLinkNeighborLoaderTestCase(unittest.TestCase):
def setUp(self):
self.dataset0 = _prepare_dataset(rank=0)
self.dataset1 = _prepare_dataset(rank=1)
self.range_partition_dataset0 = _prepare_dataset(rank=0, is_range_partition=True)
self.range_partition_dataset1 = _prepare_dataset(rank=1, is_range_partition=True)
self.input_edges0 = torch.stack(
(torch.arange(vnum_per_partition), torch.arange(vnum_per_partition)+1)
).to(dtype=torch.long)
@@ -357,37 +360,52 @@ def setUp(self):
self.master_port = glt.utils.get_free_port()
self.sampling_master_port = glt.utils.get_free_port()

def test_homo_out_sample_collocated(self):
def _get_homo_datasets(self, is_range_partition):
return (self.range_partition_dataset0, self.range_partition_dataset1) if is_range_partition else (self.dataset0, self.dataset1)

@parameterized.expand([
(True),
(False),
])
def test_homo_out_sample_collocated(self, is_range_partition):
print("\n--- DistLinkNeighborLoader Test (homogeneous, collocated) ---")
dataset0, dataset1 = self._get_homo_datasets(is_range_partition)

mp_context = torch.multiprocessing.get_context('spawn')
w0 = mp_context.Process(
target=run_test_as_worker,
args=(2, 0, self.master_port, self.sampling_master_port,
self.dataset0, self.bin_neg_sampling, self.input_edges0, _check_sample_result, True)
dataset0, self.bin_neg_sampling, self.input_edges0, _check_sample_result, True)
)
w1 = mp_context.Process(
target=run_test_as_worker,
args=(2, 1, self.master_port, self.sampling_master_port,
self.dataset1, self.bin_neg_sampling, self.input_edges1, _check_sample_result, True)
dataset1, self.bin_neg_sampling, self.input_edges1, _check_sample_result, True)
)
w0.start()
w1.start()
w0.join()
w1.join()

def test_homo_out_sample_mp(self):

@parameterized.expand([
(True),
(False),
])
def test_homo_out_sample_mp(self, is_range_partition):
print("\n--- DistLinkNeighborLoader Test (homogeneous, multiprocessing) ---")
dataset0, dataset1 = self._get_homo_datasets(is_range_partition)

mp_context = torch.multiprocessing.get_context('spawn')
w0 = mp_context.Process(
target=run_test_as_worker,
args=(2, 0, self.master_port, self.sampling_master_port,
self.dataset0, self.tri_neg_sampling, self.input_edges0,
dataset0, self.tri_neg_sampling, self.input_edges0,
_check_sample_result, False)
)
w1 = mp_context.Process(
target=run_test_as_worker,
args=(2, 1, self.master_port, self.sampling_master_port,
self.dataset1, self.tri_neg_sampling, self.input_edges1,
dataset1, self.tri_neg_sampling, self.input_edges1,
_check_sample_result, False)
)
w0.start()
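
For readers unfamiliar with the @parameterized.expand decorator introduced here, a minimal sketch of the pattern, assuming the third-party parameterized package: each listed value becomes its own generated test case, which is how every loader test now runs once against the range-partitioned datasets and once against the hash-partitioned ones.

import unittest
from parameterized import parameterized

class PartitionFlagDemo(unittest.TestCase):

    @parameterized.expand([
        (True,),
        (False,),
    ])
    def test_flag(self, is_range_partition):
        # Two independent test cases are generated, one receiving True and one False.
        self.assertIsInstance(is_range_partition, bool)

if __name__ == '__main__':
    unittest.main()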