diff --git a/cogdl/data/data.py b/cogdl/data/data.py
index 1614a9a0..d2d71475 100644
--- a/cogdl/data/data.py
+++ b/cogdl/data/data.py
@@ -480,16 +480,17 @@ def __init__(self, x=None, y=None, **kwargs):
self.x = x
self.y = y
self.grb_adj = None
+ num_nodes = x.shape[0] if x is not None else None
for key, item in kwargs.items():
if key == "num_nodes":
self.__num_nodes__ = item
+ num_nodes = item
elif key == "grb_adj":
self.grb_adj = item
elif not is_read_adj_key(key):
self[key] = item
- num_nodes = x.shape[0] if x is not None else None
if "edge_index_train" in kwargs:
self._adj_train = Adjacency(num_nodes=num_nodes)
for key, item in kwargs.items():
@@ -534,14 +535,17 @@ def add_remaining_self_loops(self):
self._adj_full.add_remaining_self_loops()
if self._adj_train is not None:
self._adj_train.add_remaining_self_loops()
+ return self
def padding_self_loops(self):
self._adj.padding_self_loops()
+ return self
def remove_self_loops(self):
self._adj_full.remove_self_loops()
if self._adj_train is not None:
self._adj_train.remove_self_loops()
+ return self
def row_norm(self):
self._adj.row_norm()
@@ -790,7 +794,7 @@ def sample_adj(self, batch, size=-1, replace=True):
if not torch.is_tensor(batch):
batch = torch.tensor(batch, dtype=torch.long)
(row_ptr, col_indices, nodes, edges) = sample_adj_c(
- self._adj.row_indptr, self.col_indices, batch, size, replace
+ self.row_indptr, self.col_indices, batch, size, replace
)
else:
if torch.is_tensor(batch):
@@ -891,13 +895,18 @@ def subgraph(self, node_idx, keep_order=False):
val = self.edge_weight.numpy()
N = self.num_nodes
self[key] = sp.csr_matrix((val, (row, col)), shape=(N, N))
- sub_adj = self[key][node_idx, :][:, node_idx]
+ sub_adj = self[key][node_idx, :][:, node_idx].tocoo()
sub_g = Graph()
- sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long()
- sub_g.col_indices = torch.from_numpy(sub_adj.indices).long()
+ # sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long()
+ # sub_g.col_indices = torch.from_numpy(sub_adj.indices).long()
+ row = torch.from_numpy(sub_adj.row).long()
+ col = torch.from_numpy(sub_adj.col).long()
+ sub_g.edge_index = (row, col)
sub_g.edge_weight = torch.from_numpy(sub_adj.data)
+ sub_g.num_nodes = len(node_idx)
for key in self.__keys__():
sub_g[key] = self[key][node_idx]
+ sub_g._adj._to_csr()
return sub_g.to(self._adj.device)
def edge_subgraph(self, edge_idx, require_idx=True):
diff --git a/examples/graphmae/imgs/compare.png b/examples/graphmae/imgs/compare.png
new file mode 100644
index 00000000..10ac5d4a
Binary files /dev/null and b/examples/graphmae/imgs/compare.png differ
diff --git a/examples/graphmae/imgs/fig.png b/examples/graphmae/imgs/fig.png
new file mode 100644
index 00000000..33877e6f
Binary files /dev/null and b/examples/graphmae/imgs/fig.png differ
diff --git a/examples/graphmae2/README.md b/examples/graphmae2/README.md
new file mode 100644
index 00000000..af672372
--- /dev/null
+++ b/examples/graphmae2/README.md
@@ -0,0 +1,87 @@
+# GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner
+
+A [**CogDL**](https://github.com/THUDM/cogdl) implementation of the WWW'23 paper: [GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner](https://arxiv.org/abs/2304.04779).
+
+
+The predecessor of this work, [GraphMAE: Self-Supervised Masked Graph Autoencoders](https://arxiv.org/abs/2205.10803), can be found [here](https://github.com/THUDM/cogdl/tree/master/examples/graphmae).
+
+## Dependencies
+
+* Python >= 3.7
+* [Pytorch](https://pytorch.org/) >= 1.9.0
+* [cogdl](https://github.com/THUDM/cogdl) >= 0.5.3
+* pyyaml == 5.4.1
+
+
+## Quick Start
+
+For a quick start, you can run the following scripts:
+
+**Node classification**
+
+```bash
+sh run_minibatch.sh # for mini batch node classification
+# example: sh run_minibatch.sh ogbn-arxiv 0
+sh run_fullbatch.sh # for full batch node classification
+# example: sh run_fullbatch.sh cora 0
+
+# Or you could run the code manually:
+# for mini batch node classification
+python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0
+# for full batch node classification
+python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0
+```
+
+Supported datasets:
+
+* mini batch node classification: `ogbn-arxiv`, `ogbn-products`, `mag-scholar-f`, `ogbn-papers100M`
+* full batch node classification: `cora`, `citeseer`, `pubmed`
+
+Run the provided scripts, or add `--use_cfg` to the command, to reproduce the reported results.
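+
+For example, the following is a sketch of reproducing the reported results with the tuned configs (this assumes `--use_cfg` loads the matching YAML file under `./configs/`, and it mirrors the commands shown above):
+
+```bash
+# full-batch node classification on Cora with the tuned hyperparameters
+python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0 --use_cfg
+# mini-batch node classification on ogbn-arxiv with the tuned hyperparameters
+python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0 --use_cfg
+```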
+
+**For large-scale graphs**
+
+Before starting mini-batch training, you need to generate local clusters if you want to use local clustering for training. By default, the program loads the dataset from `./data` and saves the generated local clusters to `./lc_ego_graphs`. To generate the local clusters, first install [localclustering](https://github.com/kfoynt/LocalGraphClustering) and then run the following command:
+
+```
+python ./datasets/localclustering.py --dataset <dataset_name> --data_dir <data_dir>
+```
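+
+For instance, a sketch of the command for `ogbn-arxiv` (the argument names follow `datasets/localclustering.py`; the paths are placeholders for your setup, and `--save_dir` defaults to `lc_ego_graphs`):
+
+```bash
+# generate local-cluster ego graphs for ogbn-arxiv and save them to ./lc_ego_graphs
+python ./datasets/localclustering.py --dataset ogbn-arxiv --data_dir ./data --save_dir ./lc_ego_graphs
+```
+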
+We also provide pre-generated local clusters, which can be downloaded [here](https://cloud.tsinghua.edu.cn/d/64f859f389ca43eda472/) and then placed in `lc_ego_graphs` for use.
+
+
+
+## Datasets
+
+During the code's execution, the OGB and small-scale datasets (Cora, Citeseer, and PubMed) will be downloaded automatically.
+
+## Experimental Results
+
+Experimental results of node classification on large-scale datasets (Accuracy, %):
+
+| | Ogbn-arxiv | Ogbn-products | Mag-Scholar-F | Ogbn-papers100M |
+| ------------------ | ------------ | ------------ | ------------ | -------------- |
+| MLP | 55.50±0.23 | 61.06±0.08 | 39.11±0.21 | 47.24±0.31 |
+| SGC | 66.92±0.08 | 74.87±0.25 | 54.68±0.23 | 63.29±0.19 |
+| Random-Init | 68.14±0.02 | 74.04±0.06 | 56.57±0.03 | 61.55±0.12 |
+| CCA-SSG | 68.57±0.02 | 75.27±0.05 | 51.55±0.03 | 55.67±0.15 |
+| GRACE | 69.34±0.01 | 79.47±0.59 | 57.39±0.02 | 61.21±0.12 |
+| BGRL | 70.51±0.03 | 78.59±0.02 | 57.57±0.01 | 62.18±0.15 |
+| GGD | - | 75.70±0.40 | - | 63.50±0.50 |
+| GraphMAE | 71.03±0.02 | 78.89±0.01 | 58.75±0.03 | 62.54±0.09 |
+| **GraphMAE2** | **71.89±0.03** | **81.59±0.02** | **59.24±0.01** | **64.89±0.04** |
+
+
+
+## Citing
+
+If you find this work helpful to your research, please consider citing our paper:
+
+```
+@inproceedings{hou2023graphmae2,
+ title={GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner},
+ author={Zhenyu Hou and Yufei He and Yukuo Cen and Xiao Liu and Yuxiao Dong and Evgeny Kharlamov and Jie Tang},
+ booktitle={Proceedings of the ACM Web Conference 2023 (WWW’23)},
+ year={2023}
+}
+```
diff --git a/examples/graphmae2/asserts/overview.png b/examples/graphmae2/asserts/overview.png
new file mode 100644
index 00000000..c1ed51eb
Binary files /dev/null and b/examples/graphmae2/asserts/overview.png differ
diff --git a/examples/graphmae2/configs/citeseer.yaml b/examples/graphmae2/configs/citeseer.yaml
new file mode 100644
index 00000000..013838d2
--- /dev/null
+++ b/examples/graphmae2/configs/citeseer.yaml
@@ -0,0 +1,27 @@
+lr: 0.0005 # 0.0005
+lr_f: 0.025
+num_hidden: 1024
+num_heads: 4
+num_out_heads: 1
+num_layers: 2
+weight_decay: 1e-4
+weight_decay_f: 1e-2
+max_epoch: 500
+max_epoch_f: 500
+mask_rate: 0.5
+num_layers: 2
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.1
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.0
+optimizer: adam
+replace_rate: 0.0
+alpha_l: 1
+scheduler: True
+remask_method: fixed
+momentum: 1
+lam: 0.1
\ No newline at end of file
diff --git a/examples/graphmae2/configs/cora.yaml b/examples/graphmae2/configs/cora.yaml
new file mode 100644
index 00000000..c0d9f2cb
--- /dev/null
+++ b/examples/graphmae2/configs/cora.yaml
@@ -0,0 +1,27 @@
+lr: 0.001
+lr_f: 0.025
+num_hidden: 1024
+num_heads: 8
+num_out_heads: 1
+num_layers: 2
+weight_decay: 2e-4
+weight_decay_f: 1e-4
+max_epoch: 2000
+max_epoch_f: 300
+mask_rate: 0.5
+num_layers: 2
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.1
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.0
+optimizer: adam
+replace_rate: 0.1
+alpha_l: 4
+scheduler: True
+remask_method: fixed
+momentum: 0
+lam: 0.1
\ No newline at end of file
diff --git a/examples/graphmae2/configs/mag-scholar-f.yaml b/examples/graphmae2/configs/mag-scholar-f.yaml
new file mode 100644
index 00000000..46aead7f
--- /dev/null
+++ b/examples/graphmae2/configs/mag-scholar-f.yaml
@@ -0,0 +1,30 @@
+lr: 0.001
+lr_f: 0.001
+num_hidden: 1024
+num_heads: 8
+num_out_heads: 1
+num_layers: 4
+weight_decay: 0.04
+weight_decay_f: 0
+max_epoch: 10
+max_epoch_f: 1000
+batch_size: 512
+batch_size_f: 256
+mask_rate: 0.5
+num_layers: 4
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.2
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.5
+optimizer: adamw
+alpha_l: 2
+scheduler: True
+remask_method: random
+momentum: 0.996
+lam: 0.1
+delayed_ema_epoch: 0
+num_remasking: 3
\ No newline at end of file
diff --git a/examples/graphmae2/configs/ogbn-arxiv.yaml b/examples/graphmae2/configs/ogbn-arxiv.yaml
new file mode 100644
index 00000000..a7ca3ccf
--- /dev/null
+++ b/examples/graphmae2/configs/ogbn-arxiv.yaml
@@ -0,0 +1,30 @@
+lr: 0.0025
+lr_f: 0.005
+num_hidden: 1024
+num_heads: 8
+num_out_heads: 1
+num_layers: 4
+weight_decay: 0.06
+weight_decay_f: 1e-4
+max_epoch: 60
+max_epoch_f: 1000
+batch_size: 512
+batch_size_f: 256
+mask_rate: 0.5
+num_layers: 4
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.1
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.5
+optimizer: adamw
+alpha_l: 6
+scheduler: True
+remask_method: random
+momentum: 0.996
+lam: 10.0
+delayed_ema_epoch: 40
+num_remasking: 3
\ No newline at end of file
diff --git a/examples/graphmae2/configs/ogbn-papers100M.yaml b/examples/graphmae2/configs/ogbn-papers100M.yaml
new file mode 100644
index 00000000..4dba4f98
--- /dev/null
+++ b/examples/graphmae2/configs/ogbn-papers100M.yaml
@@ -0,0 +1,30 @@
+lr: 0.001
+lr_f: 0.001
+num_hidden: 1024
+num_heads: 4
+num_out_heads: 1
+num_layers: 4
+weight_decay: 0.05
+weight_decay_f: 0
+max_epoch: 10
+max_epoch_f: 1000
+batch_size: 512
+batch_size_f: 256
+mask_rate: 0.5
+num_layers: 4
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.2
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.5
+optimizer: adamw
+alpha_l: 2
+scheduler: True
+remask_method: random
+momentum: 0.996
+lam: 10.0
+delayed_ema_epoch: 0
+num_remasking: 3
\ No newline at end of file
diff --git a/examples/graphmae2/configs/ogbn-products.yaml b/examples/graphmae2/configs/ogbn-products.yaml
new file mode 100644
index 00000000..d822d49b
--- /dev/null
+++ b/examples/graphmae2/configs/ogbn-products.yaml
@@ -0,0 +1,30 @@
+lr: 0.002
+lr_f: 0.001
+num_hidden: 1024
+num_heads: 4
+num_out_heads: 1
+num_layers: 4
+weight_decay: 0.04
+weight_decay_f: 0
+max_epoch: 20
+max_epoch_f: 1000
+batch_size: 512
+batch_size_f: 256
+mask_rate: 0.5
+num_layers: 4
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.2
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.5
+optimizer: adamw
+alpha_l: 3
+scheduler: True
+remask_method: random
+momentum: 0.996
+lam: 5.0
+delayed_ema_epoch: 0
+num_remasking: 3
\ No newline at end of file
diff --git a/examples/graphmae2/configs/pubmed.yaml b/examples/graphmae2/configs/pubmed.yaml
new file mode 100644
index 00000000..d80fb21d
--- /dev/null
+++ b/examples/graphmae2/configs/pubmed.yaml
@@ -0,0 +1,27 @@
+lr: 0.005
+lr_f: 0.025
+num_hidden: 512
+num_heads: 2
+num_out_heads: 1
+num_layers: 2
+weight_decay: 1e-5
+weight_decay_f: 5e-4
+max_epoch: 2000
+max_epoch_f: 500
+mask_rate: 0.9
+num_layers: 2
+encoder: gat
+decoder: gat
+activation: prelu
+attn_drop: 0.1
+linear_prob: True
+in_drop: 0.2
+loss_fn: sce
+drop_edge_rate: 0.0
+optimizer: adam
+replace_rate: 0.0
+alpha_l: 4
+scheduler: True
+remask_method: fixed
+momentum: 0.995
+lam: 1
diff --git a/examples/graphmae2/datasets/__init__.py b/examples/graphmae2/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/graphmae2/datasets/data_proc.py b/examples/graphmae2/datasets/data_proc.py
new file mode 100644
index 00000000..4710907a
--- /dev/null
+++ b/examples/graphmae2/datasets/data_proc.py
@@ -0,0 +1,50 @@
+import logging
+from typing import Counter
+from xml.sax.handler import feature_string_interning
+import numpy as np
+from collections import namedtuple
+import scipy.sparse as sp
+from scipy.special import iv
+from sklearn import preprocessing
+from sklearn.utils.extmath import randomized_svd
+
+import torch
+import torch.nn.functional as F
+
+from cogdl.utils import to_undirected
+from cogdl.datasets import build_dataset_from_path
+from sklearn.preprocessing import StandardScaler
+
+def load_small_dataset(data_dir, dataset_name):
+ dataset = build_dataset_from_path(data_dir, dataset=dataset_name)
+ if dataset_name == "ogbn-arxiv":
+ graph = dataset[0]
+ feat = graph.x
+ feat = scale_feats(feat)
+ graph.x = feat
+ else:
+ graph = dataset[0]
+ graph.add_remaining_self_loops()
+
+ num_features = graph.x.shape[1]
+ num_classes = dataset.num_classes
+ return graph, (num_features, num_classes)
+
+def preprocess(graph):
+ feat = graph.x
+ edge_index = graph.edge_index
+ edge_index = to_undirected(edge_index, num_nodes=graph.num_nodes)
+ graph.edge_index = edge_index
+ graph.x = feat
+
+ graph.add_remaining_self_loops()
+ return graph
+
+
+def scale_feats(x):
+ logging.info("### scaling features ###")
+ scaler = StandardScaler()
+ feats = x.numpy()
+ scaler.fit(feats)
+ feats = torch.from_numpy(scaler.transform(feats)).float()
+ return feats
diff --git a/examples/graphmae2/datasets/lc_sampler.py b/examples/graphmae2/datasets/lc_sampler.py
new file mode 100644
index 00000000..ff2e680e
--- /dev/null
+++ b/examples/graphmae2/datasets/lc_sampler.py
@@ -0,0 +1,156 @@
+import os
+
+import numpy as np
+import torch
+
+import cogdl
+from cogdl.datasets import build_dataset_from_path
+from cogdl.data import Graph
+
+from .data_proc import preprocess, scale_feats
+from utils import mask_edge
+
+import logging
+import torch.multiprocessing
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
+
+def load_dataset(data_dir, dataset_name):
+ dataset = build_dataset_from_path(data_dir, dataset=dataset_name)
+ if dataset_name.startswith("ogbn"):
+ graph = dataset[0]
+ logging.info("--- to undirected graph ---")
+ graph = preprocess(graph)
+ feat = graph.x
+ feat = scale_feats(feat)
+ graph.x = feat
+ graph.add_remaining_self_loops()
+
+ # num_features = graph.x.shape[1]
+ # num_classes = dataset.num_classes
+ # return graph, (num_features, num_classes)
+ # feats, graph, labels, split_idx
+ train_idx = graph.train_mask.nonzero().squeeze(1)
+ val_idx = graph.val_mask.nonzero().squeeze(1)
+ test_idx = graph.test_mask.nonzero().squeeze(1)
+ split_idx = {"train": train_idx, "valid": val_idx, "test": test_idx}
+ return graph.x, graph, graph.y, split_idx
+
+
+class LinearProbingDataLoader(DataLoader):
+ def __init__(self, idx, feats, labels=None, **kwargs):
+ self.labels = labels
+ self.feats = feats
+
+ kwargs["collate_fn"] = self.__collate_fn__
+ super().__init__(dataset=idx, **kwargs)
+
+ def __collate_fn__(self, batch_idx):
+ feats = self.feats[batch_idx]
+ label = self.labels[batch_idx]
+
+ return feats, label
+
+class OnlineLCLoader(DataLoader):
+ def __init__(self, root_nodes, graph, feats, labels=None, drop_edge_rate=0, **kwargs):
+ self.graph = graph
+ self.labels = labels
+ self._drop_edge_rate = drop_edge_rate
+ self.ego_graph_nodes = root_nodes
+ self.feats = feats
+
+ dataset = np.arange(len(root_nodes))
+ kwargs["collate_fn"] = self.__collate_fn__
+ super().__init__(dataset, **kwargs)
+
+ def drop_edge(self, g):
+ if self._drop_edge_rate <= 0:
+ return g, g
+ n_node = g.num_nodes
+ g = g.remove_self_loops()
+ mask_index1 = mask_edge(g, self._drop_edge_rate)
+ mask_index2 = mask_edge(g, self._drop_edge_rate)
+ src = g.edge_index[0]
+ dst = g.edge_index[1]
+ nsrc1 = src[mask_index1]
+ ndst1 = dst[mask_index1]
+ nsrc2 = src[mask_index2]
+ ndst2 = dst[mask_index2]
+ g1 = Graph(edge_index=(nsrc1, ndst1), num_nodes=n_node)
+ g1.add_remaining_self_loops()
+ g2 = Graph(edge_index=(nsrc2, ndst2), num_nodes=n_node)
+ g2.add_remaining_self_loops()
+
+ return g1, g2
+
+ def __collate_fn__(self, batch_idx):
+ ego_nodes = [self.ego_graph_nodes[i] for i in batch_idx]
+ subgs = [self.graph.subgraph(ego_nodes[i]) for i in range(len(ego_nodes))]
+
+ sg = cogdl.data.batch_graphs(subgs)
+
+ nodes = torch.from_numpy(np.concatenate(ego_nodes)).long()
+ num_nodes = [x.shape[0] for x in ego_nodes]
+ cum_num_nodes = np.cumsum([0] + num_nodes)[:-1]
+
+ if self._drop_edge_rate > 0:
+ drop_g1, drop_g2 = self.drop_edge(sg)
+
+ sg = sg.add_remaining_self_loops()
+ sg.x = self.feats[nodes]
+ targets = torch.from_numpy(cum_num_nodes)
+
+ if self.labels is not None:
+ label = self.labels[batch_idx]
+ else:
+ label = None
+
+ if self._drop_edge_rate > 0:
+ return sg, targets, label, nodes, drop_g1, drop_g2
+ else:
+ return sg, targets, label, nodes
+
+
+def setup_training_data(dataset_name, data_dir, ego_graphs_file_path):
+ feats, graph, labels, split_idx = load_dataset(data_dir, dataset_name)
+
+ train_lbls = labels[split_idx["train"]]
+ val_lbls = labels[split_idx["valid"]]
+ test_lbls = labels[split_idx["test"]]
+
+ labels = torch.cat([train_lbls, val_lbls, test_lbls])
+
+ if not os.path.exists(ego_graphs_file_path):
+ raise FileNotFoundError(f"{ego_graphs_file_path} doesn't exist")
+ else:
+ nodes = torch.load(ego_graphs_file_path)
+
+ return feats, graph, labels, split_idx, nodes
+
+
+def setup_training_dataloder(loader_type, training_nodes, graph, feats, batch_size, drop_edge_rate=0, pretrain_clustergcn=False, cluster_iter_data=None):
+ num_workers = 8
+
+ print(" -------- drop edge rate: {} --------".format(drop_edge_rate))
+ dataloader = OnlineLCLoader(training_nodes, graph, feats=feats, drop_edge_rate=drop_edge_rate, batch_size=batch_size, shuffle=True, drop_last=False, persistent_workers=True, num_workers=num_workers)
+ return dataloader
+
+
+def setup_eval_dataloder(loader_type, graph, feats, ego_graph_nodes=None, batch_size=128, shuffle=False):
+ num_workers = 8
+ if loader_type == "lc":
+ assert ego_graph_nodes is not None
+
+ dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, persistent_workers=True, num_workers=num_workers)
+ return dataloader
+
+
+def setup_finetune_dataloder(loader_type, graph, feats, ego_graph_nodes, labels, batch_size, shuffle=False):
+ num_workers = 8
+ dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, labels=labels, batch_size=batch_size, shuffle=shuffle, drop_last=False, num_workers=num_workers, persistent_workers=True)
+
+ return dataloader
diff --git a/examples/graphmae2/datasets/localclustering.py b/examples/graphmae2/datasets/localclustering.py
new file mode 100644
index 00000000..8b648386
--- /dev/null
+++ b/examples/graphmae2/datasets/localclustering.py
@@ -0,0 +1,157 @@
+import argparse
+from collections import namedtuple
+import multiprocessing
+import os
+
+import numpy as np
+from localgraphclustering import *
+from scipy.sparse import csr_matrix
+from ogb.nodeproppred import DglNodePropPredDataset
+import torch
+import logging
+
+import dgl
+from dgl.data import load_data
+
+
+def my_sweep_cut(g, node):
+ vol_sum = 0.0
+ in_edge = 0.0
+ conds = np.zeros_like(node, dtype=np.float32)
+ for i in range(len(node)):
+ idx = node[i]
+ vol_sum += g.d[idx]
+ denominator = min(vol_sum, g.vol_G - vol_sum)
+ if denominator == 0.0:
+ denominator = 1.0
+ in_edge += 2*sum([g.adjacency_matrix[idx,prev] for prev in node[:i+1]])
+ cut = vol_sum - in_edge
+ conds[i] = cut/denominator
+ return conds
+
+
+def calc_local_clustering(args):
+ i, log_steps, num_iter, ego_size, method = args
+ if i % log_steps == 0:
+ print(i)
+ node, ppr = approximate_PageRank(graphlocal, [i], iterations=num_iter, method=method, normalize=False)
+ d_inv = graphlocal.dn[node]
+ d_inv[d_inv > 1.0] = 1.0
+ ppr_d_inv = ppr * d_inv
+ output = list(zip(node, ppr_d_inv))[:ego_size]
+ node, ppr_d_inv = zip(*sorted(output, key=lambda x: x[1], reverse=True))
+ assert node[0] == i
+ node = np.array(node, dtype=np.int32)
+ conds = my_sweep_cut(graphlocal, node)
+ return node, conds
+
+
+def step1_local_clustering(data, name, idx_split, ego_size=128, num_iter=1000, log_steps=10000, num_workers=16, method='acl', save_dir=None):
+ if save_dir is None:
+ save_path = f"{name}-lc-ego-graphs-{ego_size}.pt"
+ else:
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+
+ save_path = os.path.join(save_dir, f"{name}-lc-ego-graphs-{ego_size}.pt")
+
+ N = data.num_nodes()
+ edge_index = data.edges()
+ edge_index = (edge_index[0].numpy(), edge_index[1].numpy())
+ adj = csr_matrix((np.ones(edge_index[0].shape[0]), edge_index), shape=(N, N))
+
+ global graphlocal
+ graphlocal = GraphLocal.from_sparse_adjacency(adj)
+ print('graphlocal generated')
+
+ train_idx = idx_split["train"].cpu().numpy()
+ valid_idx = idx_split["valid"].cpu().numpy()
+ test_idx = idx_split["test"].cpu().numpy()
+
+ with multiprocessing.Pool(num_workers) as pool:
+ ego_graphs_train, conds_train = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in train_idx], chunksize=512))
+
+ with multiprocessing.Pool(num_workers) as pool:
+ ego_graphs_valid, conds_valid = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in valid_idx], chunksize=512))
+
+ with multiprocessing.Pool(num_workers) as pool:
+ ego_graphs_test, conds_test = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in test_idx], chunksize=512))
+
+ ego_graphs = []
+ conds = []
+ ego_graphs.extend(ego_graphs_train)
+ ego_graphs.extend(ego_graphs_valid)
+ ego_graphs.extend(ego_graphs_test)
+ conds.extend(conds_train)
+ conds.extend(conds_valid)
+ conds.extend(conds_test)
+
+ ego_graphs = [ego_graphs_train, ego_graphs_valid, ego_graphs_test]
+
+ torch.save(ego_graphs, save_path)
+
+
+def preprocess(graph):
+ # make bidirected
+ if "feat" in graph.ndata:
+ feat = graph.ndata["feat"]
+ else:
+ feat = None
+ # src, dst = graph.all_edges()
+ # graph.add_edges(dst, src)
+ graph = dgl.to_bidirected(graph)
+ if feat is not None:
+ graph.ndata["feat"] = feat
+
+ # add self-loop
+ graph = graph.remove_self_loop().add_self_loop()
+ # graph.create_formats_()
+ return graph
+
+
+def load_dataset(data_dir, dataset_name):
+ if dataset_name.startswith("ogbn"):
+ dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset"))
+ graph, label = dataset[0]
+
+ if "year" in graph.ndata:
+ del graph.ndata["year"]
+ if not graph.is_multigraph:
+ graph = preprocess(graph)
+ # graph = graph.remove_self_loop().add_self_loop()
+
+ split_idx = dataset.get_idx_split()
+ label = label.view(-1)
+
+ elif dataset_name == "mag_scholar_f":
+ edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy"))
+ print(len(edge_index[0]))
+ graph = dgl.DGLGraph((edge_index[0], edge_index[1]))
+ print(graph)
+ num_nodes = graph.num_nodes()
+ assert num_nodes == 12403930
+ split_idx = torch.load(os.path.join(data_dir, dataset_name, "split_idx_f.pt"))
+ else:
+ raise NotImplementedError
+
+ return graph, split_idx
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='LCGNN (Preprocessing)')
+ parser.add_argument('--dataset', type=str, default='flickr')
+ parser.add_argument("--data_dir", type=str, default="data")
+ parser.add_argument("--save_dir", type=str, default="lc_ego_graphs")
+ parser.add_argument('--ego_size', type=int, default=256)
+ parser.add_argument('--num_iter', type=int, default=1000)
+ parser.add_argument('--log_steps', type=int, default=10000)
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--method', type=str, default='acl')
+ parser.add_argument('--num_workers', type=int, default=16)
+ args = parser.parse_args()
+ print(args)
+
+ np.random.seed(args.seed)
+
+ graph, split_idx = load_dataset(args.data_dir, args.dataset)
+ step1_local_clustering(graph, args.dataset, split_idx, args.ego_size, args.num_iter, args.log_steps, args.num_workers, args.method, args.save_dir)
diff --git a/examples/graphmae2/main_full_batch.py b/examples/graphmae2/main_full_batch.py
new file mode 100644
index 00000000..4407f5f6
--- /dev/null
+++ b/examples/graphmae2/main_full_batch.py
@@ -0,0 +1,134 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+
+from utils import (
+ build_args,
+ create_optimizer,
+ set_random_seed,
+ TBLogger,
+ get_current_lr,
+ load_best_configs,
+)
+from datasets.data_proc import load_small_dataset
+from models.finetune import linear_probing_full_batch
+from models import build_model
+
+
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
+
+
+def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
+ logging.info("start training..")
+ graph = graph.to(device)
+ x = feat.to(device)
+
+ target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long)
+ epoch_iter = tqdm(range(max_epoch))
+
+ for epoch in epoch_iter:
+ model.train()
+
+ loss = model(graph, x, targets=target_nodes)
+
+ loss_dict = {"loss": loss.item()}
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if scheduler is not None:
+ scheduler.step()
+
+ epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
+ if logger is not None:
+ loss_dict["lr"] = get_current_lr(optimizer)
+ logger.note(loss_dict, step=epoch)
+
+ if (epoch + 1) % 200 == 0:
+ linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)
+
+ return model
+
+
+def main(args):
+ device = args.device if args.device >= 0 else "cpu"
+ seeds = args.seeds
+ dataset_name = args.dataset
+ max_epoch = args.max_epoch
+ max_epoch_f = args.max_epoch_f
+ num_hidden = args.num_hidden
+ num_layers = args.num_layers
+ encoder_type = args.encoder
+ decoder_type = args.decoder
+ replace_rate = args.replace_rate
+
+ optim_type = args.optimizer
+ loss_fn = args.loss_fn
+
+ lr = args.lr
+ weight_decay = args.weight_decay
+ lr_f = args.lr_f
+ weight_decay_f = args.weight_decay_f
+ linear_prob = args.linear_prob
+ load_model = args.load_model
+ logs = args.logging
+ use_scheduler = args.scheduler
+ data_dir = args.data_dir
+
+ graph, (num_features, num_classes) = load_small_dataset(data_dir, dataset_name)
+ args.num_features = num_features
+
+ acc_list = []
+ estp_acc_list = []
+ for i, seed in enumerate(seeds):
+ print(f"####### Run {i} for seed {seed}")
+ set_random_seed(seed)
+
+ if logs:
+ logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
+ else:
+ logger = None
+
+ model = build_model(args)
+ model.to(device)
+ optimizer = create_optimizer(optim_type, model, lr, weight_decay)
+
+ if use_scheduler:
+ logging.info("Use schedular")
+ scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
+ else:
+ scheduler = None
+
+ x = graph.x
+ if not load_model:
+ model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
+ model = model.cpu()
+
+ if load_model:
+ logging.info("Loading Model ... ")
+ model.load_state_dict(torch.load("checkpoint.pt"))
+
+ model = model.to(device)
+ model.eval()
+
+ final_acc, estp_acc = linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob)
+ acc_list.append(final_acc)
+ estp_acc_list.append(estp_acc)
+
+ if logger is not None:
+ logger.finish()
+
+ final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
+ estp_acc, estp_acc_std = np.mean(estp_acc_list), np.std(estp_acc_list)
+ print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
+ print(f"# early-stopping_acc: {estp_acc:.4f}±{estp_acc_std:.4f}")
+
+
+if __name__ == "__main__":
+ args = build_args()
+ if args.use_cfg:
+ args = load_best_configs(args)
+ print(args)
+ main(args)
diff --git a/examples/graphmae2/main_large.py b/examples/graphmae2/main_large.py
new file mode 100644
index 00000000..21541e06
--- /dev/null
+++ b/examples/graphmae2/main_large.py
@@ -0,0 +1,270 @@
+import logging
+import os
+import numpy as np
+from tqdm import tqdm
+
+import torch
+
+from utils import (
+ WandbLogger,
+ build_args,
+ create_optimizer,
+ set_random_seed,
+ load_best_configs,
+ show_occupied_memory,
+)
+from models import build_model
+from datasets.lc_sampler import (
+ setup_training_dataloder,
+ setup_training_data,
+)
+from models.finetune import linear_probing_minibatch, finetune
+
+import warnings
+
+warnings.filterwarnings("ignore")
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
+
+
+def evaluate(
+ model,
+ graph, feats, labels,
+ split_idx,
+ lr_f, weight_decay_f, max_epoch_f,
+ linear_prob=True,
+ device=0,
+ batch_size=256,
+ logger=None, ego_graph_nodes=None,
+ label_rate=1.0,
+ full_graph_forward=False,
+ shuffle=True,
+):
+ logging.info("Using `lc` for evaluation...")
+ num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]]
+ print(num_train,num_val,num_test)
+
+ train_g_idx = np.arange(0, num_train)
+ val_g_idx = np.arange(num_train, num_train+num_val)
+ test_g_idx = np.arange(num_train+num_val, num_train+num_val+num_test)
+
+ train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx]
+ val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx]
+ test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx]
+
+ train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx]
+
+ # labels = [train_lbls, val_lbls, test_lbls]
+ assert len(train_ego_graph_nodes) == len(train_lbls)
+ assert len(val_ego_graph_nodes) == len(val_lbls)
+ assert len(test_ego_graph_nodes) == len(test_lbls)
+
+ print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
+ logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")
+
+
+ if linear_prob:
+ result = linear_probing_minibatch(model, graph, feats, [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], [train_lbls, val_lbls, test_lbls], lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, batch_size=batch_size, device=device, shuffle=shuffle)
+ else:
+ max_epoch_f = max_epoch_f // 2
+
+ if label_rate < 1.0:
+ rand_idx = np.arange(len(train_ego_graph_nodes))
+ np.random.shuffle(rand_idx)
+ rand_idx = rand_idx[:int(label_rate * len(train_ego_graph_nodes))]
+ train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in rand_idx]
+ train_lbls = train_lbls[rand_idx]
+
+ logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")
+
+ # train_lbls = (all_train_lbls, train_lbls)
+ result = finetune(
+ model, graph, feats,
+ [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
+ [train_lbls, val_lbls, test_lbls],
+ split_idx=split_idx,
+ lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, use_scheduler=True, batch_size=batch_size, device=device, logger=logger, full_graph_forward=full_graph_forward,
+ )
+ return result
+
+
+def pretrain(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay, batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0):
+ logging.info("start training..")
+
+ model = model.to(device)
+ optimizer = create_optimizer(optimizer, model, lr, weight_decay)
+
+ dataloader = setup_training_dataloder(
+ sampling_method, ego_graph_nodes, graph, feats, batch_size=batch_size, drop_edge_rate=drop_edge_rate)
+
+ logging.info(f"After creating dataloader: Memory: {show_occupied_memory():.2f} MB")
+ if use_scheduler and max_epoch > 0:
+ logging.info("Use scheduler")
+ scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
+ else:
+ scheduler = None
+
+ for epoch in range(max_epoch):
+ epoch_iter = tqdm(dataloader)
+ losses = []
+ # assert (graph.in_degrees() > 0).all(), "after loading"
+
+ for batch_g in epoch_iter:
+ model.train()
+ if drop_edge_rate > 0:
+ batch_g, targets, _, node_idx, drop_g1, drop_g2 = batch_g
+ batch_g = batch_g.to(device)
+ drop_g1 = drop_g1.to(device)
+ drop_g2 = drop_g2.to(device)
+ x = batch_g.x
+ loss = model(batch_g, x, targets, epoch, drop_g1, drop_g2)
+ else:
+ batch_g, targets, _, node_idx = batch_g
+ batch_g = batch_g.to(device)
+ x = batch_g.x
+ loss = model(batch_g, x, targets, epoch)
+
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
+ optimizer.step()
+
+ epoch_iter.set_description(f"train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB")
+ losses.append(loss.item())
+
+ if scheduler is not None:
+ scheduler.step()
+
+ torch.save(model.state_dict(), os.path.join(model_dir, model_name))
+
+ print(f"# Epoch {epoch} | train_loss: {np.mean(losses):.4f}, Memory: {show_occupied_memory():.2f} MB")
+
+ return model
+
+
+if __name__ == "__main__":
+ args = build_args()
+ if args.use_cfg:
+ args = load_best_configs(args)
+
+ if args.device < 0:
+ device = "cpu"
+ else:
+ device = "cuda:{}".format(args.device)
+ seeds = args.seeds
+ dataset_name = args.dataset
+ max_epoch = args.max_epoch
+ max_epoch_f = args.max_epoch_f
+ num_hidden = args.num_hidden
+ num_layers = args.num_layers
+ encoder_type = args.encoder
+ decoder_type = args.decoder
+ encoder = args.encoder
+ decoder = args.decoder
+ num_hidden = args.num_hidden
+ drop_edge_rate = args.drop_edge_rate
+
+ optim_type = args.optimizer
+ loss_fn = args.loss_fn
+
+ lr = args.lr
+ weight_decay = args.weight_decay
+ lr_f = args.lr_f
+ weight_decay_f = args.weight_decay_f
+ linear_prob = args.linear_prob
+ load_model = args.load_model
+ no_pretrain = args.no_pretrain
+ logs = args.logging
+ use_scheduler = args.scheduler
+ batch_size = args.batch_size
+ batch_size_f = args.batch_size_f
+ sampling_method = args.sampling_method
+ ego_graph_file_path = args.ego_graph_file_path
+ data_dir = args.data_dir
+
+ n_procs = torch.cuda.device_count()
+ optimizer_type = args.optimizer
+ label_rate = args.label_rate
+ lam = args.lam
+ full_graph_forward = hasattr(args, "full_graph_forward") and args.full_graph_forward and not linear_prob
+
+ model_dir = "checkpoints"
+ os.makedirs(model_dir, exist_ok=True)
+
+ set_random_seed(0)
+ print(args)
+
+ logging.info(f"Before loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB
+ feats, graph, labels, split_idx, ego_graph_nodes = setup_training_data(dataset_name, data_dir, ego_graph_file_path)
+ if dataset_name == "ogbn-papers100M":
+ pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] + ego_graph_nodes[3]
+ else:
+ pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2]
+ ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] # * merge train/val/test = all
+
+ logging.info(f"After loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB
+
+ args.num_features = feats.shape[1]
+
+ if logs:
+ logger = WandbLogger(log_path=f"{dataset_name}_loss_{loss_fn}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}", project="GraphMAE2", args=args)
+ else:
+ logger = None
+ model_name = f"{encoder}_{decoder}_{num_hidden}_{num_layers}_{dataset_name}_{args.mask_rate}_{num_hidden}_checkpoint.pt"
+
+ model = build_model(args)
+
+ if not args.no_pretrain:
+ # ------------- pretraining starts ----------------
+ if not load_model:
+ logging.info("---- start pretraining ----")
+ model = pretrain(model, feats, graph, pretrain_ego_graph_nodes, max_epoch=max_epoch, device=device, use_scheduler=use_scheduler, lr=lr,
+ weight_decay=weight_decay, batch_size=batch_size, drop_edge_rate=drop_edge_rate,
+ sampling_method=sampling_method, optimizer=optimizer_type)
+
+ model = model.cpu()
+ logging.info(f"saving model to {model_dir}/{model_name}...")
+ torch.save(model.state_dict(), os.path.join(model_dir, model_name))
+ # ------------- pretraining ends ----------------
+
+ if load_model:
+ model.load_state_dict(torch.load(os.path.join(args.checkpoint_path)))
+ logging.info(f"Loading Model from {args.checkpoint_path}...")
+ else:
+ logging.info("--- no pretrain ---")
+
+ model = model.to(device)
+ model.eval()
+
+ logging.info("---- start finetuning / evaluation ----")
+
+ final_accs = []
+ for i,_ in enumerate(seeds):
+ print(f"####### Run seed {seeds[i]}")
+ set_random_seed(seeds[i])
+ eval_model = build_model(args)
+ eval_model.load_state_dict(model.state_dict())
+ eval_model.to(device)
+
+ print(f"features size : {feats.shape[1]}")
+ logging.info("start evaluation...")
+ final_acc = evaluate(
+ eval_model, graph, feats, labels, split_idx,
+ lr_f, weight_decay_f, max_epoch_f,
+ device=device,
+ batch_size=batch_size_f,
+ ego_graph_nodes=ego_graph_nodes,
+ linear_prob=linear_prob,
+ label_rate=label_rate,
+ full_graph_forward=full_graph_forward,
+ shuffle=False if dataset_name == "ogbn-papers100M" else True
+ )
+
+ final_accs.append(float(final_acc))
+
+ print(f"Run {seeds[i]} | TestAcc: {final_acc:.4f}")
+
+ print(f"# final_acc: {np.mean(final_accs):.4f}, std: {np.std(final_accs):.4f}")
+
+ if logger is not None:
+ logger.finish()
diff --git a/examples/graphmae2/models/__init__.py b/examples/graphmae2/models/__init__.py
new file mode 100644
index 00000000..363ff1a9
--- /dev/null
+++ b/examples/graphmae2/models/__init__.py
@@ -0,0 +1,62 @@
+from .edcoder import PreModel
+
+
+def build_model(args):
+ num_heads = args.num_heads
+ num_out_heads = args.num_out_heads
+ num_hidden = args.num_hidden
+ num_layers = args.num_layers
+ residual = args.residual
+ attn_drop = args.attn_drop
+ in_drop = args.in_drop
+ norm = args.norm
+ negative_slope = args.negative_slope
+ encoder_type = args.encoder
+ decoder_type = args.decoder
+ mask_rate = args.mask_rate
+ remask_rate = args.remask_rate
+ mask_method = args.mask_method
+ drop_edge_rate = args.drop_edge_rate
+
+ activation = args.activation
+ loss_fn = args.loss_fn
+ alpha_l = args.alpha_l
+
+ num_features = args.num_features
+ num_dec_layers = args.num_dec_layers
+ num_remasking = args.num_remasking
+ lam = args.lam
+ delayed_ema_epoch = args.delayed_ema_epoch
+ replace_rate = args.replace_rate
+ remask_method = args.remask_method
+ momentum = args.momentum
+
+ model = PreModel(
+ in_dim=num_features,
+ num_hidden=num_hidden,
+ num_layers=num_layers,
+ num_dec_layers=num_dec_layers,
+ num_remasking=num_remasking,
+ nhead=num_heads,
+ nhead_out=num_out_heads,
+ activation=activation,
+ feat_drop=in_drop,
+ attn_drop=attn_drop,
+ negative_slope=negative_slope,
+ residual=residual,
+ encoder_type=encoder_type,
+ decoder_type=decoder_type,
+ mask_rate=mask_rate,
+ remask_rate=remask_rate,
+ mask_method=mask_method,
+ norm=norm,
+ loss_fn=loss_fn,
+ drop_edge_rate=drop_edge_rate,
+ alpha_l=alpha_l,
+ lam=lam,
+ delayed_ema_epoch=delayed_ema_epoch,
+ replace_rate=replace_rate,
+ remask_method=remask_method,
+ momentum=momentum
+ )
+ return model
diff --git a/examples/graphmae2/models/edcoder.py b/examples/graphmae2/models/edcoder.py
new file mode 100644
index 00000000..0b46dda2
--- /dev/null
+++ b/examples/graphmae2/models/edcoder.py
@@ -0,0 +1,364 @@
+from itertools import chain
+
+from typing import Optional
+import torch
+import torch.nn as nn
+from functools import partial
+
+from .gat import GAT
+
+from .loss_func import sce_loss
+
+
+def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm, nhead, nhead_out, attn_drop, negative_slope=0.2, concat_out=True, **kwargs) -> nn.Module:
+ if m_type in ("gat", "tsgat"):
+ mod = GAT(
+ in_dim=in_dim,
+ num_hidden=num_hidden,
+ out_dim=out_dim,
+ num_layers=num_layers,
+ nhead=nhead,
+ nhead_out=nhead_out,
+ concat_out=concat_out,
+ activation=activation,
+ feat_drop=dropout,
+ attn_drop=attn_drop,
+ negative_slope=negative_slope,
+ residual=residual,
+ norm=norm,
+ encoding=(enc_dec == "encoding"),
+ **kwargs,
+ )
+ elif m_type == "mlp":
+ # * just for decoder
+ mod = nn.Sequential(
+ nn.Linear(in_dim, num_hidden * 2),
+ nn.PReLU(),
+ nn.Dropout(0.2),
+ nn.Linear(num_hidden * 2, out_dim)
+ )
+ elif m_type == "linear":
+ mod = nn.Linear(in_dim, out_dim)
+ else:
+ raise NotImplementedError
+
+ return mod
+
+class PreModel(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ num_hidden: int,
+ num_layers: int,
+ num_dec_layers: int,
+ num_remasking: int,
+ nhead: int,
+ nhead_out: int,
+ activation: str,
+ feat_drop: float,
+ attn_drop: float,
+ negative_slope: float,
+ residual: bool,
+ norm: Optional[str],
+ mask_rate: float = 0.3,
+ remask_rate: float = 0.5,
+ remask_method: str = "random",
+ mask_method: str = "random",
+ encoder_type: str = "gat",
+ decoder_type: str = "gat",
+ loss_fn: str = "byol",
+ drop_edge_rate: float = 0.0,
+ alpha_l: float = 2,
+ lam: float = 1.0,
+ delayed_ema_epoch: int = 0,
+ momentum: float = 0.996,
+ replace_rate: float = 0.0,
+ ):
+ super(PreModel, self).__init__()
+ self._mask_rate = mask_rate
+ self._remask_rate = remask_rate
+ self._mask_method = mask_method
+ self._alpha_l = alpha_l
+ self._delayed_ema_epoch = delayed_ema_epoch
+
+ self.num_remasking = num_remasking
+ self._encoder_type = encoder_type
+ self._decoder_type = decoder_type
+ self._drop_edge_rate = drop_edge_rate
+ self._output_hidden_size = num_hidden
+ self._momentum = momentum
+ self._replace_rate = replace_rate
+ self._num_remasking = num_remasking
+ self._remask_method = remask_method
+
+ self._token_rate = 1 - self._replace_rate
+ self._lam = lam
+
+ assert num_hidden % nhead == 0
+ assert num_hidden % nhead_out == 0
+ if encoder_type in ("gat",):
+ enc_num_hidden = num_hidden // nhead
+ enc_nhead = nhead
+ else:
+ enc_num_hidden = num_hidden
+ enc_nhead = 1
+
+ dec_in_dim = num_hidden
+ dec_num_hidden = num_hidden // nhead if decoder_type in ("gat",) else num_hidden
+
+ # build encoder
+ self.encoder = setup_module(
+ m_type=encoder_type,
+ enc_dec="encoding",
+ in_dim=in_dim,
+ num_hidden=enc_num_hidden,
+ out_dim=enc_num_hidden,
+ num_layers=num_layers,
+ nhead=enc_nhead,
+ nhead_out=enc_nhead,
+ concat_out=True,
+ activation=activation,
+ dropout=feat_drop,
+ attn_drop=attn_drop,
+ negative_slope=negative_slope,
+ residual=residual,
+ norm=norm,
+ )
+
+ self.decoder = setup_module(
+ m_type=decoder_type,
+ enc_dec="decoding",
+ in_dim=dec_in_dim,
+ num_hidden=dec_num_hidden,
+ out_dim=in_dim,
+ nhead_out=nhead_out,
+ num_layers=num_dec_layers,
+ nhead=nhead,
+ activation=activation,
+ dropout=feat_drop,
+ attn_drop=attn_drop,
+ negative_slope=negative_slope,
+ residual=residual,
+ norm=norm,
+ concat_out=True,
+ )
+
+ self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
+ self.dec_mask_token = nn.Parameter(torch.zeros(1, num_hidden))
+
+ self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)
+ #self.reset_parameters_for_token()
+
+
+ # * setup loss function
+ self.criterion = self.setup_loss_fn(loss_fn, alpha_l)
+
+ self.projector = nn.Sequential(
+ nn.Linear(num_hidden, 256),
+ nn.PReLU(),
+ nn.Linear(256, num_hidden),
+ )
+ self.projector_ema = nn.Sequential(
+ nn.Linear(num_hidden, 256),
+ nn.PReLU(),
+ nn.Linear(256, num_hidden),
+ )
+ self.predictor = nn.Sequential(
+ nn.PReLU(),
+ nn.Linear(num_hidden, num_hidden)
+ )
+
+ self.encoder_ema = setup_module(
+ m_type=encoder_type,
+ enc_dec="encoding",
+ in_dim=in_dim,
+ num_hidden=enc_num_hidden,
+ out_dim=enc_num_hidden,
+ num_layers=num_layers,
+ nhead=enc_nhead,
+ nhead_out=enc_nhead,
+ concat_out=True,
+ activation=activation,
+ dropout=feat_drop,
+ attn_drop=attn_drop,
+ negative_slope=negative_slope,
+ residual=residual,
+ norm=norm,
+ )
+ self.encoder_ema.load_state_dict(self.encoder.state_dict())
+ self.projector_ema.load_state_dict(self.projector.state_dict())
+
+ for p in self.encoder_ema.parameters():
+ p.requires_grad = False
+ p.detach_()
+ for p in self.projector_ema.parameters():
+ p.requires_grad = False
+ p.detach_()
+
+ self.print_num_parameters()
+
+ def print_num_parameters(self):
+ num_encoder_params = [p.numel() for p in self.encoder.parameters() if p.requires_grad]
+ num_decoder_params = [p.numel() for p in self.decoder.parameters() if p.requires_grad]
+ num_params = [p.numel() for p in self.parameters() if p.requires_grad]
+
+ print(f"num_encoder_params: {sum(num_encoder_params)}, num_decoder_params: {sum(num_decoder_params)}, num_params_in_total: {sum(num_params)}")
+
+ def reset_parameters_for_token(self):
+ nn.init.xavier_normal_(self.enc_mask_token)
+ nn.init.xavier_normal_(self.dec_mask_token)
+ nn.init.xavier_normal_(self.encoder_to_decoder.weight, gain=1.414)
+
+ @property
+ def output_hidden_dim(self):
+ return self._output_hidden_size
+
+ def setup_loss_fn(self, loss_fn, alpha_l):
+ if loss_fn == "mse":
+ print(f"=== Use mse_loss ===")
+ criterion = nn.MSELoss()
+ elif loss_fn == "sce":
+ print(f"=== Use sce_loss and alpha_l={alpha_l} ===")
+ criterion = partial(sce_loss, alpha=alpha_l)
+ else:
+ raise NotImplementedError
+ return criterion
+
+ def forward(self, g, x, targets=None, epoch=0, drop_g1=None, drop_g2=None): # ---- attribute reconstruction ----
+ loss = self.mask_attr_prediction(g, x, targets, epoch, drop_g1, drop_g2)
+
+ return loss
+
+ def mask_attr_prediction(self, g, x, targets, epoch, drop_g1=None, drop_g2=None):
+ pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate)
+ use_g = drop_g1 if drop_g1 is not None else g
+
+ enc_rep = self.encoder(use_g, use_x,)
+
+ with torch.no_grad():
+ drop_g2 = drop_g2 if drop_g2 is not None else g
+ latent_target = self.encoder_ema(drop_g2, x,)
+ if targets is not None:
+ latent_target = self.projector_ema(latent_target[targets])
+ else:
+ latent_target = self.projector_ema(latent_target[keep_nodes])
+
+ if targets is not None:
+ latent_pred = self.projector(enc_rep[targets])
+ latent_pred = self.predictor(latent_pred)
+ loss_latent = sce_loss(latent_pred, latent_target, 1)
+ else:
+ latent_pred = self.projector(enc_rep[keep_nodes])
+ latent_pred = self.predictor(latent_pred)
+ loss_latent = sce_loss(latent_pred, latent_target, 1)
+
+ # ---- attribute reconstruction ----
+ origin_rep = self.encoder_to_decoder(enc_rep)
+
+ loss_rec_all = 0
+ if self._remask_method == "random":
+ for i in range(self._num_remasking):
+ rep = origin_rep.clone()
+ rep, remask_nodes, rekeep_nodes = self.random_remask(use_g, rep, self._remask_rate)
+ recon = self.decoder(pre_use_g, rep)
+
+ x_init = x[mask_nodes]
+ x_rec = recon[mask_nodes]
+ loss_rec = self.criterion(x_init, x_rec)
+ loss_rec_all += loss_rec
+ loss_rec = loss_rec_all
+ elif self._remask_method == "fixed":
+ rep = self.fixed_remask(g, origin_rep, mask_nodes)
+ x_rec = self.decoder(pre_use_g, rep)[mask_nodes]
+ x_init = x[mask_nodes]
+ loss_rec = self.criterion(x_init, x_rec)
+ else:
+ raise NotImplementedError
+
+ loss = loss_rec + self._lam * loss_latent
+
+ if epoch >= self._delayed_ema_epoch:
+ self.ema_update()
+ return loss
+
+ def ema_update(self):
+ def update(student, teacher):
+ with torch.no_grad():
+ # m = momentum_schedule[it] # momentum parameter
+ m = self._momentum
+ for param_q, param_k in zip(student.parameters(), teacher.parameters()):
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
+ update(self.encoder, self.encoder_ema)
+ update(self.projector, self.projector_ema)
+
+ def embed(self, g, x):
+ rep = self.encoder(g, x)
+ return rep
+
+ def get_encoder(self):
+ #self.encoder.reset_classifier(out_size)
+ return self.encoder
+
+ def reset_encoder(self, out_size):
+ self.encoder.reset_classifier(out_size)
+
+ @property
+ def enc_params(self):
+ return self.encoder.parameters()
+
+ @property
+ def dec_params(self):
+ return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
+
+ def output_grad(self):
+ grad_dict = {}
+ for n, p in self.named_parameters():
+ if p.grad is not None:
+ grad_dict[n] = p.grad.abs().mean().item()
+ return grad_dict
+
+ def encoding_mask_noise(self, g, x, mask_rate=0.3):
+ num_nodes = g.num_nodes
+ perm = torch.randperm(num_nodes, device=x.device)
+ num_mask_nodes = int(mask_rate * num_nodes)
+
+ # exclude isolated nodes
+ # isolated_nodes = torch.where(g.in_degrees() <= 1)[0]
+ # mask_nodes = perm[: num_mask_nodes]
+ # mask_nodes = torch.index_fill(torch.full((num_nodes,), False, device=device), 0, mask_nodes, True)
+ # mask_nodes[isolated_nodes] = False
+ # keep_nodes = torch.where(~mask_nodes)[0]
+ # mask_nodes = torch.where(mask_nodes)[0]
+ # num_mask_nodes = mask_nodes.shape[0]
+
+ # random masking
+ num_mask_nodes = int(mask_rate * num_nodes)
+ mask_nodes = perm[: num_mask_nodes]
+ keep_nodes = perm[num_mask_nodes: ]
+
+ out_x = x.clone()
+ token_nodes = mask_nodes
+ out_x[mask_nodes] = 0.0
+
+ out_x[token_nodes] += self.enc_mask_token
+ use_g = g.clone()
+
+ return use_g, out_x, (mask_nodes, keep_nodes)
+
+ def random_remask(self,g,rep,remask_rate=0.5):
+
+ num_nodes = g.num_nodes
+ perm = torch.randperm(num_nodes, device=rep.device)
+ num_remask_nodes = int(remask_rate * num_nodes)
+ remask_nodes = perm[: num_remask_nodes]
+ rekeep_nodes = perm[num_remask_nodes: ]
+
+ rep = rep.clone()
+ rep[remask_nodes] = 0
+ rep[remask_nodes] += self.dec_mask_token
+
+ return rep, remask_nodes, rekeep_nodes
+
+ def fixed_remask(self, g, rep, masked_nodes):
+ rep[masked_nodes] = 0
+ return rep
\ No newline at end of file
diff --git a/examples/graphmae2/models/finetune.py b/examples/graphmae2/models/finetune.py
new file mode 100644
index 00000000..a9055691
--- /dev/null
+++ b/examples/graphmae2/models/finetune.py
@@ -0,0 +1,388 @@
+import copy
+import logging
+
+from tqdm import tqdm
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from datasets.lc_sampler import setup_eval_dataloder, setup_finetune_dataloder, LinearProbingDataLoader
+from utils import accuracy, set_random_seed, show_occupied_memory, get_current_lr
+
+import wandb
+
+
+def linear_probing_minibatch(
+ model, graph,
+ feats, ego_graph_nodes, labels,
+ lr_f, weight_decay_f, max_epoch_f,
+ device, batch_size=-1, shuffle=True):
+ logging.info("-- Linear Probing in downstream tasks ---")
+ train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes = ego_graph_nodes
+ num_train, num_val = len(train_ego_graph_nodes), len(val_ego_graph_nodes)
+ train_lbls, val_lbls, test_lbls = labels
+ # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c","ogbn-arxiv","ogbn-products"]:
+ eval_loader = setup_eval_dataloder("lc", graph, feats, train_ego_graph_nodes+val_ego_graph_nodes+test_ego_graph_nodes, 512)
+
+ with torch.no_grad():
+ model.eval()
+ embeddings = []
+
+ for batch in tqdm(eval_loader, desc="Inferring..."):
+ batch_g, targets, _, node_idx = batch
+ batch_g = batch_g.to(device)
+ x = batch_g.x
+ targets = targets.to(device)
+
+ batch_emb = model.embed(batch_g, x)[targets]
+ embeddings.append(batch_emb.cpu())
+ embeddings = torch.cat(embeddings, dim=0)
+
+ train_emb, val_emb, test_emb = embeddings[:num_train], embeddings[num_train:num_train+num_val], embeddings[num_train+num_val:]
+
+ batch_size = 5120
+ acc = []
+ seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ for i,_ in enumerate(seeds):
+ print(f"####### Run seed {seeds[i]} for LinearProbing...")
+ set_random_seed(seeds[i])
+ print(f"training sample:{len(train_emb)}")
+ test_acc = node_classification_linear_probing(
+ (train_emb, val_emb, test_emb),
+ (train_lbls, val_lbls, test_lbls),
+ lr_f, weight_decay_f, max_epoch_f, device, batch_size=batch_size, shuffle=shuffle)
+ acc.append(test_acc)
+
+ print(f"# final_acc: {np.mean(acc):.4f}, std: {np.std(acc):.4f}")
+
+ return np.mean(acc)
+
+
+
+class LogisticRegression(nn.Module):
+ def __init__(self, num_dim, num_class):
+ super().__init__()
+ self.linear = nn.Linear(num_dim, num_class)
+
+ def forward(self, g, x, *args):
+ logits = self.linear(x)
+ return logits
+
+
+def node_classification_linear_probing(embeddings, labels, lr, weight_decay, max_epoch, device, mute=False, batch_size=-1, shuffle=True):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ train_emb, val_emb, test_emb = embeddings
+ train_label, val_label, test_label = labels
+ train_label = train_label.to(torch.long)
+ val_label = val_label.to(torch.long)
+ test_label = test_label.to(torch.long)
+
+ best_val_acc = 0
+ best_val_epoch = 0
+ best_model = None
+
+ if not mute:
+ epoch_iter = tqdm(range(max_epoch))
+ else:
+ epoch_iter = range(max_epoch)
+
+ encoder = LogisticRegression(train_emb.shape[1], int(train_label.max().item() + 1))
+ encoder = encoder.to(device)
+ optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=weight_decay)
+
+ if batch_size > 0:
+ train_loader = LinearProbingDataLoader(np.arange(len(train_emb)), train_emb, train_label, batch_size=batch_size, num_workers=4, persistent_workers=True, shuffle=shuffle)
+ # train_loader = DataLoader(np.arange(len(train_emb)), batch_size=batch_size, shuffle=False)
+ val_loader = LinearProbingDataLoader(np.arange(len(val_emb)), val_emb, val_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False)
+ test_loader = LinearProbingDataLoader(np.arange(len(test_emb)), test_emb, test_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False)
+ else:
+ train_loader = [np.arange(len(train_emb))]
+ val_loader = [np.arange(len(val_emb))]
+ test_loader = [np.arange(len(test_emb))]
+
+ def eval_forward(loader, _label):
+ pred_all = []
+ for batch_x, _ in loader:
+ batch_x = batch_x.to(device)
+ pred = encoder(None, batch_x)
+ pred_all.append(pred.cpu())
+ pred = torch.cat(pred_all, dim=0)
+ acc = accuracy(pred, _label)
+ return acc
+
+ for epoch in epoch_iter:
+ encoder.train()
+
+ for batch_x, batch_label in train_loader:
+ batch_x = batch_x.to(device)
+ batch_label = batch_label.to(device)
+ pred = encoder(None, batch_x)
+ loss = criterion(pred, batch_label)
+ optimizer.zero_grad()
+ loss.backward()
+ # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
+ optimizer.step()
+
+ with torch.no_grad():
+ encoder.eval()
+ val_acc = eval_forward(val_loader, val_label)
+
+ if val_acc >= best_val_acc:
+ best_val_acc = val_acc
+ best_val_epoch = epoch
+ best_model = copy.deepcopy(encoder)
+
+ if not mute:
+ epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_acc:{val_acc:.4f}")
+
+ best_model.eval()
+ encoder = best_model
+ with torch.no_grad():
+ test_acc = eval_forward(test_loader, test_label)
+ if mute:
+ print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
+ else:
+ print(f"--- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
+
+ return test_acc
+
+
+def finetune(
+ model,
+ graph,
+ feats,
+ ego_graph_nodes,
+ labels,
+ split_idx,
+ lr_f, weight_decay_f, max_epoch_f,
+ use_scheduler, batch_size,
+ device,
+ logger=None,
+ full_graph_forward=False,
+):
+ logging.info("-- Finetuning in downstream tasks ---")
+ train_egs, val_egs, test_egs = ego_graph_nodes
+ print(f"num of egos:{len(train_egs)},{len(val_egs)},{len(test_egs)}")
+
+ print(graph.num_nodes())
+
+ train_nid = split_idx["train"].numpy()
+ val_nid = split_idx["valid"].numpy()
+ test_nid = split_idx["test"].numpy()
+
+ train_lbls, val_lbls, test_lbls = [x.long() for x in labels]
+ print(f"num of labels:{len(train_lbls)},{len(val_lbls)},{len(test_lbls)}")
+
+ num_classes = max(max(train_lbls.max().item(), val_lbls.max().item()), test_lbls.max().item()) + 1
+
+ model = model.get_encoder()
+ model.reset_classifier(int(num_classes))
+ model = model.to(device)
+ criterion = torch.nn.CrossEntropyLoss()
+
+ train_loader = setup_finetune_dataloder("lc", graph, feats, train_egs, train_lbls, batch_size=batch_size, shuffle=True)
+ val_loader = setup_finetune_dataloder("lc", graph, feats, val_egs, val_lbls, batch_size=batch_size, shuffle=False)
+ test_loader = setup_finetune_dataloder("lc", graph, feats, test_egs, test_lbls, batch_size=batch_size, shuffle=False)
+
+ #optimizer = torch.optim.Adam(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+
+ if use_scheduler and max_epoch_f > 0:
+ logging.info("Use scheduler")
+ warmup_epochs = int(max_epoch_f * 0.1)
+ # scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch_f) ) * 0.5
+ scheduler = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else ( 1 + np.cos((epoch - warmup_epochs) * np.pi / (max_epoch_f - warmup_epochs))) * 0.5
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
+ else:
+ scheduler = None
+
+ def eval_with_lc(model, loader):
+ pred_counts = []
+ model.eval()
+ epoch_iter = tqdm(loader)
+ with torch.no_grad():
+ for batch in epoch_iter:
+ batch_g, targets, batch_lbls, node_idx = batch
+ batch_g = batch_g.to(device)
+ batch_lbls = batch_lbls.to(device)
+ x = batch_g.x
+
+ prediction = model(batch_g, x)
+ prediction = prediction[targets]
+ pred_counts.append((prediction.argmax(1) == batch_lbls))
+ pred_counts = torch.cat(pred_counts)
+ acc = pred_counts.float().sum() / pred_counts.shape[0]
+ return acc
+
+ def eval_full_prop(model, g, nfeat, val_nid, test_nid, batch_size, device):
+ model.eval()
+
+ with torch.no_grad():
+ pred = model.inference(g, nfeat, batch_size, device)
+ model.train()
+
+ return accuracy(pred[val_nid], val_lbls.cpu()), accuracy(pred[test_nid], test_lbls.cpu())
+
+ best_val_acc = 0
+ best_model = None
+ best_epoch = 0
+ test_acc = 0
+ early_stop_cnt = 0
+
+ for epoch in range(max_epoch_f):
+ if epoch == 0:
+ if scheduler is not None:
+ scheduler.step()
+ continue
+ if early_stop_cnt >= 10:
+ break
+ epoch_iter = tqdm(train_loader)
+ losses = []
+ model.train()
+
+ for batch_g, targets, batch_lbls, node_idx in epoch_iter:
+ batch_g = batch_g.to(device)
+ targets = targets.to(device)
+ batch_lbls = batch_lbls.to(device)
+ x = batch_g.x
+
+ prediction = model(batch_g, x)
+ prediction = prediction[targets]
+ loss = criterion(prediction, batch_lbls)
+
+ optimizer.zero_grad()
+ loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+ optimizer.step()
+
+ metrics = {"finetune_loss": loss.item()}
+ wandb.log(metrics)
+
+ if logger is not None:
+ logger.log(metrics)
+
+ epoch_iter.set_description(f"Finetuning | train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB")
+ losses.append(loss.item())
+
+ if scheduler is not None:
+ scheduler.step()
+
+ if not full_graph_forward:
+ if epoch > 0:
+ val_acc = eval_with_lc(model, val_loader)
+ _test_acc = 0
+ else:
+ if epoch > 0 and epoch % 1 == 0:
+ val_acc, _test_acc = eval_full_prop(model, graph, feats, val_nid, test_nid, 10000, device)
+ model = model.to(device)
+
+ print('val Acc {:.4f}'.format(val_acc))
+ if val_acc > best_val_acc:
+ best_model = copy.deepcopy(model)
+ best_val_acc = val_acc
+ test_acc = _test_acc
+ best_epoch = epoch
+ early_stop_cnt = 0
+ else:
+ early_stop_cnt += 1
+
+ if not full_graph_forward:
+ print("val Acc {:.4f}, Best Val Acc {:.4f}".format(val_acc, best_val_acc))
+ else:
+ print("Val Acc {:.4f}, Best Val Acc {:.4f} Test Acc {:.4f}".format(val_acc, best_val_acc, test_acc))
+
+ metrics = {"epoch_val_acc": val_acc,
+ "test_acc": test_acc,
+ "epoch": epoch,
+ "lr_f": get_current_lr(optimizer)}
+
+ wandb.log(metrics)
+ if logger is not None:
+ logger.log(metrics)
+ print(f"# Finetuning - Epoch {epoch} | train_loss: {np.mean(losses):.4f}, ValAcc: {val_acc:.4f}, TestAcc: {test_acc:.4f}, Memory: {show_occupied_memory():.2f} MB")
+
+ model = best_model
+ if not full_graph_forward:
+ test_acc = eval_with_lc(model, test_loader)
+
+ print(f"Finetune | TestAcc: {test_acc:.4f} from Epoch {best_epoch}")
+ return test_acc
+
+
+def linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob=True, mute=False):
+ model.eval()
+ with torch.no_grad():
+ x = model.embed(graph.to(device), x.to(device))
+ in_feat = x.shape[1]
+ encoder = LogisticRegression(in_feat, num_classes)
+
+ num_finetune_params = [p.numel() for p in encoder.parameters() if p.requires_grad]
+ if not mute:
+ print(f"num parameters for finetuning: {sum(num_finetune_params)}")
+
+ encoder.to(device)
+ optimizer_f = torch.optim.Adam(encoder.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+ final_acc, estp_acc = _linear_probing_full_batch(encoder, graph, x, optimizer_f, max_epoch_f, device, mute)
+ return final_acc, estp_acc
+
+
+def _linear_probing_full_batch(model, graph, feat, optimizer, max_epoch, device, mute=False):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ graph = graph.to(device)
+ x = feat.to(device)
+
+ train_mask = graph.train_mask
+ val_mask = graph.val_mask
+ test_mask = graph.test_mask
+ labels = graph.y
+
+ best_val_acc = 0
+ best_val_epoch = 0
+ best_model = None
+
+ if not mute:
+ epoch_iter = tqdm(range(max_epoch))
+ else:
+ epoch_iter = range(max_epoch)
+
+ for epoch in epoch_iter:
+ model.train()
+ out = model(graph, x)
+ loss = criterion(out[train_mask], labels[train_mask])
+ optimizer.zero_grad()
+ loss.backward()
+ # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
+ optimizer.step()
+
+ with torch.no_grad():
+ model.eval()
+ pred = model(graph, x)
+ val_acc = accuracy(pred[val_mask], labels[val_mask])
+ val_loss = criterion(pred[val_mask], labels[val_mask])
+ test_acc = accuracy(pred[test_mask], labels[test_mask])
+ test_loss = criterion(pred[test_mask], labels[test_mask])
+
+ if val_acc >= best_val_acc:
+ best_val_acc = val_acc
+ best_val_epoch = epoch
+ best_model = copy.deepcopy(model)
+
+ if not mute:
+ epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_loss:{val_loss.item(): .4f}, val_acc:{val_acc: .4f}, test_loss:{test_loss.item(): .4f}, test_acc:{test_acc: .4f}")
+
+ best_model.eval()
+ with torch.no_grad():
+ pred = best_model(graph, x)
+ estp_test_acc = accuracy(pred[test_mask], labels[test_mask])
+ if mute:
+ print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
+ else:
+ print(f"--- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
+
+ return test_acc, estp_test_acc
diff --git a/examples/graphmae2/models/gat.py b/examples/graphmae2/models/gat.py
new file mode 100644
index 00000000..28eb5e29
--- /dev/null
+++ b/examples/graphmae2/models/gat.py
@@ -0,0 +1,198 @@
+import tqdm
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from cogdl.utils import (
+ EdgeSoftmax,
+ MultiHeadSpMM,
+ get_activation,
+ get_norm_layer,
+)
+
+from utils import create_activation, create_norm
+
+
+class GAT(nn.Module):
+ def __init__(self,
+ in_dim,
+ num_hidden,
+ out_dim,
+ num_layers,
+ nhead,
+ nhead_out,
+ activation,
+ feat_drop,
+ attn_drop,
+ negative_slope,
+ residual,
+ norm,
+ concat_out=False,
+ encoding=False
+ ):
+ super(GAT, self).__init__()
+ self.out_dim = out_dim
+ self.num_heads = nhead
+ self.num_layers = num_layers
+ self.gat_layers = nn.ModuleList()
+ self.activation = activation
+ self.concat_out = concat_out
+
+ last_activation = activation if encoding else None
+ last_residual = (encoding and residual)
+ last_norm = norm if encoding else None
+ self.feat_drop = nn.Dropout(feat_drop)
+
+ if num_layers == 1:
+ self.gat_layers.append(
+ GATLayer(
+ in_dim, out_dim, nhead=nhead_out,
+ attn_drop=attn_drop,
+ alpha=negative_slope,
+ residual=last_residual,
+ norm=last_norm,
+ activation=last_activation,
+ )
+ )
+ else:
+ # input projection (no residual)
+ self.gat_layers.append(
+ GATLayer(
+ in_dim, num_hidden, nhead,
+ attn_drop=attn_drop,
+ alpha=negative_slope,
+ residual=residual,
+ activation=activation,
+ norm=norm
+ )
+ )
+ # hidden layers
+ for l in range(1, num_layers - 1):
+ # due to multi-head, the in_dim = num_hidden * num_heads
+ self.gat_layers.append(
+ GATLayer(
+ num_hidden * nhead, num_hidden, nhead=nhead,
+ attn_drop=attn_drop,
+ alpha=negative_slope,
+ residual=residual,
+ activation=activation,
+ norm=norm
+ )
+ )
+ # output projection
+ self.gat_layers.append(
+ GATLayer(
+ num_hidden * nhead, out_dim,
+ nhead=nhead_out,
+ attn_drop=attn_drop,
+ alpha=negative_slope,
+ residual=last_residual,
+ activation=last_activation,
+ norm=last_norm
+ )
+ )
+ self.head = nn.Identity()
+
+ def forward(self, g, inputs, return_hidden=False):
+ h = inputs
+ hidden_list = []
+ for l in range(self.num_layers):
+ h = self.feat_drop(h)
+ h = self.gat_layers[l](g, h)
+ hidden_list.append(h)
+ # h = h.flatten(1)
+ # output projection
+ if return_hidden:
+ return self.head(h), hidden_list
+ else:
+ return self.head(h)
+
+ def reset_classifier(self, num_classes):
+ self.head = nn.Linear(self.num_heads * self.out_dim, num_classes)
+
+
+class GATLayer(nn.Module):
+ """
+ Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
+ """
+
+ def __init__(
+ self, in_feats, out_feats, nhead=1, alpha=0.2, attn_drop=0.5, activation=None, residual=False, norm=None
+ ):
+ super(GATLayer, self).__init__()
+ self.in_features = in_feats
+ self.out_features = out_feats
+ self.alpha = alpha
+ self.nhead = nhead
+
+ self.W = nn.Parameter(torch.FloatTensor(in_feats, out_feats * nhead))
+ self.bias = nn.Parameter(torch.FloatTensor(out_feats * nhead,))
+
+ self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_feats)))
+ self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_feats)))
+
+ self.edge_softmax = EdgeSoftmax()
+ self.mhspmm = MultiHeadSpMM()
+
+ self.dropout = nn.Dropout(attn_drop)
+ self.leakyrelu = nn.LeakyReLU(self.alpha)
+ self.act = None if activation is None else get_activation(activation)
+ self.norm = None if norm is None else get_norm_layer(norm, out_feats * nhead)
+
+ if residual:
+ self.residual = nn.Linear(in_feats, out_feats * nhead)
+ else:
+ self.register_buffer("residual", None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # def reset(tensor):
+ # stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
+ # tensor.data.uniform_(-stdv, stdv)
+
+ # reset(self.a_l)
+ # reset(self.a_r)
+ # reset(self.W)
+
+ gain = nn.init.calculate_gain('relu')
+ nn.init.xavier_normal_(self.W, gain=gain)
+ nn.init.xavier_normal_(self.a_l, gain=gain)
+ nn.init.xavier_normal_(self.a_r, gain=gain)
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0)
+ if isinstance(self.residual, nn.Linear):
+ nn.init.xavier_normal_(self.residual.weight, gain=gain)
+
+ def forward(self, graph, x):
+ h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
+ h[torch.isnan(h)] = 0.0
+
+ row, col = graph.edge_index
+ # Self-attention on the nodes - Shared attention mechanism
+ h_l = (self.a_l * h).sum(dim=-1)
+ h_r = (self.a_r * h).sum(dim=-1)
+
+ # edge_attention: E * H
+ edge_attention = self.leakyrelu(h_l[row] + h_r[col])
+ edge_attention = self.edge_softmax(graph, edge_attention)
+ edge_attention = self.dropout(edge_attention)
+
+ out = self.mhspmm(graph, edge_attention, h)
+
+ if self.bias is not None:
+ out = out + self.bias
+
+ if self.residual:
+ res = self.residual(x)
+ out += res
+ if self.norm is not None:
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+ return out
+
+ def __repr__(self):
+ return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")"
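+
+
+# Illustrative usage sketch (commented out; the values below are examples, not
+# taken from the training scripts). Assuming `g` is a cogdl `Graph` and `x`
+# holds node features of shape [num_nodes, in_dim]:
+#   encoder = GAT(in_dim=128, num_hidden=64, out_dim=64, num_layers=2,
+#                 nhead=4, nhead_out=1, activation="prelu", feat_drop=0.2,
+#                 attn_drop=0.1, negative_slope=0.2, residual=True,
+#                 norm="layernorm", encoding=True)
+#   h = encoder(g, x)            # node embeddings from the final GAT layer
+#   encoder.reset_classifier(40) # swap the identity head for a linear classifier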
diff --git a/examples/graphmae2/models/gcn.py b/examples/graphmae2/models/gcn.py
new file mode 100644
index 00000000..23292e72
--- /dev/null
+++ b/examples/graphmae2/models/gcn.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from cogdl.layers import GCNLayer
+
+from utils import create_activation, create_norm
+
+
+class GCN(nn.Module):
+ def __init__(self,
+ in_dim,
+ num_hidden,
+ out_dim,
+ num_layers,
+ dropout,
+ activation,
+ residual,
+ norm,
+ encoding=False
+ ):
+ super(GCN, self).__init__()
+ self.out_dim = out_dim
+ self.num_layers = num_layers
+ self.gcn_layers = nn.ModuleList()
+ self.activation = activation
+ self.dropout = dropout
+
+ last_activation = activation if encoding else None
+ last_residual = encoding and residual
+ last_norm = norm if encoding else None
+
+ if num_layers == 1:
+ self.gcn_layers.append(GCNLayer(
+ in_dim, out_dim, residual=last_residual, norm=last_norm, activation=last_activation))
+ else:
+ self.gcn_layers.append(GCNLayer(
+ in_dim, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
+ for l in range(1, num_layers - 1):
+ # hidden layers
+ self.gcn_layers.append(GCNLayer(
+ num_hidden, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
+ self.gcn_layers.append(GCNLayer(
+ num_hidden, out_dim, residual=last_residual, activation=last_activation, norm=last_norm))
+
+ self.head = nn.Identity()
+
+ def forward(self, g, inputs, return_hidden=False):
+ h = inputs
+ hidden_list = []
+ for l in range(self.num_layers):
+ h = F.dropout(h, p=self.dropout, training=self.training)
+ h = self.gcn_layers[l](g, h)
+ hidden_list.append(h)
+ if return_hidden:
+ return self.head(h), hidden_list
+ else:
+ return self.head(h)
+
+ def reset_classifier(self, num_classes):
+ self.head = nn.Linear(self.out_dim, num_classes)
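+
+
+# Note: GCN exposes the same encoder interface as models/gat.py --
+# forward(g, x) returns node embeddings (and all hidden states when
+# return_hidden=True), while reset_classifier() replaces the identity head
+# with a linear classifier for finetuning.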
diff --git a/examples/graphmae2/models/loss_func.py b/examples/graphmae2/models/loss_func.py
new file mode 100644
index 00000000..2105c64f
--- /dev/null
+++ b/examples/graphmae2/models/loss_func.py
@@ -0,0 +1,164 @@
+import math
+import warnings
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def auc_pair_loss(x, y, z):
+ x = F.normalize(x, p=2, dim=-1)
+ y = F.normalize(y, p=2, dim=-1)
+ z = F.normalize(z, p=2, dim=-1)
+
+ sim = (x * y).sum(dim=-1)
+ dissim = (x * z).sum(dim=-1)
+ loss = (1 - sim + dissim).mean()
+ # loss = (1 - sim).mean()
+ return loss
+
+
+def sce_loss(x, y, alpha=3):
+ x = F.normalize(x, p=2, dim=-1)
+ y = F.normalize(y, p=2, dim=-1)
+
+ # loss = - (x * y).sum(dim=-1)
+ # loss = (x_h - y_h).norm(dim=1).pow(alpha)
+
+ loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
+
+ loss = loss.mean()
+ return loss
+
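+# Note: `sce_loss` is the scaled cosine error used as the reconstruction
+# criterion (--loss_fn "sce" in the run scripts). With row-normalized x and y
+# it computes mean((1 - cos(x_i, y_i)) ** alpha); e.g. (illustrative only):
+#   loss = sce_loss(torch.randn(8, 16), torch.randn(8, 16), alpha=3)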
+
+class DINOLoss(nn.Module):
+ def __init__(self, out_dim, warmup_teacher_temp, teacher_temp,
+ warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
+ center_momentum=0.9):
+ super().__init__()
+ self.student_temp = student_temp
+ self.center_momentum = center_momentum
+ self.register_buffer("center", torch.zeros(1, out_dim))
+ # we apply a warm up for the teacher temperature because
+ # a too high temperature makes the training instable at the beginning
+ self.teacher_temp_schedule = np.concatenate((
+ np.linspace(warmup_teacher_temp,
+ teacher_temp, warmup_teacher_temp_epochs),
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
+ ))
+
+ def forward(self, student_output, teacher_output, epoch):
+ """
+ Cross-entropy between softmax outputs of the teacher and student networks.
+ """
+ student_out = student_output / self.student_temp
+
+ # teacher centering and sharpening
+ temp = self.teacher_temp_schedule[epoch]
+ teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
+ teacher_out = teacher_out.detach()
+
+ loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1)
+ loss = loss.mean()
+ self.update_center(teacher_output)
+ return loss
+
+ # total_loss = 0
+ # n_loss_terms = 0
+ # for iq, q in enumerate(teacher_out):
+ # for v in range(len(student_out)):
+ # if v == iq:
+ # # we skip cases where student and teacher operate on the same view
+ # continue
+ # loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
+ # total_loss += loss.mean()
+ # n_loss_terms += 1
+ # total_loss /= n_loss_terms
+ # self.update_center(teacher_output)
+ # return total_loss
+
+ @torch.no_grad()
+ def update_center(self, teacher_output):
+ """
+ Update center used for teacher output.
+ """
+ batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
+
+ # ema update
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
+
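+# Commented usage sketch (illustrative): with student/teacher projections
+# `s` and `t` of shape [batch_size, out_dim],
+#   dino = DINOLoss(out_dim=256, warmup_teacher_temp=0.04, teacher_temp=0.07,
+#                   warmup_teacher_temp_epochs=10, nepochs=100)
+#   loss = dino(s, t, epoch)  # also EMA-updates the teacher center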
+
+class MLPHead(nn.Module):
+ def __init__(self, hidden_size, out_dim, num_layers=2, bottleneck_dim=256):
+ super().__init__()
+ self._num_layers = num_layers
+ self.mlp = nn.ModuleList()
+ for i in range(num_layers):
+ if i == num_layers - 1:
+ self.mlp.append(
+ nn.Linear(hidden_size, bottleneck_dim)
+ )
+ else:
+ self.mlp.append(nn.Linear(hidden_size, hidden_size))
+ # self.mlp.append(nn.LayerNorm(hidden_size))
+ self.mlp.append(nn.PReLU())
+
+ self.apply(self._init_weights)
+ # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+ # self.last_layer.weight_g.requires_grad = False
+ # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ num_layers = len(self.mlp)
+ for i, layer in enumerate(self.mlp):
+ x = layer(x)
+
+ x = nn.functional.normalize(x, dim=-1, p=2)
+ x = self.last_layer(x)
+ return x
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
\ No newline at end of file
diff --git a/examples/graphmae2/run_fullbatch.sh b/examples/graphmae2/run_fullbatch.sh
new file mode 100644
index 00000000..44009bea
--- /dev/null
+++ b/examples/graphmae2/run_fullbatch.sh
@@ -0,0 +1,37 @@
+dataset=$1
+device=$2
+
+[ -z "${dataset}" ] && dataset="cora"
+[ -z "${device}" ] && device=0
+
+CUDA_VISIBLE_DEVICES=$device \
+ python main_full_batch.py \
+ --device 0 \
+ --dataset $dataset \
+ --mask_method "random" \
+ --remask_method "fixed" \
+ --mask_rate 0.5 \
+ --in_drop 0.2 \
+ --attn_drop 0.1 \
+ --num_layers 2 \
+ --num_dec_layers 1 \
+ --num_hidden 256 \
+ --num_heads 4 \
+ --num_out_heads 1 \
+ --encoder "gat" \
+ --decoder "gat" \
+ --max_epoch 1000 \
+ --max_epoch_f 300 \
+ --lr 0.001 \
+ --weight_decay 0.04 \
+ --lr_f 0.005 \
+ --weight_decay_f 1e-4 \
+ --activation "prelu" \
+ --loss_fn "sce" \
+ --alpha_l 3 \
+ --scheduler \
+ --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 \
+ --lam 0.5 \
+ --linear_prob \
+ --data_dir "./dataset" \
+ --use_cfg
diff --git a/examples/graphmae2/run_minibatch.sh b/examples/graphmae2/run_minibatch.sh
new file mode 100644
index 00000000..b211b5bc
--- /dev/null
+++ b/examples/graphmae2/run_minibatch.sh
@@ -0,0 +1,50 @@
+dataset=$1
+device=$2
+
+[ -z "${dataset}" ] && dataset="ogbn-arxiv"
+[ -z "${device}" ] && device=0
+
+CUDA_VISIBLE_DEVICES=$device \
+ python main_large.py \
+ --device 0 \
+ --dataset $dataset \
+ --mask_type "mask" \
+ --mask_rate 0.5 \
+ --remask_rate 0.5 \
+ --num_remasking 3 \
+ --in_drop 0.2 \
+ --attn_drop 0.2 \
+ --num_layers 4 \
+ --num_dec_layers 1 \
+ --num_hidden 1024 \
+ --num_heads 4 \
+ --num_out_heads 1 \
+ --encoder "gat" \
+ --decoder "gat" \
+ --max_epoch 60 \
+ --max_epoch_f 1000 \
+ --lr 0.002 \
+ --weight_decay 0.04 \
+ --lr_f 0.005 \
+ --weight_decay_f 1e-4 \
+ --activation "prelu" \
+ --optimizer "adamw" \
+ --drop_edge_rate 0.5 \
+ --loss_fn "sce" \
+ --alpha_l 4 \
+ --mask_method "random" \
+ --scheduler \
+ --batch_size 512 \
+ --batch_size_f 256 \
+ --seeds 0 \
+ --residual \
+ --norm "layernorm" \
+ --sampling_method "lc" \
+ --label_rate 1.0 \
+ --lam 1.0 \
+ --momentum 0.996 \
+ --linear_prob \
+ --use_cfg \
+ --ego_graph_file_path "./lc_ego_graphs/${dataset}-lc-ego-graphs-256.pt" \
+ --data_dir "./dataset" \
+ # --logging
diff --git a/examples/graphmae2/utils.py b/examples/graphmae2/utils.py
new file mode 100644
index 00000000..487a602e
--- /dev/null
+++ b/examples/graphmae2/utils.py
@@ -0,0 +1,317 @@
+import os
+import argparse
+import random
+import psutil
+import yaml
+import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+import wandb
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import optim as optim
+
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+import matplotlib.pyplot as plt
+
+
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
+
+
+def accuracy(y_pred, y_true):
+ y_true = y_true.squeeze().long()
+ preds = y_pred.max(1)[1].type_as(y_true)
+ correct = preds.eq(y_true).double()
+ correct = correct.sum().item()
+ return correct / len(y_true)
+
+
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+def get_current_lr(optimizer):
+ return optimizer.state_dict()["param_groups"][0]["lr"]
+
+
+def build_args():
+ parser = argparse.ArgumentParser(description="GAT")
+ parser.add_argument("--seeds", type=int, nargs="+", default=[0])
+ parser.add_argument("--dataset", type=str, default="cora")
+ parser.add_argument("--device", type=int, default=0)
+ parser.add_argument("--max_epoch", type=int, default=500,
+ help="number of training epochs")
+ parser.add_argument("--warmup_steps", type=int, default=-1)
+
+ parser.add_argument("--num_heads", type=int, default=4,
+ help="number of hidden attention heads")
+ parser.add_argument("--num_out_heads", type=int, default=1,
+ help="number of output attention heads")
+ parser.add_argument("--num_layers", type=int, default=2,
+ help="number of hidden layers")
+ parser.add_argument("--num_dec_layers", type=int, default=1)
+ parser.add_argument("--num_remasking", type=int, default=3)
+ parser.add_argument("--num_hidden", type=int, default=512,
+ help="number of hidden units")
+ parser.add_argument("--residual", action="store_true", default=False,
+ help="use residual connection")
+ parser.add_argument("--in_drop", type=float, default=.2,
+ help="input feature dropout")
+ parser.add_argument("--attn_drop", type=float, default=.1,
+ help="attention dropout")
+ parser.add_argument("--norm", type=str, default=None)
+ parser.add_argument("--lr", type=float, default=0.001,
+ help="learning rate")
+ parser.add_argument("--weight_decay", type=float, default=0,
+ help="weight decay")
+ parser.add_argument("--negative_slope", type=float, default=0.2,
+ help="the negative slope of leaky relu")
+ parser.add_argument("--activation", type=str, default="prelu")
+ parser.add_argument("--mask_rate", type=float, default=0.5)
+ parser.add_argument("--remask_rate", type=float, default=0.5)
+ parser.add_argument("--remask_method", type=str, default="random")
+ parser.add_argument("--mask_type", type=str, default="mask",
+ help="`mask` or `drop`")
+ parser.add_argument("--mask_method", type=str, default="random")
+ parser.add_argument("--drop_edge_rate", type=float, default=0.0)
+ parser.add_argument("--drop_edge_rate_f", type=float, default=0.0)
+
+ parser.add_argument("--encoder", type=str, default="gat")
+ parser.add_argument("--decoder", type=str, default="gat")
+ parser.add_argument("--loss_fn", type=str, default="sce")
+ parser.add_argument("--alpha_l", type=float, default=2)
+ parser.add_argument("--optimizer", type=str, default="adam")
+
+ parser.add_argument("--max_epoch_f", type=int, default=300)
+ parser.add_argument("--lr_f", type=float, default=0.01)
+ parser.add_argument("--weight_decay_f", type=float, default=0.0)
+ parser.add_argument("--linear_prob", action="store_true", default=False)
+
+
+ parser.add_argument("--no_pretrain", action="store_true")
+ parser.add_argument("--load_model", action="store_true")
+ parser.add_argument("--checkpoint_path", type=str, default=None)
+ parser.add_argument("--use_cfg", action="store_true")
+ parser.add_argument("--logging", action="store_true")
+ parser.add_argument("--scheduler", action="store_true", default=False)
+
+ parser.add_argument("--batch_size", type=int, default=256)
+ parser.add_argument("--batch_size_f", type=int, default=128)
+ parser.add_argument("--sampling_method", type=str, default="saint", help="sampling method, `lc` or `saint`")
+
+ parser.add_argument("--label_rate", type=float, default=1.0)
+ parser.add_argument("--ego_graph_file_path", type=str, default=None)
+ parser.add_argument("--data_dir", type=str, default="data")
+
+ parser.add_argument("--lam", type=float, default=1.0)
+ parser.add_argument("--full_graph_forward", action="store_true", default=False)
+ parser.add_argument("--delayed_ema_epoch", type=int, default=0)
+ parser.add_argument("--replace_rate", type=float, default=0.0)
+ parser.add_argument("--momentum", type=float, default=0.996)
+
+ args = parser.parse_args()
+ return args
+
+def create_activation(name):
+ if name == "relu":
+ return nn.ReLU()
+ elif name == "gelu":
+ return nn.GELU()
+ elif name == "prelu":
+ return nn.PReLU()
+ elif name == "selu":
+ return nn.SELU()
+ elif name == "elu":
+ return nn.ELU()
+ elif name == "silu":
+ return nn.SiLU()
+ elif name is None:
+ return nn.Identity()
+ else:
+ raise NotImplementedError(f"{name} is not implemented.")
+
+
+def identity_norm(num_features):
+ # Behaves like a norm-layer constructor (cf. nn.LayerNorm): the feature
+ # dimension is ignored and an identity function is returned.
+ def func(x):
+ return x
+ return func
+
+def create_norm(name):
+ if name == "layernorm":
+ return nn.LayerNorm
+ elif name == "batchnorm":
+ return nn.BatchNorm1d
+ elif name == "identity":
+ return identity_norm
+ else:
+ # print("Identity norm")
+ return None
+
+
+def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None):
+ opt_lower = opt.lower()
+ parameters = model.parameters()
+ opt_args = dict(lr=lr, weight_decay=weight_decay)
+
+ opt_split = opt_lower.split("_")
+ opt_lower = opt_split[-1]
+
+ if opt_lower == "adam":
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == "adamw":
+ optimizer = optim.AdamW(parameters, **opt_args)
+ elif opt_lower == "adadelta":
+ optimizer = optim.Adadelta(parameters, **opt_args)
+ elif opt_lower == "sgd":
+ opt_args["momentum"] = 0.9
+ return optim.SGD(parameters, **opt_args)
+ else:
+ raise NotImplementedError("Invalid optimizer")
+
+ return optimizer
+
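+# Examples: create_optimizer("adam", model, 1e-3, 0.0) or
+# create_optimizer("adamw", model, 2e-3, 0.04) return the corresponding
+# torch.optim optimizer over model.parameters(); "sgd" additionally uses
+# momentum=0.9, and unknown names raise NotImplementedError.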
+
+def show_occupied_memory():
+ process = psutil.Process(os.getpid())
+ return process.memory_info().rss / 1024**2
+
+
+# -------------------
+def mask_edge(graph, mask_prob):
+ E = graph.num_edges
+
+ mask_rates = torch.FloatTensor(np.ones(E) * mask_prob)
+ masks = torch.bernoulli(1 - mask_rates)
+ mask_idx = masks.nonzero().squeeze(1)
+ return mask_idx.to(graph.device)
+
+
+def drop_edge(graph, drop_rate, return_edges=False):
+ if drop_rate <= 0:
+ return graph
+
+ n_node = graph.num_nodes
+ edge_mask = mask_edge(graph, drop_rate)
+ src = graph.edge_index[0]
+ dst = graph.edge_index[1]
+
+ nsrc = src[edge_mask]
+ ndst = dst[edge_mask]
+
+ ng = Graph(edge_index=(nsrc, ndst), num_nodes=n_node)
+ ng.add_remaining_self_loops()
+
+ # `edge_mask` holds the indices of the kept edges, so build a boolean mask
+ # to select the dropped ones (`~` on an index tensor would be incorrect).
+ drop_mask = torch.ones_like(src, dtype=torch.bool)
+ drop_mask[edge_mask] = False
+ dsrc = src[drop_mask]
+ ddst = dst[drop_mask]
+
+ if return_edges:
+ return ng, (dsrc, ddst)
+ return ng
+
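+# Sketch of intended behavior (assumes `g` is a cogdl Graph): drop_edge(g, 0.3)
+# keeps each edge independently with probability 0.7, rebuilds the graph with
+# self-loops re-added, and with return_edges=True also returns the dropped
+# (src, dst) pairs.
+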
+def visualize(x, y, method="tsne"):
+ if torch.is_tensor(x):
+ x = x.cpu().numpy()
+
+ if torch.is_tensor(y):
+ y = y.cpu().numpy()
+
+ if method == "tsne":
+ func = TSNE(n_components=2)
+ else:
+ func = PCA(n_components=2)
+ out = func.fit_transform(x)
+ plt.scatter(out[:, 0], out[:, 1], c=y)
+ plt.savefig("vis.png")
+
+
+def load_best_configs(args):
+ dataset_name = args.dataset
+ config_path = os.path.join("configs", f"{dataset_name}.yaml")
+ with open(config_path, "r") as f:
+ configs = yaml.load(f, yaml.FullLoader)
+
+ for k, v in configs.items():
+ if "lr" in k or "weight_decay" in k:
+ v = float(v)
+ setattr(args, k, v)
+ logging.info(f"----- Using best configs from {config_path} -----")
+
+ return args
+
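+# Note: configs/<dataset>.yaml is expected to map argparse attribute names to
+# values (e.g. `lr: 0.001`, `num_hidden: 256`); keys containing "lr" or
+# "weight_decay" are cast to float before being written onto `args`.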
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
+ warmup_schedule = np.array([])
+ warmup_iters = warmup_epochs * niter_per_ep
+ if warmup_epochs > 0:
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+
+ schedule = np.concatenate((warmup_schedule, schedule))
+ assert len(schedule) == epochs * niter_per_ep
+ return schedule
+
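+# Example: cosine_scheduler(base_value=0.996, final_value=1.0, epochs=60,
+# niter_per_ep=100) yields a length-6000 per-iteration schedule that anneals
+# from base_value to final_value along a cosine curve, with an optional linear
+# warmup controlled by warmup_epochs / start_warmup_value.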
+
+
+# ------ logging ------
+
+class TBLogger(object):
+ def __init__(self, log_path="./logging_data", name="run"):
+ super(TBLogger, self).__init__()
+
+ if not os.path.exists(log_path):
+ os.makedirs(log_path, exist_ok=True)
+
+ self.last_step = 0
+ self.log_path = log_path
+ raw_name = os.path.join(log_path, name)
+ name = raw_name
+ for i in range(1000):
+ name = f"{raw_name}_{i}"
+ if not os.path.exists(name):
+ break
+ self.writer = SummaryWriter(logdir=name)
+
+ def note(self, metrics, step=None):
+ if step is None:
+ step = self.last_step
+ for key, value in metrics.items():
+ self.writer.add_scalar(key, value, step)
+ self.last_step = step
+
+ def finish(self):
+ self.writer.close()
+
+
+class WandbLogger(object):
+ def __init__(self, log_path, project, args):
+ self.log_path = log_path
+ self.project = project
+ self.args = args
+ self.last_step = 0
+ self.start()
+
+ def start(self):
+ self.run = wandb.init(config=self.args, project=self.project)
+
+ def log(self, metrics, step=None):
+ if not hasattr(self, "run"):
+ self.start()
+ if step is None:
+ step = self.last_step
+ self.run.log(metrics)
+ self.last_step = step
+
+ def finish(self):
+ self.run.finish()