Skip to content

Commit

Permalink
[Data] add graph_store & feature_store (#215)
Browse files Browse the repository at this point in the history
* add feature_store & graph_store

* update __init__.py

* add example of feature_store & graph_store

* update login information
  • Loading branch information
xy-Ji authored Jul 12, 2024
1 parent 10dfccc commit e7c2901
Show file tree
Hide file tree
Showing 6 changed files with 1,063 additions and 0 deletions.
25 changes: 25 additions & 0 deletions examples/database/graph_store/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from gammagl.utils import mask_to_index
from gammagl.datasets import Reddit
import tensorlayerx as tlx
import numpy as np
from gdbi.ggl import Neo4jFeatureStore, Neo4jGraphStore

dataset = Reddit('')
graph = dataset[0]

train_idx = tlx.convert_to_numpy(mask_to_index(graph.train_mask))
val_idx = tlx.convert_to_numpy(mask_to_index(graph.val_mask))
test_idx = tlx.convert_to_numpy(mask_to_index(graph.test_mask))

np.savez('idx.npz', train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)

uri = 'bolt://localhost:7687'
user_name= 'neo4j'
password= 'neo4j'

feature_store = Neo4jFeatureStore(uri=uri, user_name=user_name, password=password)
graph_store = Neo4jGraphStore(uri=uri, user_name=user_name, password=password)

feature_store['reddit_node', 'x', tlx.arange(start=0, limit=graph.x.shape[0])] = graph.x
feature_store['reddit_node', 'y', tlx.arange(start=0, limit=graph.x.shape[0])] = graph.y
graph_store['reddit_edge', 'coo'] = graph.edge_index
134 changes: 134 additions & 0 deletions examples/database/graph_store/reddit_sage_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from gammagl.utils import mask_to_index
from tensorlayerx.model import WithLoss, TrainOneStep
from tqdm import tqdm
import tensorlayerx as tlx
import argparse
from gammagl.loader.neighbor_sampler import NeighborSampler
from gammagl.models import GraphSAGE_Sample_Model
import numpy as np
from gdbi.ggl import Neo4jFeatureStore, Neo4jGraphStore

class SemiSpvzLoss(WithLoss):
def __init__(self, net, loss_fn):
super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

def forward(self, data, y):
logits = self.backbone_network(data['x'], data['subgs'])
loss = self._loss_fn(logits, tlx.gather(data['y'], data['dst_node']))
return loss


def calculate_acc(logits, y, metrics):
"""
Args:
logits: node logits
y: node labels
metrics: tensorlayerx.metrics
Returns:
rst
"""

metrics.update(logits, y)
rst = metrics.result()
metrics.reset()
return rst


def main(args):
uri = 'bolt://localhost:7687'
user_name= 'neo4j'
password= 'neo4j'

feature_store = Neo4jFeatureStore(uri=uri, user_name=user_name, password=password)
graph_store = Neo4jGraphStore(uri=uri, user_name=user_name, password=password)

edge_index = graph_store['reddit_edge', 'coo']

with np.load('idx.npz') as data:
train_idx = tlx.convert_to_tensor(data['train_idx'])
val_idx = tlx.convert_to_tensor(data['val_idx'])
test_idx = tlx.convert_to_tensor(data['test_idx'])
train_loader = NeighborSampler(edge_index=edge_index,
node_idx=train_idx,
sample_lists=[25, 10], batch_size=2048, shuffle=True, num_workers=0)

val_loader = NeighborSampler(edge_index=edge_index,
node_idx=val_idx,
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)
test_loader = NeighborSampler(edge_index=edge_index,
node_idx=test_idx,
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)


net = GraphSAGE_Sample_Model(in_feat=602,
hid_feat=args.hidden_dim,
out_feat=41,
drop_rate=args.drop_rate,
num_layers=args.num_layers)
optimizer = tlx.optimizers.Adam(args.lr)
metrics = tlx.metrics.Accuracy()
train_weights = net.trainable_weights

loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits)
train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

y = feature_store['reddit_node', 'y', tlx.arange(start=0, limit=232965)]
y = tlx.reshape(tlx.cast(y, dtype=tlx.int64), (-1,))

for epoch in range(args.n_epoch):
pbar = tqdm(total=int(len(train_loader.dataset)))
pbar.set_description(f'Epoch {epoch:02d}')
for dst_node, n_id, adjs in train_loader:
net.set_train()
# input : sampled subgraphs, sampled node's feat
data = {"x": feature_store['reddit_node', 'x', n_id],
"y": y,
"dst_node": dst_node,
"subgs": adjs}
# label is not used
train_loss = train_one_step(data, tlx.convert_to_tensor([0]))
pbar.update(len(dst_node))
print("Epoch [{:0>3d}] ".format(epoch + 1) + " train loss: {:.4f}".format(train_loss.item()))

logits = net.inference(x, val_loader, data['x'])
if tlx.BACKEND == 'torch':
val_idx = val_idx.to(data['x'].device)
val_logits = tlx.gather(logits, val_idx)
val_y = tlx.gather(data['y'], val_idx)
val_acc = calculate_acc(val_logits, val_y, metrics)

logits = net.inference(x, test_loader, data['x'])
test_logits = tlx.gather(logits, test_idx)
test_y = tlx.gather(data['y'], test_idx)
test_acc = calculate_acc(test_logits, test_y, metrics)

print("val acc: {:.4f} || test acc{:.4f}".format(val_acc, test_acc))


if __name__ == '__main__':
# parameters setting
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.0005, help="learnin rate")
parser.add_argument("--n_epoch", type=int, default=50, help="number of epoch")
parser.add_argument("--hidden_dim", type=int, default=256, help="dimention of hidden layers")
parser.add_argument("--drop_rate", type=float, default=0.8, help="drop_rate")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--l2_coef", type=float, default=0., help="l2 loss coeficient")
parser.add_argument('--dataset', type=str, default='reddit', help='dataset')
parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
# parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
parser.add_argument("--gpu", type=int, default=-1)

args = parser.parse_args()
if args.gpu >= 0:
tlx.set_device("GPU", args.gpu)
else:
tlx.set_device("CPU")

main(args)
6 changes: 6 additions & 0 deletions gammagl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .in_memory_dataset import InMemoryDataset
from .extract import extract_zip, extract_tar
from .utils import global_config_init
from .feature_store import FeatureStore, TensorAttr
from .graph_store import GraphStore, EdgeAttr

__all__ = [
'BaseGraph',
Expand All @@ -19,6 +21,10 @@
'extract_zip',
'extract_tar',
'global_config_init',
'FeatureStore',
'TensorAttr',
'GraphStore',
'EdgeAttr',
]

classes = __all__
Loading

0 comments on commit e7c2901

Please sign in to comment.