diff --git a/cogdl/data/data.py b/cogdl/data/data.py index 1614a9a0..d2d71475 100644 --- a/cogdl/data/data.py +++ b/cogdl/data/data.py @@ -480,16 +480,17 @@ def __init__(self, x=None, y=None, **kwargs): self.x = x self.y = y self.grb_adj = None + num_nodes = x.shape[0] if x is not None else None for key, item in kwargs.items(): if key == "num_nodes": self.__num_nodes__ = item + num_nodes = item elif key == "grb_adj": self.grb_adj = item elif not is_read_adj_key(key): self[key] = item - num_nodes = x.shape[0] if x is not None else None if "edge_index_train" in kwargs: self._adj_train = Adjacency(num_nodes=num_nodes) for key, item in kwargs.items(): @@ -534,14 +535,17 @@ def add_remaining_self_loops(self): self._adj_full.add_remaining_self_loops() if self._adj_train is not None: self._adj_train.add_remaining_self_loops() + return self def padding_self_loops(self): self._adj.padding_self_loops() + return self def remove_self_loops(self): self._adj_full.remove_self_loops() if self._adj_train is not None: self._adj_train.remove_self_loops() + return self def row_norm(self): self._adj.row_norm() @@ -790,7 +794,7 @@ def sample_adj(self, batch, size=-1, replace=True): if not torch.is_tensor(batch): batch = torch.tensor(batch, dtype=torch.long) (row_ptr, col_indices, nodes, edges) = sample_adj_c( - self._adj.row_indptr, self.col_indices, batch, size, replace + self.row_indptr, self.col_indices, batch, size, replace ) else: if torch.is_tensor(batch): @@ -891,13 +895,18 @@ def subgraph(self, node_idx, keep_order=False): val = self.edge_weight.numpy() N = self.num_nodes self[key] = sp.csr_matrix((val, (row, col)), shape=(N, N)) - sub_adj = self[key][node_idx, :][:, node_idx] + sub_adj = self[key][node_idx, :][:, node_idx].tocoo() sub_g = Graph() - sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long() - sub_g.col_indices = torch.from_numpy(sub_adj.indices).long() + # sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long() + # sub_g.col_indices = torch.from_numpy(sub_adj.indices).long() + row = torch.from_numpy(sub_adj.row).long() + col = torch.from_numpy(sub_adj.col).long() + sub_g.edge_index = (row, col) sub_g.edge_weight = torch.from_numpy(sub_adj.data) + sub_g.num_nodes = len(node_idx) for key in self.__keys__(): sub_g[key] = self[key][node_idx] + sub_g._adj._to_csr() return sub_g.to(self._adj.device) def edge_subgraph(self, edge_idx, require_idx=True): diff --git a/examples/graphmae/imgs/compare.png b/examples/graphmae/imgs/compare.png new file mode 100644 index 00000000..10ac5d4a Binary files /dev/null and b/examples/graphmae/imgs/compare.png differ diff --git a/examples/graphmae/imgs/fig.png b/examples/graphmae/imgs/fig.png new file mode 100644 index 00000000..33877e6f Binary files /dev/null and b/examples/graphmae/imgs/fig.png differ diff --git a/examples/graphmae2/README.md b/examples/graphmae2/README.md new file mode 100644 index 00000000..af672372 --- /dev/null +++ b/examples/graphmae2/README.md @@ -0,0 +1,87 @@ +

+# GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner

+
+[**CogDL**](https://github.com/THUDM/cogdl) implementation of the WWW'23 paper: [GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner](https://arxiv.org/abs/2304.04779).
+
+
+The predecessor of this work, [GraphMAE: Self-Supervised Masked Graph Autoencoders](https://arxiv.org/abs/2205.10803), can be found [here](https://github.com/THUDM/cogdl/tree/master/examples/graphmae).
+

+## Dependencies

+
+* Python >= 3.7
+* [Pytorch](https://pytorch.org/) >= 1.9.0
+* [cogdl](https://github.com/THUDM/cogdl) >= 0.5.3
+* pyyaml == 5.4.1
+
+

+## Quick Start

+
+For a quick start, you can run the provided scripts:
+
+**Node classification**
+
+```bash
+sh run_minibatch.sh # for mini-batch node classification
+# example: sh run_minibatch.sh ogbn-arxiv 0
+sh run_fullbatch.sh # for full-batch node classification
+# example: sh run_fullbatch.sh cora 0
+
+# Or run the code manually:
+# for mini-batch node classification
+python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0
+# for full-batch node classification
+python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0
+```
+
+Supported datasets:
+
+* mini-batch node classification: `ogbn-arxiv`, `ogbn-products`, `mag-scholar-f`, `ogbn-papers100M`
+* full-batch node classification: `cora`, `citeseer`, `pubmed`
+
+Run the provided scripts or add `--use_cfg` to the command to reproduce the reported results (see the configuration-loading sketch below).
+
+**For large-scale graphs**
+If you want to use local clustering for mini-batch training, you need to generate the local clusters before training starts. By default, the program loads the dataset from `./data` and saves the generated local clusters to `./lc_ego_graphs`. To generate the local clusters, first install [localclustering](https://github.com/kfoynt/LocalGraphClustering) and then run:
+
+```
+python ./datasets/localclustering.py --dataset <dataset_name> --data_dir <data_dir>
+```
+We also provide pre-generated local clusters, which can be downloaded [here](https://cloud.tsinghua.edu.cn/d/64f859f389ca43eda472/) and placed into `lc_ego_graphs`.
+
+
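+
+With `--use_cfg`, the training scripts call `load_best_configs(args)` (defined in `utils.py`, which is not part of this diff) to apply the tuned hyper-parameters from `configs/<dataset>.yaml`. The following is only a rough sketch of that behavior, not the actual implementation:
+
+```python
+import argparse
+
+import yaml
+
+
+def load_best_configs(args: argparse.Namespace) -> argparse.Namespace:
+    # Sketch of the --use_cfg path (assumed behavior): read the per-dataset
+    # YAML config and override the matching command-line arguments.
+    with open(f"configs/{args.dataset}.yaml", "r") as f:
+        cfg = yaml.safe_load(f)
+    for key, value in cfg.items():
+        setattr(args, key, value)  # e.g. lr, num_hidden, mask_rate, ...
+    return args
+```
+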

+## Datasets

+
+The OGB datasets and the small-scale datasets (Cora, Citeseer, and PubMed) are downloaded automatically when the code runs; a minimal loading sketch is shown below.
+
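+
+Dataset loading goes through cogdl's `build_dataset_from_path`, as used in `datasets/data_proc.py`. As an illustrative sketch (`./data` is the default directory used by the scripts):
+
+```python
+from cogdl.datasets import build_dataset_from_path
+
+# Download (if necessary) and load Cora, mirroring load_small_dataset
+# in datasets/data_proc.py.
+dataset = build_dataset_from_path("./data", dataset="cora")
+graph = dataset[0]
+graph.add_remaining_self_loops()
+
+num_features = graph.x.shape[1]
+num_classes = dataset.num_classes
+print(num_features, num_classes)
+```
+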

+## Experimental Results

+
+Experimental results of node classification on large-scale datasets (Accuracy, %):
+
+|                    | Ogbn-arxiv     | Ogbn-products  | Mag-Scholar-F  | Ogbn-papers100M |
+| ------------------ | -------------- | -------------- | -------------- | --------------- |
+| MLP                | 55.50±0.23     | 61.06±0.08     | 39.11±0.21     | 47.24±0.31      |
+| SGC                | 66.92±0.08     | 74.87±0.25     | 54.68±0.23     | 63.29±0.19      |
+| Random-Init        | 68.14±0.02     | 74.04±0.06     | 56.57±0.03     | 61.55±0.12      |
+| CCA-SSG            | 68.57±0.02     | 75.27±0.05     | 51.55±0.03     | 55.67±0.15      |
+| GRACE              | 69.34±0.01     | 79.47±0.59     | 57.39±0.02     | 61.21±0.12      |
+| BGRL               | 70.51±0.03     | 78.59±0.02     | 57.57±0.01     | 62.18±0.15      |
+| GGD                | -              | 75.70±0.40     | -              | 63.50±0.50      |
+| GraphMAE           | 71.03±0.02     | 78.89±0.01     | 58.75±0.03     | 62.54±0.09      |
+| **GraphMAE2**      | **71.89±0.03** | **81.59±0.02** | **59.24±0.01** | **64.89±0.04**  |
+
+

+## Citing

+ +If you find this work is helpful to your research, please consider citing our paper: + +``` +@inproceedings{hou2023graphmae2, + title={GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner}, + author={Zhenyu Hou, Yufei He, Yukuo Cen, Xiao Liu, Yuxiao Dong, Evgeny Kharlamov, Jie Tang}, + booktitle={Proceedings of the ACM Web Conference 2023 (WWW’23)}, + year={2023} +} +``` diff --git a/examples/graphmae2/asserts/overview.png b/examples/graphmae2/asserts/overview.png new file mode 100644 index 00000000..c1ed51eb Binary files /dev/null and b/examples/graphmae2/asserts/overview.png differ diff --git a/examples/graphmae2/configs/citeseer.yaml b/examples/graphmae2/configs/citeseer.yaml new file mode 100644 index 00000000..013838d2 --- /dev/null +++ b/examples/graphmae2/configs/citeseer.yaml @@ -0,0 +1,27 @@ +lr: 0.0005 # 0.0005 +lr_f: 0.025 +num_hidden: 1024 +num_heads: 4 +num_out_heads: 1 +num_layers: 2 +weight_decay: 1e-4 +weight_decay_f: 1e-2 +max_epoch: 500 +max_epoch_f: 500 +mask_rate: 0.5 +num_layers: 2 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.1 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.0 +optimizer: adam +replace_rate: 0.0 +alpha_l: 1 +scheduler: True +remask_method: fixed +momentum: 1 +lam: 0.1 \ No newline at end of file diff --git a/examples/graphmae2/configs/cora.yaml b/examples/graphmae2/configs/cora.yaml new file mode 100644 index 00000000..c0d9f2cb --- /dev/null +++ b/examples/graphmae2/configs/cora.yaml @@ -0,0 +1,27 @@ +lr: 0.001 +lr_f: 0.025 +num_hidden: 1024 +num_heads: 8 +num_out_heads: 1 +num_layers: 2 +weight_decay: 2e-4 +weight_decay_f: 1e-4 +max_epoch: 2000 +max_epoch_f: 300 +mask_rate: 0.5 +num_layers: 2 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.1 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.0 +optimizer: adam +replace_rate: 0.1 +alpha_l: 4 +scheduler: True +remask_method: fixed +momentum: 0 +lam: 0.1 \ No newline at end of file diff --git a/examples/graphmae2/configs/mag-scholar-f.yaml b/examples/graphmae2/configs/mag-scholar-f.yaml new file mode 100644 index 00000000..46aead7f --- /dev/null +++ b/examples/graphmae2/configs/mag-scholar-f.yaml @@ -0,0 +1,30 @@ +lr: 0.001 +lr_f: 0.001 +num_hidden: 1024 +num_heads: 8 +num_out_heads: 1 +num_layers: 4 +weight_decay: 0.04 +weight_decay_f: 0 +max_epoch: 10 +max_epoch_f: 1000 +batch_size: 512 +batch_size_f: 256 +mask_rate: 0.5 +num_layers: 4 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.2 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.5 +optimizer: adamw +alpha_l: 2 +scheduler: True +remask_method: random +momentum: 0.996 +lam: 0.1 +delayed_ema_epoch: 0 +num_remasking: 3 \ No newline at end of file diff --git a/examples/graphmae2/configs/ogbn-arxiv.yaml b/examples/graphmae2/configs/ogbn-arxiv.yaml new file mode 100644 index 00000000..a7ca3ccf --- /dev/null +++ b/examples/graphmae2/configs/ogbn-arxiv.yaml @@ -0,0 +1,30 @@ +lr: 0.0025 +lr_f: 0.005 +num_hidden: 1024 +num_heads: 8 +num_out_heads: 1 +num_layers: 4 +weight_decay: 0.06 +weight_decay_f: 1e-4 +max_epoch: 60 +max_epoch_f: 1000 +batch_size: 512 +batch_size_f: 256 +mask_rate: 0.5 +num_layers: 4 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.1 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.5 +optimizer: adamw +alpha_l: 6 +scheduler: True +remask_method: random +momentum: 0.996 +lam: 10.0 +delayed_ema_epoch: 40 +num_remasking: 3 \ No newline at end of file diff --git 
a/examples/graphmae2/configs/ogbn-papers100M.yaml b/examples/graphmae2/configs/ogbn-papers100M.yaml new file mode 100644 index 00000000..4dba4f98 --- /dev/null +++ b/examples/graphmae2/configs/ogbn-papers100M.yaml @@ -0,0 +1,30 @@ +lr: 0.001 +lr_f: 0.001 +num_hidden: 1024 +num_heads: 4 +num_out_heads: 1 +num_layers: 4 +weight_decay: 0.05 +weight_decay_f: 0 +max_epoch: 10 +max_epoch_f: 1000 +batch_size: 512 +batch_size_f: 256 +mask_rate: 0.5 +num_layers: 4 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.2 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.5 +optimizer: adamw +alpha_l: 2 +scheduler: True +remask_method: random +momentum: 0.996 +lam: 10.0 +delayed_ema_epoch: 0 +num_remasking: 3 \ No newline at end of file diff --git a/examples/graphmae2/configs/ogbn-products.yaml b/examples/graphmae2/configs/ogbn-products.yaml new file mode 100644 index 00000000..d822d49b --- /dev/null +++ b/examples/graphmae2/configs/ogbn-products.yaml @@ -0,0 +1,30 @@ +lr: 0.002 +lr_f: 0.001 +num_hidden: 1024 +num_heads: 4 +num_out_heads: 1 +num_layers: 4 +weight_decay: 0.04 +weight_decay_f: 0 +max_epoch: 20 +max_epoch_f: 1000 +batch_size: 512 +batch_size_f: 256 +mask_rate: 0.5 +num_layers: 4 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.2 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.5 +optimizer: adamw +alpha_l: 3 +scheduler: True +remask_method: random +momentum: 0.996 +lam: 5.0 +delayed_ema_epoch: 0 +num_remasking: 3 \ No newline at end of file diff --git a/examples/graphmae2/configs/pubmed.yaml b/examples/graphmae2/configs/pubmed.yaml new file mode 100644 index 00000000..d80fb21d --- /dev/null +++ b/examples/graphmae2/configs/pubmed.yaml @@ -0,0 +1,27 @@ +lr: 0.005 +lr_f: 0.025 +num_hidden: 512 +num_heads: 2 +num_out_heads: 1 +num_layers: 2 +weight_decay: 1e-5 +weight_decay_f: 5e-4 +max_epoch: 2000 +max_epoch_f: 500 +mask_rate: 0.9 +num_layers: 2 +encoder: gat +decoder: gat +activation: prelu +attn_drop: 0.1 +linear_prob: True +in_drop: 0.2 +loss_fn: sce +drop_edge_rate: 0.0 +optimizer: adam +replace_rate: 0.0 +alpha_l: 4 +scheduler: True +remask_method: fixed +momentum: 0.995 +lam: 1 diff --git a/examples/graphmae2/datasets/__init__.py b/examples/graphmae2/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/graphmae2/datasets/data_proc.py b/examples/graphmae2/datasets/data_proc.py new file mode 100644 index 00000000..4710907a --- /dev/null +++ b/examples/graphmae2/datasets/data_proc.py @@ -0,0 +1,50 @@ +import logging +from typing import Counter +from xml.sax.handler import feature_string_interning +import numpy as np +from collections import namedtuple +import scipy.sparse as sp +from scipy.special import iv +from sklearn import preprocessing +from sklearn.utils.extmath import randomized_svd + +import torch +import torch.nn.functional as F + +from cogdl.utils import to_undirected +from cogdl.datasets import build_dataset_from_path +from sklearn.preprocessing import StandardScaler + +def load_small_dataset(data_dir, dataset_name): + dataset = build_dataset_from_path(data_dir, dataset=dataset_name) + if dataset_name == "ogbn-arxiv": + graph = dataset[0] + feat = graph.x + feat = scale_feats(feat) + graph.x = feat + else: + graph = dataset[0] + graph.add_remaining_self_loops() + + num_features = graph.x.shape[1] + num_classes = dataset.num_classes + return graph, (num_features, num_classes) + +def preprocess(graph): + feat = graph.x + edge_index = graph.edge_index + edge_index = to_undirected(edge_index, 
num_nodes=graph.num_nodes) + graph.edge_index = edge_index + graph.x = feat + + graph.add_remaining_self_loops() + return graph + + +def scale_feats(x): + logging.info("### scaling features ###") + scaler = StandardScaler() + feats = x.numpy() + scaler.fit(feats) + feats = torch.from_numpy(scaler.transform(feats)).float() + return feats diff --git a/examples/graphmae2/datasets/lc_sampler.py b/examples/graphmae2/datasets/lc_sampler.py new file mode 100644 index 00000000..ff2e680e --- /dev/null +++ b/examples/graphmae2/datasets/lc_sampler.py @@ -0,0 +1,156 @@ +import os + +import numpy as np +import torch + +import cogdl +from cogdl.datasets import build_dataset_from_path +from cogdl.data import Graph + +from .data_proc import preprocess, scale_feats +from utils import mask_edge + +import logging +import torch.multiprocessing +from torch.utils.data import DataLoader +from tqdm import tqdm +torch.multiprocessing.set_sharing_strategy('file_system') + + +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) + +def load_dataset(data_dir, dataset_name): + dataset = build_dataset_from_path(data_dir, dataset=dataset_name) + if dataset_name.startswith("ogbn"): + graph = dataset[0] + logging.info("--- to undirected graph ---") + graph = preprocess(graph) + feat = graph.x + feat = scale_feats(feat) + graph.x = feat + graph.add_remaining_self_loops() + + # num_features = graph.x.shape[1] + # num_classes = dataset.num_classes + # return graph, (num_features, num_classes) + # feats, graph, labels, split_idx + train_idx = graph.train_mask.nonzero().squeeze(1) + val_idx = graph.val_mask.nonzero().squeeze(1) + test_idx = graph.test_mask.nonzero().squeeze(1) + split_idx = {"train": train_idx, "valid": val_idx, "test": test_idx} + return graph.x, graph, graph.y, split_idx + + +class LinearProbingDataLoader(DataLoader): + def __init__(self, idx, feats, labels=None, **kwargs): + self.labels = labels + self.feats = feats + + kwargs["collate_fn"] = self.__collate_fn__ + super().__init__(dataset=idx, **kwargs) + + def __collate_fn__(self, batch_idx): + feats = self.feats[batch_idx] + label = self.labels[batch_idx] + + return feats, label + +class OnlineLCLoader(DataLoader): + def __init__(self, root_nodes, graph, feats, labels=None, drop_edge_rate=0, **kwargs): + self.graph = graph + self.labels = labels + self._drop_edge_rate = drop_edge_rate + self.ego_graph_nodes = root_nodes + self.feats = feats + + dataset = np.arange(len(root_nodes)) + kwargs["collate_fn"] = self.__collate_fn__ + super().__init__(dataset, **kwargs) + + def drop_edge(self, g): + if self._drop_edge_rate <= 0: + return g, g + n_node = g.num_nodes + g = g.remove_self_loops() + mask_index1 = mask_edge(g, self._drop_edge_rate) + mask_index2 = mask_edge(g, self._drop_edge_rate) + src = g.edge_index[0] + dst = g.edge_index[1] + nsrc1 = src[mask_index1] + ndst1 = dst[mask_index1] + nsrc2 = src[mask_index2] + ndst2 = dst[mask_index2] + g1 = Graph(edge_index=(nsrc1, ndst1), num_nodes=n_node) + g1.add_remaining_self_loops() + g2 = Graph(edge_index=(nsrc2, ndst2), num_nodes=n_node) + g2.add_remaining_self_loops() + + return g1, g2 + + def __collate_fn__(self, batch_idx): + ego_nodes = [self.ego_graph_nodes[i] for i in batch_idx] + subgs = [self.graph.subgraph(ego_nodes[i]) for i in range(len(ego_nodes))] + + sg = cogdl.data.batch_graphs(subgs) + + nodes = torch.from_numpy(np.concatenate(ego_nodes)).long() + num_nodes = [x.shape[0] for x in ego_nodes] + cum_num_nodes = np.cumsum([0] + num_nodes)[:-1] + + if 
self._drop_edge_rate > 0: + drop_g1, drop_g2 = self.drop_edge(sg) + + sg = sg.add_remaining_self_loops() + sg.x = self.feats[nodes] + targets = torch.from_numpy(cum_num_nodes) + + if self.labels != None: + label = self.labels[batch_idx] + else: + label = None + + if self._drop_edge_rate > 0: + return sg, targets, label, nodes, drop_g1, drop_g2 + else: + return sg, targets, label, nodes + + +def setup_training_data(dataset_name, data_dir, ego_graphs_file_path): + feats, graph, labels, split_idx = load_dataset(data_dir, dataset_name) + + train_lbls = labels[split_idx["train"]] + val_lbls = labels[split_idx["valid"]] + test_lbls = labels[split_idx["test"]] + + labels = torch.cat([train_lbls, val_lbls, test_lbls]) + + if not os.path.exists(ego_graphs_file_path): + raise FileNotFoundError(f"{ego_graphs_file_path} doesn't exist") + else: + nodes = torch.load(ego_graphs_file_path) + + return feats, graph, labels, split_idx, nodes + + +def setup_training_dataloder(loader_type, training_nodes, graph, feats, batch_size, drop_edge_rate=0, pretrain_clustergcn=False, cluster_iter_data=None): + num_workers = 8 + + print(" -------- drop edge rate: {} --------".format(drop_edge_rate)) + dataloader = OnlineLCLoader(training_nodes, graph, feats=feats, drop_edge_rate=drop_edge_rate, batch_size=batch_size, shuffle=True, drop_last=False, persistent_workers=True, num_workers=num_workers) + return dataloader + + +def setup_eval_dataloder(loader_type, graph, feats, ego_graph_nodes=None, batch_size=128, shuffle=False): + num_workers = 8 + if loader_type == "lc": + assert ego_graph_nodes is not None + + dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, persistent_workers=True, num_workers=num_workers) + return dataloader + + +def setup_finetune_dataloder(loader_type, graph, feats, ego_graph_nodes, labels, batch_size, shuffle=False): + num_workers = 8 + dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, labels=labels, feats=feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, num_workers=num_workers, persistent_workers=True) + + return dataloader diff --git a/examples/graphmae2/datasets/localclustering.py b/examples/graphmae2/datasets/localclustering.py new file mode 100644 index 00000000..8b648386 --- /dev/null +++ b/examples/graphmae2/datasets/localclustering.py @@ -0,0 +1,157 @@ +import argparse +from collections import namedtuple +import multiprocessing +import os + +import numpy as np +from localgraphclustering import * +from scipy.sparse import csr_matrix +from ogb.nodeproppred import DglNodePropPredDataset +import torch +import logging + +import dgl +from dgl.data import load_data + + +def my_sweep_cut(g, node): + vol_sum = 0.0 + in_edge = 0.0 + conds = np.zeros_like(node, dtype=np.float32) + for i in range(len(node)): + idx = node[i] + vol_sum += g.d[idx] + denominator = min(vol_sum, g.vol_G - vol_sum) + if denominator == 0.0: + denominator = 1.0 + in_edge += 2*sum([g.adjacency_matrix[idx,prev] for prev in node[:i+1]]) + cut = vol_sum - in_edge + conds[i] = cut/denominator + return conds + + +def calc_local_clustering(args): + i, log_steps, num_iter, ego_size, method = args + if i % log_steps == 0: + print(i) + node, ppr = approximate_PageRank(graphlocal, [i], iterations=num_iter, method=method, normalize=False) + d_inv = graphlocal.dn[node] + d_inv[d_inv > 1.0] = 1.0 + ppr_d_inv = ppr * d_inv + output = list(zip(node, ppr_d_inv))[:ego_size] + node, ppr_d_inv = zip(*sorted(output, key=lambda x: x[1], reverse=True)) + assert 
node[0] == i + node = np.array(node, dtype=np.int32) + conds = my_sweep_cut(graphlocal, node) + return node, conds + + +def step1_local_clustering(data, name, idx_split, ego_size=128, num_iter=1000, log_steps=10000, num_workers=16, method='acl', save_dir=None): + if save_dir is None: + save_path = f"{name}-lc-ego-graphs-{ego_size}.pt" + else: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + + save_path = os.path.join(save_dir, f"{name}-lc-ego-graphs-{ego_size}.pt") + + N = data.num_nodes() + edge_index = data.edges() + edge_index = (edge_index[0].numpy(), edge_index[1].numpy()) + adj = csr_matrix((np.ones(edge_index[0].shape[0]), edge_index), shape=(N, N)) + + global graphlocal + graphlocal = GraphLocal.from_sparse_adjacency(adj) + print('graphlocal generated') + + train_idx = idx_split["train"].cpu().numpy() + valid_idx = idx_split["valid"].cpu().numpy() + test_idx = idx_split["test"].cpu().numpy() + + with multiprocessing.Pool(num_workers) as pool: + ego_graphs_train, conds_train = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in train_idx], chunksize=512)) + + with multiprocessing.Pool(num_workers) as pool: + ego_graphs_valid, conds_valid = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in valid_idx], chunksize=512)) + + with multiprocessing.Pool(num_workers) as pool: + ego_graphs_test, conds_test = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in test_idx], chunksize=512)) + + ego_graphs = [] + conds = [] + ego_graphs.extend(ego_graphs_train) + ego_graphs.extend(ego_graphs_valid) + ego_graphs.extend(ego_graphs_test) + conds.extend(conds_train) + conds.extend(conds_valid) + conds.extend(conds_test) + + ego_graphs = [ego_graphs_train, ego_graphs_valid, ego_graphs_test] + + torch.save(ego_graphs, save_path) + + +def preprocess(graph): + # make bidirected + if "feat" in graph.ndata: + feat = graph.ndata["feat"] + else: + feat = None + # src, dst = graph.all_edges() + # graph.add_edges(dst, src) + graph = dgl.to_bidirected(graph) + if feat is not None: + graph.ndata["feat"] = feat + + # add self-loop + graph = graph.remove_self_loop().add_self_loop() + # graph.create_formats_() + return graph + + +def load_dataset(data_dir, dataset_name): + if dataset_name.startswith("ogbn"): + dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset")) + graph, label = dataset[0] + + if "year" in graph.ndata: + del graph.ndata["year"] + if not graph.is_multigraph: + graph = preprocess(graph) + # graph = graph.remove_self_loop().add_self_loop() + + split_idx = dataset.get_idx_split() + label = label.view(-1) + + elif dataset_name == "mag_scholar_f": + edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy")) + print(len(edge_index[0])) + graph = dgl.DGLGraph((edge_index[0], edge_index[1])) + print(graph) + num_nodes = graph.num_nodes() + assert num_nodes == 12403930 + split_idx = torch.load(os.path.join(data_dir, dataset_name, "split_idx_f.pt")) + else: + raise NotImplementedError + + return graph, split_idx + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='LCGNN (Preprocessing)') + parser.add_argument('--dataset', type=str, default='flickr') + parser.add_argument("--data_dir", type=str, default="data") + parser.add_argument("--save_dir", type=str, default="lc_ego_graphs") + parser.add_argument('--ego_size', type=int, default=256) + parser.add_argument('--num_iter', type=int, 
default=1000) + parser.add_argument('--log_steps', type=int, default=10000) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--method', type=str, default='acl') + parser.add_argument('--num_workers', type=int, default=16) + args = parser.parse_args() + print(args) + + np.random.seed(args.seed) + + graph, split_idx = load_dataset(args.data_dir, args.dataset) + step1_local_clustering(graph, args.dataset, split_idx, args.ego_size, args.num_iter, args.log_steps, args.num_workers, args.method, args.save_dir) diff --git a/examples/graphmae2/main_full_batch.py b/examples/graphmae2/main_full_batch.py new file mode 100644 index 00000000..4407f5f6 --- /dev/null +++ b/examples/graphmae2/main_full_batch.py @@ -0,0 +1,134 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch + +from utils import ( + build_args, + create_optimizer, + set_random_seed, + TBLogger, + get_current_lr, + load_best_configs, +) +from datasets.data_proc import load_small_dataset +from models.finetune import linear_probing_full_batch +from models import build_model + + +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) + + +def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None): + logging.info("start training..") + graph = graph.to(device) + x = feat.to(device) + + target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long) + epoch_iter = tqdm(range(max_epoch)) + + for epoch in epoch_iter: + model.train() + + loss = model(graph, x, targets=target_nodes) + + loss_dict = {"loss": loss.item()} + optimizer.zero_grad() + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + + epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}") + if logger is not None: + loss_dict["lr"] = get_current_lr(optimizer) + logger.note(loss_dict, step=epoch) + + if (epoch + 1) % 200 == 0: + linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True) + + return model + + +def main(args): + device = args.device if args.device >= 0 else "cpu" + seeds = args.seeds + dataset_name = args.dataset + max_epoch = args.max_epoch + max_epoch_f = args.max_epoch_f + num_hidden = args.num_hidden + num_layers = args.num_layers + encoder_type = args.encoder + decoder_type = args.decoder + replace_rate = args.replace_rate + + optim_type = args.optimizer + loss_fn = args.loss_fn + + lr = args.lr + weight_decay = args.weight_decay + lr_f = args.lr_f + weight_decay_f = args.weight_decay_f + linear_prob = args.linear_prob + load_model = args.load_model + logs = args.logging + use_scheduler = args.scheduler + data_dir = args.data_dir + + graph, (num_features, num_classes) = load_small_dataset(data_dir, dataset_name) + args.num_features = num_features + + acc_list = [] + estp_acc_list = [] + for i, seed in enumerate(seeds): + print(f"####### Run {i} for seed {seed}") + set_random_seed(seed) + + if logs: + logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}") + else: + logger = None + + model = build_model(args) + model.to(device) + optimizer = create_optimizer(optim_type, model, lr, weight_decay) + + if use_scheduler: + logging.info("Use schedular") + scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / 
max_epoch) ) * 0.5 + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler) + else: + scheduler = None + + x = graph.x + if not load_model: + model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger) + model = model.cpu() + + if load_model: + logging.info("Loading Model ... ") + model.load_state_dict(torch.load("checkpoint.pt")) + + model = model.to(device) + model.eval() + + final_acc, estp_acc = linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob) + acc_list.append(final_acc) + estp_acc_list.append(estp_acc) + + if logger is not None: + logger.finish() + + final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) + estp_acc, estp_acc_std = np.mean(estp_acc_list), np.std(estp_acc_list) + print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}") + print(f"# early-stopping_acc: {estp_acc:.4f}±{estp_acc_std:.4f}") + + +# Press the green button in the gutter to run the script. +if __name__ == "__main__": + args = build_args() + if args.use_cfg: + args = load_best_configs(args) + print(args) + main(args) diff --git a/examples/graphmae2/main_large.py b/examples/graphmae2/main_large.py new file mode 100644 index 00000000..21541e06 --- /dev/null +++ b/examples/graphmae2/main_large.py @@ -0,0 +1,270 @@ +import logging +import os +import numpy as np +from tqdm import tqdm + +import torch + +from utils import ( + WandbLogger, + build_args, + create_optimizer, + set_random_seed, + load_best_configs, + show_occupied_memory, +) +from models import build_model +from datasets.lc_sampler import ( + setup_training_dataloder, + setup_training_data, +) +from models.finetune import linear_probing_minibatch, finetune + +import warnings + +warnings.filterwarnings("ignore") +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) + + +def evaluate( + model, + graph, feats, labels, + split_idx, + lr_f, weight_decay_f, max_epoch_f, + linear_prob=True, + device=0, + batch_size=256, + logger=None, ego_graph_nodes=None, + label_rate=1.0, + full_graph_forward=False, + shuffle=True, +): + logging.info("Using `lc` for evaluation...") + num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]] + print(num_train,num_val,num_test) + + train_g_idx = np.arange(0, num_train) + val_g_idx = np.arange(num_train, num_train+num_val) + test_g_idx = np.arange(num_train+num_val, num_train+num_val+num_test) + + train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx] + val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx] + test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx] + + train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx] + + # labels = [train_lbls, val_lbls, test_lbls] + assert len(train_ego_graph_nodes) == len(train_lbls) + assert len(val_ego_graph_nodes) == len(val_lbls) + assert len(test_ego_graph_nodes) == len(test_lbls) + + print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}") + logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---") + + + if linear_prob: + result = linear_probing_minibatch(model, graph, feats, [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], [train_lbls, val_lbls, test_lbls], lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, 
batch_size=batch_size, device=device, shuffle=shuffle) + else: + max_epoch_f = max_epoch_f // 2 + + if label_rate < 1.0: + rand_idx = np.arange(len(train_ego_graph_nodes)) + np.random.shuffle(rand_idx) + rand_idx = rand_idx[:int(label_rate * len(train_ego_graph_nodes))] + train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in rand_idx] + train_lbls = train_lbls[rand_idx] + + logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---") + + # train_lbls = (all_train_lbls, train_lbls) + result = finetune( + model, graph, feats, + [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], + [train_lbls, val_lbls, test_lbls], + split_idx=split_idx, + lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, use_scheduler=True, batch_size=batch_size, device=device, logger=logger, full_graph_forward=full_graph_forward, + ) + return result + + +def pretrain(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay, batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0): + logging.info("start training..") + + model = model.to(device) + optimizer = create_optimizer(optimizer, model, lr, weight_decay) + + dataloader = setup_training_dataloder( + sampling_method, ego_graph_nodes, graph, feats, batch_size=batch_size, drop_edge_rate=drop_edge_rate) + + logging.info(f"After creating dataloader: Memory: {show_occupied_memory():.2f} MB") + if use_scheduler and max_epoch > 0: + logging.info("Use scheduler") + scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5 + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler) + else: + scheduler = None + + for epoch in range(max_epoch): + epoch_iter = tqdm(dataloader) + losses = [] + # assert (graph.in_degrees() > 0).all(), "after loading" + + for batch_g in epoch_iter: + model.train() + if drop_edge_rate > 0: + batch_g, targets, _, node_idx, drop_g1, drop_g2 = batch_g + batch_g = batch_g.to(device) + drop_g1 = drop_g1.to(device) + drop_g2 = drop_g2.to(device) + x = batch_g.x + loss = model(batch_g, x, targets, epoch, drop_g1, drop_g2) + else: + batch_g, targets, _, node_idx = batch_g + batch_g = batch_g.to(device) + x = batch_g.x + loss = model(batch_g, x, targets, epoch) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 3) + optimizer.step() + + epoch_iter.set_description(f"train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB") + losses.append(loss.item()) + + if scheduler is not None: + scheduler.step() + + torch.save(model.state_dict(), os.path.join(model_dir, model_name)) + + print(f"# Epoch {epoch} | train_loss: {np.mean(losses):.4f}, Memory: {show_occupied_memory():.2f} MB") + + return model + + +if __name__ == "__main__": + args = build_args() + if args.use_cfg: + args = load_best_configs(args) + + if args.device < 0: + device = "cpu" + else: + device = "cuda:{}".format(args.device) + seeds = args.seeds + dataset_name = args.dataset + max_epoch = args.max_epoch + max_epoch_f = args.max_epoch_f + num_hidden = args.num_hidden + num_layers = args.num_layers + encoder_type = args.encoder + decoder_type = args.decoder + encoder = args.encoder + decoder = args.decoder + num_hidden = args.num_hidden + drop_edge_rate = args.drop_edge_rate + + optim_type = args.optimizer + loss_fn = args.loss_fn + + lr = args.lr + weight_decay = args.weight_decay + lr_f = args.lr_f + weight_decay_f = 
args.weight_decay_f + linear_prob = args.linear_prob + load_model = args.load_model + no_pretrain = args.no_pretrain + logs = args.logging + use_scheduler = args.scheduler + batch_size = args.batch_size + batch_size_f = args.batch_size_f + sampling_method = args.sampling_method + ego_graph_file_path = args.ego_graph_file_path + data_dir = args.data_dir + + n_procs = torch.cuda.device_count() + optimizer_type = args.optimizer + label_rate = args.label_rate + lam = args.lam + full_graph_forward = hasattr(args, "full_graph_forward") and args.full_graph_forward and not linear_prob + + model_dir = "checkpoints" + os.makedirs(model_dir, exist_ok=True) + + set_random_seed(0) + print(args) + + logging.info(f"Before loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB + feats, graph, labels, split_idx, ego_graph_nodes = setup_training_data(dataset_name, data_dir, ego_graph_file_path) + if dataset_name == "ogbn-papers100M": + pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] + ego_graph_nodes[3] + else: + pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] + ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] # * merge train/val/test = all + + logging.info(f"After loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB + + args.num_features = feats.shape[1] + + if logs: + logger = WandbLogger(log_path=f"{dataset_name}_loss_{loss_fn}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}", project="GraphMAE2", args=args) + else: + logger = None + model_name = f"{encoder}_{decoder}_{num_hidden}_{num_layers}_{dataset_name}_{args.mask_rate}_{num_hidden}_checkpoint.pt" + + model = build_model(args) + + if not args.no_pretrain: + # ------------- pretraining starts ---------------- + if not load_model: + logging.info("---- start pretraining ----") + model = pretrain(model, feats, graph, pretrain_ego_graph_nodes, max_epoch=max_epoch, device=device, use_scheduler=use_scheduler, lr=lr, + weight_decay=weight_decay, batch_size=batch_size, drop_edge_rate=drop_edge_rate, + sampling_method=sampling_method, optimizer=optimizer_type) + + model = model.cpu() + logging.info(f"saving model to {model_dir}/{model_name}...") + torch.save(model.state_dict(), os.path.join(model_dir, model_name)) + # ------------- pretraining ends ---------------- + + if load_model: + model.load_state_dict(torch.load(os.path.join(args.checkpoint_path))) + logging.info(f"Loading Model from {args.checkpoint_path}...") + else: + logging.info("--- no pretrain ---") + + model = model.to(device) + model.eval() + + logging.info("---- start finetuning / evaluation ----") + + final_accs = [] + for i,_ in enumerate(seeds): + print(f"####### Run seed {seeds[i]}") + set_random_seed(seeds[i]) + eval_model = build_model(args) + eval_model.load_state_dict(model.state_dict()) + eval_model.to(device) + + print(f"features size : {feats.shape[1]}") + logging.info("start evaluation...") + final_acc = evaluate( + eval_model, graph, feats, labels, split_idx, + lr_f, weight_decay_f, max_epoch_f, + device=device, + batch_size=batch_size_f, + ego_graph_nodes=ego_graph_nodes, + linear_prob=linear_prob, + label_rate=label_rate, + full_graph_forward=full_graph_forward, + shuffle=False if dataset_name == "ogbn-papers100M" else True + ) + + final_accs.append(float(final_acc)) + + print(f"Run {seeds[i]} | TestAcc: {final_acc:.4f}") + + print(f"# 
final_acc: {np.mean(final_accs):.4f}, std: {np.std(final_accs):.4f}") + + if logger is not None: + logger.finish() diff --git a/examples/graphmae2/models/__init__.py b/examples/graphmae2/models/__init__.py new file mode 100644 index 00000000..363ff1a9 --- /dev/null +++ b/examples/graphmae2/models/__init__.py @@ -0,0 +1,62 @@ +from .edcoder import PreModel + + +def build_model(args): + num_heads = args.num_heads + num_out_heads = args.num_out_heads + num_hidden = args.num_hidden + num_layers = args.num_layers + residual = args.residual + attn_drop = args.attn_drop + in_drop = args.in_drop + norm = args.norm + negative_slope = args.negative_slope + encoder_type = args.encoder + decoder_type = args.decoder + mask_rate = args.mask_rate + remask_rate = args.remask_rate + mask_method = args.mask_method + drop_edge_rate = args.drop_edge_rate + + activation = args.activation + loss_fn = args.loss_fn + alpha_l = args.alpha_l + + num_features = args.num_features + num_dec_layers = args.num_dec_layers + num_remasking = args.num_remasking + lam = args.lam + delayed_ema_epoch = args.delayed_ema_epoch + replace_rate = args.replace_rate + remask_method = args.remask_method + momentum = args.momentum + + model = PreModel( + in_dim=num_features, + num_hidden=num_hidden, + num_layers=num_layers, + num_dec_layers=num_dec_layers, + num_remasking=num_remasking, + nhead=num_heads, + nhead_out=num_out_heads, + activation=activation, + feat_drop=in_drop, + attn_drop=attn_drop, + negative_slope=negative_slope, + residual=residual, + encoder_type=encoder_type, + decoder_type=decoder_type, + mask_rate=mask_rate, + remask_rate=remask_rate, + mask_method=mask_method, + norm=norm, + loss_fn=loss_fn, + drop_edge_rate=drop_edge_rate, + alpha_l=alpha_l, + lam=lam, + delayed_ema_epoch=delayed_ema_epoch, + replace_rate=replace_rate, + remask_method=remask_method, + momentum=momentum + ) + return model diff --git a/examples/graphmae2/models/edcoder.py b/examples/graphmae2/models/edcoder.py new file mode 100644 index 00000000..0b46dda2 --- /dev/null +++ b/examples/graphmae2/models/edcoder.py @@ -0,0 +1,364 @@ +from itertools import chain + +from typing import Optional +import torch +import torch.nn as nn +from functools import partial + +from .gat import GAT + +from .loss_func import sce_loss + + +def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm, nhead, nhead_out, attn_drop, negative_slope=0.2, concat_out=True, **kwargs) -> nn.Module: + if m_type in ("gat", "tsgat"): + mod = GAT( + in_dim=in_dim, + num_hidden=num_hidden, + out_dim=out_dim, + num_layers=num_layers, + nhead=nhead, + nhead_out=nhead_out, + concat_out=concat_out, + activation=activation, + feat_drop=dropout, + attn_drop=attn_drop, + negative_slope=negative_slope, + residual=residual, + norm=norm, + encoding=(enc_dec == "encoding"), + **kwargs, + ) + elif m_type == "mlp": + # * just for decoder + mod = nn.Sequential( + nn.Linear(in_dim, num_hidden * 2), + nn.PReLU(), + nn.Dropout(0.2), + nn.Linear(num_hidden * 2, out_dim) + ) + elif m_type == "linear": + mod = nn.Linear(in_dim, out_dim) + else: + raise NotImplementedError + + return mod + +class PreModel(nn.Module): + def __init__( + self, + in_dim: int, + num_hidden: int, + num_layers: int, + num_dec_layers: int, + num_remasking: int, + nhead: int, + nhead_out: int, + activation: str, + feat_drop: float, + attn_drop: float, + negative_slope: float, + residual: bool, + norm: Optional[str], + mask_rate: float = 0.3, + remask_rate: float = 0.5, + 
remask_method: str = "random", + mask_method: str = "random", + encoder_type: str = "gat", + decoder_type: str = "gat", + loss_fn: str = "byol", + drop_edge_rate: float = 0.0, + alpha_l: float = 2, + lam: float = 1.0, + delayed_ema_epoch: int = 0, + momentum: float = 0.996, + replace_rate: float = 0.0, + ): + super(PreModel, self).__init__() + self._mask_rate = mask_rate + self._remask_rate = remask_rate + self._mask_method = mask_method + self._alpha_l = alpha_l + self._delayed_ema_epoch = delayed_ema_epoch + + self.num_remasking = num_remasking + self._encoder_type = encoder_type + self._decoder_type = decoder_type + self._drop_edge_rate = drop_edge_rate + self._output_hidden_size = num_hidden + self._momentum = momentum + self._replace_rate = replace_rate + self._num_remasking = num_remasking + self._remask_method = remask_method + + self._token_rate = 1 - self._replace_rate + self._lam = lam + + assert num_hidden % nhead == 0 + assert num_hidden % nhead_out == 0 + if encoder_type in ("gat",): + enc_num_hidden = num_hidden // nhead + enc_nhead = nhead + else: + enc_num_hidden = num_hidden + enc_nhead = 1 + + dec_in_dim = num_hidden + dec_num_hidden = num_hidden // nhead if decoder_type in ("gat",) else num_hidden + + # build encoder + self.encoder = setup_module( + m_type=encoder_type, + enc_dec="encoding", + in_dim=in_dim, + num_hidden=enc_num_hidden, + out_dim=enc_num_hidden, + num_layers=num_layers, + nhead=enc_nhead, + nhead_out=enc_nhead, + concat_out=True, + activation=activation, + dropout=feat_drop, + attn_drop=attn_drop, + negative_slope=negative_slope, + residual=residual, + norm=norm, + ) + + self.decoder = setup_module( + m_type=decoder_type, + enc_dec="decoding", + in_dim=dec_in_dim, + num_hidden=dec_num_hidden, + out_dim=in_dim, + nhead_out=nhead_out, + num_layers=num_dec_layers, + nhead=nhead, + activation=activation, + dropout=feat_drop, + attn_drop=attn_drop, + negative_slope=negative_slope, + residual=residual, + norm=norm, + concat_out=True, + ) + + self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim)) + self.dec_mask_token = nn.Parameter(torch.zeros(1, num_hidden)) + + self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False) + #self.reset_parameters_for_token() + + + # * setup loss function + self.criterion = self.setup_loss_fn(loss_fn, alpha_l) + + self.projector = nn.Sequential( + nn.Linear(num_hidden, 256), + nn.PReLU(), + nn.Linear(256, num_hidden), + ) + self.projector_ema = nn.Sequential( + nn.Linear(num_hidden, 256), + nn.PReLU(), + nn.Linear(256, num_hidden), + ) + self.predictor = nn.Sequential( + nn.PReLU(), + nn.Linear(num_hidden, num_hidden) + ) + + self.encoder_ema = setup_module( + m_type=encoder_type, + enc_dec="encoding", + in_dim=in_dim, + num_hidden=enc_num_hidden, + out_dim=enc_num_hidden, + num_layers=num_layers, + nhead=enc_nhead, + nhead_out=enc_nhead, + concat_out=True, + activation=activation, + dropout=feat_drop, + attn_drop=attn_drop, + negative_slope=negative_slope, + residual=residual, + norm=norm, + ) + self.encoder_ema.load_state_dict(self.encoder.state_dict()) + self.projector_ema.load_state_dict(self.projector.state_dict()) + + for p in self.encoder_ema.parameters(): + p.requires_grad = False + p.detach_() + for p in self.projector_ema.parameters(): + p.requires_grad = False + p.detach_() + + self.print_num_parameters() + + def print_num_parameters(self): + num_encoder_params = [p.numel() for p in self.encoder.parameters() if p.requires_grad] + num_decoder_params = [p.numel() for p in self.decoder.parameters() if 
p.requires_grad] + num_params = [p.numel() for p in self.parameters() if p.requires_grad] + + print(f"num_encoder_params: {sum(num_encoder_params)}, num_decoder_params: {sum(num_decoder_params)}, num_params_in_total: {sum(num_params)}") + + def reset_parameters_for_token(self): + nn.init.xavier_normal_(self.enc_mask_token) + nn.init.xavier_normal_(self.dec_mask_token) + nn.init.xavier_normal_(self.encoder_to_decoder.weight, gain=1.414) + + @property + def output_hidden_dim(self): + return self._output_hidden_size + + def setup_loss_fn(self, loss_fn, alpha_l): + if loss_fn == "mse": + print(f"=== Use mse_loss ===") + criterion = nn.MSELoss() + elif loss_fn == "sce": + print(f"=== Use sce_loss and alpha_l={alpha_l} ===") + criterion = partial(sce_loss, alpha=alpha_l) + else: + raise NotImplementedError + return criterion + + def forward(self, g, x, targets=None, epoch=0, drop_g1=None, drop_g2=None): # ---- attribute reconstruction ---- + loss = self.mask_attr_prediction(g, x, targets, epoch, drop_g1, drop_g2) + + return loss + + def mask_attr_prediction(self, g, x, targets, epoch, drop_g1=None, drop_g2=None): + pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate) + use_g = drop_g1 if drop_g1 is not None else g + + enc_rep = self.encoder(use_g, use_x,) + + with torch.no_grad(): + drop_g2 = drop_g2 if drop_g2 is not None else g + latent_target = self.encoder_ema(drop_g2, x,) + if targets is not None: + latent_target = self.projector_ema(latent_target[targets]) + else: + latent_target = self.projector_ema(latent_target[keep_nodes]) + + if targets is not None: + latent_pred = self.projector(enc_rep[targets]) + latent_pred = self.predictor(latent_pred) + loss_latent = sce_loss(latent_pred, latent_target, 1) + else: + latent_pred = self.projector(enc_rep[keep_nodes]) + latent_pred = self.predictor(latent_pred) + loss_latent = sce_loss(latent_pred, latent_target, 1) + + # ---- attribute reconstruction ---- + origin_rep = self.encoder_to_decoder(enc_rep) + + loss_rec_all = 0 + if self._remask_method == "random": + for i in range(self._num_remasking): + rep = origin_rep.clone() + rep, remask_nodes, rekeep_nodes = self.random_remask(use_g, rep, self._remask_rate) + recon = self.decoder(pre_use_g, rep) + + x_init = x[mask_nodes] + x_rec = recon[mask_nodes] + loss_rec = self.criterion(x_init, x_rec) + loss_rec_all += loss_rec + loss_rec = loss_rec_all + elif self._remask_method == "fixed": + rep = self.fixed_remask(g, origin_rep, mask_nodes) + x_rec = self.decoder(pre_use_g, rep)[mask_nodes] + x_init = x[mask_nodes] + loss_rec = self.criterion(x_init, x_rec) + else: + raise NotImplementedError + + loss = loss_rec + self._lam * loss_latent + + if epoch >= self._delayed_ema_epoch: + self.ema_update() + return loss + + def ema_update(self): + def update(student, teacher): + with torch.no_grad(): + # m = momentum_schedule[it] # momentum parameter + m = self._momentum + for param_q, param_k in zip(student.parameters(), teacher.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + update(self.encoder, self.encoder_ema) + update(self.projector, self.projector_ema) + + def embed(self, g, x): + rep = self.encoder(g, x) + return rep + + def get_encoder(self): + #self.encoder.reset_classifier(out_size) + return self.encoder + + def reset_encoder(self, out_size): + self.encoder.reset_classifier(out_size) + + @property + def enc_params(self): + return self.encoder.parameters() + + @property + def dec_params(self): + return 
chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) + + def output_grad(self): + grad_dict = {} + for n, p in self.named_parameters(): + if p.grad is not None: + grad_dict[n] = p.grad.abs().mean().item() + return grad_dict + + def encoding_mask_noise(self, g, x, mask_rate=0.3): + num_nodes = g.num_nodes + perm = torch.randperm(num_nodes, device=x.device) + num_mask_nodes = int(mask_rate * num_nodes) + + # exclude isolated nodes + # isolated_nodes = torch.where(g.in_degrees() <= 1)[0] + # mask_nodes = perm[: num_mask_nodes] + # mask_nodes = torch.index_fill(torch.full((num_nodes,), False, device=device), 0, mask_nodes, True) + # mask_nodes[isolated_nodes] = False + # keep_nodes = torch.where(~mask_nodes)[0] + # mask_nodes = torch.where(mask_nodes)[0] + # num_mask_nodes = mask_nodes.shape[0] + + # random masking + num_mask_nodes = int(mask_rate * num_nodes) + mask_nodes = perm[: num_mask_nodes] + keep_nodes = perm[num_mask_nodes: ] + + out_x = x.clone() + token_nodes = mask_nodes + out_x[mask_nodes] = 0.0 + + out_x[token_nodes] += self.enc_mask_token + use_g = g.clone() + + return use_g, out_x, (mask_nodes, keep_nodes) + + def random_remask(self,g,rep,remask_rate=0.5): + + num_nodes = g.num_nodes + perm = torch.randperm(num_nodes, device=rep.device) + num_remask_nodes = int(remask_rate * num_nodes) + remask_nodes = perm[: num_remask_nodes] + rekeep_nodes = perm[num_remask_nodes: ] + + rep = rep.clone() + rep[remask_nodes] = 0 + rep[remask_nodes] += self.dec_mask_token + + return rep, remask_nodes, rekeep_nodes + + def fixed_remask(self, g, rep, masked_nodes): + rep[masked_nodes] = 0 + return rep \ No newline at end of file diff --git a/examples/graphmae2/models/finetune.py b/examples/graphmae2/models/finetune.py new file mode 100644 index 00000000..a9055691 --- /dev/null +++ b/examples/graphmae2/models/finetune.py @@ -0,0 +1,388 @@ +import copy +import logging + +from tqdm import tqdm +import numpy as np + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from datasets.lc_sampler import setup_eval_dataloder, setup_finetune_dataloder, LinearProbingDataLoader +from utils import accuracy, set_random_seed, show_occupied_memory, get_current_lr + +import wandb + + +def linear_probing_minibatch( + model, graph, + feats, ego_graph_nodes, labels, + lr_f, weight_decay_f, max_epoch_f, + device, batch_size=-1, shuffle=True): + logging.info("-- Linear Probing in downstream tasks ---") + train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes = ego_graph_nodes + num_train, num_val = len(train_ego_graph_nodes), len(val_ego_graph_nodes) + train_lbls, val_lbls, test_lbls = labels + # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c","ogbn-arxiv","ogbn-products"]: + # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c", "ogbn-arxiv", "ogbn-products"]: + eval_loader = setup_eval_dataloder("lc", graph, feats, train_ego_graph_nodes+val_ego_graph_nodes+test_ego_graph_nodes, 512) + + with torch.no_grad(): + model.eval() + embeddings = [] + + for batch in tqdm(eval_loader, desc="Infering..."): + batch_g, targets, _, node_idx = batch + batch_g = batch_g.to(device) + x = batch_g.x + targets = targets.to(device) + + batch_emb = model.embed(batch_g, x)[targets] + embeddings.append(batch_emb.cpu()) + embeddings = torch.cat(embeddings, dim=0) + + train_emb, val_emb, test_emb = embeddings[:num_train], embeddings[num_train:num_train+num_val], embeddings[num_train+num_val:] + + batch_size = 5120 + acc = [] + seeds = [0, 
1, 2, 3, 4, 5, 6, 7, 8, 9] + for i,_ in enumerate(seeds): + print(f"####### Run seed {seeds[i]} for LinearProbing...") + set_random_seed(seeds[i]) + print(f"training sample:{len(train_emb)}") + test_acc = node_classification_linear_probing( + (train_emb, val_emb, test_emb), + (train_lbls, val_lbls, test_lbls), + lr_f, weight_decay_f, max_epoch_f, device, batch_size=batch_size, shuffle=shuffle) + acc.append(test_acc) + + print(f"# final_acc: {np.mean(acc):.4f}, std: {np.std(acc):.4f}") + + return np.mean(acc) + + + +class LogisticRegression(nn.Module): + def __init__(self, num_dim, num_class): + super().__init__() + self.linear = nn.Linear(num_dim, num_class) + + def forward(self, g, x, *args): + logits = self.linear(x) + return logits + + +def node_classification_linear_probing(embeddings, labels, lr, weight_decay, max_epoch, device, mute=False, batch_size=-1, shuffle=True): + criterion = torch.nn.CrossEntropyLoss() + + train_emb, val_emb, test_emb = embeddings + train_label, val_label, test_label = labels + train_label = train_label.to(torch.long) + val_label = val_label.to(torch.long) + test_label = test_label.to(torch.long) + + best_val_acc = 0 + best_val_epoch = 0 + best_model = None + + if not mute: + epoch_iter = tqdm(range(max_epoch)) + else: + epoch_iter = range(max_epoch) + + encoder = LogisticRegression(train_emb.shape[1], int(train_label.max().item() + 1)) + encoder = encoder.to(device) + optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=weight_decay) + + if batch_size > 0: + train_loader = LinearProbingDataLoader(np.arange(len(train_emb)), train_emb, train_label, batch_size=batch_size, num_workers=4, persistent_workers=True, shuffle=shuffle) + # train_loader = DataLoader(np.arange(len(train_emb)), batch_size=batch_size, shuffle=False) + val_loader = LinearProbingDataLoader(np.arange(len(val_emb)), val_emb, val_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False) + test_loader = LinearProbingDataLoader(np.arange(len(test_emb)), test_emb, test_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False) + else: + train_loader = [np.arange(len(train_emb))] + val_loader = [np.arange(len(val_emb))] + test_loader = [np.arange(len(test_emb))] + + def eval_forward(loader, _label): + pred_all = [] + for batch_x, _ in loader: + batch_x = batch_x.to(device) + pred = encoder(None, batch_x) + pred_all.append(pred.cpu()) + pred = torch.cat(pred_all, dim=0) + acc = accuracy(pred, _label) + return acc + + for epoch in epoch_iter: + encoder.train() + + for batch_x, batch_label in train_loader: + batch_x = batch_x.to(device) + batch_label = batch_label.to(device) + pred = encoder(None, batch_x) + loss = criterion(pred, batch_label) + optimizer.zero_grad() + loss.backward() + # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3) + optimizer.step() + + with torch.no_grad(): + encoder.eval() + val_acc = eval_forward(val_loader, val_label) + + if val_acc >= best_val_acc: + best_val_acc = val_acc + best_val_epoch = epoch + best_model = copy.deepcopy(encoder) + + if not mute: + epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_acc:{val_acc:.4f}") + + best_model.eval() + encoder = best_model + with torch.no_grad(): + test_acc = eval_forward(test_loader, test_label) + if mute: + print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") + else: + print(f"--- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch 
{best_val_epoch} --- ")
+
+    return test_acc
+
+
+def finetune(
+    model,
+    graph,
+    feats,
+    ego_graph_nodes,
+    labels,
+    split_idx,
+    lr_f, weight_decay_f, max_epoch_f,
+    use_scheduler, batch_size,
+    device,
+    logger=None,
+    full_graph_forward=False,
+):
+    logging.info("-- Finetuning in downstream tasks ---")
+    train_egs, val_egs, test_egs = ego_graph_nodes
+    print(f"num of egos:{len(train_egs)},{len(val_egs)},{len(test_egs)}")
+
+    print(graph.num_nodes)
+
+    train_nid = split_idx["train"].numpy()
+    val_nid = split_idx["valid"].numpy()
+    test_nid = split_idx["test"].numpy()
+
+    train_lbls, val_lbls, test_lbls = [x.long() for x in labels]
+    print(f"num of labels:{len(train_lbls)},{len(val_lbls)},{len(test_lbls)}")
+
+    num_classes = max(max(train_lbls.max().item(), val_lbls.max().item()), test_lbls.max().item()) + 1
+
+    model = model.get_encoder()
+    model.reset_classifier(int(num_classes))
+    model = model.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    train_loader = setup_finetune_dataloder("lc", graph, feats, train_egs, train_lbls, batch_size=batch_size, shuffle=True)
+    val_loader = setup_finetune_dataloder("lc", graph, feats, val_egs, val_lbls, batch_size=batch_size, shuffle=False)
+    test_loader = setup_finetune_dataloder("lc", graph, feats, test_egs, test_lbls, batch_size=batch_size, shuffle=False)
+
+    # optimizer = torch.optim.Adam(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+
+    if use_scheduler and max_epoch_f > 0:
+        logging.info("Use scheduler")
+        warmup_epochs = int(max_epoch_f * 0.1)
+        # linear warmup followed by cosine decay
+        # scheduler = lambda epoch: (1 + np.cos(epoch * np.pi / max_epoch_f)) * 0.5
+        scheduler = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else (1 + np.cos((epoch - warmup_epochs) * np.pi / (max_epoch_f - warmup_epochs))) * 0.5
+        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
+    else:
+        scheduler = None
+
+    def eval_with_lc(model, loader):
+        pred_counts = []
+        model.eval()
+        epoch_iter = tqdm(loader)
+        with torch.no_grad():
+            for batch in epoch_iter:
+                batch_g, targets, batch_lbls, node_idx = batch
+                batch_g = batch_g.to(device)
+                batch_lbls = batch_lbls.to(device)
+                x = batch_g.x
+
+                prediction = model(batch_g, x)
+                prediction = prediction[targets]
+                pred_counts.append(prediction.argmax(1) == batch_lbls)
+        pred_counts = torch.cat(pred_counts)
+        acc = pred_counts.float().sum() / pred_counts.shape[0]
+        return acc
+
+    def eval_full_prop(model, g, nfeat, val_nid, test_nid, batch_size, device):
+        model.eval()
+
+        with torch.no_grad():
+            pred = model.inference(g, nfeat, batch_size, device)
+        model.train()
+
+        return accuracy(pred[val_nid], val_lbls.cpu()), accuracy(pred[test_nid], test_lbls.cpu())
+
+    best_val_acc = 0
+    best_model = None
+    best_epoch = 0
+    test_acc = 0
+    early_stop_cnt = 0
+
+    for epoch in range(max_epoch_f):
+        if epoch == 0:
+            # skip training in epoch 0; only advance the LR schedule if one is used
+            if scheduler is not None:
+                scheduler.step()
+            continue
+        if early_stop_cnt >= 10:
+            break
+        epoch_iter = tqdm(train_loader)
+        losses = []
+        model.train()
+
+        for batch_g, targets, batch_lbls, node_idx in epoch_iter:
+            batch_g = batch_g.to(device)
+            targets = targets.to(device)
+            batch_lbls = batch_lbls.to(device)
+            x = batch_g.x
+
+            prediction = model(batch_g, x)
+            prediction = prediction[targets]
+            loss = criterion(prediction, batch_lbls)
+
+            optimizer.zero_grad()
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+            optimizer.step()
+
+            metrics = {"finetune_loss": loss}
+            wandb.log(metrics)
+
+            if logger is not None:
+                logger.log(metrics)
+
+            epoch_iter.set_description(f"Finetuning | train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB")
+            losses.append(loss.item())
+
+        if scheduler is not None:
+            scheduler.step()
+
+        if not full_graph_forward:
+            if epoch > 0:
+                val_acc = eval_with_lc(model, val_loader)
+                _test_acc = 0
+        else:
+            if epoch > 0 and epoch % 1 == 0:
+                val_acc, _test_acc = eval_full_prop(model, graph, feats, val_nid, test_nid, 10000, device)
+                model = model.to(device)
+
+        print('val Acc {:.4f}'.format(val_acc))
+        if val_acc > best_val_acc:
+            best_model = copy.deepcopy(model)
+            best_val_acc = val_acc
+            test_acc = _test_acc
+            best_epoch = epoch
+            early_stop_cnt = 0
+        else:
+            early_stop_cnt += 1
+
+        if not full_graph_forward:
+            print("val Acc {:.4f}, Best Val Acc {:.4f}".format(val_acc, best_val_acc))
+        else:
+            print("Val Acc {:.4f}, Best Val Acc {:.4f} Test Acc {:.4f}".format(val_acc, best_val_acc, test_acc))
+
+        metrics = {"epoch_val_acc": val_acc,
+                   "test_acc": test_acc,
+                   "epoch": epoch,
+                   "lr_f": get_current_lr(optimizer)}
+
+        wandb.log(metrics)
+        if logger is not None:
+            logger.log(metrics)
+        print(f"# Finetuning - Epoch {epoch} | train_loss: {np.mean(losses):.4f}, ValAcc: {val_acc:.4f}, TestAcc: {test_acc:.4f}, Memory: {show_occupied_memory():.2f} MB")
+
+    model = best_model
+    if not full_graph_forward:
+        # evaluate the best checkpoint on the held-out ego-graph loader
+        test_acc = eval_with_lc(model, test_loader)
+
+    print(f"Finetune | TestAcc: {test_acc:.4f} from Epoch {best_epoch}")
+    return test_acc
+
+
+def linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob=True, mute=False):
+    model.eval()
+    with torch.no_grad():
+        x = model.embed(graph.to(device), x.to(device))
+        in_feat = x.shape[1]
+    encoder = LogisticRegression(in_feat, num_classes)
+
+    num_finetune_params = [p.numel() for p in encoder.parameters() if p.requires_grad]
+    if not mute:
+        print(f"num parameters for finetuning: {sum(num_finetune_params)}")
+
+    encoder.to(device)
+    optimizer_f = torch.optim.Adam(encoder.parameters(), lr=lr_f, weight_decay=weight_decay_f)
+    final_acc, estp_acc = _linear_probing_full_batch(encoder, graph, x, optimizer_f, max_epoch_f, device, mute)
+    return final_acc, estp_acc
+
+
+def _linear_probing_full_batch(model, graph, feat, optimizer, max_epoch, device, mute=False):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    graph = graph.to(device)
+    x = feat.to(device)
+
+    train_mask = graph.train_mask
+    val_mask = graph.val_mask
+    test_mask = graph.test_mask
+    labels = graph.y
+
+    best_val_acc = 0
+    best_val_epoch = 0
+    best_model = None
+
+    if not mute:
+        epoch_iter = tqdm(range(max_epoch))
+    else:
+        epoch_iter = range(max_epoch)
+
+    for epoch in epoch_iter:
+        model.train()
+        out = model(graph, x)
+        loss = criterion(out[train_mask], labels[train_mask])
+        optimizer.zero_grad()
+        loss.backward()
+        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
+        optimizer.step()
+
+        with torch.no_grad():
+            model.eval()
+            pred = model(graph, x)
+            val_acc = accuracy(pred[val_mask], labels[val_mask])
+            val_loss = criterion(pred[val_mask], labels[val_mask])
+            test_acc = accuracy(pred[test_mask], labels[test_mask])
+            test_loss = criterion(pred[test_mask], labels[test_mask])
+
+        if val_acc >= best_val_acc:
+            best_val_acc = val_acc
+            best_val_epoch = epoch
+            best_model = copy.deepcopy(model)
+
+        if not mute:
+            epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_loss:{val_loss.item(): .4f}, val_acc:{val_acc: .4f}, test_loss:{test_loss.item(): .4f}, test_acc:{test_acc: 
.4f}") + + best_model.eval() + with torch.no_grad(): + pred = best_model(graph, x) + estp_test_acc = accuracy(pred[test_mask], labels[test_mask]) + if mute: + print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") + else: + print(f"--- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") + + return test_acc, estp_test_acc diff --git a/examples/graphmae2/models/gat.py b/examples/graphmae2/models/gat.py new file mode 100644 index 00000000..28eb5e29 --- /dev/null +++ b/examples/graphmae2/models/gat.py @@ -0,0 +1,198 @@ +import tqdm + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cogdl.utils import ( + EdgeSoftmax, + MultiHeadSpMM, + get_activation, + get_norm_layer, +) + +from utils import create_activation, create_norm + + +class GAT(nn.Module): + def __init__(self, + in_dim, + num_hidden, + out_dim, + num_layers, + nhead, + nhead_out, + activation, + feat_drop, + attn_drop, + negative_slope, + residual, + norm, + concat_out=False, + encoding=False + ): + super(GAT, self).__init__() + self.out_dim = out_dim + self.num_heads = nhead + self.num_layers = num_layers + self.gat_layers = nn.ModuleList() + self.activation = activation + self.concat_out = concat_out + + last_activation = activation if encoding else None + last_residual = (encoding and residual) + last_norm = norm if encoding else None + self.feat_drop = nn.Dropout(feat_drop) + + if num_layers == 1: + self.gat_layers.append( + GATLayer( + in_dim, out_dim, nhead=nhead_out, + attn_drop=attn_drop, + alpha=negative_slope, + residual=last_residual, + norm=last_norm, + activation=last_activation, + ) + ) + else: + # input projection (no residual) + self.gat_layers.append( + GATLayer( + in_dim, num_hidden, nhead, + attn_drop=attn_drop, + alpha=negative_slope, + residual=residual, + activation=activation, + norm=norm + ) + ) + # hidden layers + for l in range(1, num_layers - 1): + # due to multi-head, the in_dim = num_hidden * num_heads + self.gat_layers.append( + GATLayer( + num_hidden * nhead, num_hidden, nhead=nhead, + attn_drop=attn_drop, + alpha=negative_slope, + residual=residual, + activation=activation, + norm=norm + ) + ) + # output projection + self.gat_layers.append( + GATLayer( + num_hidden * nhead, out_dim, + nhead=nhead_out, + attn_drop=attn_drop, + alpha=negative_slope, + residual=last_residual, + activation=last_activation, + norm=last_norm + ) + ) + self.head = nn.Identity() + + def forward(self, g, inputs, return_hidden=False): + h = inputs + hidden_list = [] + for l in range(self.num_layers): + h = self.feat_drop(h) + h = self.gat_layers[l](g, h) + hidden_list.append(h) + # h = h.flatten(1) + # output projection + if return_hidden: + return self.head(h), hidden_list + else: + return self.head(h) + + def reset_classifier(self, num_classes): + self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) + + +class GATLayer(nn.Module): + """ + Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 + """ + + def __init__( + self, in_feats, out_feats, nhead=1, alpha=0.2, attn_drop=0.5, activation=None, residual=False, norm=None + ): + super(GATLayer, self).__init__() + self.in_features = in_feats + self.out_features = out_feats + self.alpha = alpha + self.nhead = nhead + + self.W = nn.Parameter(torch.FloatTensor(in_feats, out_feats * nhead)) + self.bias = 
nn.Parameter(torch.FloatTensor(out_feats * nhead,)) + + self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_feats))) + self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_feats))) + + self.edge_softmax = EdgeSoftmax() + self.mhspmm = MultiHeadSpMM() + + self.dropout = nn.Dropout(attn_drop) + self.leakyrelu = nn.LeakyReLU(self.alpha) + self.act = None if activation is None else get_activation(activation) + self.norm = None if norm is None else get_norm_layer(norm, out_feats * nhead) + + if residual: + self.residual = nn.Linear(in_feats, out_feats * nhead) + else: + self.register_buffer("residual", None) + self.reset_parameters() + + def reset_parameters(self): + # def reset(tensor): + # stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) + # tensor.data.uniform_(-stdv, stdv) + + # reset(self.a_l) + # reset(self.a_r) + # reset(self.W) + + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.W, gain=gain) + nn.init.xavier_normal_(self.a_l, gain=gain) + nn.init.xavier_normal_(self.a_r, gain=gain) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + if isinstance(self.residual, nn.Linear): + nn.init.xavier_normal_(self.residual.weight, gain=gain) + + def forward(self, graph, x): + h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features) + h[torch.isnan(h)] = 0.0 + + row, col = graph.edge_index + # Self-attention on the nodes - Shared attention mechanism + h_l = (self.a_l * h).sum(dim=-1) + h_r = (self.a_r * h).sum(dim=-1) + + # edge_attention: E * H + edge_attention = self.leakyrelu(h_l[row] + h_r[col]) + edge_attention = self.edge_softmax(graph, edge_attention) + edge_attention = self.dropout(edge_attention) + + out = self.mhspmm(graph, edge_attention, h) + + if self.bias is not None: + out = out + self.bias + + if self.residual: + res = self.residual(x) + out += res + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + def __repr__(self): + return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")" diff --git a/examples/graphmae2/models/gcn.py b/examples/graphmae2/models/gcn.py new file mode 100644 index 00000000..23292e72 --- /dev/null +++ b/examples/graphmae2/models/gcn.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cogdl.layers import GCNLayer + +from utils import create_activation, create_norm + + +class GCN(nn.Module): + def __init__(self, + in_dim, + num_hidden, + out_dim, + num_layers, + dropout, + activation, + residual, + norm, + encoding=False + ): + super(GCN, self).__init__() + self.out_dim = out_dim + self.num_layers = num_layers + self.gcn_layers = nn.ModuleList() + self.activation = activation + self.dropout = dropout + + last_activation = activation if encoding else None + last_residual = encoding and residual + last_norm = norm if encoding else None + + if num_layers == 1: + self.gcn_layers.append(GCNLayer( + in_dim, out_dim, residual=last_residual, norm=last_norm, activation=last_activation)) + else: + self.gcn_layers.append(GCNLayer( + in_dim, num_hidden, residual=residual, norm=norm, activation=create_activation(activation))) + for l in range(1, num_layers - 1): + # due to multi-head, the in_dim = num_hidden * num_heads + self.gcn_layers.append(GCNLayer( + num_hidden, num_hidden, residual=residual, norm=norm, activation=create_activation(activation))) + self.gcn_layers.append(GCNLayer( + num_hidden, out_dim, residual=last_residual, activation=last_activation, 
norm=last_norm)) + + self.head = nn.Identity() + + def forward(self, g, inputs, return_hidden=False): + h = inputs + hidden_list = [] + for l in range(self.num_layers): + h = F.dropout(h, p=self.dropout, training=self.training) + h = self.gcn_layers[l](g, h) + hidden_list.append(h) + if return_hidden: + return self.head(h), hidden_list + else: + return self.head(h) + + def reset_classifier(self, num_classes): + self.head = nn.Linear(self.out_dim, num_classes) diff --git a/examples/graphmae2/models/loss_func.py b/examples/graphmae2/models/loss_func.py new file mode 100644 index 00000000..2105c64f --- /dev/null +++ b/examples/graphmae2/models/loss_func.py @@ -0,0 +1,164 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def auc_pair_loss(x, y, z): + x = F.normalize(x, p=2, dim=-1) + y = F.normalize(y, p=2, dim=-1) + z = F.normalize(z, p=2, dim=-1) + + sim = (x * y).sum(dim=-1) + dissim = (x * z).sum(dim=-1) + loss = (1 - sim + dissim).mean() + # loss = (1 - sim).mean() + return loss + + +def sce_loss(x, y, alpha=3): + x = F.normalize(x, p=2, dim=-1) + y = F.normalize(y, p=2, dim=-1) + + # loss = - (x * y).sum(dim=-1) + # loss = (x_h - y_h).norm(dim=1).pow(alpha) + + loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) + + loss = loss.mean() + return loss + + +class DINOLoss(nn.Module): + def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, + warmup_teacher_temp_epochs, nepochs, student_temp=0.1, + center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + # we apply a warm up for the teacher temperature because + # a too high temperature makes the training instable at the beginning + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + + def forward(self, student_output, teacher_output, epoch): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + student_out = student_output / self.student_temp + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) + teacher_out = teacher_out.detach() + + loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1) + loss = loss.mean() + self.update_center(teacher_output) + return loss + + # total_loss = 0 + # n_loss_terms = 0 + # for iq, q in enumerate(teacher_out): + # for v in range(len(student_out)): + # if v == iq: + # # we skip cases where student and teacher operate on the same view + # continue + # loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) + # total_loss += loss.mean() + # n_loss_terms += 1 + # total_loss /= n_loss_terms + # self.update_center(teacher_output) + # return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + """ + Update center used for teacher output. 
+ """ + batch_center = torch.mean(teacher_output, dim=0, keepdim=True) + + # ema update + self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) + + +class MLPHead(nn.Module): + def __init__(self, hidden_size, out_dim, num_layers=2, bottleneck_dim=256): + super().__init__() + self._num_layers = num_layers + self.mlp = nn.ModuleList() + for i in range(num_layers): + if i == num_layers - 1: + self.mlp.append( + nn.Linear(hidden_size, bottleneck_dim) + ) + else: + self.mlp.append(nn.Linear(hidden_size, hidden_size)) + # self.mlp.append(nn.LayerNorm(hidden_size)) + self.mlp.append(nn.PReLU()) + + self.apply(self._init_weights) + # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + # self.last_layer.weight_g.requires_grad = False + # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + num_layers = len(self.mlp) + for i, layer in enumerate(self.mlp): + x = layer(x) + + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) \ No newline at end of file diff --git a/examples/graphmae2/run_fullbatch.sh b/examples/graphmae2/run_fullbatch.sh new file mode 100644 index 00000000..44009bea --- /dev/null +++ b/examples/graphmae2/run_fullbatch.sh @@ -0,0 +1,37 @@ +dataset=$1 +device=$2 + +[ -z "${dataset}" ] && dataset="cora" +[ -z "${device}" ] && device=0 + +CUDA_VISIBLE_DEVICES=$device \ + python main_full_batch.py \ + --device 0 \ + --dataset $dataset \ + --mask_method "random" \ + --remask_method "fixed" \ + --mask_rate 0.5 \ + --in_drop 0.2 \ + --attn_drop 0.1 \ + --num_layers 2 \ + --num_dec_layers 1 \ + --num_hidden 256 \ + --num_heads 4 \ + --num_out_heads 1 \ + --encoder "gat" \ + --decoder "gat" \ + --max_epoch 1000 \ + --max_epoch_f 300 \ + --lr 0.001 \ + --weight_decay 0.04 \ + --lr_f 0.005 \ + --weight_decay_f 1e-4 \ + --activation "prelu" \ + --loss_fn "sce" \ + --alpha_l 3 \ + --scheduler \ + --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 \ + --lam 0.5 \ + --linear_prob \ + --data_dir "./dataset" \ + --use_cfg diff --git a/examples/graphmae2/run_minibatch.sh b/examples/graphmae2/run_minibatch.sh new file mode 100644 index 00000000..b211b5bc --- /dev/null +++ b/examples/graphmae2/run_minibatch.sh @@ -0,0 +1,50 @@ +dataset=$1 +device=$2 + +[ -z "${dataset}" ] && dataset="ogbn-arxiv" +[ -z "${device}" ] && device=0 + +CUDA_VISIBLE_DEVICES=$device \ + python main_large.py \ + --device 0 \ + --dataset $dataset \ + --mask_type "mask" \ + --mask_rate 0.5 \ + --remask_rate 0.5 \ + --num_remasking 3 \ + --in_drop 0.2 \ + --attn_drop 0.2 \ + --num_layers 4 \ + --num_dec_layers 1 \ + --num_hidden 1024 \ + --num_heads 4 \ + --num_out_heads 1 \ + --encoder "gat" \ + --decoder "gat" \ + --max_epoch 60 \ + --max_epoch_f 1000 \ + --lr 0.002 \ + --weight_decay 0.04 \ + --lr_f 0.005 \ + --weight_decay_f 1e-4 \ + --activation "prelu" \ + --optimizer "adamw" \ + --drop_edge_rate 0.5 \ + --loss_fn "sce" \ + --alpha_l 4 \ + --mask_method "random" \ + --scheduler \ + --batch_size 512 \ + --batch_size_f 256 \ + --seeds 0 \ + --residual \ + --norm "layernorm" \ + --sampling_method "lc" \ + --label_rate 1.0 \ + --lam 1.0 \ + --momentum 0.996 \ + --linear_prob \ + --use_cfg \ + --ego_graph_file_path "./lc_ego_graphs/${dataset}-lc-ego-graphs-256.pt" \ + --data_dir "./dataset" \ + # --logging diff --git a/examples/graphmae2/utils.py b/examples/graphmae2/utils.py new file mode 100644 index 00000000..487a602e --- /dev/null +++ b/examples/graphmae2/utils.py @@ -0,0 +1,317 @@ +import os +import argparse +import random +import psutil +import yaml +import logging +from functools import partial +from tensorboardX import SummaryWriter +import wandb + +import numpy as np +import torch +import torch.nn as nn +from torch import optim as optim + +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt + + +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) + + +def accuracy(y_pred, y_true): + y_true = y_true.squeeze().long() + preds = y_pred.max(1)[1].type_as(y_true) + correct = 
preds.eq(y_true).double()
+    correct = correct.sum().item()
+    return correct / len(y_true)
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def get_current_lr(optimizer):
+    return optimizer.state_dict()["param_groups"][0]["lr"]
+
+
+def build_args():
+    parser = argparse.ArgumentParser(description="GAT")
+    parser.add_argument("--seeds", type=int, nargs="+", default=[0])
+    parser.add_argument("--dataset", type=str, default="cora")
+    parser.add_argument("--device", type=int, default=0)
+    parser.add_argument("--max_epoch", type=int, default=500,
+                        help="number of training epochs")
+    parser.add_argument("--warmup_steps", type=int, default=-1)
+
+    parser.add_argument("--num_heads", type=int, default=4,
+                        help="number of hidden attention heads")
+    parser.add_argument("--num_out_heads", type=int, default=1,
+                        help="number of output attention heads")
+    parser.add_argument("--num_layers", type=int, default=2,
+                        help="number of hidden layers")
+    parser.add_argument("--num_dec_layers", type=int, default=1)
+    parser.add_argument("--num_remasking", type=int, default=3)
+    parser.add_argument("--num_hidden", type=int, default=512,
+                        help="number of hidden units")
+    parser.add_argument("--residual", action="store_true", default=False,
+                        help="use residual connection")
+    parser.add_argument("--in_drop", type=float, default=.2,
+                        help="input feature dropout")
+    parser.add_argument("--attn_drop", type=float, default=.1,
+                        help="attention dropout")
+    parser.add_argument("--norm", type=str, default=None)
+    parser.add_argument("--lr", type=float, default=0.001,
+                        help="learning rate")
+    parser.add_argument("--weight_decay", type=float, default=0,
+                        help="weight decay")
+    parser.add_argument("--negative_slope", type=float, default=0.2,
+                        help="the negative slope of leaky relu")
+    parser.add_argument("--activation", type=str, default="prelu")
+    parser.add_argument("--mask_rate", type=float, default=0.5)
+    parser.add_argument("--remask_rate", type=float, default=0.5)
+    parser.add_argument("--remask_method", type=str, default="random")
+    parser.add_argument("--mask_type", type=str, default="mask",
+                        help="`mask` or `drop`")
+    parser.add_argument("--mask_method", type=str, default="random")
+    parser.add_argument("--drop_edge_rate", type=float, default=0.0)
+    parser.add_argument("--drop_edge_rate_f", type=float, default=0.0)
+
+    parser.add_argument("--encoder", type=str, default="gat")
+    parser.add_argument("--decoder", type=str, default="gat")
+    parser.add_argument("--loss_fn", type=str, default="sce")
+    parser.add_argument("--alpha_l", type=float, default=2)
+    parser.add_argument("--optimizer", type=str, default="adam")
+
+    parser.add_argument("--max_epoch_f", type=int, default=300)
+    parser.add_argument("--lr_f", type=float, default=0.01)
+    parser.add_argument("--weight_decay_f", type=float, default=0.0)
+    parser.add_argument("--linear_prob", action="store_true", default=False)
+
+    parser.add_argument("--no_pretrain", action="store_true")
+    parser.add_argument("--load_model", action="store_true")
+    parser.add_argument("--checkpoint_path", type=str, default=None)
+    parser.add_argument("--use_cfg", action="store_true")
+    parser.add_argument("--logging", action="store_true")
+    parser.add_argument("--scheduler", action="store_true", default=False)
+
+    parser.add_argument("--batch_size", type=int, default=256)
+    parser.add_argument("--batch_size_f", 
type=int, default=128) + parser.add_argument("--sampling_method", type=str, default="saint", help="sampling method, `lc` or `saint`") + + parser.add_argument("--label_rate", type=float, default=1.0) + parser.add_argument("--ego_graph_file_path", type=str, default=None) + parser.add_argument("--data_dir", type=str, default="data") + + parser.add_argument("--lam", type=float, default=1.0) + parser.add_argument("--full_graph_forward", action="store_true", default=False) + parser.add_argument("--delayed_ema_epoch", type=int, default=0) + parser.add_argument("--replace_rate", type=float, default=0.0) + parser.add_argument("--momentum", type=float, default=0.996) + + args = parser.parse_args() + return args + +def create_activation(name): + if name == "relu": + return nn.ReLU() + elif name == "gelu": + return nn.GELU() + elif name == "prelu": + return nn.PReLU() + elif name == "selu": + return nn.SELU() + elif name == "elu": + return nn.ELU() + elif name == "silu": + return nn.SiLU() + elif name is None: + return nn.Identity() + else: + raise NotImplementedError(f"{name} is not implemented.") + + +def identity_norm(x): + def func(x): + return x + return func + +def create_norm(name): + if name == "layernorm": + return nn.LayerNorm + elif name == "batchnorm": + return nn.BatchNorm1d + elif name == "identity": + return identity_norm + else: + # print("Identity norm") + return None + + +def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): + opt_lower = opt.lower() + parameters = model.parameters() + opt_args = dict(lr=lr, weight_decay=weight_decay) + + opt_split = opt_lower.split("_") + opt_lower = opt_split[-1] + + if opt_lower == "adam": + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == "adamw": + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == "adadelta": + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == "sgd": + opt_args["momentum"] = 0.9 + return optim.SGD(parameters, **opt_args) + else: + raise NotImplementedError("Invalid optimizer") + + return optimizer + + +def show_occupied_memory(): + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024**2 + + +# ------------------- +def mask_edge(graph, mask_prob): + E = graph.num_edges + + mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) + masks = torch.bernoulli(1 - mask_rates) + mask_idx = masks.nonzero().squeeze(1) + return mask_idx.to(graph.device) + + +def drop_edge(graph, drop_rate, return_edges=False): + if drop_rate <= 0: + return graph + + n_node = graph.num_nodes + edge_mask = mask_edge(graph, drop_rate) + src = graph.edge_index[0] + dst = graph.edge_index[1] + + nsrc = src[edge_mask] + ndst = dst[edge_mask] + + ng = Graph(edge_index=(nsrc, ndst), num_nodes=n_node) + ng.add_remaining_self_loops() + + dsrc = src[~edge_mask] + ddst = dst[~edge_mask] + + if return_edges: + return ng, (dsrc, ddst) + return ng + +def visualize(x, y, method="tsne"): + if torch.is_tensor(x): + x = x.cpu().numpy() + + if torch.is_tensor(y): + y = y.cpu().numpy() + + if method == "tsne": + func = TSNE(n_components=2) + else: + func = PCA(n_components=2) + out = func.fit_transform(x) + plt.scatter(out[:, 0], out[:, 1], c=y) + plt.savefig("vis.png") + + +def load_best_configs(args): + dataset_name = args.dataset + config_path = os.path.join("configs", f"{dataset_name}.yaml") + with open(config_path, "r") as f: + configs = yaml.load(f, yaml.FullLoader) + + for k, v in configs.items(): + if "lr" in k or "weight_decay" in k: + v = float(v) 
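+        # (the float() cast above handles lr/weight_decay values that PyYAML loads as
+        #  strings, e.g. "1e-4"; each config key overrides the matching CLI argument)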
+ setattr(args, k, v) + logging.info(f"----- Using best configs from {config_path} -----") + + return args + + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + + +# ------ logging ------ + +class TBLogger(object): + def __init__(self, log_path="./logging_data", name="run"): + super(TBLogger, self).__init__() + + if not os.path.exists(log_path): + os.makedirs(log_path, exist_ok=True) + + self.last_step = 0 + self.log_path = log_path + raw_name = os.path.join(log_path, name) + name = raw_name + for i in range(1000): + name = raw_name + str(f"_{i}") + if not os.path.exists(name): + break + self.writer = SummaryWriter(logdir=name) + + def note(self, metrics, step=None): + if step is None: + step = self.last_step + for key, value in metrics.items(): + self.writer.add_scalar(key, value, step) + self.last_step = step + + def finish(self): + self.writer.close() + + +class WandbLogger(object): + def __init__(self, log_path, project, args): + self.log_path = log_path + self.project = project + self.args = args + self.last_step = 0 + self.project = project + self.start() + + def start(self): + self.run = wandb.init(config=self.args, project=self.project) + + def log(self, metrics, step=None): + if not hasattr(self, "run"): + self.start() + if step is None: + step = self.last_step + self.run.log(metrics) + self.last_step = step + + def finish(self): + self.run.finish()
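For readers skimming the loss definitions above: `sce_loss` in `models/loss_func.py` is the scaled cosine error, `mean((1 - cos(x_i, y_i))^alpha)`, computed row-wise over pairs of embeddings. The snippet below is a minimal standalone sketch (illustration only, not a file in this patch) that copies the function and probes its range:

```python
import torch
import torch.nn.functional as F


def sce_loss(x, y, alpha=3):
    # scaled cosine error: (1 - cosine similarity)^alpha, averaged over the batch
    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)
    return (1 - (x * y).sum(dim=-1)).pow_(alpha).mean()


x = torch.randn(8, 16)
print(sce_loss(x, x).item())                    # ~0 for identical embeddings
print(sce_loss(x, -x).item())                   # 2**alpha for opposite embeddings
print(sce_loss(x, torch.randn(8, 16)).item())   # intermediate for unrelated embeddings
```

With `--loss_fn "sce"`, the `--alpha_l` value in the run scripts (3 for full-batch, 4 for mini-batch) is presumably the `alpha` used here; larger exponents shrink the contribution of pairs that are already well aligned, so training focuses on the harder-to-reconstruct nodes.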