Skip to content

Commit bcda1c7

Browse files
committed
Fix edge weight normalization
1 parent c4154c8 commit bcda1c7

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

Diff for: torch_cluster/rw.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,17 @@ def random_walk(
6262
rowptr, col, start, walk_length, p, q,
6363
)
6464
else:
65-
# Normalize edge weights by node degrees
66-
edge_weight = edge_weight / deg[row]
65+
# Normalize edge weights
66+
from torch_sparse import SparseTensor
67+
68+
adj = SparseTensor(
69+
row=row,
70+
col=col,
71+
value=edge_weight,
72+
sparse_sizes=(num_nodes, num_nodes),
73+
)
74+
75+
edge_weight = edge_weight / adj.sum(dim=1).repeat_interleave(deg)
6776

6877
node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted(
6978
rowptr, col, edge_weight, start, walk_length, p, q,

0 commit comments

Comments
 (0)