From d281e4240c606f540c66a4d68e845cdffea2654a Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:00:12 +0800 Subject: [PATCH] [DistGB] add testcases for DistDGL local sampling on multiple partitions (#7451) --- .../distributed/test_distributed_sampling.py | 173 ++++++++++++++++++ tests/distributed/test_mp_dataloader.py | 40 ++-- 2 files changed, 195 insertions(+), 18 deletions(-) diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 2a88d5516226..5aabb4a9defc 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -22,6 +22,9 @@ sample_etype_neighbors, sample_neighbors, ) + +from dgl.distributed.graph_partition_book import _etype_tuple_to_str + from scipy import sparse as spsp from utils import generate_ip_config, reset_envs @@ -1685,6 +1688,176 @@ def test_standalone_etype_sampling(): check_standalone_etype_sampling(Path(tmpdirname)) +@pytest.mark.parametrize("num_parts", [1, 4]) +@pytest.mark.parametrize("use_graphbolt", [False]) +@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"]) +def test_local_sampling_homograph(num_parts, use_graphbolt, prob_or_mask): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as test_dir: + g = CitationGraphDataset("cora")[0] + prob = torch.rand(g.num_edges()) + mask = prob > 0.2 + prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0 + g.edata["prob"] = prob + g.edata["mask"] = mask + graph_name = "test_local_sampling" + + _, orig_eids = partition_graph( + g, + graph_name, + num_parts, + test_dir, + num_hops=1, + part_method="metis", + return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, + ) + + part_config = os.path.join(test_dir, f"{graph_name}.json") + for part_id in range(num_parts): + local_g, _, edge_feats, gpb, _, _, _ = load_partition( + part_config, + part_id, + load_feats=True, + use_graphbolt=use_graphbolt, + ) + inner_global_nids = gpb.partid2nids(part_id) + inner_global_eids = gpb.partid2eids(part_id) + inner_node_data = ( + local_g.node_attributes["inner_node"] + if use_graphbolt + else local_g.ndata["inner_node"] + ) + inner_edge_data = ( + local_g.edge_attributes["inner_edge"] + if use_graphbolt + else local_g.edata["inner_edge"] + ) + assert len(inner_global_nids) == inner_node_data.sum() + assert len(inner_global_eids) == inner_edge_data.sum() + + c_etype = gpb.canonical_etypes[0] + _prob = [] + prob = edge_feats[_etype_tuple_to_str(c_etype) + "/" + prob_or_mask] + assert len(prob) == len(inner_global_eids) + assert len(prob) <= inner_edge_data.shape[0] + _prob.append(prob) + + sampled_g = dgl.distributed.graph_services._sample_neighbors( + use_graphbolt, + local_g, + gpb, + inner_global_nids, + 5, + prob=_prob, + ) + sampled_homo_eids = sampled_g.global_eids + sampled_orig_eids = orig_eids[sampled_homo_eids] + assert torch.all(g.edata[prob_or_mask][sampled_orig_eids] > 0) + + +@pytest.mark.parametrize("num_parts", [1, 4]) +@pytest.mark.parametrize("use_graphbolt", [False]) +@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"]) +def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as test_dir: + g = create_random_hetero() + for c_etype in g.canonical_etypes: + prob = torch.rand(g.num_edges(c_etype)) + mask = prob > 0.2 + prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0 + g.edges[c_etype].data["prob"] = prob + g.edges[c_etype].data["mask"] = mask + graph_name = "test_local_sampling" + + _, orig_eids = partition_graph( + g, + graph_name, + num_parts, + test_dir, + num_hops=1, + part_method="metis", + return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=True, + store_inner_node=True, + store_inner_edge=True, + ) + + part_config = os.path.join(test_dir, f"{graph_name}.json") + for part_id in range(num_parts): + local_g, _, edge_feats, gpb, _, _, _ = load_partition( + part_config, + part_id, + load_feats=True, + use_graphbolt=use_graphbolt, + ) + inner_global_nids = [ + gpb.map_to_homo_nid(gpb.partid2nids(part_id, ntype), ntype) + for ntype in gpb.ntypes + ] + inner_global_nids = torch.cat(inner_global_nids) + inner_global_eids = { + c_etype: gpb.partid2eids(part_id, c_etype) + for c_etype in gpb.canonical_etypes + } + inner_node_data = ( + local_g.node_attributes["inner_node"] + if use_graphbolt + else local_g.ndata["inner_node"] + ) + inner_edge_data = ( + local_g.edge_attributes["inner_edge"] + if use_graphbolt + else local_g.edata["inner_edge"] + ) + assert len(inner_global_nids) == inner_node_data.sum() + num_inner_global_eids = sum( + [len(eids) for eids in inner_global_eids.values()] + ) + assert num_inner_global_eids == inner_edge_data.sum() + + _prob = [] + for i, c_etype in enumerate(gpb.canonical_etypes): + prob = edge_feats[ + _etype_tuple_to_str(c_etype) + "/" + prob_or_mask + ] + assert len(prob) == len(inner_global_eids[c_etype]) + assert ( + len(prob) + == gpb.local_etype_offset[i + 1] - gpb.local_etype_offset[i] + ) + assert len(prob) <= inner_edge_data.shape[0] + _prob.append(prob) + + sampled_g = dgl.distributed.graph_services._sample_etype_neighbors( + use_graphbolt, + local_g, + gpb, + inner_global_nids, + torch.full((len(g.canonical_etypes),), 5, dtype=torch.int64), + prob=_prob, + etype_offset=gpb.local_etype_offset, + ) + sampled_homo_eids = sampled_g.global_eids + sampled_etype_ids, sampled_per_etype_eids = gpb.map_to_per_etype( + sampled_homo_eids + ) + for etype_id, c_etype in enumerate(gpb.canonical_etypes): + indices = torch.nonzero(sampled_etype_ids == etype_id).squeeze() + sampled_eids = sampled_per_etype_eids[indices] + sampled_orig_eids = orig_eids[c_etype][sampled_eids] + assert torch.all( + g.edges[c_etype].data[prob_or_mask][sampled_orig_eids] > 0 + ) + + if __name__ == "__main__": import tempfile diff --git a/tests/distributed/test_mp_dataloader.py b/tests/distributed/test_mp_dataloader.py index 4cf867ecd217..0aa2e37ca90d 100644 --- a/tests/distributed/test_mp_dataloader.py +++ b/tests/distributed/test_mp_dataloader.py @@ -21,6 +21,27 @@ from utils import generate_ip_config, reset_envs +def _unique_rand_graph(num_nodes=1000, num_edges=10 * 1000): + edges_set = set() + while len(edges_set) < num_edges: + src = np.random.randint(0, num_nodes - 1) + dst = np.random.randint(0, num_nodes - 1) + if ( + src != dst + and (src, dst) not in edges_set + and (dst, src) not in edges_set + ): + edges_set.add((src, dst)) + src_list, dst_list = zip(*edges_set) + + src = th.tensor(src_list, dtype=th.long) + dst = th.tensor(dst_list, dtype=th.long) + g = dgl.graph((th.cat([src, dst]), th.cat([dst, src]))) + E = len(src) + reverse_eids = th.cat([th.arange(E, 2 * E), th.arange(0, E)]) + return g, reverse_eids + + class NeighborSampler(object): def __init__( self, @@ -889,24 +910,7 @@ def test_edge_dataloader_homograph( num_server = 1 dataloader_type = "edge" reset_envs() - g = CitationGraphDataset("cora")[0] - src, dst = g.edges() - # Remove reverse edges. - visited = th.zeros_like(src, dtype=th.bool) - remove_mask = th.zeros_like(src, dtype=th.bool) - for i, (src_id, dst_id) in enumerate(zip(src, dst)): - if visited[i]: - continue - if g.has_edges_between(dst_id, src_id): - eid = g.edge_ids(dst_id, src_id) - visited[eid] = True - remove_mask[i] = True - visited[i] = True - src = src[~remove_mask] - dst = dst[~remove_mask] - g = dgl.graph((th.cat([src, dst]), th.cat([dst, src]))) - E = len(src) - reverse_eids = th.cat([th.arange(E, 2 * E), th.arange(0, E)]) + g, reverse_eids = _unique_rand_graph() check_dataloader( g, num_server,