-
Notifications
You must be signed in to change notification settings - Fork 5
/
E_high.py
24 lines (17 loc) · 878 Bytes
/
E_high.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch_geometric.utils as pyg_utils
def compute_E_high(adj_matrix, feat_matrix):
    """Compute the high-frequency energy of node features on a graph.

    Builds the unnormalized graph Laplacian L = D - A (D = diagonal of
    row-sums of A) and returns sum(X^T L X) / sum(X^T X), a scalar
    Rayleigh-quotient-style score: larger values mean the feature signal
    varies more across edges (higher graph frequency).

    Args:
        adj_matrix: square (n, n) adjacency matrix; any array-like
            accepted by ``torch.as_tensor`` (nested list, numpy array,
            or tensor).
        feat_matrix: (n, d) node-feature matrix; any array-like accepted
            by ``torch.as_tensor``. (The original required a tensor here;
            this is a backward-compatible generalization.)

    Returns:
        float: the high-frequency energy score.
    """
    # as_tensor handles tensors, numpy arrays and lists uniformly and
    # avoids torch.tensor's warning/copy when given an existing tensor.
    adj = torch.as_tensor(adj_matrix, dtype=torch.float32)
    feats = torch.as_tensor(feat_matrix, dtype=torch.float32)
    # Unnormalized Laplacian: L = D - A.
    laplacian = torch.diag(adj.sum(dim=1)) - adj
    numerator = feats.T @ laplacian @ feats      # X^T L X, shape (d, d)
    denominator = feats.T @ feats                # X^T X,  shape (d, d)
    # NOTE(review): the original sums ALL entries of both (d, d) matrices
    # (not just the trace) — preserved here; confirm this is intended.
    return (numerator.sum() / denominator.sum()).item()
def compute_G_ano(adj_matrix, feat_matrix):
    """Return the pair (a_high, s_high) of high-frequency energies.

    a_high is the high-frequency energy of the node features themselves;
    s_high is the energy obtained when the diagonal degree matrix is
    substituted for the features (a purely structural signal).

    Args:
        adj_matrix: square (n, n) adjacency matrix (array-like).
        feat_matrix: (n, d) node-feature tensor.

    Returns:
        tuple[float, float]: ``(a_high, s_high)``.
    """
    attribute_energy = compute_E_high(adj_matrix, feat_matrix)
    # Degree matrix D = diag(row sums of A), used as a structural
    # "feature" matrix for the second energy score.
    adj = torch.tensor(adj_matrix, dtype=torch.float32)
    degree_diag = torch.diag(adj.sum(dim=1))
    structure_energy = compute_E_high(adj_matrix, degree_diag)
    return attribute_energy, structure_energy