Upload
RiemannGraph committed Jun 28, 2023
1 parent eeb227f commit b7b5805
Showing 23 changed files with 1,911 additions and 0 deletions.
156 changes: 156 additions & 0 deletions backbone.py
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution, GraphAttentionLayer, SpGraphAttentionLayer
from torch_geometric.nn import SAGEConv
from utils import graph_top_K


class GCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features, n_layers, dropout_node=0.5, dropout_edge=0.25):
super(GCN, self).__init__()
self.conv_layers = nn.ModuleList()
self.conv_layers.append(GraphConvolution(in_features, hidden_features))
for _ in range(n_layers - 2):
self.conv_layers.append(GraphConvolution(hidden_features, hidden_features))
self.conv_layers.append(GraphConvolution(hidden_features, out_features))
self.dropout_node = nn.Dropout(dropout_node)
self.dropout_edge = nn.Dropout(dropout_edge)

def forward(self, x, adj):
adj = self.dropout_edge(adj)
        for layer in self.conv_layers[:-1]:
x = layer(x, adj)
x = self.dropout_node(F.relu(x))
x = self.conv_layers[-1](x, adj)
return x


class GAT(nn.Module):
def __init__(self, in_features, hidden_features, out_features, dropout_node=0.5, dropout_edge=0.25, alpha=0.2,
n_heads=4):
"""Dense version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout_node
self.dropout_edge = nn.Dropout(dropout_edge)

        # nn.ModuleList registers each attention head, so no manual add_module loop is needed.
        self.attentions = nn.ModuleList([
            GraphAttentionLayer(in_features, hidden_features, dropout=dropout_node, alpha=alpha, concat=True)
            for _ in range(n_heads)
        ])

self.out_att = GraphAttentionLayer(hidden_features * n_heads, out_features, dropout=dropout_node, alpha=alpha,
concat=False)

def forward(self, x, adj):
adj = self.dropout_edge(adj)
x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)
x = self.out_att(x, adj)
return x


class SpGAT(nn.Module):
def __init__(self, in_features, hidden_features, out_features, dropout_node=0.5, dropout_edge=0.25, alpha=0.2,
n_heads=4):
"""Sparse version of GAT."""
super(SpGAT, self).__init__()
self.dropout = dropout_node
self.dropout_edge = nn.Dropout(dropout_edge)

        # nn.ModuleList registers each attention head, so no manual add_module loop is needed.
        self.attentions = nn.ModuleList([
            SpGraphAttentionLayer(in_features, hidden_features, dropout=dropout_node, alpha=alpha, concat=True)
            for _ in range(n_heads)
        ])

self.out_att = SpGraphAttentionLayer(hidden_features * n_heads,
out_features,
dropout=dropout_node,
alpha=alpha,
concat=False)

def forward(self, x, adj):
adj = self.dropout_edge(adj)
x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)
x = self.out_att(x, adj)
return x


class GraphSAGE(nn.Module):
def __init__(self, in_features, hidden_features, out_features, n_layers, dropout_node=0.5, dropout_edge=0.25):
super().__init__()
self.conv_layers = nn.ModuleList()
self.conv_layers.append(SAGEConv(in_features, hidden_features))
for _ in range(n_layers - 2):
self.conv_layers.append(SAGEConv(hidden_features, hidden_features))
self.conv_layers.append(SAGEConv(hidden_features, out_features))
self.dropout_node = nn.Dropout(dropout_node)
self.dropout_edge = nn.Dropout(dropout_edge)

def forward(self, x, adj):
adj = self.dropout_edge(adj)
        # PyG's SAGEConv expects an edge_index, so convert the (dropped-out) dense adjacency.
        edge_index = adj.nonzero().t()
        for layer in self.conv_layers[:-1]:
x = layer(x, edge_index)
x = self.dropout_node(F.relu(x))
x = self.conv_layers[-1](x, edge_index)
return x


class GraphEncoder(nn.Module):
def __init__(self, backbone, n_layers, in_features, hidden_features, embed_features,
dropout, dropout_edge, alpha=0.2, n_heads=4, topk=30):
super(GraphEncoder, self).__init__()
if backbone == 'gcn':
model = GCN(in_features, hidden_features, embed_features, n_layers,
dropout, dropout_edge)
elif backbone == 'sage':
model = GraphSAGE(in_features, hidden_features, embed_features, n_layers,
dropout, dropout_edge)
elif backbone == 'gat':
model = GAT(in_features, hidden_features, embed_features,
dropout, dropout_edge,
alpha, n_heads)
elif backbone == 'spgat':
model = SpGAT(in_features, hidden_features, embed_features,
dropout, dropout_edge,
alpha, n_heads)
        else:
            raise NotImplementedError(f'Unsupported backbone: {backbone}')

self.backbone = backbone
self.model = model
self.topk = topk

def forward(self, x, adj):
if self.backbone in ['gat', 'spgat', 'sage']:
adj = graph_top_K(adj, self.topk)
return self.model(x, adj)
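
For orientation, a minimal usage sketch of GraphEncoder (not part of the commit). The toy shapes are hypothetical, and it assumes layers.GraphConvolution takes a feature matrix and a dense adjacency, as the forward passes above suggest:

import torch
from backbone import GraphEncoder

# Hypothetical toy graph: 100 nodes, 16-dim features, random dense adjacency.
x = torch.randn(100, 16)
adj = (torch.rand(100, 100) > 0.9).float()

encoder = GraphEncoder(backbone='gcn', n_layers=2, in_features=16,
                       hidden_features=32, embed_features=8,
                       dropout=0.5, dropout_edge=0.25)
z = encoder(x, adj)  # expected shape: (100, 8) node embeddings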
138 changes: 138 additions & 0 deletions data_factory.py
@@ -0,0 +1,138 @@
import torch
import networkx as nx
import numpy as np
from torch_geometric.datasets import Planetoid, WikipediaNetwork, Actor
from torch_geometric.utils import to_networkx
from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.datasets import load_wine, load_breast_cancer, load_digits, fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split


def get_mask(idx, length):
    """Return a boolean mask of size `length` that is True at positions `idx`."""
    mask = torch.zeros(length, dtype=torch.bool)
    mask[idx] = True
    return mask


def load_graph_data(root: str, data_name: str, split='public', **kwargs):
if data_name in ['Cora', 'Citeseer', 'Pubmed']:
dataset = Planetoid(root=root, name=data_name, split=split)
train_mask, val_mask, test_mask = dataset.data.train_mask, dataset.data.val_mask, dataset.data.test_mask
elif data_name == 'ogbn-arxiv':
dataset = PygNodePropPredDataset(root=root, name=data_name)
mask = dataset.get_idx_split()
train_mask, val_mask, test_mask = mask.values()
elif data_name in ['actor', 'chameleon', 'squirrel']:
if data_name == 'actor':
path = root + f'/{data_name}'
dataset = Actor(root=path)
else:
dataset = WikipediaNetwork(root=root, name=data_name)
num_nodes = dataset.data.x.shape[0]
idx_train = []
for j in range(dataset.num_classes):
idx_train.extend([i for i, x in enumerate(dataset.data.y) if x == j][:20])
idx_val = np.arange(num_nodes - 1500, num_nodes - 1000)
idx_test = np.arange(num_nodes - 1000, num_nodes)
label_len = dataset.data.y.shape[0]
        train_mask = get_mask(idx_train, label_len)
        val_mask = get_mask(idx_val, label_len)
        test_mask = get_mask(idx_test, label_len)
else:
raise NotImplementedError

print(dataset.data)
G = to_networkx(dataset.data)
features = dataset.data.x
num_features = dataset.num_features
labels = dataset.data.y
adjacency = torch.from_numpy(nx.adjacency_matrix(G).toarray())
num_classes = dataset.num_classes
return features, num_features, labels, adjacency, (train_mask, val_mask, test_mask), num_classes


def load_non_graph_data(root: str, data_name: str, seed=100, **kwargs):
features = None
if data_name == 'wine':
dataset = load_wine()
n_train = 10
n_val = 10
n_es = 10
is_scale = True
elif data_name == 'digits':
dataset = load_digits()
n_train = 50
n_val = 50
n_es = 50
is_scale = False
elif data_name == 'cancer':
dataset = load_breast_cancer()
n_train = 10
n_val = 10
n_es = 10
is_scale = True
elif data_name == '20news10':
n_train = 100
n_val = 100
n_es = 100
is_scale = False
categories = ['alt.atheism',
'comp.sys.ibm.pc.hardware',
'misc.forsale',
'rec.autos',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'talk.politics.guns']
dataset = fetch_20newsgroups(subset='all', categories=categories)
vectorizer = CountVectorizer(stop_words='english', min_df=0.05)
X_counts = vectorizer.fit_transform(dataset.data).toarray()
transformer = TfidfTransformer(smooth_idf=False)
        # toarray() yields a plain ndarray; todense() returns np.matrix, which interacts poorly with torch.
        features = transformer.fit_transform(X_counts).toarray()
else:
raise NotImplementedError

if data_name != '20news10':
if is_scale:
features = scale(dataset.data)
else:
features = dataset.data
features = torch.from_numpy(features)
y = dataset.target
n, num_features = features.shape
train, test, y_train, y_test = train_test_split(np.arange(n), y, random_state=seed,
train_size=n_train + n_val + n_es,
test_size=n - n_train - n_val - n_es,
stratify=y)
train, es, y_train, y_es = train_test_split(train, y_train, random_state=seed,
train_size=n_train + n_val, test_size=n_es,
stratify=y_train)
train, val, y_train, y_val = train_test_split(train, y_train, random_state=seed,
train_size=n_train, test_size=n_val,
stratify=y_train)

    # Reuse the helper above instead of building each boolean mask by hand.
    train_mask = get_mask(train, n)
    val_mask = get_mask(val, n)
    es_mask = get_mask(es, n)
    test_mask = get_mask(test, n)
labels = torch.from_numpy(y)
num_classes = len(dataset.target_names)
    # These datasets come with no input graph, so an all-zero adjacency is returned as a placeholder.
    return features, num_features, labels, torch.zeros(n, n), (train_mask, val_mask, test_mask), num_classes


def load_data(args, **kwargs):
    if args.is_graph:
        data_getter = load_graph_data
    else:
        data_getter = load_non_graph_data
    # Forward any extra keyword arguments (e.g. split or seed) to the underlying loader.
    return data_getter(args.root_path, args.dataset, **kwargs)
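
For reference, a hypothetical call to load_data (not part of the commit). The SimpleNamespace stands in for the repo's argparse namespace; root_path, dataset, and is_graph are the only fields load_data reads:

from types import SimpleNamespace
from data_factory import load_data

# Hypothetical argument object mirroring the fields load_data reads.
args = SimpleNamespace(root_path='./data', dataset='Cora', is_graph=True)
features, num_features, labels, adj, masks, num_classes = load_data(args)
train_mask, val_mask, test_mask = masks
print(features.shape, adj.shape, num_classes)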
Empty file added exp/__init__.py