Skip to content

Commit

Permalink
Edge-based temporal sampling (#8372)
Browse files Browse the repository at this point in the history
This PR enables edge-based temporal sampling for `NeighborSampler`,
covering both the homogeneous and the heterogeneous case.
The associated pyg-lib PR is
pyg-team/pyg-lib#280.

Thanks,
Poovaiah

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2023
1 parent 987d767 commit 7357a34
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372](https://github.com/pyg-team/pytorch_geometric/pull/8372))
- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
Expand Down
59 changes: 59 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
withPackage,
)
from torch_geometric.typing import (
WITH_EDGE_TIME_NEIGHBOR_SAMPLE,
WITH_PYG_LIB,
WITH_TORCH_SPARSE,
WITH_WEIGHTED_NEIGHBOR_SAMPLE,
Expand Down Expand Up @@ -786,6 +787,64 @@ def test_weighted_hetero_neighbor_loader():
assert global_edge_index.tolist() == [[3, 4], [2, 3]]


@pytest.mark.skipif(
    not WITH_EDGE_TIME_NEIGHBOR_SAMPLE,
    reason="'pyg-lib' does not support edge-level temporal sampling",
)
def test_edge_level_temporal_homo_neighbor_loader():
    """Check edge-level temporal sampling on a homogeneous graph.

    Every sampled mini-batch must carry one timestamp per sampled edge, and
    no sampled edge may be newer than the seed time of ``4``.
    """
    # Path graph 0-1-2-3-4 with both directions of every edge stored.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    # Edge `i` gets timestamp `i`, so timestamps 5, 6 and 7 lie after the
    # seed time used below.
    edge_time = torch.arange(edge_index.size(1))

    data = Data(edge_index=edge_index, edge_time=edge_time, num_nodes=5)

    loader = NeighborLoader(
        data,
        num_neighbors=[-1, -1],
        input_time=torch.tensor([4, 4, 4, 4, 4]),
        time_attr='edge_time',
        batch_size=1,
    )

    for batch in loader:
        # The edge-level time attribute must stay aligned with the edges.
        assert batch.edge_time.numel() == batch.num_edges
        # `max()` is undefined on an empty tensor, hence the guard.
        if batch.edge_time.numel() > 0:
            assert batch.edge_time.max() <= 4


@pytest.mark.skipif(
    not WITH_EDGE_TIME_NEIGHBOR_SAMPLE,
    reason="'pyg-lib' does not support edge-level temporal sampling",
)
def test_edge_level_temporal_hetero_neighbor_loader():
    """Check edge-level temporal sampling on a heterogeneous graph.

    Uses a single node type ``'A'`` and a single ``('A', 'A')`` edge type.
    Every sampled mini-batch must carry one timestamp per sampled edge, and
    no sampled edge may be newer than the seed time of ``4``.
    """
    # Path graph 0-1-2-3-4 with both directions of every edge stored.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    # Edge `i` gets timestamp `i`, so timestamps 5, 6 and 7 lie after the
    # seed time used below.
    edge_time = torch.arange(edge_index.size(1))

    data = HeteroData()
    data['A'].num_nodes = 5
    data['A', 'A'].edge_index = edge_index
    data['A', 'A'].edge_time = edge_time

    loader = NeighborLoader(
        data,
        num_neighbors=[-1, -1],
        input_nodes='A',
        input_time=torch.tensor([4, 4, 4, 4, 4]),
        time_attr='edge_time',
        batch_size=1,
    )

    for batch in loader:
        store = batch['A', 'A']
        # The edge-level time attribute must stay aligned with the edges.
        assert store.edge_time.numel() == store.num_edges
        # `max()` is undefined on an empty tensor, hence the guard.
        if store.edge_time.numel() > 0:
            assert store.edge_time.max() <= 4


@withCUDA
@onlyNeighborSampler
@withPackage('torch_frame')
Expand Down
3 changes: 3 additions & 0 deletions torch_geometric/datasets/movie_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ def process(self):
edge_index = torch.tensor([src, dst])

rating = torch.from_numpy(df['rating'].values).to(torch.long)
time = torch.from_numpy(df['timestamp'].values).to(torch.long)

data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = rating
data['user', 'rates', 'movie'].time = time

if self.pre_transform is not None:
data = self.pre_transform(data)
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class LinkNeighborLoader(LinkLoader):
Deprecated in favor of the :obj:`neg_sampling` argument.
(default: :obj:`None`)
time_attr (str, optional): The name of the attribute that denotes
timestamps for the nodes in the graph.
timestamps for either the nodes or edges in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the center node.
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class NeighborLoader(NodeLoader):
fulfill temporal constraints.
(default: :obj:`"uniform"`)
time_attr (str, optional): The name of the attribute that denotes
timestamps for the nodes in the graph.
timestamps for either the nodes or edges in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the center node.
Expand Down
67 changes: 60 additions & 7 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,26 @@ def __init__(
self.num_nodes = data.num_nodes

self.node_time: Optional[Tensor] = None
self.edge_time: Optional[Tensor] = None

if time_attr is not None:
self.node_time = data[time_attr]
if data.is_node_attr(time_attr):
self.node_time = data[time_attr]
elif data.is_edge_attr(time_attr):
self.edge_time = data[time_attr]
else:
raise ValueError(
f"The time attribute '{time_attr}' is neither a "
f"node-level or edge-level attribute")

# Convert the graph data into CSC format for sampling:
self.colptr, self.row, self.perm = to_csc(
data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted, src_node_time=self.node_time)
is_sorted=is_sorted, src_node_time=self.node_time,
edge_time=self.edge_time)

if self.edge_time is not None and self.perm is not None:
self.edge_time = self.edge_time[self.perm]

self.edge_weight: Optional[Tensor] = None
if weight_attr is not None:
Expand All @@ -89,8 +102,32 @@ def __init__(
self.num_nodes = {k: data[k].num_nodes for k in self.node_types}

self.node_time: Optional[Dict[NodeType, Tensor]] = None
self.edge_time: Optional[Dict[EdgeType, Tensor]] = None

if time_attr is not None:
self.node_time = data.collect(time_attr)
is_node_level_time = is_edge_level_time = False

for store in data.node_stores:
if time_attr in store:
is_node_level_time = True
for store in data.edge_stores:
if time_attr in store:
is_edge_level_time = True

if is_node_level_time and is_edge_level_time:
raise ValueError(
f"The time attribute '{time_attr}' holds both "
f"node-level and edge-level information")

if not is_node_level_time and not is_edge_level_time:
raise ValueError(
f"The time attribute '{time_attr}' is neither a "
f"node-level or edge-level attribute")

if is_node_level_time:
self.node_time = data.collect(time_attr)
else:
self.edge_time = data.collect(time_attr)

# Conversion to/from C++ string type: Since C++ cannot take
# dictionaries with tuples as key as input, edge type triplets need
Expand All @@ -101,10 +138,19 @@ def __init__(
# Convert the graph data into CSC format for sampling:
colptr_dict, row_dict, self.perm = to_hetero_csc(
data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted, node_time_dict=self.node_time)
is_sorted=is_sorted, node_time_dict=self.node_time,
edge_time_dict=self.edge_time)

self.row_dict = remap_keys(row_dict, self.to_rel_type)
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

if self.edge_time is not None:
for edge_type, edge_time in self.edge_time.items():
if self.perm.get(edge_type, None) is not None:
edge_time = edge_time[self.perm[edge_type]]
self.edge_time[edge_type] = edge_time
self.edge_time = remap_keys(self.edge_time, self.to_rel_type)

self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
if weight_attr is not None:
self.edge_weight = data.collect(weight_attr)
Expand Down Expand Up @@ -180,6 +226,8 @@ def __init__(
self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None

self.node_time: Optional[Dict[NodeType, Tensor]] = None
self.edge_time: Optional[Dict[NodeType, Tensor]] = None

if time_attr is not None:
for attr in time_attrs: # Reset index for full data.
attr.index = None
Expand All @@ -198,6 +246,11 @@ def __init__(
self.row_dict = remap_keys(row_dict, self.to_rel_type)
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

if (self.edge_time is not None
and not torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE):
raise ImportError("Edge-level temporal sampling requires a "
"more recent 'pyg-lib' installation")

if (self.edge_weight is not None
and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE):
raise ImportError("Weighted neighbor sampling requires "
Expand Down Expand Up @@ -232,7 +285,7 @@ def is_hetero(self) -> bool:

@property
def is_temporal(self) -> bool:
    """Whether temporal sampling is active.

    Returns:
        bool: ``True`` if either node-level (``self.node_time``) or
        edge-level (``self.edge_time``) timestamps are set.
    """
    # Both sources must be checked: a sampler configured with only edge-level
    # timestamps has `node_time` set to `None` but is still temporal.
    return self.node_time is not None or self.edge_time is not None

@property
def disjoint(self) -> bool:
Expand Down Expand Up @@ -302,7 +355,7 @@ def _sample(
self.node_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (None, )
args += (self.edge_time, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
Expand Down Expand Up @@ -384,7 +437,7 @@ def _sample(
self.node_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (None, )
args += (self.edge_time, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
Expand Down
24 changes: 17 additions & 7 deletions torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import NodeType, OptTensor
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import coalesce, index_sort, lexsort
from torch_geometric.utils.sparse import index2ptr

Expand All @@ -16,11 +16,19 @@ def sort_csc(
row: Tensor,
col: Tensor,
src_node_time: OptTensor = None,
edge_time: OptTensor = None,
) -> Tuple[Tensor, Tensor, Tensor]:
if src_node_time is None:

if src_node_time is None and edge_time is None:
col, perm = index_sort(col)
return row[perm], col, perm
else:

elif edge_time is not None:
assert src_node_time is None
perm = lexsort([edge_time, col])
return row[perm], col[perm], perm

else: # src_node_time is not None
perm = lexsort([src_node_time[row], col])
return row[perm], col[perm], perm

Expand All @@ -32,6 +40,7 @@ def to_csc(
share_memory: bool = False,
is_sorted: bool = False,
src_node_time: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, OptTensor]:
# Convert the graph data into a suitable format for sampling (CSC format).
# Returns the `colptr` and `row` indices of the graph, as well as an
Expand Down Expand Up @@ -61,10 +70,8 @@ def to_csc(
elif data.edge_index is not None:
row, col = data.edge_index
if not is_sorted:
row, col, perm = sort_csc(row, col, src_node_time)

row, col, perm = sort_csc(row, col, src_node_time, edge_time)
colptr = index2ptr(col, data.size(1))

else:
row = torch.empty(0, dtype=torch.long, device=device)
colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
Expand All @@ -89,6 +96,7 @@ def to_hetero_csc(
share_memory: bool = False,
is_sorted: bool = False,
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
# Convert the heterogeneous graph data into a suitable format for sampling
# (CSC format).
Expand All @@ -98,7 +106,9 @@ def to_hetero_csc(

for edge_type, store in data.edge_items():
src_node_time = (node_time_dict or {}).get(edge_type[0], None)
out = to_csc(store, device, share_memory, is_sorted, src_node_time)
edge_time = (edge_time_dict or {}).get(edge_type, None)
out = to_csc(store, device, share_memory, is_sorted, src_node_time,
edge_time)
colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out

return colptr_dict, row_dict, perm_dict
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
WITH_SAMPLED_OP = False
WITH_INDEX_SORT = False
WITH_METIS = False
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
WITH_WEIGHTED_NEIGHBOR_SAMPLE = False

try:
Expand Down

0 comments on commit 7357a34

Please sign in to comment.