Skip to content

Commit

Permalink
Update hetero dist relabel (#284)
Browse files Browse the repository at this point in the history
Due to the fact that the implementation of distributed training for
hetero has changed, it is also necessary to change the dist hetero
relabel neighborhood function.

Related pytorch_geometric PR:
[#8503](pyg-team/pytorch_geometric#8503)

Changes made:
- The `num_sampled_neighbors_per_node` dictionary now stores
information about the number of sampled neighbors for each layer
separately:

`const c10::Dict<rel_type,
std::vector<int64_t>>&num_sampled_neighbors_per_node_dict` -> `const
c10::Dict<rel_type,
std::vector<std::vector<int64_t>>>&num_sampled_neighbors_per_node_dict`
- The method of mapping nodes has also been changed. This is now done
layer by layer.
- After each layer, the range of src nodes for each edge type for the
next layer is calculated; the offsets for edge types that share the
same src node type must be equal.
- The src node range for each edge type in a given layer is defined by
the `srcs_slice_dict` dictionary. Local src nodes (`sampled_rows`) are
created from this range, and the starting value for the next layer is
the end value from the previous layer.
  • Loading branch information
kgajdamo authored Dec 4, 2023
1 parent a5fcc87 commit d2370c2
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 68 deletions.
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ merge_outputs(
}

const auto p_size = partition_ids.size();
std::vector<int64_t> sampled_neighbors_per_node(p_size);
std::vector<int64_t> num_sampled_neighbors_per_node(p_size);

const auto scalar_type = node_ids[0].scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
Expand Down Expand Up @@ -106,7 +106,7 @@ merge_outputs(
batch_data[j]);
}

sampled_neighbors_per_node[j] = end_node - begin_node;
num_sampled_neighbors_per_node[j] = end_node - begin_node;
}
});

Expand All @@ -128,7 +128,7 @@ merge_outputs(
});

return std::make_tuple(out_node_id, out_edge_id, out_batch,
sampled_neighbors_per_node);
num_sampled_neighbors_per_node);
}

#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \
Expand Down
111 changes: 77 additions & 34 deletions pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ relabel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand All @@ -117,9 +117,16 @@ relabel(
phmap::flat_hash_map<node_type, scalar_t*> batch_data_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_rows_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_cols_dict;
// `srcs_slice_dict` defines the number of src nodes for each edge type in
// a given layer in the form of a range. Local src nodes (`sampled_rows`)
// will be created on its basis, so for a given edge type the ranges will
// not be repeated, and the starting value of the next layer will be the
// end value from the previous layer.
phmap::flat_hash_map<edge_type, std::pair<size_t, size_t>> srcs_slice_dict;

phmap::flat_hash_map<node_type, Mapper<node_t, scalar_t>> mapper_dict;
phmap::flat_hash_map<node_type, std::pair<size_t, size_t>> slice_dict;
phmap::flat_hash_map<node_type, int64_t> srcs_offset_dict;

const bool parallel = at::get_num_threads() > 1 && edge_types.size() > 1;
std::vector<std::vector<edge_type>> threads_edge_types;
Expand All @@ -129,6 +136,14 @@ relabel(
sampled_rows_dict[k];
sampled_cols_dict[k];

// `num_sampled_neighbors_per_node_dict` is a dictionary where for
// each edge type it contains information about how many neighbors every
// src node has sampled. These values are saved in a separate vector for
// each layer.
size_t num_src_nodes =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[0].size();
srcs_slice_dict[k] = {0, num_src_nodes};

if (parallel) {
// Each thread is assigned edge types that have the same dst node
// type. Thanks to this, each thread will operate on a separate
Expand Down Expand Up @@ -161,6 +176,7 @@ relabel(
{k, sampled_nodes_with_duplicates_dict.at(k).data_ptr<scalar_t>()});
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
srcs_offset_dict[k] = 0;
if constexpr (disjoint) {
batch_data_dict.insert(
{k, batch_dict.value().at(k).data_ptr<scalar_t>()});
Expand All @@ -178,44 +194,71 @@ relabel(
}
}
}
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
for (const auto& k : threads_edge_types[j]) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

const auto num_sampled_neighbors_size =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k)).size();

if (num_sampled_neighbors_size == 0) {
continue;
}

for (auto i = 0; i < num_sampled_neighbors_size; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst);

slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[i];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
size_t num_layers =
num_sampled_neighbors_per_node_dict.at(to_rel_type(edge_types[0]))
.size();
// Iterate over the layers
for (auto ell = 0; ell < num_layers; ++ell) {
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto t = _s; t < _e; t++) {
for (const auto& k : threads_edge_types[t]) {
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

auto [src_begin, src_end] = srcs_slice_dict.at(k);

for (auto i = src_begin; i < src_end; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data =
sampled_nodes_data_dict.at(dst);

// For each edge type `slice_dict` defines the number of
// nodes sampled by a src node `i` in the form of a range.
// The indices in the given range point to global dst nodes
// from `dst_sampled_nodes_data`.
slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(
to_rel_type(k))[ell][i - src_begin];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
slice_dict.at(dst).first = end;
}
slice_dict.at(dst).first = end;
}
}
});

// Get local src nodes ranges for the next layer
if (ell < num_layers - 1) {
for (const auto& k : edge_types) {
// Edges with the same src node types will have the same src node
// offsets.
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
if (srcs_offset_dict[src] < srcs_slice_dict.at(k).second) {
srcs_offset_dict[src] = srcs_slice_dict.at(k).second;
}
});
}
for (const auto& k : edge_types) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
srcs_slice_dict[k] = {
srcs_offset_dict.at(src),
srcs_offset_dict.at(src) + num_sampled_neighbors_per_node_dict
.at(to_rel_type(k))[ell + 1]
.size()};
}
}
}

for (const auto& k : edge_types) {
const auto edges = get_sampled_edges<scalar_t>(
Expand Down Expand Up @@ -254,7 +297,7 @@ hetero_relabel_neighborhood_kernel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand Down
20 changes: 11 additions & 9 deletions pyg_lib/csrc/sampler/dist_relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace sampler {
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch,
bool csc,
Expand All @@ -28,7 +28,8 @@ std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
.findSchemaOrThrow("pyg::relabel_neighborhood", "")
.typed<decltype(relabel_neighborhood)>();
return op.call(seed, sampled_nodes_with_duplicates,
sampled_neighbors_per_node, num_nodes, batch, csc, disjoint);
num_sampled_neighbors_per_node, num_nodes, batch, csc,
disjoint);
}

std::tuple<c10::Dict<rel_type, at::Tensor>, c10::Dict<rel_type, at::Tensor>>
Expand All @@ -37,8 +38,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
bool csc,
Expand All @@ -62,21 +63,22 @@ hetero_relabel_neighborhood(
.typed<decltype(hetero_relabel_neighborhood)>();
return op.call(node_types, edge_types, seed_dict,
sampled_nodes_with_duplicates_dict,
sampled_neighbors_per_node_dict, num_nodes_dict, batch_dict,
csc, disjoint);
num_sampled_neighbors_per_node_dict, num_nodes_dict,
batch_dict, csc, disjoint);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::relabel_neighborhood(Tensor seed, Tensor "
"sampled_nodes_with_duplicates, int[] sampled_neighbors_per_node, int "
"sampled_nodes_with_duplicates, int[] num_sampled_neighbors_per_node, "
"int "
"num_nodes, Tensor? batch = None, bool csc = False, bool disjoint = "
"False) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) "
"sampled_nodes_with_duplicates_dict, Dict(str, int[]) "
"sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"sampled_nodes_with_duplicates_dict, Dict(str, int[][]) "
"num_sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"Dict(str, Tensor)? batch_dict = None, bool csc = False, bool disjoint = "
"False) -> (Dict(str, Tensor), Dict(str, Tensor))"));
}
Expand Down
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/dist_relabel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ PYG_API
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch = c10::nullopt,
bool csc = false,
Expand All @@ -32,8 +32,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict =
c10::nullopt,
Expand Down
15 changes: 9 additions & 6 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
Expand Down Expand Up @@ -82,8 +83,9 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 3};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
3};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
Expand Down Expand Up @@ -124,6 +126,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
auto expected_batch = at::tensor({0, 0, 1, 2, 2, 3, 3}, options);
EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_batch));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}
Loading

0 comments on commit d2370c2

Please sign in to comment.