From 7357a34433c70a5d210f518e91b764095854463c Mon Sep 17 00:00:00 2001 From: Poovaiah Palangappa <98763718+pmpalang@users.noreply.github.com> Date: Wed, 15 Nov 2023 06:23:06 -0800 Subject: [PATCH] Edge-based temporal sampling (#8372) This PR is to enable the edge-based temporal sampling for NeighborSampler. This PR covers both homogeneous and heterogeneous cases. The associated PYG-LIB PR is https://github.com/pyg-team/pyg-lib/pull/280 . Thanks, Poovaiah --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta Co-authored-by: Matthias Fey --- CHANGELOG.md | 1 + test/loader/test_neighbor_loader.py | 59 ++++++++++++++++ torch_geometric/datasets/movie_lens.py | 3 + .../loader/link_neighbor_loader.py | 2 +- torch_geometric/loader/neighbor_loader.py | 2 +- torch_geometric/sampler/neighbor_sampler.py | 67 +++++++++++++++++-- torch_geometric/sampler/utils.py | 24 +++++-- torch_geometric/typing.py | 1 + 8 files changed, 143 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a95cb3e6ee7..6e02c9ad29ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372] (https://github.com/pyg-team/pytorch_geometric/pull/8372)) - Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363)) - Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357)) - Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345)) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 11bb42c6c125..b02b708254bd 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -22,6 +22,7 @@ withPackage, ) from torch_geometric.typing import ( + WITH_EDGE_TIME_NEIGHBOR_SAMPLE, WITH_PYG_LIB, WITH_TORCH_SPARSE, WITH_WEIGHTED_NEIGHBOR_SAMPLE, @@ -786,6 +787,64 @@ def test_weighted_hetero_neighbor_loader(): assert global_edge_index.tolist() == [[3, 4], [2, 3]] +@pytest.mark.skipif( + not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_edge_level_temporal_homo_neighbor_loader(): + edge_index = torch.tensor([ + [0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3], + ]) + edge_time = torch.arange(edge_index.size(1)) + + data = Data(edge_index=edge_index, edge_time=edge_time, num_nodes=5) + + loader = NeighborLoader( + data, + num_neighbors=[-1, -1], + input_time=torch.tensor([4, 4, 4, 4, 4]), + time_attr='edge_time', + batch_size=1, + ) + + for batch in loader: + assert batch.edge_time.numel() == batch.num_edges + if batch.edge_time.numel() > 0: + assert batch.edge_time.max() <= 4 + + +@pytest.mark.skipif( + not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_edge_level_temporal_hetero_neighbor_loader(): + edge_index = torch.tensor([ + [0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3], + ]) + edge_time = torch.arange(edge_index.size(1)) + + data = HeteroData() + data['A'].num_nodes = 5 + data['A', 'A'].edge_index = edge_index + data['A', 'A'].edge_time = edge_time + + loader = NeighborLoader( + data, + num_neighbors=[-1, -1], + input_nodes='A', + input_time=torch.tensor([4, 4, 4, 4, 4]), + time_attr='edge_time', + batch_size=1, + ) + + for batch in loader: + assert batch['A', 'A'].edge_time.numel() == batch['A', 'A'].num_edges + if batch['A', 'A'].edge_time.numel() > 0: + assert batch['A', 'A'].edge_time.max() <= 4 + + @withCUDA @onlyNeighborSampler @withPackage('torch_frame') diff --git a/torch_geometric/datasets/movie_lens.py b/torch_geometric/datasets/movie_lens.py index bc93e7045eca..ca32f3523671 100644 --- a/torch_geometric/datasets/movie_lens.py +++ b/torch_geometric/datasets/movie_lens.py @@ -90,8 +90,11 @@ def process(self): edge_index = torch.tensor([src, dst]) rating = torch.from_numpy(df['rating'].values).to(torch.long) + time = torch.from_numpy(df['timestamp'].values).to(torch.long) + data['user', 'rates', 'movie'].edge_index = edge_index data['user', 'rates', 'movie'].edge_label = rating + data['user', 'rates', 'movie'].time = time if self.pre_transform is not None: data = self.pre_transform(data) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 9871088eef23..9fd3962b8ca9 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -160,7 +160,7 @@ class LinkNeighborLoader(LinkLoader): Deprecated in favor of the :obj:`neg_sampling` argument. (default: :obj:`None`) time_attr (str, optional): The name of the attribute that denotes - timestamps for the nodes in the graph. + timestamps for either the nodes or edges in the graph. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 9cbf0bcbaebe..341f2f5a23b6 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -160,7 +160,7 @@ class NeighborLoader(NodeLoader): fulfill temporal constraints. (default: :obj:`"uniform"`) time_attr (str, optional): The name of the attribute that denotes - timestamps for the nodes in the graph. + timestamps for either the nodes or edges in the graph. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 099479ca9992..81dfcf421ff4 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -69,13 +69,26 @@ def __init__( self.num_nodes = data.num_nodes self.node_time: Optional[Tensor] = None + self.edge_time: Optional[Tensor] = None + if time_attr is not None: - self.node_time = data[time_attr] + if data.is_node_attr(time_attr): + self.node_time = data[time_attr] + elif data.is_edge_attr(time_attr): + self.edge_time = data[time_attr] + else: + raise ValueError( + f"The time attribute '{time_attr}' is neither a " + f"node-level or edge-level attribute") # Convert the graph data into CSC format for sampling: self.colptr, self.row, self.perm = to_csc( data, device='cpu', share_memory=share_memory, - is_sorted=is_sorted, src_node_time=self.node_time) + is_sorted=is_sorted, src_node_time=self.node_time, + edge_time=self.edge_time) + + if self.edge_time is not None and self.perm is not None: + self.edge_time = self.edge_time[self.perm] self.edge_weight: Optional[Tensor] = None if weight_attr is not None: @@ -89,8 +102,32 @@ def __init__( self.num_nodes = {k: data[k].num_nodes for k in self.node_types} self.node_time: Optional[Dict[NodeType, Tensor]] = None + self.edge_time: Optional[Dict[EdgeType, Tensor]] = None + if time_attr is not None: - self.node_time = data.collect(time_attr) + is_node_level_time = is_edge_level_time = False + + for store in data.node_stores: + if time_attr in store: + is_node_level_time = True + for store in data.edge_stores: + if time_attr in store: + is_edge_level_time = True + + if is_node_level_time and is_edge_level_time: + raise ValueError( + f"The time attribute '{time_attr}' holds both " + f"node-level and edge-level information") + + if not is_node_level_time and not is_edge_level_time: + raise ValueError( + f"The time attribute '{time_attr}' is neither a " + f"node-level or edge-level attribute") + + if is_node_level_time: + self.node_time = data.collect(time_attr) + else: + self.edge_time = data.collect(time_attr) # Conversion to/from C++ string type: Since C++ cannot take # dictionaries with tuples as key as input, edge type triplets need @@ -101,10 +138,19 @@ def __init__( # Convert the graph data into CSC format for sampling: colptr_dict, row_dict, self.perm = to_hetero_csc( data, device='cpu', share_memory=share_memory, - is_sorted=is_sorted, node_time_dict=self.node_time) + is_sorted=is_sorted, node_time_dict=self.node_time, + edge_time_dict=self.edge_time) + self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) + if self.edge_time is not None: + for edge_type, edge_time in self.edge_time.items(): + if self.perm.get(edge_type, None) is not None: + edge_time = edge_time[self.perm[edge_type]] + self.edge_time[edge_type] = edge_time + self.edge_time = remap_keys(self.edge_time, self.to_rel_type) + self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None if weight_attr is not None: self.edge_weight = data.collect(weight_attr) @@ -180,6 +226,8 @@ def __init__( self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None self.node_time: Optional[Dict[NodeType, Tensor]] = None + self.edge_time: Optional[Dict[NodeType, Tensor]] = None + if time_attr is not None: for attr in time_attrs: # Reset index for full data. attr.index = None @@ -198,6 +246,11 @@ def __init__( self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) + if (self.edge_time is not None + and not torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE): + raise ImportError("Edge-level temporal sampling requires a " + "more recent 'pyg-lib' installation") + if (self.edge_weight is not None and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE): raise ImportError("Weighted neighbor sampling requires " @@ -232,7 +285,7 @@ def is_hetero(self) -> bool: @property def is_temporal(self) -> bool: - return self.node_time is not None + return self.node_time is not None or self.edge_time is not None @property def disjoint(self) -> bool: @@ -302,7 +355,7 @@ def _sample( self.node_time, ) if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE: - args += (None, ) + args += (self.edge_time, ) args += (seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: args += (self.edge_weight, ) @@ -384,7 +437,7 @@ def _sample( self.node_time, ) if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE: - args += (None, ) + args += (self.edge_time, ) args += (seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: args += (self.edge_weight, ) diff --git a/torch_geometric/sampler/utils.py b/torch_geometric/sampler/utils.py index ced0c72ed21b..402d794e68c3 100644 --- a/torch_geometric/sampler/utils.py +++ b/torch_geometric/sampler/utils.py @@ -5,7 +5,7 @@ from torch_geometric.data import Data, HeteroData from torch_geometric.data.storage import EdgeStorage -from torch_geometric.typing import NodeType, OptTensor +from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import coalesce, index_sort, lexsort from torch_geometric.utils.sparse import index2ptr @@ -16,11 +16,19 @@ def sort_csc( row: Tensor, col: Tensor, src_node_time: OptTensor = None, + edge_time: OptTensor = None, ) -> Tuple[Tensor, Tensor, Tensor]: - if src_node_time is None: + + if src_node_time is None and edge_time is None: col, perm = index_sort(col) return row[perm], col, perm - else: + + elif edge_time is not None: + assert src_node_time is None + perm = lexsort([edge_time, col]) + return row[perm], col[perm], perm + + else: # src_node_time is not None perm = lexsort([src_node_time[row], col]) return row[perm], col[perm], perm @@ -32,6 +40,7 @@ def to_csc( share_memory: bool = False, is_sorted: bool = False, src_node_time: Optional[Tensor] = None, + edge_time: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, OptTensor]: # Convert the graph data into a suitable format for sampling (CSC format). # Returns the `colptr` and `row` indices of the graph, as well as an @@ -61,10 +70,8 @@ def to_csc( elif data.edge_index is not None: row, col = data.edge_index if not is_sorted: - row, col, perm = sort_csc(row, col, src_node_time) - + row, col, perm = sort_csc(row, col, src_node_time, edge_time) colptr = index2ptr(col, data.size(1)) - else: row = torch.empty(0, dtype=torch.long, device=device) colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long, @@ -89,6 +96,7 @@ def to_hetero_csc( share_memory: bool = False, is_sorted: bool = False, node_time_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None, ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]: # Convert the heterogeneous graph data into a suitable format for sampling # (CSC format). @@ -98,7 +106,9 @@ def to_hetero_csc( for edge_type, store in data.edge_items(): src_node_time = (node_time_dict or {}).get(edge_type[0], None) - out = to_csc(store, device, share_memory, is_sorted, src_node_time) + edge_time = (edge_time_dict or {}).get(edge_type, None) + out = to_csc(store, device, share_memory, is_sorted, src_node_time, + edge_time) colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out return colptr_dict, row_dict, perm_dict diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index d9e687601178..085ecc613022 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -57,6 +57,7 @@ WITH_SAMPLED_OP = False WITH_INDEX_SORT = False WITH_METIS = False + WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False WITH_WEIGHTED_NEIGHBOR_SAMPLE = False try: