From b7b580591747623dadf776c53e082695c2b84073 Mon Sep 17 00:00:00 2001 From: RiemanGraph Date: Wed, 28 Jun 2023 22:39:21 +0800 Subject: [PATCH] Upload --- backbone.py | 156 ++++++++++++++++++++ data_factory.py | 138 ++++++++++++++++++ exp/__init__.py | 0 exp/exp.py | 236 +++++++++++++++++++++++++++++++ layers.py | 170 ++++++++++++++++++++++ logger.py | 78 ++++++++++ main.py | 97 +++++++++++++ models.py | 207 +++++++++++++++++++++++++++ scripts/cls/GAT_chameleon_nc.sh | 43 ++++++ scripts/cls/GAT_citeseer_nc.sh | 43 ++++++ scripts/cls/GAT_cora_nc.sh | 43 ++++++ scripts/cls/GAT_squirrel_nc.sh | 43 ++++++ scripts/cls/GCN_chameleon_nc.sh | 43 ++++++ scripts/cls/GCN_citeseer_nc.sh | 43 ++++++ scripts/cls/GCN_cora_nc.sh | 43 ++++++ scripts/cls/GCN_squirrel_nc.sh | 43 ++++++ scripts/cls/SAGE_chameleon_nc.sh | 43 ++++++ scripts/cls/SAGE_citeseer_nc.sh | 43 ++++++ scripts/cls/SAGE_cora_nc.sh | 43 ++++++ scripts/cls/SAGE_squirrel_nc.sh | 43 ++++++ scripts/cluster/citeseer_clu.sh | 43 ++++++ scripts/cluster/cora_clu.sh | 37 +++++ utils.py | 233 ++++++++++++++++++++++++++++++ 23 files changed, 1911 insertions(+) create mode 100644 backbone.py create mode 100644 data_factory.py create mode 100644 exp/__init__.py create mode 100644 exp/exp.py create mode 100644 layers.py create mode 100644 logger.py create mode 100644 main.py create mode 100644 models.py create mode 100644 scripts/cls/GAT_chameleon_nc.sh create mode 100644 scripts/cls/GAT_citeseer_nc.sh create mode 100644 scripts/cls/GAT_cora_nc.sh create mode 100644 scripts/cls/GAT_squirrel_nc.sh create mode 100644 scripts/cls/GCN_chameleon_nc.sh create mode 100644 scripts/cls/GCN_citeseer_nc.sh create mode 100644 scripts/cls/GCN_cora_nc.sh create mode 100644 scripts/cls/GCN_squirrel_nc.sh create mode 100644 scripts/cls/SAGE_chameleon_nc.sh create mode 100644 scripts/cls/SAGE_citeseer_nc.sh create mode 100644 scripts/cls/SAGE_cora_nc.sh create mode 100644 scripts/cls/SAGE_squirrel_nc.sh create mode 100644 scripts/cluster/citeseer_clu.sh create mode 100644 scripts/cluster/cora_clu.sh create mode 100644 utils.py diff --git a/backbone.py b/backbone.py new file mode 100644 index 0000000..ed86e2f --- /dev/null +++ b/backbone.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from layers import GraphConvolution, GraphAttentionLayer, SpGraphAttentionLayer +from torch_geometric.nn import GCNConv, SAGEConv +from utils import graph_top_K + + +class GCN(nn.Module): + def __init__(self, in_features, hidden_features, out_features, n_layers, dropout_node=0.5, dropout_edge=0.25): + super(GCN, self).__init__() + self.conv_layers = nn.ModuleList() + self.conv_layers.append(GraphConvolution(in_features, hidden_features)) + for _ in range(n_layers - 2): + self.conv_layers.append(GraphConvolution(hidden_features, hidden_features)) + self.conv_layers.append(GraphConvolution(hidden_features, out_features)) + self.dropout_node = nn.Dropout(dropout_node) + self.dropout_edge = nn.Dropout(dropout_edge) + + def forward(self, x, adj): + adj = self.dropout_edge(adj) + for layer in self.conv_layers[: -1]: + x = layer(x, adj) + x = self.dropout_node(F.relu(x)) + x = self.conv_layers[-1](x, adj) + return x + + +class GAT(nn.Module): + def __init__(self, in_features, hidden_features, out_features, dropout_node=0.5, dropout_edge=0.25, alpha=0.2, + n_heads=4): + """Dense version of GAT.""" + super(GAT, self).__init__() + self.dropout = dropout_node + self.dropout_edge = nn.Dropout(dropout_edge) + + self.attentions = [ + 
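+            # one GraphAttentionLayer per head; each is registered as a named
+            # submodule below via add_module() so its parameters are trained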
GraphAttentionLayer(in_features, hidden_features, dropout=dropout_node, alpha=alpha, concat=True) for _ in + range(n_heads)] + for i, attention in enumerate(self.attentions): + self.add_module('attention_{}'.format(i), attention) + + self.out_att = GraphAttentionLayer(hidden_features * n_heads, out_features, dropout=dropout_node, alpha=alpha, + concat=False) + + def forward(self, x, adj): + adj = self.dropout_edge(adj) + x = F.dropout(x, self.dropout, training=self.training) + x = torch.cat([att(x, adj) for att in self.attentions], dim=1) + x = F.dropout(x, self.dropout, training=self.training) + x = self.out_att(x, adj) + return x + + +class SpGAT(nn.Module): + def __init__(self, in_features, hidden_features, out_features, dropout_node=0.5, dropout_edge=0.25, alpha=0.2, + n_heads=4): + """Sparse version of GAT.""" + super(SpGAT, self).__init__() + self.dropout = dropout_node + self.dropout_edge = nn.Dropout(dropout_edge) + + self.attentions = [SpGraphAttentionLayer(in_features, + hidden_features, + dropout=dropout_node, + alpha=alpha, + concat=True) for _ in range(n_heads)] + for i, attention in enumerate(self.attentions): + self.add_module('attention_{}'.format(i), attention) + + self.out_att = SpGraphAttentionLayer(hidden_features * n_heads, + out_features, + dropout=dropout_node, + alpha=alpha, + concat=False) + + def forward(self, x, adj): + adj = self.dropout_edge(adj) + x = F.dropout(x, self.dropout, training=self.training) + x = torch.cat([att(x, adj) for att in self.attentions], dim=1) + x = F.dropout(x, self.dropout, training=self.training) + x = self.out_att(x, adj) + return x + + +class GraphSAGE(nn.Module): + def __init__(self, in_features, hidden_features, out_features, n_layers, dropout_node=0.5, dropout_edge=0.25): + super().__init__() + self.conv_layers = nn.ModuleList() + self.conv_layers.append(SAGEConv(in_features, hidden_features)) + for _ in range(n_layers - 2): + self.conv_layers.append(SAGEConv(hidden_features, hidden_features)) + self.conv_layers.append(SAGEConv(hidden_features, out_features)) + self.dropout_node = nn.Dropout(dropout_node) + self.dropout_edge = nn.Dropout(dropout_edge) + + def forward(self, x, adj): + adj = self.dropout_edge(adj) + edge_index = adj.nonzero().t() + for layer in self.conv_layers[: -1]: + x = layer(x, edge_index) + x = self.dropout_node(F.relu(x)) + x = self.conv_layers[-1](x, edge_index) + return x + + # class GraphEncoder(nn.Module): + + +# def __init__(self, n_layers, in_features, hidden_features, embed_features, dropout, dropout_edge): +# super(GraphEncoder, self).__init__() +# self.dropout_node = nn.Dropout(dropout) +# self.dropout_adj = nn.Dropout(dropout_edge) + +# self.encoder_layers = nn.ModuleList() +# self.encoder_layers.append(GraphConvolution(in_features, hidden_features)) +# for _ in range(n_layers - 2): +# self.encoder_layers.append(GraphConvolution(hidden_features, hidden_features)) +# self.encoder_layers.append(GraphConvolution(hidden_features, embed_features)) + +# def forward(self, x, adj): +# adj = self.dropout_adj(adj) +# for layer in self.encoder_layers[:-1]: +# x = self.dropout_node(F.relu(layer(x, adj))) +# x = self.encoder_layers[-1](x, adj) +# return x + + +class GraphEncoder(nn.Module): + def __init__(self, backbone, n_layers, in_features, hidden_features, embed_features, + dropout, dropout_edge, alpha=0.2, n_heads=4, topk=30): + super(GraphEncoder, self).__init__() + if backbone == 'gcn': + model = GCN(in_features, hidden_features, embed_features, n_layers, + dropout, dropout_edge) + elif backbone == 
'sage': + model = GraphSAGE(in_features, hidden_features, embed_features, n_layers, + dropout, dropout_edge) + elif backbone == 'gat': + model = GAT(in_features, hidden_features, embed_features, + dropout, dropout_edge, + alpha, n_heads) + elif backbone == 'spgat': + model = SpGAT(in_features, hidden_features, embed_features, + dropout, dropout_edge, + alpha, n_heads) + else: + raise NotImplementedError + + self.backbone = backbone + self.model = model + self.topk = topk + + def forward(self, x, adj): + if self.backbone in ['gat', 'spgat', 'sage']: + adj = graph_top_K(adj, self.topk) + return self.model(x, adj) \ No newline at end of file diff --git a/data_factory.py b/data_factory.py new file mode 100644 index 0000000..164fd23 --- /dev/null +++ b/data_factory.py @@ -0,0 +1,138 @@ +import torch +import networkx as nx +from GraphRicciCurvature.OllivierRicci import OllivierRicci +import numpy as np +from torch_geometric.datasets import Planetoid, WikipediaNetwork, Actor +from torch_geometric.utils import to_networkx +from ogb.nodeproppred import PygNodePropPredDataset +from sklearn.datasets import load_wine, load_breast_cancer, load_digits, fetch_20newsgroups +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.feature_extraction.text import TfidfTransformer +from sklearn.preprocessing import scale +from sklearn.model_selection import train_test_split + + +def get_mask(idx, length): + """Create mask. + """ + mask = torch.zeros(length, dtype=torch.bool) + mask[idx] = 1 + return mask + + +def load_graph_data(root: str, data_name: str, split='public', **kwargs): + if data_name in ['Cora', 'Citeseer', 'Pubmed']: + dataset = Planetoid(root=root, name=data_name, split=split) + train_mask, val_mask, test_mask = dataset.data.train_mask, dataset.data.val_mask, dataset.data.test_mask + elif data_name == 'ogbn-arxiv': + dataset = PygNodePropPredDataset(root=root, name=data_name) + mask = dataset.get_idx_split() + train_mask, val_mask, test_mask = mask.values() + elif data_name in ['actor', 'chameleon', 'squirrel']: + if data_name == 'actor': + path = root + f'/{data_name}' + dataset = Actor(root=path) + else: + dataset = WikipediaNetwork(root=root, name=data_name) + num_nodes = dataset.data.x.shape[0] + idx_train = [] + for j in range(dataset.num_classes): + idx_train.extend([i for i, x in enumerate(dataset.data.y) if x == j][:20]) + idx_val = np.arange(num_nodes - 1500, num_nodes - 1000) + idx_test = np.arange(num_nodes - 1000, num_nodes) + label_len = dataset.data.y.shape[0] + train_mask, val_mask, test_mask = get_mask(idx_train, label_len), get_mask(idx_val, label_len), get_mask(idx_test, label_len) + else: + raise NotImplementedError + + print(dataset.data) + G = to_networkx(dataset.data) + features = dataset.data.x + num_features = dataset.num_features + labels = dataset.data.y + adjacency = torch.from_numpy(nx.adjacency_matrix(G).toarray()) + num_classes = dataset.num_classes + return features, num_features, labels, adjacency, (train_mask, val_mask, test_mask), num_classes + + +def load_non_graph_data(root: str, data_name: str, seed=100, **kwargs): + features = None + if data_name == 'wine': + dataset = load_wine() + n_train = 10 + n_val = 10 + n_es = 10 + is_scale = True + elif data_name == 'digits': + dataset = load_digits() + n_train = 50 + n_val = 50 + n_es = 50 + is_scale = False + elif data_name == 'cancer': + dataset = load_breast_cancer() + n_train = 10 + n_val = 10 + n_es = 10 + is_scale = True + elif data_name == '20news10': + n_train = 100 + n_val = 100 + n_es = 
100 + is_scale = False + categories = ['alt.atheism', + 'comp.sys.ibm.pc.hardware', + 'misc.forsale', + 'rec.autos', + 'rec.sport.hockey', + 'sci.crypt', + 'sci.electronics', + 'sci.med', + 'sci.space', + 'talk.politics.guns'] + dataset = fetch_20newsgroups(subset='all', categories=categories) + vectorizer = CountVectorizer(stop_words='english', min_df=0.05) + X_counts = vectorizer.fit_transform(dataset.data).toarray() + transformer = TfidfTransformer(smooth_idf=False) + features = transformer.fit_transform(X_counts).todense() + else: + raise NotImplementedError + + if data_name != '20news10': + if is_scale: + features = scale(dataset.data) + else: + features = dataset.data + features = torch.from_numpy(features) + y = dataset.target + n, num_features = features.shape + train, test, y_train, y_test = train_test_split(np.arange(n), y, random_state=seed, + train_size=n_train + n_val + n_es, + test_size=n - n_train - n_val - n_es, + stratify=y) + train, es, y_train, y_es = train_test_split(train, y_train, random_state=seed, + train_size=n_train + n_val, test_size=n_es, + stratify=y_train) + train, val, y_train, y_val = train_test_split(train, y_train, random_state=seed, + train_size=n_train, test_size=n_val, + stratify=y_train) + + train_mask = torch.zeros(n, dtype=bool) + train_mask[train] = True + val_mask = torch.zeros(n, dtype=bool) + val_mask[val] = True + es_mask = torch.zeros(n, dtype=bool) + es_mask[es] = True + test_mask = torch.zeros(n, dtype=bool) + test_mask[test] = True + labels = torch.from_numpy(y) + num_classes = len(dataset.target_names) + return features, num_features, labels, torch.zeros(n, n), (train_mask, val_mask, test_mask), num_classes + + +def load_data(args, **kwargs): + if args.is_graph: + data_getter = load_graph_data + else: + data_getter = load_non_graph_data + return data_getter(args.root_path, args.dataset) \ No newline at end of file diff --git a/exp/__init__.py b/exp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/exp/exp.py b/exp/exp.py new file mode 100644 index 0000000..0988fea --- /dev/null +++ b/exp/exp.py @@ -0,0 +1,236 @@ +import torch +import numpy as np +import torch.nn.functional as F +import torch.nn as nn + +from models import DeepRicci, RiemannianEmbeds +from backbone import GCN, SpGAT, GAT, GraphSAGE +from utils import cal_accuracy, cluster_metrics, cal_F1, graph_top_K +from data_factory import load_data +from sklearn.cluster import KMeans +from logger import create_logger +from geoopt.optim import RiemannianAdam + + +class Exp: + def __init__(self, configs): + self.configs = configs + if self.configs.use_gpu and torch.cuda.is_available(): + self.device = torch.device('cuda:0') + else: + self.device = torch.device('cpu') + + def load_data(self): + features, in_features, labels, adj, masks, n_classes = load_data(self.configs) + return features, in_features, labels, adj, masks, n_classes + + def cal_cls_loss(self, model, mask, adj, features, labels): + out = model(features, adj) + loss = F.cross_entropy(out[mask], labels[mask]) + acc = cal_accuracy(out[mask], labels[mask]) + weighted_f1, macro_f1 = cal_F1(out[mask].detach().cpu(), labels[mask].detach().cpu()) + return loss, acc, weighted_f1, macro_f1 + + def select_backbone_model(self, in_features, n_classes): + if self.configs.backbone == 'gcn': + model = GCN(in_features, self.configs.hidden_features_cls, n_classes, self.configs.n_layers_cls, + self.configs.dropout_node_cls, self.configs.dropout_edge_cls) + elif self.configs.backbone == 'sage': + model = 
GraphSAGE(in_features, self.configs.hidden_features_cls, n_classes, self.configs.n_layers_cls, + self.configs.dropout_node_cls, self.configs.dropout_edge_cls) + elif self.configs.backbone == 'gat': + model = GAT(in_features, self.configs.hidden_features_cls, n_classes, + self.configs.dropout_node_cls, self.configs.dropout_edge_cls, + alpha=self.configs.alpha_gat, n_heads=self.configs.n_heads_gat) + elif self.configs.backbone == 'spgat': + model = SpGAT(in_features, self.configs.hidden_features_cls, n_classes, + self.configs.dropout_node_cls, self.configs.dropout_edge_cls, + alpha=self.configs.alpha_gat, n_heads=self.configs.n_heads_gat) + else: + raise NotImplementedError + return model + + def evaluate_adj_by_cls(self, adj, features, in_features, labels, n_classes, masks): + """masks = (train, val, test)""" + device = self.device + model = self.select_backbone_model(in_features, n_classes).to(device) + optimizer = torch.optim.Adam(model.parameters(), betas=self.configs.betas, lr=self.configs.lr_cls, + weight_decay=self.configs.w_decay_cls) + + best_acc = 0. + best_weighted_f1, best_macro_f1 = 0., 0. + early_stop_count = 0 + best_model = None + + for epoch in range(1, self.configs.epochs_cls + 1): + model.train() + loss, acc, weighted_f1, macro_f1 = self.cal_cls_loss(model, masks[0], adj, features, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + # print(f"Epoch {epoch}: train_loss={loss.item()}, train_accuracy={acc}") + + if epoch % 10 == 0: + model.eval() + val_loss, acc, weighted_f1, macro_f1 = self.cal_cls_loss(model, masks[1], adj, features, labels) + # print(f"Epoch {epoch}: val_loss={val_loss.item()}, val_accuracy={acc}") + if acc > best_acc: + early_stop_count = 0 + best_acc = acc + best_weighted_f1, best_macro_f1 = weighted_f1, macro_f1 + best_model = model + else: + early_stop_count += 1 + if early_stop_count >= self.configs.patience_cls: + break + best_model.eval() + test_loss, test_acc, test_weighted_f1, test_macro_f1 = self.cal_cls_loss(best_model, masks[2], adj, features, + labels) + return best_acc, test_acc, best_model, test_weighted_f1, test_macro_f1 + + def train(self): + logger = create_logger(self.configs.log_path) + device = self.device + features, in_features, labels, init_adj, masks, n_classes = self.load_data() + init_adj = init_adj + torch.eye(features.shape[0]) + + if self.configs.downstream_task == 'clustering': + self.configs.exp_iters = 1 + + best_vals = [] + best_tests = [] + best_weighted_f1s = [] + best_macro_f1s = [] + for exp_iter in range(self.configs.exp_iters): + logger.info(f"\ntrain iters {exp_iter}") + Riemann_embeds_getter = RiemannianEmbeds(features.shape[0], self.configs.d_riemann, + self.configs.cur_h, self.configs.cur_s).to(device) + model = DeepRicci(backbone=self.configs.backbone, n_layers=self.configs.n_layers, + n_heads=self.configs.n_heads, + in_features=in_features, d_riemann=self.configs.d_riemann, d_hyla=self.configs.d_hyla, + embed_features=self.configs.embed_features, hidden_features=self.configs.hidden_features, + dropout=self.configs.dropout_node, dropout_edge=self.configs.dropout_edge, + init_adj=init_adj, + scale=self.configs.scale, cur_h=self.configs.cur_h, cur_s=self.configs.cur_s, + gamma=self.configs.gamma, top_k=self.configs.topk, + act=self.configs.act_func, wq=self.configs.wq_type, act_adj=self.configs.act_adj, + eps=self.configs.eps, temperature=self.configs.temperature, + backbone_topk=self.configs.backbone_topk, + s=self.configs.s, r=self.configs.r, alpha_gat=self.configs.alpha_gat, + 
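+                              # annotation: cur_h / cur_s are the curvatures of the hyperbolic
+                              # and spherical factors, gamma weights the structural (Ricci)
+                              # loss, and topk bounds the neighbourhood size of the learned
+                              # adjacency (the argument help strings live in main.py)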
n_heads_gat=self.configs.n_heads_gat, device=device).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=self.configs.lr, weight_decay=self.configs.w_decay) + r_optim = RiemannianAdam(Riemann_embeds_getter.parameters(), lr=self.configs.lr, + weight_decay=self.configs.w_decay, + stabilize=100) + + init_adj = init_adj.to(device) + features = features.to(device) + labels = labels.to(device) + + logger.info("--------------------------Training Start-------------------------") + best_val = 0. + best_val_test = 0. + best_weighted_f1 = 0. + best_macro_f1 = 0. + best_cluster = {'acc': 0, 'nmi': 0, 'f1': 0, 'ari': 0} + best_cluster_result = {} + n_cluster_trials = self.configs.n_cluster_trials + for epoch in range(1, self.configs.epochs + 1): + model.train() + Riemann_embeds_getter.train() + update_adj_ratio = self.configs.update_adj_ratio + if update_adj_ratio > 0.: + update_adj_ratio = 1 - np.sin(((epoch / self.configs.epochs) * np.pi) / 2) * update_adj_ratio + else: + update_adj_ratio = 0. + _, new_feature, adj, loss = model(features, init_adj, update_adj_ratio, Riemann_embeds_getter) + + r_optim.zero_grad() + optimizer.zero_grad() + loss.backward() + r_optim.step() + optimizer.step() + + if (1 - self.configs.tau) and (self.configs.iterations == 0 or epoch % self.configs.iterations == 0): + init_adj = self.configs.tau * init_adj + (1 - self.configs.tau) * adj.detach() + + logger.info(f"Epoch {epoch}: train_loss={loss.item()}") + + if epoch % self.configs.eval_freq == 0: + logger.info("---------------Evaluation Start-----------------") + model.eval() + Riemann_embeds_getter.eval() + if self.configs.downstream_task == 'classification': + edge = adj.detach() + if self.configs.backbone in ['gat', 'spgat', 'sage']: + edge = graph_top_K(edge, self.configs.topk) + val_acc, test_acc, _, test_weighted_f1, test_macro_f1 = self.evaluate_adj_by_cls(edge, + new_feature.detach(), + in_features + 2 * self.configs.d_riemann, + labels, + n_classes, + masks) + logger.info( + f"Epoch {epoch}: val_accuracy={val_acc.item() * 100: .2f}%, test_accuracy={test_acc * 100: .2f}%") + logger.info( + f"\t\t weighted_f1={test_weighted_f1.item() * 100: .2f}%, macro_f1={test_macro_f1 * 100: .2f}%") + logger.info("-------------------------------------------------------------------------") + if val_acc > best_val: + best_val = val_acc + best_val_test = test_acc + best_weighted_f1 = test_weighted_f1 + best_macro_f1 = test_macro_f1 + + elif self.configs.downstream_task == 'clustering': + embedding, _, _, _ = model(features, init_adj, update_adj_ratio, Riemann_embeds_getter) + embedding = embedding.detach().cpu().numpy() + acc, nmi, f1, ari = [], [], [], [] + for step in range(n_cluster_trials): + kmeans = KMeans(n_clusters=n_classes, random_state=step) + predicts = kmeans.fit_predict(embedding) + metrics = cluster_metrics(labels.cpu().numpy(), predicts) + acc_, nmi_, f1_, ari_ = metrics.evaluateFromLabel() + acc.append(acc_) + nmi.append(nmi_) + f1.append(f1_) + ari.append(ari_) + acc, nmi, f1, ari = np.mean(acc), np.mean(nmi), np.mean(f1), np.mean(ari) + if acc > best_cluster['acc']: + best_cluster['acc'] = acc + best_cluster_result['acc'] = [acc, nmi, f1, ari] + if nmi > best_cluster['nmi']: + best_cluster['nmi'] = nmi + best_cluster_result['nmi'] = [acc, nmi, f1, ari] + if f1 > best_cluster['f1']: + best_cluster['f1'] = f1 + best_cluster_result['f1'] = [acc, nmi, f1, ari] + if ari > best_cluster['ari']: + best_cluster['ari'] = ari + best_cluster_result['ari'] = [acc, nmi, f1, ari] + logger.info(f"Epoch {epoch}: 
ACC: {acc}, NMI: {nmi}, F1: {f1}, ARI: {ari}") + logger.info("-------------------------------------------------------------------------") + else: + raise NotImplementedError + + if self.configs.downstream_task == 'classification': + logger.info( + f"best_val_accuracy={best_val.item() * 100: .2f}%, best_test_accuracy={best_val_test * 100: .2f}%") + logger.info( + f"weighted_f1={test_weighted_f1.item() * 100: .2f}%, macro_f1={test_macro_f1 * 100: .2f}%") + best_vals.append(best_val.item()) + best_tests.append(best_val_test.item()) + best_weighted_f1s.append(best_weighted_f1.item()) + best_macro_f1s.append(best_macro_f1.item()) + + if self.configs.downstream_task == 'clustering': + for k, result in best_cluster_result.items(): + acc, nmi, f1, ari = result + logger.info(f"Best Results according to {k}: ACC: {acc}, NMI: {nmi}, F1: {f1}, ARI: {ari} \n") + if self.configs.downstream_task == 'classification': + logger.info(f"best valid results: {np.max(best_vals)}") + logger.info(f"best test results: {np.max(best_tests)}") + logger.info(f"valid results: {np.mean(best_vals)}~{np.std(best_vals)}") + logger.info(f"test results: {np.mean(best_tests)}~{np.std(best_tests)}") + logger.info(f"test weighted-f1: {np.mean(best_weighted_f1s)}~{np.std(best_weighted_f1s)}") + logger.info(f"test macro-f1: {np.mean(best_macro_f1s)}~{np.std(best_macro_f1s)}") diff --git a/layers.py b/layers.py new file mode 100644 index 0000000..385029f --- /dev/null +++ b/layers.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GraphConvolution(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super(GraphConvolution, self).__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x, adj): + x = self.linear(x) + x = torch.matmul(adj, x) + return x + + +class SparseDropout(nn.Module): + def __init__(self, prob=0.5): + super(SparseDropout, self).__init__() + self.prob = prob + + def forward(self, x: torch.Tensor): + mask = (torch.rand(x._values().shape) + (1 - self.prob)).floor().bool() + idx = x._indices()[:, mask] + value = x._values()[mask] / (1 - self.prob) + return torch.sparse_coo_tensor(idx, value, x.shape) + + +class GraphAttentionLayer(nn.Module): + """ + Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 + """ + + def __init__(self, in_features, out_features, dropout, alpha, concat=True): + super(GraphAttentionLayer, self).__init__() + self.dropout = dropout + self.in_features = in_features + self.out_features = out_features + self.alpha = alpha + self.concat = concat + + self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) + nn.init.xavier_uniform_(self.W.data, gain=1.414) + self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1))) + nn.init.xavier_uniform_(self.a.data, gain=1.414) + + self.leakyrelu = nn.LeakyReLU(self.alpha) + + def forward(self, h, adj): + Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features) + e = self._prepare_attentional_mechanism_input(Wh) + + zero_vec = -9e15 * torch.ones_like(e) + attention = torch.where(adj > 0, e, zero_vec) + attention = F.softmax(attention, dim=1) + attention = F.dropout(attention, self.dropout, training=self.training) + h_prime = torch.matmul(attention, Wh) + + if self.concat: + return F.elu(h_prime) + else: + return h_prime + + def _prepare_attentional_mechanism_input(self, Wh): + # Wh.shape (N, out_feature) + # self.a.shape (2 * out_feature, 1) + # Wh1&2.shape (N, 1) + # e.shape (N, N) + Wh1 = 
torch.matmul(Wh, self.a[:self.out_features, :]) + Wh2 = torch.matmul(Wh, self.a[self.out_features:, :]) + # broadcast add + e = Wh1 + Wh2.T + return self.leakyrelu(e) + + def __repr__(self): + return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' + + +class SpecialSpmmFunction(torch.autograd.Function): + """Special function for only sparse region backpropataion layer.""" + + @staticmethod + def forward(ctx, indices, values, shape, b): + assert indices.requires_grad == False + a = torch.sparse_coo_tensor(indices, values, shape) + ctx.save_for_backward(a, b) + ctx.N = shape[0] + return torch.matmul(a, b) + + @staticmethod + def backward(ctx, grad_output): + a, b = ctx.saved_tensors + grad_values = grad_b = None + if ctx.needs_input_grad[1]: + grad_a_dense = grad_output.matmul(b.t()) + edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] + grad_values = grad_a_dense.view(-1)[edge_idx] + if ctx.needs_input_grad[3]: + grad_b = a.t().matmul(grad_output) + return None, grad_values, None, grad_b + + +class SpecialSpmm(nn.Module): + def forward(self, indices, values, shape, b): + return SpecialSpmmFunction.apply(indices, values, shape, b) + + +class SpGraphAttentionLayer(nn.Module): + """ + Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 + """ + + def __init__(self, in_features, out_features, dropout, alpha, concat=True): + super(SpGraphAttentionLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.alpha = alpha + self.concat = concat + + self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) + nn.init.xavier_normal_(self.W.data, gain=1.414) + + self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features))) + nn.init.xavier_normal_(self.a.data, gain=1.414) + + self.dropout = nn.Dropout(dropout) + self.leakyrelu = nn.LeakyReLU(self.alpha) + self.special_spmm = SpecialSpmm() + + def forward(self, input, adj): + dv = 'cuda' if input.is_cuda else 'cpu' + + N = input.size()[0] + edge = adj.nonzero().t() + + h = torch.mm(input, self.W) + # h: N x out + assert not torch.isnan(h).any() + + # Self-attention on the nodes - Shared attention mechanism + edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() + # edge: 2*D x E + + edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) + assert not torch.isnan(edge_e).any() + # edge_e: E + + e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=dv)) + # e_rowsum: N x 1 + + edge_e = self.dropout(edge_e) + # edge_e: E + + h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) + assert not torch.isnan(h_prime).any() + # h_prime: N x out + + h_prime = h_prime.div(e_rowsum) + # h_prime: N x out + assert not torch.isnan(h_prime).any() + + if self.concat: + # if this layer is not last layer, + return F.elu(h_prime) + else: + # if this layer is last layer, + return h_prime + + def __repr__(self): + return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' \ No newline at end of file diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..0b567e2 --- /dev/null +++ b/logger.py @@ -0,0 +1,78 @@ +import logging +import time +from datetime import timedelta + + +class DotDict(dict): + """dict.key access""" + def __getattr__(*args): # nested + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def create_logger(filepath, 
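+# args: filepath (None logs to the console only), colored (wrap level names
+# in ANSI color codes), debug (lower the file handler threshold from INFO to DEBUG)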
colored=False, debug=False): + log_formatter = LogFormatter(colored=colored) + + # create file handler and set level + if filepath is not None: + file_handler = logging.FileHandler(filepath, "a") + file_handler.setLevel(logging.INFO) + if debug: + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(log_formatter) + + # create console handler and set level to info + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(log_formatter) + + # create logger and set level + logger = logging.getLogger() + logger.handlers = [] + logger.setLevel(logging.DEBUG) + logger.propagate = False + if filepath is not None: + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + # reset logger elapsed time + def reset_time(): + log_formatter.start_time = time.time() + logger.reset_time = reset_time + + return logger + + +class LogFormatter: + def __init__(self, colored=False): + self.colored = colored + self.start_time = time.time() + + def format(self, record): + BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) + RESET_SEQ = "\033[0m" + COLOR_SEQ = "\033[1;%dm" + + COLORS = { + 'WARNING': GREEN, + 'INFO': WHITE, + 'DEBUG': BLUE, + 'CRITICAL': YELLOW, + 'ERROR': RED + } + elapsed_seconds = round(record.created - self.start_time) + levelname = record.levelname + if self.colored: + levelname = COLOR_SEQ % ( + 30 + COLORS[record.levelname]) + record.levelname + RESET_SEQ + + prefix = "%s - %s - %s" % ( + levelname, + time.strftime('%x %X'), + timedelta(seconds=elapsed_seconds) + ) + message = record.getMessage() + message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) + return "%s - %s" % (prefix, message) if message else '' \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..d2a0a94 --- /dev/null +++ b/main.py @@ -0,0 +1,97 @@ +import torch +import numpy as np +import os +import random +import argparse +from exp.exp import Exp +from logger import create_logger + + +seed = 3047 +random.seed(seed) +torch.manual_seed(seed) +np.random.seed(seed) + +parser = argparse.ArgumentParser(description='DeepRicci') + +# Experiment settings +parser.add_argument('--downstream_task', type=str, default='classification', + choices=['classification', 'clustering']) +parser.add_argument('--dataset', type=str, default='Cora', + choices=['Cora', 'Citeseer', 'chameleon', 'squirrel']) +parser.add_argument('--is_graph', type=bool, default=True) +parser.add_argument('--root_path', type=str, default='./datasets') +parser.add_argument('--eval_freq', type=int, default=50) +parser.add_argument('--exp_iters', type=int, default=5) +parser.add_argument('--version', type=str, default="run") +parser.add_argument('--log_path', type=str, default="./results/v2302152230/cls_Cora.log") + +# Deep Ricci Graph Contrastive Learning Module +parser.add_argument('--backbone', type=str, default='spgat', choices=['gcn', 'spgat', 'gat', 'sage']) +parser.add_argument('--backbone_topk', type=int, default=8) +parser.add_argument('--epochs', type=int, default=1000) +parser.add_argument('--hidden_features', type=int, default=512) +parser.add_argument('--embed_features', type=int, default=32, help='dimensions of graph embedding') +parser.add_argument('--n_layers', type=int, default=2) +parser.add_argument('--dropout_node', type=float, default=0.5) +parser.add_argument('--dropout_edge', type=float, default=0.75) +parser.add_argument('--lr', type=float, default=0.005) +parser.add_argument('--w_decay', 
type=float, default=0.) +parser.add_argument('--update_adj_ratio', type=float, default=0.1) +parser.add_argument('--act_adj', type=str, default='elu', choices=['relu', 'elu', 'gumbel']) +parser.add_argument('--act_func', type=str, default='tanh', choices=['relu', 'tanh']) +parser.add_argument('--wq_type', type=str, default='mlp', choices=['linear', 'mlp']) +parser.add_argument('--n_heads', type=int, default=8, help='number of attention heads') +parser.add_argument('--d_riemann', type=int, default=32, help='dimension of Riemannian embedding') +parser.add_argument('--d_hyla', type=int, default=256, help='dimension of HyLa features') +parser.add_argument('--scale', type=float, default=0.1, help='scale for sampling eigenvalues') +parser.add_argument('--cur_h', type=float, default=-1., help='curvature of the hyperbolic factor') +parser.add_argument('--cur_s', type=float, default=1., help='curvature of the spherical factor') +parser.add_argument('--s', type=float, default=1., help='s of the Fermi-Dirac decoder') +parser.add_argument('--r', type=float, default=2., help='r of the Fermi-Dirac decoder') +parser.add_argument('--gamma', type=float, default=5., help='coefficient for the structural loss') +parser.add_argument('--eps', type=float, default=None, help='threshold for sparsifying the learned adjacency') +parser.add_argument('--topk', type=int, default=60, help='number of top-k neighbours to keep') +parser.add_argument('--temperature', type=float, default=0.2, help='temperature of the NT-Xent loss') + +# Evaluation Network for Classification +parser.add_argument('--alpha_gat', type=float, default=0.2) +parser.add_argument('--n_heads_gat', type=int, default=8, help='number of attention heads of GAT') +parser.add_argument('--hidden_features_cls', type=int, default=8) +parser.add_argument('--dropout_node_cls', type=float, default=0.6) +parser.add_argument('--dropout_edge_cls', type=float, default=0.6) +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999)) +parser.add_argument('--n_layers_cls', type=int, default=2) +parser.add_argument('--lr_cls', type=float, default=0.005) +parser.add_argument('--w_decay_cls', type=float, default=5e-4) +parser.add_argument('--epochs_cls', type=int, default=200) +parser.add_argument('--patience_cls', type=int, default=10) +parser.add_argument('--save_path_cls', type=str, default='./checkpoints/cls.pth') + +# Evaluation Network for Clustering +parser.add_argument('--n_cluster_trials', type=int, default=5) + +# Structure Bootstrapping +parser.add_argument('--tau', type=float, default=0.9999) +parser.add_argument('--iterations', type=int, default=0) + +# GPU +parser.add_argument('--use_gpu', action='store_false', help='use gpu') +parser.add_argument('--gpu', type=int, default=0, help='gpu') +parser.add_argument('--devices', type=str, default='0,1', help='device ids of multiple GPUs') + +configs = parser.parse_args() +log_path = f"./results/{configs.version}/{configs.downstream_task}_{configs.backbone}_{configs.dataset}.log" +configs.log_path = log_path +if not os.path.exists(f"./results"): + os.mkdir("./results") +if not os.path.exists(f"./results/{configs.version}"): + os.mkdir(f"./results/{configs.version}") +print(f"Log path: {configs.log_path}") +logger = create_logger(configs.log_path) +logger.info(configs) + +exp = Exp(configs) +exp.train() +torch.cuda.empty_cache() + diff --git a/models.py b/models.py new file mode 100644 index 0000000..31a85e0 --- /dev/null +++ b/models.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from backbone import GraphEncoder +import numpy as np +from utils import 
normalize, graph_top_K, adjacency2index, graph_threshold +from utils import gumbel_sigmoid +from geoopt.manifolds.stereographic.math import artan_k +from geoopt.manifolds.stereographic import StereographicExact +from geoopt import ManifoldTensor +from geoopt import ManifoldParameter +from geoopt.manifolds.stereographic.math import logmap0 + + +class RiemannianEmbeds(nn.Module): + def __init__(self, num_nodes, d_riemann, cur_h=-1., cur_s=1.): + super(RiemannianEmbeds, self).__init__() + self.hyperbolic = StereographicExact(k=cur_h) + self.sphere = StereographicExact(k=cur_s) + self.embeds_hyperbolic = ManifoldParameter(ManifoldTensor(torch.empty(num_nodes, d_riemann), + manifold=self.hyperbolic)) # N, D_r + self.init_weights(self.embeds_hyperbolic) + + self.embeds_sphere = ManifoldParameter(ManifoldTensor(torch.empty(num_nodes, d_riemann), + manifold=self.sphere)) # N, D_r + self.init_weights(self.embeds_sphere) + + def init_weights(self, w, scale=1e-4): + w.data.uniform_(-scale, scale) + w_norm = w.data.norm(p=2, dim=-1, keepdim=True) + w.data = w.data / w_norm * w.manifold.radius * 0.9 * torch.rand(1) + + def normalize(self, x, manifold): + x_norm = x.norm(p=2, dim=-1, keepdim=True) + x = x / x_norm * 0.9 * torch.rand(1).to(x.device) * manifold.radius + return x + + def forward(self): + embeds_hyperbolic = self.normalize(self.embeds_hyperbolic, self.hyperbolic) + embeds_sphere = self.normalize(self.embeds_sphere, self.sphere) + return embeds_hyperbolic, embeds_sphere + + +class DeepRicci(nn.Module): + def __init__(self, backbone, n_layers, n_heads, in_features, d_riemann, + d_hyla, hidden_features, embed_features, dropout, + dropout_edge, init_adj=None, scale=0.1, cur_h=-1., cur_s=1., + s=1, r=2, gamma=1., top_k=30, eps=None, + alpha=0.5, backbone_topk=30, act='relu', wq='linear', act_adj='gumbel', + temperature=0.1, alpha_gat=0.2, n_heads_gat=8, device=torch.device('cuda')): + super(DeepRicci, self).__init__() + self.n_heads = n_heads + self.act_adj = act_adj + if wq == 'linear': + self.w_q = nn.Linear(3 * d_hyla, embed_features * n_heads) + else: + activation = nn.ReLU() if act == 'relu' else nn.Tanh() + self.w_q = nn.Sequential( + nn.Linear(3 * d_hyla, hidden_features), + activation, + nn.Linear(hidden_features, embed_features * n_heads) + ) + self.scale = scale + self.alpha = alpha + self.s = s + self.r = r + + pre = torch.randn(d_hyla, in_features).to(device) + self.boundary_matrix_eu = pre / torch.norm(pre, dim=-1, keepdim=True) # D_h, in_feats + self.bias_eu = 2 * torch.pi * torch.rand(d_hyla).to(device) # n, D_h + + pre = torch.randn(d_hyla, d_riemann).to(device) + self.boundary_matrix_h = pre / torch.norm(pre, dim=-1, keepdim=True) # D_h, D_r + self.eigen_values_h = torch.randn(d_hyla).to(device) * self.scale # n, D_h + self.bias_h = 2 * torch.pi * torch.rand(d_hyla).to(device) # n, D_h + + pre = torch.randn(d_hyla, d_riemann).to(device) + self.boundary_matrix_s = pre / torch.norm(pre, dim=-1, keepdim=True) # D_h, D_r + self.eigen_values_s = torch.randn(d_hyla).to(device) * self.scale # n, D_h + self.bias_s = 2 * torch.pi * torch.rand(d_hyla).to(device) # n, D_h + + self.encoder = GraphEncoder(backbone, n_layers, in_features, hidden_features, d_hyla, dropout, dropout_edge, + alpha_gat, n_heads_gat, backbone_topk) + self.lipschitz_func = nn.Linear(3 * d_hyla, 1) + self.init_adj = init_adj.to(device) + self.gamma = gamma + self.cur_h = cur_h + self.cur_s = cur_s + self.top_k = top_k + self.eps = eps + self.temperature = temperature + + def forward(self, feature, A, ratio, 
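+                # feature: [N, in_features] node features; A: anchor adjacency;
+                # ratio: mixing weight for the anchor graph, annealed per epoch
+                # in exp.py; riemann_embeds_getter: RiemannianEmbeds module that
+                # returns the hyperbolic and spherical node embeddings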
riemann_embeds_getter): + embeds_hyperbolic, embeds_sphere = riemann_embeds_getter() + + feature_euclidean = np.sqrt(2) * torch.cos(torch.matmul(feature, self.boundary_matrix_eu.t()) + self.bias_eu) + feature_hyperbolic = \ + self.cal_laplacian_features(embeds_hyperbolic, self.cur_h, self.boundary_matrix_h, self.eigen_values_h, + self.bias_h) + feature_sphere = \ + self.cal_laplacian_features(embeds_sphere, self.cur_s, self.boundary_matrix_s, self.eigen_values_s, + self.bias_s) + product_features = torch.concat([feature_euclidean, feature_hyperbolic, feature_sphere], dim=-1) + + learned_A = self.learn_adjacency(product_features) + learned_A_normed = (1 - ratio) * normalize(learned_A, "sym") + ratio * normalize(A, "sym") + learned_A_rownormed = (1 - ratio) * normalize(learned_A, "row") + ratio * normalize(A, "row") + + z_E = self.encoder(feature, learned_A_normed) + product_z = torch.concat([z_E, embeds_hyperbolic, embeds_sphere], dim=-1) + + info_loss = self.cal_cl_loss(torch.concat([z_E, feature_hyperbolic], dim=-1), + torch.concat([z_E, feature_sphere], dim=-1)) + + structure_loss = self.cal_structure_loss(product_features, learned_A_rownormed, self.init_adj) + loss = info_loss + self.gamma * structure_loss + new_feature = torch.concat([feature, embeds_hyperbolic, embeds_sphere], dim=-1) + return product_z, new_feature, learned_A_normed, loss + + def dist_to_horocycle(self, z, k, boundary_matrix): + boundary_matrix = boundary_matrix / torch.norm(boundary_matrix, dim=-1, keepdim=True) + div = 1 - torch.matmul(z, boundary_matrix.t()) + dist = (torch.matmul(z, boundary_matrix.t()) - torch.norm(z, dim=-1, keepdim=True) ** 2) / div + P = artan_k(dist, torch.tensor(k)) + return P + + def cal_laplacian_features(self, z, k, boundary_matrix, eigen_values, bias): + horocycle_dist = self.dist_to_horocycle(z, k, boundary_matrix) + n = z.shape[-1] + laplacian_features = torch.exp((n - 1) * horocycle_dist / 2) * \ + torch.cos(eigen_values * horocycle_dist + bias) + return laplacian_features + + def multi_head_attention(self, q, k): + """ + q: [N, H * D], reshaped internally to [H, N, D] + k: [M, H * D], reshaped internally to [H, M, D] + """ + N = q.shape[0] + M = k.shape[0] + H = self.n_heads + q = q.reshape(N, -1, H).permute(2, 0, 1) + k = k.reshape(M, -1, H).permute(2, 0, 1) + score = torch.einsum('hnd, hmd->hnm', q, k).mean(0) + score = torch.softmax(score, dim=-1) + return score + + def post_process(self, A, k=None, eps=None): + # assert (k is not None) or (eps is not None) + A = graph_top_K(A, k) if k is not None else graph_threshold(A, eps) + return A + + def learn_adjacency(self, x): + q = self.w_q(x) + learned_A = self.multi_head_attention(q, q) + learned_A = self.post_process(learned_A, self.top_k, self.eps) + learned_A = (learned_A + learned_A.t()) / 2 + if self.act_adj == 'gumbel': + learned_A = torch.sigmoid(learned_A) + learned_A = gumbel_sigmoid(learned_A) + elif self.act_adj == 'relu': + learned_A = torch.relu(learned_A) + elif self.act_adj == 'elu': + learned_A = F.elu(6 * (learned_A - 1)) + 1 + return learned_A + + def compute_wasserstein_dist(self, x, adj, edge_index): + """Edge-wise transport cost from a learned potential (dual-form Wasserstein distance). + + Args: + x (torch.Tensor): node features + adj (torch.Tensor): row-normalized adjacency matrix + edge_index (torch.Tensor): [2, E] source/target node indices + """ + src_idx = edge_index[0] + tgt_idx = edge_index[1] + f = self.lipschitz_func(x) + f_x_src = f[src_idx] + f_x_tgt = f[tgt_idx] + L_bar = torch.eye(x.shape[0]).to(x.device) - (1 - 1 / self.alpha) * adj + L_weight = L_bar[src_idx, tgt_idx] + w_dist = self.alpha * L_weight * (f_x_src - f_x_tgt).reshape(-1) + return w_dist.clip(min=0.)
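+    # How the structural pieces fit together (descriptive sketch): for each
+    # edge (u, v) of the anchor graph, compute_wasserstein_dist() above scores
+    # a transport cost from the learned potential f (a Kantorovich-dual style
+    # estimate) weighted by L_bar = I - (1 - 1/alpha) * A_row, and
+    # compute_ricci_flow() below squashes that cost through a Fermi-Dirac gate
+    # sigmoid((r - W) / s), so cheap-to-transport edges score near 1;
+    # cal_structure_loss() then pushes the gated scores of the top-k anchor
+    # edges towards 1.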
+ + def compute_ricci_flow(self, x, adjacency, edge_index): + """adjacency: row normalized""" + w_dist_t = self.compute_wasserstein_dist(x, adjacency, edge_index) + return torch.sigmoid((self.r - w_dist_t) / self.s) + + def cal_cl_loss(self, x1, x2): + norm1 = x1.norm(dim=-1) + norm2 = x2.norm(dim=-1) + sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', norm1, norm2) + sim_matrix = torch.exp(sim_matrix / self.temperature) + pos_sim = sim_matrix.diag() + loss_1 = pos_sim / (sim_matrix.sum(dim=-2) - pos_sim) + loss_2 = pos_sim / (sim_matrix.sum(dim=-1) - pos_sim) + + loss_1 = -torch.log(loss_1).mean() + loss_2 = -torch.log(loss_2).mean() + loss = (loss_1 + loss_2) / 2. + return loss + + def cal_structure_loss(self, x, learned_A_rownormed, A): + edge_idx = adjacency2index(A, weight=False, topk=True, k=self.top_k) + pos = self.compute_ricci_flow(x, learned_A_rownormed, edge_idx) + loss = F.nll_loss(torch.log(pos + 1e-5), torch.ones_like(pos).long().to(pos.device)) + return loss \ No newline at end of file diff --git a/scripts/cls/GAT_chameleon_nc.sh b/scripts/cls/GAT_chameleon_nc.sh new file mode 100644 index 0000000..29de3f9 --- /dev/null +++ b/scripts/cls/GAT_chameleon_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset chameleon \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func relu \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.01 \ +--w_decay 0.0 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone spgat \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 8 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/GAT_citeseer_nc.sh b/scripts/cls/GAT_citeseer_nc.sh new file mode 100644 index 0000000..f7d4230 --- /dev/null +++ b/scripts/cls/GAT_citeseer_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Citeseer \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 1000 \ +--hidden_features 512 \ +--embed_features 16 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.001 \ +--w_decay 0.0005 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 128 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 1. 
\ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone spgat \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 8 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.005 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/GAT_cora_nc.sh b/scripts/cls/GAT_cora_nc.sh new file mode 100644 index 0000000..2a29076 --- /dev/null +++ b/scripts/cls/GAT_cora_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Cora \ +--eval_freq 50 \ +--exp_iters 5 \ +--epochs 1000 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.75 \ +--lr 0.005 \ +--w_decay 0. \ +--n_heads 8 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 5. \ +--eps 0.01 \ +--topk 60 \ +--temperature 0.2 \ +--backbone spgat \ +--backbone_topk 8 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 8 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.5 \ +--n_layers_cls 2 \ +--lr_cls 0.005 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 diff --git a/scripts/cls/GAT_squirrel_nc.sh b/scripts/cls/GAT_squirrel_nc.sh new file mode 100644 index 0000000..c4d5706 --- /dev/null +++ b/scripts/cls/GAT_squirrel_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset squirrel \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func tanh \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.5 \ +--lr 0.01 \ +--w_decay 0. \ +--n_heads 4 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone spgat \ +--backbone_topk 30 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 8 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.5 \ +--n_layers_cls 2 \ +--lr_cls 0.005 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/GCN_chameleon_nc.sh b/scripts/cls/GCN_chameleon_nc.sh new file mode 100644 index 0000000..f89ad64 --- /dev/null +++ b/scripts/cls/GCN_chameleon_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset chameleon \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func relu \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.01 \ +--w_decay 0.0 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. 
\ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone gcn \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/GCN_citeseer_nc.sh b/scripts/cls/GCN_citeseer_nc.sh new file mode 100644 index 0000000..ceb8569 --- /dev/null +++ b/scripts/cls/GCN_citeseer_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Citeseer \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 1000 \ +--hidden_features 512 \ +--embed_features 16 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.001 \ +--w_decay 0.0005 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 128 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 1. \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone gcn \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.05 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/GCN_cora_nc.sh b/scripts/cls/GCN_cora_nc.sh new file mode 100644 index 0000000..e4219a6 --- /dev/null +++ b/scripts/cls/GCN_cora_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Cora \ +--eval_freq 50 \ +--exp_iters 5 \ +--epochs 2000 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.75 \ +--dropout_edge 0.75 \ +--lr 0.01 \ +--w_decay 0. \ +--n_heads 8 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 5. \ +--eps 0.01 \ +--topk 60 \ +--temperature 0.2 \ +--backbone gcn \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.8 \ +--dropout_edge_cls 0.8 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 15 \ +--tau 0.9999 \ +--iterations 0 diff --git a/scripts/cls/GCN_squirrel_nc.sh b/scripts/cls/GCN_squirrel_nc.sh new file mode 100644 index 0000000..72b9c95 --- /dev/null +++ b/scripts/cls/GCN_squirrel_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset squirrel \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func tanh \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.01 \ +--w_decay 5e-4 \ +--n_heads 4 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. 
\ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone gcn \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/SAGE_chameleon_nc.sh b/scripts/cls/SAGE_chameleon_nc.sh new file mode 100644 index 0000000..492d5b3 --- /dev/null +++ b/scripts/cls/SAGE_chameleon_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset chameleon \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func relu \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.01 \ +--w_decay 0.0 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone sage \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/SAGE_citeseer_nc.sh b/scripts/cls/SAGE_citeseer_nc.sh new file mode 100644 index 0000000..7d4e437 --- /dev/null +++ b/scripts/cls/SAGE_citeseer_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Citeseer \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 1000 \ +--hidden_features 512 \ +--embed_features 16 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.001 \ +--w_decay 0.0005 \ +--n_heads 4 \ +--d_riemann 16 \ +--d_hyla 128 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 1. \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone sage \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.2 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.05 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cls/SAGE_cora_nc.sh b/scripts/cls/SAGE_cora_nc.sh new file mode 100644 index 0000000..52b0739 --- /dev/null +++ b/scripts/cls/SAGE_cora_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset Cora \ +--eval_freq 50 \ +--exp_iters 5 \ +--epochs 2000 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.75 \ +--dropout_edge 0.75 \ +--lr 0.01 \ +--w_decay 0. \ +--n_heads 8 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 5. 
\ +--eps 0.01 \ +--topk 60 \ +--temperature 0.2 \ +--backbone sage \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.8 \ +--dropout_edge_cls 0.8 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 15 \ +--tau 0.9999 \ +--iterations 0 diff --git a/scripts/cls/SAGE_squirrel_nc.sh b/scripts/cls/SAGE_squirrel_nc.sh new file mode 100644 index 0000000..10d932c --- /dev/null +++ b/scripts/cls/SAGE_squirrel_nc.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task classification \ +--dataset squirrel \ +--eval_freq 20 \ +--exp_iters 5 \ +--epochs 100 \ +--hidden_features 512 \ +--embed_features 32 \ +--update_adj_ratio 0.1 \ +--act_adj elu \ +--act_func tanh \ +--wq_type mlp \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.25 \ +--lr 0.01 \ +--w_decay 5e-4 \ +--n_heads 4 \ +--d_riemann 32 \ +--d_hyla 256 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1. \ +--r 2. \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone sage \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.2 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 5e-5 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cluster/citeseer_clu.sh b/scripts/cluster/citeseer_clu.sh new file mode 100644 index 0000000..0a41016 --- /dev/null +++ b/scripts/cluster/citeseer_clu.sh @@ -0,0 +1,43 @@ +python3 main.py \ +--downstream_task clustering \ +--dataset Citeseer \ +--eval_freq 100 \ +--exp_iters 10 \ +--epochs 1000 \ +--hidden_features 512 \ +--embed_features 16 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.5 \ +--lr 0.001 \ +--w_decay 0.0005 \ +--n_heads 4 \ +--d_riemann 64 \ +--d_hyla 128 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--s 1.0 \ +--r 2.0 \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--backbone gcn \ +--backbone_topk 10 \ +--n_heads_gat 8 \ +--alpha_gat 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.75 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.999 \ +--iterations 0 \ No newline at end of file diff --git a/scripts/cluster/cora_clu.sh b/scripts/cluster/cora_clu.sh new file mode 100644 index 0000000..3d2efa9 --- /dev/null +++ b/scripts/cluster/cora_clu.sh @@ -0,0 +1,37 @@ +python3 main.py \ +--downstream_task clustering \ +--dataset Cora \ +--eval_freq 100 \ +--exp_iters 10 \ +--epochs 2000 \ +--hidden_features 512 \ +--embed_features 4 \ +--update_adj_ratio 0.1 \ +--act_adj gumbel \ +--act_func relu \ +--wq_type linear \ +--n_layers 2 \ +--dropout_node 0.5 \ +--dropout_edge 0.5 \ +--lr 0.0001 \ +--w_decay 0.0 \ +--n_heads 8 \ +--d_riemann 32 \ +--d_hyla 128 \ +--scale 0.1 \ +--cur_h -1.0 \ +--cur_s 1.0 \ +--gamma 0.05 \ +--eps 0.01 \ +--topk 30 \ +--temperature 0.2 \ +--hidden_features_cls 32 \ +--dropout_node_cls 0.5 \ +--dropout_edge_cls 0.75 \ +--n_layers_cls 2 \ +--lr_cls 0.001 \ +--w_decay_cls 0.0005 \ +--epochs_cls 200 \ +--patience_cls 10 \ +--tau 0.9999 \ +--iterations 0 \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..c349923 --- /dev/null +++ b/utils.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F 
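+# Usage sketch (illustrative, not part of the original API surface): the
+# helpers below are typically composed as
+#     adj = graph_top_K(scores, k=30)   # keep the k largest entries per row
+#     adj = normalize(adj, 'sym')       # symmetric D^{-1/2} A D^{-1/2} scaling
+# before the graph is handed to a GCN-style backbone (see backbone.py).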
+from sklearn.neighbors import kneighbors_graph
+import numpy as np
+from sklearn import metrics
+from munkres import Munkres
+import networkx as nx
+import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
+import random
+
+
+def sparse_mx_to_torch_sparse_tensor(sparse_mx):
+    """Convert a scipy sparse matrix to a torch sparse tensor."""
+    sparse_mx = sparse_mx.tocoo().astype(np.float32)
+    indices = torch.from_numpy(
+        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
+    values = torch.from_numpy(sparse_mx.data)
+    shape = torch.Size(sparse_mx.shape)
+    # torch.sparse.FloatTensor is deprecated; torch.sparse_coo_tensor is the
+    # supported constructor.
+    return torch.sparse_coo_tensor(indices, values, shape)
+
+
+def k_nearest_neighbors(x, k_neighbours, metric):
+    """Build a kNN graph over the rows of x and add self-loops."""
+    adj = kneighbors_graph(x, k_neighbours, metric=metric)
+    adj = adj.toarray().astype(np.float32)
+    adj += np.eye(adj.shape[0])
+    return adj
+
+
+def graph_top_K(dense_adj, k):
+    """Keep only the k largest entries per row of a dense adjacency; zero the rest."""
+    assert k < dense_adj.shape[-1]
+    _, indices = dense_adj.topk(k=k, dim=-1)
+    mask = torch.zeros(dense_adj.shape, dtype=torch.bool, device=dense_adj.device)
+    mask[torch.arange(dense_adj.shape[0])[:, None], indices] = True
+    sparse_adj = torch.masked_fill(dense_adj, ~mask, value=0.)
+    return sparse_adj
+
+
+def graph_threshold(dense_adj, eps):
+    """Zero out all adjacency entries below the threshold eps."""
+    sparse_adj = torch.masked_fill(dense_adj, dense_adj < eps, value=0.)
+    return sparse_adj
+
+
+def cal_accuracy(preds, trues):
+    preds = torch.argmax(preds, dim=-1)
+    correct = (preds == trues).sum()
+    return correct / len(trues)
+
+
+def cal_F1(preds, trues):
+    # sklearn expects CPU arrays, so move tensors off the GPU first.
+    preds = torch.argmax(preds, dim=-1).detach().cpu()
+    trues = trues.detach().cpu() if torch.is_tensor(trues) else trues
+    weighted_f1 = metrics.f1_score(trues, preds, average='weighted')
+    macro_f1 = metrics.f1_score(trues, preds, average='macro')
+    return weighted_f1, macro_f1
+
+
+def normalize(adj, mode, sparse=False):
+    """Symmetrically ('sym') or row-wise ('row') normalize an adjacency matrix."""
+    if sparse:
+        adj = adj.coalesce()
+        # torch.sparse.sum returns a sparse tensor; densify the degree vector
+        # before taking reciprocals and fancy-indexing it.
+        degree = torch.sparse.sum(adj, -1).to_dense()
+        if mode == 'sym':
+            d_inv_sqrt = 1. / (torch.sqrt(degree) + 1e-10)
+            value = d_inv_sqrt[adj.indices()[0]] * d_inv_sqrt[adj.indices()[1]]
+        elif mode == 'row':
+            d_inv = 1. / (degree + 1e-10)
+            value = d_inv[adj.indices()[0]]
+        else:
+            raise NotImplementedError
+        return torch.sparse_coo_tensor(adj.indices(), value * adj.values(), adj.shape)
+    else:
+        if mode == 'sym':
+            degree_matrix = 1. / (torch.sqrt(adj.sum(-1)) + 1e-10)
+            return degree_matrix[:, None] * adj * degree_matrix[None, :]
+        elif mode == 'row':
+            degree_matrix = 1. / (adj.sum(-1) + 1e-10)
+            return degree_matrix[:, None] * adj
+        else:
+            raise NotImplementedError
+
+
+def sample_gumbel(shape, device=None, eps=1e-20):
+    # Sample Gumbel(0, 1) noise on the caller's device instead of
+    # hard-coding .cuda(), so the code also runs on CPU.
+    U = torch.rand(shape, device=device)
+    return -torch.log(-torch.log(U + eps) + eps)
+
+
+def gumbel_softmax_sample(logits, temperature=1):
+    y = logits + sample_gumbel(logits.size(), logits.device)
+    return torch.nn.functional.softmax(y / temperature, dim=-1)
+
+
+def gumbel_softmax(logits, temperature=1, hard=False):
+    """
+    Straight-through Gumbel-softmax.
+    input: [*, n_class]
+    return: [*, n_class]; a one-hot vector if hard=True
+    """
+    y = gumbel_softmax_sample(logits, temperature)
+
+    if not hard:
+        return y
+
+    shape = y.size()
+    _, ind = y.max(dim=-1)
+    y_hard = torch.zeros_like(y).view(-1, shape[-1])
+    y_hard.scatter_(1, ind.view(-1, 1), 1)
+    y_hard = y_hard.view(*shape)
+    # Straight-through estimator: the forward pass uses y_hard,
+    # while gradients flow through the relaxed sample y.
+    y_hard = (y_hard - y).detach() + y
+    return y_hard
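+
+
+# Usage sketch for the straight-through estimator above (illustrative only;
+# the tensor shape and temperature are assumptions, not values taken from the
+# training scripts):
+#
+#     logits = torch.randn(8, 4)                        # [N, n_class] scores
+#     soft = gumbel_softmax(logits, temperature=0.2)    # differentiable sample
+#     hard = gumbel_softmax(logits, temperature=0.2, hard=True)
+#     # `hard` is exactly one-hot in the forward pass, while its gradient is
+#     # that of the relaxed sample, so it can be trained end-to-end.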
+
+
+def adjacency2index(adjacency, weight=False, topk=False, k=10):
+    """Convert a dense adjacency matrix to an edge-index list.
+
+    Args:
+        adjacency (torch.Tensor): [N, N] matrix
+    Returns:
+        edge_index: [2, E]
+        edge_weight: only if weight=True
+    """
+    if topk and k:
+        adj = graph_top_K(adjacency, k)
+    else:
+        adj = adjacency
+    edge_index = torch.nonzero(adj).t().contiguous()
+    if weight:
+        edge_weight = adjacency[edge_index[0], edge_index[1]].reshape(-1)
+        return edge_index, edge_weight
+    else:
+        return edge_index
+
+
+def plot_graph(edges, labels, n_class):
+    """Draw a weighted graph with one random color per class.
+
+    Args:
+        edges (ndarray): [E, 3] rows of (u, v, weight)
+        labels (ndarray): [N, ]
+        n_class (int): number of classes
+    """
+    color_list = []
+    colors = list(mcolors.CSS4_COLORS.keys())
+    for i in range(n_class):
+        color_code = random.choice(colors)
+        color_list.append(color_code)
+    print(f"Chosen color list: {color_list}")
+    G = nx.Graph()
+    nodes = [(i, {'label': labels[i], 'color': color_list[labels[i]]}) for i in range(len(labels))]
+    G.add_nodes_from(nodes)
+    G.add_weighted_edges_from(edges)
+
+    nodes_color = nx.get_node_attributes(G, 'color').values()
+    edge_weights = nx.get_edge_attributes(G, 'weight')
+    pos = nx.spring_layout(G)
+    nx.draw(G, with_labels=True, node_color=nodes_color, pos=pos)
+    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_weights)
+    plt.show()
+
+
+def curvature_arctanh(x, k):
+    # Inverse of the curvature-scaled tangent map: arctan for positive
+    # (spherical) curvature, identity at zero, arctanh for negative
+    # (hyperbolic) curvature. arctanh is undefined once sqrt(k) * x >= 1,
+    # so the positive branch must use arctan.
+    if k > 0:
+        return (1 / np.sqrt(k)) * torch.arctan(np.sqrt(k) * x)
+    if k == 0:
+        return x
+    return (1 / np.sqrt(-k)) * torch.arctanh(np.sqrt(-k) * x)
+
+
+def gumbel_sigmoid(probs, t=0.1, hard=False):
+    eps = torch.rand_like(probs).clip(0.01, 0.99)
+    probs = probs.clip(0.001, 0.999)
+    # Binary-concrete / Gumbel-sigmoid sampling: add logistic noise
+    # log(eps) - log(1 - eps) to the Bernoulli logits, then squash.
+    logits1 = probs.log() - (1 - probs).log()
+    logits2 = eps.log() - (1 - eps).log()
+    samples = torch.sigmoid((logits1 + logits2) / t)
+    if hard:
+        # Straight-through: hard 0/1 samples forward, soft gradients backward.
+        sample_hard = (samples > 0.5).float()
+        return (sample_hard - samples).detach() + samples
+    return samples
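+
+
+# Usage sketch for the samplers above on a learned adjacency (illustrative
+# only; the graph size, temperature, and top-k value here are assumptions):
+#
+#     probs = torch.rand(100, 100)                 # learned edge probabilities
+#     adj_soft = gumbel_sigmoid(probs, t=0.1)      # relaxed Bernoulli edges
+#     adj_hard = gumbel_sigmoid(probs, t=0.1, hard=True)   # 0/1 edges, ST grads
+#     adj_sparse = graph_top_K(adj_soft, 10)       # keep 10 strongest per row
+#     edge_index, edge_weight = adjacency2index(adj_sparse, weight=True)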
+
+
+class cluster_metrics:
+    def __init__(self, trues, predicts):
+        self.trues = trues
+        self.predicts = predicts
+
+    def clusterAcc(self):
+        l1 = list(set(self.trues))
+        l2 = list(set(self.predicts))
+        num1 = len(l1)
+        num2 = len(l2)
+        if num1 != num2:
+            raise ValueError("number of classes not equal")
+
+        # Cost of allocating true class c1 to predicted cluster c2:
+        # the number of samples they share.
+        cost = np.zeros((num1, num2), dtype=int)
+        for i, c1 in enumerate(l1):
+            maps = np.where(self.trues == c1)[0]
+            for j, c2 in enumerate(l2):
+                maps_d = [i1 for i1 in maps if self.predicts[i1] == c2]
+                cost[i, j] = len(maps_d)
+
+        # Munkres minimizes cost and expects a list of lists, so negate the
+        # overlap counts to obtain the maximum-overlap assignment.
+        mks = Munkres()
+        index = mks.compute((-cost).tolist())
+        new_predicts = np.zeros(len(self.predicts))
+        for i, c in enumerate(l1):
+            c2 = l2[index[i][1]]
+            allocate_index = np.where(self.predicts == c2)[0]
+            new_predicts[allocate_index] = c
+
+        acc = metrics.accuracy_score(self.trues, new_predicts)
+        f1_macro = metrics.f1_score(self.trues, new_predicts, average='macro')
+        precision_macro = metrics.precision_score(self.trues, new_predicts, average='macro')
+        recall_macro = metrics.recall_score(self.trues, new_predicts, average='macro')
+        f1_micro = metrics.f1_score(self.trues, new_predicts, average='micro')
+        precision_micro = metrics.precision_score(self.trues, new_predicts, average='micro')
+        recall_micro = metrics.recall_score(self.trues, new_predicts, average='micro')
+        return acc, f1_macro, precision_macro, recall_macro, f1_micro, precision_micro, recall_micro
+
+    def evaluateFromLabel(self):
+        nmi = metrics.normalized_mutual_info_score(self.trues, self.predicts)
+        adjscore = metrics.adjusted_rand_score(self.trues, self.predicts)
+        acc, f1_macro, precision_macro, recall_macro, f1_micro, precision_micro, recall_micro = self.clusterAcc()
+        return acc, nmi, f1_macro, adjscore
+
+
+class dotdict(dict):
+    """dot.notation access to dictionary attributes.
+
+    Note: missing keys return None (dict.get semantics) rather than
+    raising AttributeError.
+    """
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
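+
+
+if __name__ == '__main__':
+    # Minimal smoke test (not part of any experiment pipeline): checks the
+    # cluster metrics on a toy labelling whose cluster ids are a permutation
+    # of the true classes; the Munkres mapping should recover a perfect score.
+    trues = np.array([0, 0, 1, 1, 2, 2])
+    predicts = np.array([1, 1, 2, 2, 0, 0])
+    acc, nmi, f1_macro, ari = cluster_metrics(trues, predicts).evaluateFromLabel()
+    print(f"acc={acc:.2f}, nmi={nmi:.2f}, f1={f1_macro:.2f}, ari={ari:.2f}")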