Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edge-based temporal sampling #280

Merged
merged 15 commits into from
Nov 14, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
- Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275))
Expand Down
1 change: 1 addition & 0 deletions benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_hetero_neighbor(dataset, **kwargs):
seed_dict,
num_neighbors_dict,
node_time_dict,
edge_time_dict=None,
seed_time_dict=None,
edge_weight_dict=edge_weight_dict,
csc=True,
Expand Down
3 changes: 2 additions & 1 deletion benchmark/sampler/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def test_neighbor(dataset, **kwargs):
col,
seed,
num_neighbors,
time=node_time,
node_time=node_time,
edge_time=None,
seed_time=None,
edge_weight=edge_weight,
replace=args.replace,
Expand Down
232 changes: 168 additions & 64 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -38,7 +39,8 @@ hetero_neighbor_sample_kernel(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
Expand All @@ -53,7 +55,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand Down
47 changes: 26 additions & 21 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -38,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time,
seed_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -56,7 +57,8 @@ hetero_neighbor_sample(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
Expand Down Expand Up @@ -89,17 +91,18 @@ hetero_neighbor_sample(
.findSchemaOrThrow("pyg::hetero_neighbor_sample", "")
.typed<decltype(hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
num_neighbors_dict, node_time_dict, edge_time_dict,
seed_time_dict, edge_weight_dict, csc, replace, directed,
disjoint, temporal_strategy, return_edge_id);
}

std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -118,34 +121,36 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy);
return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time,
seed_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) -> "
"num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, "
"Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Tensor, Tensor, Tensor, Tensor?, int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict "
"= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"Dict(str, Tensor)? node_time_dict = None, Dict(str, Tensor)? "
"edge_time_dict = None, Dict(str, Tensor)? seed_time_dict = None, "
"Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform') -> (Tensor, Tensor, int[])"));
"num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, "
"Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform') -> (Tensor, Tensor, int[])"));
}

} // namespace sampler
Expand Down
10 changes: 7 additions & 3 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& node_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
Expand All @@ -48,7 +49,9 @@ hetero_neighbor_sample(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict =
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict =
c10::nullopt,
Expand All @@ -72,7 +75,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& node_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
Expand Down
59 changes: 31 additions & 28 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def neighbor_sample(
col: Tensor,
seed: Tensor,
num_neighbors: List[int],
time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
Expand All @@ -39,16 +40,27 @@ def neighbor_sample(
num_neighbors (List[int]): The number of neighbors to sample for each
node in each iteration. If an entry is set to :obj:`-1`, all
neighbors will be included.
time (torch.Tensor, optional): Timestamps for the nodes in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the seed node.
node_time (torch.Tensor, optional): Timestamps for the nodes in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
nodes have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
edge_time (torch.Tensor, optional): Timestamps for the edges in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
edges have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
seed_time (torch.Tensor, optional): Optional values to override the
timestamp for seed nodes. If not set, will use timestamps in
:obj:`time` as default for seed nodes. (default: :obj:`None`)
:obj:`node_time` as default for seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge-weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
Expand All @@ -75,18 +87,19 @@ def neighbor_sample(
Lastly, returns information about the sampled amount of nodes and edges
per hop.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, seed_time, edge_weight, csc,
replace, directed, disjoint,
temporal_strategy, return_edge_id)
return torch.ops.pyg.neighbor_sample( #
rowptr, col, seed, num_neighbors, node_time, edge_time, seed_time,
edge_weight, csc, replace, directed, disjoint, temporal_strategy,
return_edge_id)


def hetero_neighbor_sample(
rowptr_dict: Dict[EdgeType, Tensor],
col_dict: Dict[EdgeType, Tensor],
seed_dict: Dict[NodeType, Tensor],
num_neighbors_dict: Dict[EdgeType, List[int]],
time_dict: Optional[Dict[NodeType, Tensor]] = None,
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
seed_time_dict: Optional[Dict[NodeType, Tensor]] = None,
edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
csc: bool = False,
Expand Down Expand Up @@ -123,29 +136,19 @@ def hetero_neighbor_sample(
TO_REL_TYPE[k]: v
for k, v in num_neighbors_dict.items()
}
if edge_time_dict is not None:
edge_time_dict = {TO_REL_TYPE[k]: v for k, v in edge_time_dict.items()}
if edge_weight_dict is not None:
edge_weight_dict = {
TO_REL_TYPE[k]: v
for k, v in edge_weight_dict.items()
}

out = torch.ops.pyg.hetero_neighbor_sample(
node_types,
edge_types,
rowptr_dict,
col_dict,
seed_dict,
num_neighbors_dict,
time_dict,
seed_time_dict,
edge_weight_dict,
csc,
replace,
directed,
disjoint,
temporal_strategy,
return_edge_id,
)
out = torch.ops.pyg.hetero_neighbor_sample( #
node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, node_time_dict, edge_time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint, temporal_strategy,
return_edge_id)

(row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict,
num_edges_per_hop_dict) = out
Expand Down
9 changes: 6 additions & 3 deletions test/csrc/sampler/test_dist_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand All @@ -85,7 +86,8 @@ TEST(DistDisjointNeighborTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down Expand Up @@ -121,7 +123,8 @@ TEST(DistTemporalNeighborTest, BasicAssertions) {
/*col=*/col,
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/time,
/*node_time=*/time,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down
9 changes: 6 additions & 3 deletions test/csrc/sampler/test_dist_relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/seed,
/*num_neighbors=*/{2},
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down Expand Up @@ -182,7 +183,8 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) {
/*col_dict=*/col_dict,
/*seed_dict=*/seed_dict,
/*num_neighbors_dict=*/num_neighbors_dict,
/*time_dict=*/c10::nullopt,
/*node_time_dict=*/c10::nullopt,
/*edge_time_dict=*/c10::nullopt,
/*seed_time_dict=*/c10::nullopt,
/*edge_weight_dict=*/c10::nullopt,
/*csc=*/true);
Expand Down Expand Up @@ -246,7 +248,8 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) {
/*col_dict=*/col_dict,
/*seed_dict=*/seed_dict,
/*num_neighbors_dict=*/num_neighbors_dict,
/*time_dict=*/c10::nullopt,
/*node_time_dict=*/c10::nullopt,
/*edge_time_dict=*/c10::nullopt,
/*seed_time_dict=*/c10::nullopt,
/*edge_weight_dict=*/c10::nullopt,
/*csc=*/false,
Expand Down
Loading
Loading