[Model & Dataset] facebook & sp2gcl (#201)

* Create facebook.py
* sp2_gcl
* sp2gcl
* sp2-gcl
* sp2_gcl_1
* sp2_gcl
* sp2_gcl_new
* sp2_gcl_new
* sp2gcl_new
* change the code of sp2gcl
* update
* update
* update test file
* update
* update
* update

Co-authored-by: Guangyu Zhou <[email protected]>
Co-authored-by: gyzhou2000 <[email protected]>
Parent: 27a563a. Commit: e276485.

Showing 11 changed files with 460 additions and 8 deletions.
File: evaluation_test.py (new, +65 lines)

```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # set before TF-backed imports so it takes effect

import tensorlayerx as tlx
import tensorlayerx.nn as nn
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    """Cross-entropy loss evaluated only on the training indices."""

    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, label):
        logits = self.backbone_network(data['x'])
        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])
        loss = self._loss_fn(train_logits, train_y)
        return loss


class LogReg(nn.Module):
    """Single linear layer used as a logistic-regression probe."""

    def __init__(self, hid_dim, out_dim, name=None):
        super(LogReg, self).__init__(name=name)
        self.linear = nn.Linear(in_features=hid_dim, out_features=out_dim,
                                W_init=tlx.initializers.xavier_uniform(),
                                b_init=tlx.initializers.zeros())

    def forward(self, x):
        return self.linear(x)


def node_evaluation(emb, y, train_idx, test_idx, lr=1e-2, weight_decay=1e-4):
    """Fit a linear probe on frozen embeddings and return test accuracy.

    Note: the tensor indexing and the ``.item()``/``.size()``/``.float()``
    calls below assume the torch backend (TL_BACKEND="torch"), which is how
    the trainer is run in the README.
    """
    nclass = y.max().item() + 1
    logreg = LogReg(hid_dim=emb.shape[1], out_dim=nclass, name="logreg")
    opt = tlx.optimizers.Adam(lr=lr, weight_decay=weight_decay)
    train_weights = logreg.trainable_weights
    loss = tlx.losses.softmax_cross_entropy_with_logits
    loss_func = SemiSpvzLoss(logreg, loss)
    train_one_step = TrainOneStep(loss_func, opt, train_weights)

    data = {
        'x': emb,
        'y': y,
        'train_idx': train_idx,
        'test_idx': test_idx
    }

    best_test_acc = 0
    for epoch in range(100):
        logreg.set_train()
        loss = train_one_step(data=data, label=y)
        logreg.set_eval()
        test_logits = logreg(emb[test_idx])
        test_preds = tlx.argmax(test_logits, axis=1)
        test_acc = tlx.reduce_sum(test_preds == y[test_idx]).float() / test_idx.size(0)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            logreg.save_weights("./" + logreg.name + ".npz", format='npz_dict')

    # Reload the best checkpoint and report its accuracy on the test indices.
    logreg.load_weights("./" + logreg.name + ".npz", format='npz_dict')
    logreg.set_eval()
    test_logits = logreg(emb[test_idx])
    test_preds = tlx.argmax(test_logits, axis=1)
    test_acc = tlx.reduce_sum(test_preds == y[test_idx]).float() / test_idx.size(0)

    return test_acc.item()
```
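For orientation, here is a minimal smoke test of this linear-probe routine on random data. Everything below (the sizes, the split, the random labels) is illustrative rather than part of the commit, and it assumes the torch backend, matching the README commands:

```python
# Hypothetical smoke test for node_evaluation (illustrative, not in the commit).
import os
os.environ['TL_BACKEND'] = 'torch'  # must be set before importing tensorlayerx

import numpy as np
import tensorlayerx as tlx
from evaluation_test import node_evaluation

num_nodes, emb_dim, num_classes = 200, 32, 4
emb = tlx.convert_to_tensor(np.random.randn(num_nodes, emb_dim), dtype=tlx.float32)
y = tlx.convert_to_tensor(np.random.randint(0, num_classes, num_nodes), dtype=tlx.int64)

perm = np.random.permutation(num_nodes)
train_idx = tlx.convert_to_tensor(perm[:100], dtype=tlx.int64)
test_idx = tlx.convert_to_tensor(perm[100:], dtype=tlx.int64)

acc = node_evaluation(emb, y, train_idx, test_idx)
print("probe accuracy: {:.4f}".format(acc))  # near 1/num_classes on random data
```

On random embeddings the probe should land near chance; in the trainer it is called on the concatenated spatial and spectral embeddings, which is what produces the accuracies in the README table.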
File: README (new, +33 lines)
# Graph Contrastive Learning with Stable and Scalable Spectral Encoding (Sp2GCL)

- Paper link: [https://proceedings.neurips.cc/paper_files/paper/2023/file/8e9a6582caa59fda0302349702965171-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2023/file/8e9a6582caa59fda0302349702965171-Paper-Conference.pdf)
- Author's code repo: [https://github.com/bdy9527/Sp2GCL](https://github.com/bdy9527/Sp2GCL)

# Dataset Statistics

| Dataset  | # Nodes | # Edges | # Classes |
|----------|---------|---------|-----------|
| PubMed   | 19,717  | 88,648  | 3         |
| Wiki-CS  | 11,701  | 216,123 | 10        |
| Facebook | 22,470  | 342,004 | 4         |

Results
-------

```bash
TL_BACKEND="torch" python sp2gcl_trainer.py --dataset pubmed --hidden_dim 4096 --spe_dim 30 --output_dim 32 --lr 0.001 --period 50 --l2_coef 5e-4 --n_epoch 5
TL_BACKEND="torch" python sp2gcl_trainer.py --dataset wikics --hidden_dim 2048 --spe_dim 100 --output_dim 32 --lr 0.001 --period 10 --l2_coef 5e-1 --n_epoch 3
TL_BACKEND="torch" python sp2gcl_trainer.py --dataset facebook --hidden_dim 1500 --spe_dim 100 --output_dim 32 --lr 0.001 --period 10 --l2_coef 5e-4 --n_epoch 5
```

"Paper Code" is the accuracy reported for the reference implementation; "Out (torch)" is this implementation run with the torch backend.

| Dataset  | Paper Code | Out (torch) |
|----------|------------|-------------|
| PubMed   | 82.3±0.3   | 78.66±0.76  |
| Wiki-CS  | 79.42±0.19 | 78.64±0.20  |
| Facebook | 90.43±0.13 | 87.53±0.34  |
File: sp2gcl_trainer.py (new, +142 lines)

```python
import os
# os.environ['TL_BACKEND'] = 'torch'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import argparse
import numpy as np
import scipy.sparse.linalg
import tensorlayerx as tlx
from tensorlayerx.model import WithLoss, TrainOneStep
from evaluation_test import node_evaluation
from gammagl.models import SpaSpeNode
from gammagl.datasets import FacebookPagePage, WikiCS, Planetoid
from gammagl.utils import get_laplacian, to_scipy_sparse_matrix, get_train_val_test_split, mask_to_index


def compute_laplacian(data, args):
    """Attach the spe_dim smallest Laplacian eigenpairs to the graph."""
    edge_index = data.edge_index
    num_nodes = data.num_nodes

    index, attr = get_laplacian(edge_index, num_nodes, normalization="sym")
    L = to_scipy_sparse_matrix(index, attr)
    # 'SM' requests the smallest-magnitude eigenvalues, i.e. the low-frequency spectrum.
    e, u = scipy.sparse.linalg.eigsh(L, k=args.spe_dim, which='SM', tol=1e-3)

    data.e = tlx.convert_to_tensor(e, dtype=tlx.float32)
    data.u = tlx.convert_to_tensor(u, dtype=tlx.float32)

    return data


class ContrastiveLoss(WithLoss):
    """Symmetric contrastive loss between the spatial and spectral views.

    Note: loss_fn is kept only to satisfy the WithLoss interface; the
    objective is computed explicitly in forward().
    """

    def __init__(self, model, loss_fn):
        super(ContrastiveLoss, self).__init__(backbone=model, loss_fn=loss_fn)

    def forward(self, data, label):
        h1, h2 = self.backbone_network(data['x'], data['edge_index'], data['e'], data['u'])

        # Pairwise similarities; the diagonal holds the positive pairs.
        logits = tlx.matmul(h1, tlx.transpose(h2, perm=(1, 0)))

        exp_logits = tlx.exp(logits)
        diag = tlx.convert_to_tensor(np.diagonal(tlx.convert_to_numpy(exp_logits)))

        sum_rows_a = tlx.reduce_sum(exp_logits, axis=1)
        sum_rows_b = tlx.reduce_sum(exp_logits, axis=0)

        # Contrast each positive pair against its row and, symmetrically, its column.
        log_prob1 = tlx.log(diag / (sum_rows_a - diag))
        log_prob2 = tlx.log(diag / (sum_rows_b - diag))
        loss = -0.5 * (log_prob1 + log_prob2).mean()

        return loss


def main(args):
    if args.dataset == 'facebook':
        dataset = FacebookPagePage()
    elif args.dataset == 'wikics':
        dataset = WikiCS()
    elif args.dataset == 'pubmed':
        dataset = Planetoid(name=args.dataset)
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    data = dataset[0]
    data = compute_laplacian(data, args)
    # Truncate to the first spe_dim eigenpairs (a no-op safeguard here).
    e = tlx.convert_to_tensor(data.e[:args.spe_dim], dtype=tlx.float32)
    u = tlx.convert_to_tensor(data.u[:, :args.spe_dim], dtype=tlx.float32)
    if 'train_mask' in data.keys:
        train_mask = data.train_mask
        test_mask = data.test_mask
        val_mask = data.val_mask
    else:
        train_mask, val_mask, test_mask = get_train_val_test_split(data, 0.1, 0.1)
    train_idx = mask_to_index(train_mask)
    test_idx = mask_to_index(test_mask)
    val_idx = mask_to_index(val_mask)

    model = SpaSpeNode(input_dim=data.x.shape[1],
                       spe_dim=args.spe_dim,
                       hidden_dim=args.hidden_dim,
                       output_dim=args.output_dim,
                       period=args.period,
                       name="sp2gcl")

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    train_weights = model.trainable_weights
    loss_func = ContrastiveLoss(model, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data_all = {
        'x': data.x,
        'edge_index': data.edge_index,
        'train_idx': train_idx,
        'valid_idx': val_idx,
        'test_idx': test_idx,
        'e': data.e,
        'u': data.u
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        model.set_train()
        train_loss = train_one_step(data=data_all, label=data.y)
        model.set_eval()
        # Evaluate a linear probe on the concatenated spatial + spectral embeddings.
        spa_emb = tlx.detach(model.spa_encoder(data.x, data.edge_index))
        spe_emb = tlx.detach(model.spe_encoder(e, u))
        val_acc = node_evaluation(tlx.concat((spa_emb, spe_emb), axis=-1), data.y, train_idx, val_idx)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_weights(args.best_model_path + model.name + ".npz", format='npz_dict')

        print("Epoch [{:0>3d}]".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

    model.load_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
    model.set_eval()
    spa_emb = tlx.detach(model.spa_encoder(data.x, data.edge_index))
    spe_emb = tlx.detach(model.spe_encoder(e, u))
    test_acc = node_evaluation(tlx.concat((spa_emb, spe_emb), axis=-1), data.y, train_idx, test_idx)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='wikics', help='dataset')
    parser.add_argument('--dataset_path', type=str, default=r'', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument('--spe_dim', type=int, default=100)
    parser.add_argument('--period', type=int, default=256)
    parser.add_argument('--hidden_dim', type=int, default=512, help="dimension of hidden layers")
    parser.add_argument('--output_dim', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--l2_coef', type=float, default=5e-4, help="l2 loss coefficient")
    parser.add_argument('--n_epoch', type=int, default=3, help="number of epochs")
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
```
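The eigsh call in compute_laplacian is the core of the spectral view: the k smallest eigenpairs of the symmetric normalized Laplacian give every node a k-dimensional positional encoding that the spectral encoder consumes. A self-contained sketch on a toy graph; the 4-node path graph and everything else here is illustrative, not from the commit:

```python
# Illustrative only: smallest-k Laplacian eigenpairs on a 4-node path graph.
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg

# Path graph 0-1-2-3 as a symmetric adjacency matrix.
row = np.array([0, 1, 1, 2, 2, 3])
col = np.array([1, 0, 2, 1, 3, 2])
A = sp.coo_matrix((np.ones(6), (row, col)), shape=(4, 4)).tocsr()

# Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
deg = np.asarray(A.sum(axis=1)).flatten()
d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
L = sp.eye(4) - d_inv_sqrt @ A @ d_inv_sqrt

# k smallest eigenpairs: the low-frequency "spectral positions" of the nodes.
e, u = scipy.sparse.linalg.eigsh(L, k=2, which='SM')
print(e)        # first eigenvalue is ~0 for a connected graph
print(u.shape)  # (4, 2): a 2-dim positional encoding per node
```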
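The objective in ContrastiveLoss is likewise easy to check by hand: each diagonal similarity (the spatial and spectral embedding of the same node) is contrasted against the off-diagonal entries of its row and, symmetrically, its column. A NumPy sketch with random matrices standing in for the encoder outputs (illustrative, not from the commit):

```python
# Illustrative only: the symmetric contrastive loss on random "embeddings".
import numpy as np

rng = np.random.default_rng(0)
h1 = rng.normal(size=(8, 16))  # stand-in for spatial embeddings
h2 = rng.normal(size=(8, 16))  # stand-in for spectral embeddings

logits = h1 @ h2.T             # pairwise similarities; diagonal = positive pairs
exp_logits = np.exp(logits)
diag = np.diagonal(exp_logits)

sum_rows = exp_logits.sum(axis=1)  # contrast each positive with its row...
sum_cols = exp_logits.sum(axis=0)  # ...and, symmetrically, its column

log_prob1 = np.log(diag / (sum_rows - diag))
log_prob2 = np.log(diag / (sum_cols - diag))
loss = -0.5 * (log_prob1 + log_prob2).mean()
print(loss)  # large for random pairs; training pushes the diagonal up
```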
File: facebook.py (new, +70 lines)

```python
from typing import Callable, Optional
import os
import numpy as np
import tensorlayerx as tlx

from gammagl.data import Graph, InMemoryDataset, download_url


class FacebookPagePage(InMemoryDataset):
    r"""The Facebook Page-Page network dataset introduced in the
    `"Multi-Scale Attributed Node Embedding"
    <https://arxiv.org/abs/1909.13021>`_ paper.
    Nodes represent verified pages on Facebook and edges are mutual likes.
    It contains 22,470 nodes, 342,004 edges, 128 node features and 4 classes.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`gammagl.data.Graph` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`gammagl.data.Graph` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """

    url = 'https://graphmining.ai/datasets/ptg/facebook.npz'

    def __init__(
        self,
        root: str = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        return 'facebook.npz'

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + '_data.pt'

    def download(self) -> None:
        download_url(self.url, self.raw_dir)

    def process(self) -> None:
        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
        x = tlx.convert_to_tensor(data['features'], dtype=tlx.float32)
        y = tlx.convert_to_tensor(data['target'], dtype=tlx.int64)
        # Edges are stored as (num_edges, 2); transpose to (2, num_edges).
        edge_index = tlx.convert_to_tensor(data['edges'], dtype=tlx.int64)
        if tlx.BACKEND == 'mindspore':
            edge_index = edge_index.T
        else:
            edge_index = tlx.ops.transpose(edge_index)

        data = Graph(x=x, edge_index=edge_index, y=y)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])
```
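A minimal sketch of the new dataset class in use; the root path is illustrative, and the expected shapes follow the docstring above:

```python
# Illustrative only: load the Facebook Page-Page dataset and inspect it.
from gammagl.datasets import FacebookPagePage

dataset = FacebookPagePage(root='./facebook')  # downloads facebook.npz on first use
graph = dataset[0]

print(graph.x.shape)           # (22470, 128): node features
print(graph.y.shape)           # (22470,): page categories, 4 classes
print(graph.edge_index.shape)  # (2, 342004): mutual-like edges
```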
(The remaining changed files in this commit did not load in the diff view.)