save embedding of unsupervised graphsage (#276)
Song-xx authored Aug 30, 2023
1 parent 3810908 commit 4fb7ced
Showing 2 changed files with 95 additions and 2 deletions.
45 changes: 45 additions & 0 deletions graphlearn/examples/tf/ego_sage/train_unsupervised.py
@@ -68,6 +68,43 @@ def load_graph(args):
decoder=gl.Decoder(weighted=True), directed=False)
return g

def meta_path_sample(ego, node_type, edge_type, ego_name, nbrs_num, sampler):
  """ Creates the meta-path sampler of the input ego.
  Args:
    ego: A query object, the input centric nodes/edges.
    node_type: A string, the type of `ego`, such as 'paper' or 'user'.
    edge_type: A string, the edge type to traverse at each hop.
    ego_name: A string, the name of `ego`.
    nbrs_num: A list, the number of neighbors for each hop.
    sampler: A string, the strategy of neighbor sampling.
  """
  hops = range(len(nbrs_num))
  meta_path = ['outV' for i in hops]
  alias_list = [ego_name + '_hop_' + str(i + 1) for i in hops]
  meta_path_string = ""
  for path, nbr_count, alias in zip(meta_path, nbrs_num, alias_list):
    meta_path_string += path + '(' + edge_type + ').'
    ego = getattr(ego, path)(edge_type).sample(nbr_count).by(sampler).alias(alias)
  print("Sampling meta path for {} is {}.".format(node_type, meta_path_string))
  return ego
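
# A minimal usage sketch (not part of this commit): meta_path_sample chains one
# outV(edge_type).sample(n).by(sampler) step per hop onto a seed query. The
# 'i'/'train' types and hop sizes below mirror run() and are illustrative only.
#
#   seed = g.V('i').batch(64).alias('save_node_i')
#   query = meta_path_sample(seed, 'i', 'train', 'save_node_i',
#                            [10, 5], 'random_without_replacement').values()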

def node_embedding(graph, model, node_type, edge_type, **kwargs):
  """ Builds the query, dataset and output tensors for saving node embeddings.
  Args:
    graph: The gl.Graph instance.
    model: The trained model used to compute embeddings.
    node_type: A string, such as 'paper' or 'user'.
    edge_type: A string, the edge type to traverse, such as 'train'.
  Returns:
    The dataset iterator, node ids and embedding tensor.
  """
  tfg.conf.training = False
  ego_name = 'save_node_' + node_type
  seed = graph.V(node_type).batch(kwargs.get('batch_size', 64)).alias(ego_name)
  query_save = meta_path_sample(seed, node_type, edge_type, ego_name,
                                kwargs.get('nbrs_num', [10, 5]),
                                kwargs.get('sampler', 'random_without_replacement')).values()
  dataset = tfg.Dataset(query_save, window=kwargs.get('window', 10))
  ego_graph = dataset.get_egograph(ego_name)
  emb = model.forward(ego_graph)
  return dataset.iterator, ego_graph.src.ids, emb
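
# A hedged sketch of calling node_embedding (names mirror run() below; the
# 'i'/'train' node and edge types are this example's training graph):
#
#   save_iter, save_ids, save_emb = node_embedding(
#       graph=g, model=model, node_type='i', edge_type='train',
#       batch_size=args.batch_size, nbrs_num=args.nbrs_num)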

def run(args):
# graph input data
@@ -82,6 +119,10 @@ def run(args):
# prepare train dataset
train_data = EgoSAGEUnsupervisedDataLoader(g, None, args.sampler, args.neg_sampler, args.batch_size,
node_type='i', edge_type='train', nbrs_num=args.nbrs_num)

## Uncommenting the line below will save the embeddings.
# save_iter, save_ids, save_emb = node_embedding(graph=g, model=model, node_type='i', edge_type='train', nbrs_num=args.nbrs_num)

src_emb = model.forward(train_data.src_ego)
dst_emb = model.forward(train_data.dst_ego)
neg_dst_emb = model.forward(train_data.neg_dst_ego)
Expand All @@ -92,6 +133,10 @@ def run(args):
# train
trainer = LocalTrainer()
trainer.train(train_data.iterator, loss, optimizer, epochs=args.epochs)
## Uncommenting the lines below will save the embeddings after training.
# save_file = './emb.txt'
# trainer.save_node_embedding_bigdata(save_iter=save_iter, save_ids=save_ids, save_emb=save_emb,
#                                     save_file=save_file, block_max_lines=100000, batch_size=args.batch_size)

# finish
g.close()
52 changes: 50 additions & 2 deletions graphlearn/examples/tf/trainer.py
@@ -28,6 +28,7 @@
import graphlearn as gl
import graphlearn.python.nn.tf as tfg


try:
# https://www.tensorflow.org/guide/migrate
import tensorflow.compat.v1 as tf
@@ -39,6 +40,7 @@
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()


class TFTrainer(object):
"""Class for local or distributed training and evaluation.
@@ -209,6 +211,52 @@ def train_and_evaluate(self, train_iterator, test_iterator, loss, test_acc, opti
self.train(train_iterator, loss, optimizer, learning_rate, epochs, hooks, **kwargs)
self.test(test_iterator, test_acc, hooks, **kwargs)

def save_node_embedding_bigdata(self, save_iter, save_ids, save_emb, save_file, block_max_lines, batch_size):
  """Save node embeddings into a series of files of at most `block_max_lines` lines each."""
  if batch_size >= block_max_lines:
    # A single batch could span more than one block boundary, so fall back
    # to writing everything into one file.
    emb_writer = open(save_file, 'w')
    emb_writer.write('id:int64\temb:string\n')
    self.save_node_embedding(emb_writer, save_iter, save_ids, save_emb, batch_size)
  else:
    print('Start saving embeddings...')
    with self.context():
      self.global_step = tf.train.get_or_create_global_step()
      if self.sess is None:
        self.init_session()
      self.sess._tf_sess().run(save_iter.initializer)
      total_line = 0
      block_id = 0
      save_file = save_file[0:-4] if save_file.endswith('.txt') else save_file
      current_file = save_file + '_%d.txt' % block_id
      with open(current_file, 'a') as f:
        f.write('id:int64\temb:string\n')
      while True:
        try:
          outs = self.sess._tf_sess().run([save_ids, save_emb])
          # One line per node: "<id>\t<comma-separated embedding values>".
          id_feat = ['%d\t' % i + ','.join(str(x) for x in arr) + '\n' for i, arr in zip(outs[0], outs[1])]
          if total_line + len(id_feat) <= (block_id + 1) * block_max_lines:
            # The whole batch fits into the current block.
            with open(current_file, 'a') as f:
              f.writelines(id_feat)
            if total_line + len(id_feat) == (block_id + 1) * block_max_lines:
              # The current block is exactly full; open the next one.
              current_file = save_file + '_%d.txt' % (block_id + 1)
              with open(current_file, 'a') as f:
                f.write('id:int64\temb:string\n')
              block_id += 1
          elif (block_id + 1) * block_max_lines < total_line + len(id_feat):
            # The batch crosses a block boundary: split it between the
            # current block and the next one.
            with open(current_file, 'a') as f:
              f.writelines(id_feat[0: (block_id + 1) * block_max_lines - total_line])
            current_file = save_file + '_%d.txt' % (block_id + 1)
            with open(current_file, 'a') as f:
              f.write('id:int64\temb:string\n')
              f.writelines(id_feat[(block_id + 1) * block_max_lines - total_line:])
            block_id += 1
          total_line += len(id_feat)
        except tf.errors.OutOfRangeError:
          print('Save node embeddings done.')
          break
    print("#################################################")
    print("total lines saved = {}, number of blocks = {}".format(total_line, block_id + 1))
    print("#################################################")

def save_node_embedding(self, emb_writer, iterator, ids, emb, batch_size):
print('Start saving embeddings...')
with self.context():
@@ -222,8 +270,8 @@ def save_node_embedding(self, emb_writer, iterator, ids, emb, batch_size):
t = time.time()
outs = self.sess._tf_sess().run([ids, emb])
# [B,], [B,dim]
- feat = [','.join(str(x) for x in arr) for arr in outs[1]]
- emb_writer.write(list(zip(outs[0], feat)), indices=[0, 1]) # id,emb
+ id_feat = ['%d\t' % i + ','.join(str(x) for x in arr) + '\n' for i, arr in zip(outs[0], outs[1])]
+ emb_writer.writelines(id_feat) # id,emb
local_step += 1
if local_step % self.progress_steps == 0:
print('Saved {} node embeddings, Time(s) {:.4f}'.format(local_step * batch_size, time.time() - t))
