From b9bad61cac26ab458cc41f8ba9176d7e683fccc6 Mon Sep 17 00:00:00 2001 From: GentleZhu Date: Fri, 15 Mar 2024 04:51:57 +0000 Subject: [PATCH 1/4] knn retrieval evaluation --- examples/knn_retriever/build_index.py | 141 ++++++++++++++++++ examples/peft_llm_gnn/AR_Video_Games.json | 2 +- .../peft_llm_gnn/lp_config_Video_Games.yaml | 2 +- examples/peft_llm_gnn/main_lp.py | 7 +- python/graphstorm/model/utils.py | 27 ++++ 5 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 examples/knn_retriever/build_index.py diff --git a/examples/knn_retriever/build_index.py b/examples/knn_retriever/build_index.py new file mode 100644 index 0000000000..2a81f64a11 --- /dev/null +++ b/examples/knn_retriever/build_index.py @@ -0,0 +1,141 @@ +import torch as th +import time +import graphstorm as gs +from graphstorm.utils import is_distributed +import faiss +import dgl +import numpy as np +from collections import defaultdict +from graphstorm.config import get_argument_parser +from graphstorm.config import GSConfig +from graphstorm.dataloading import GSgnnNodeDataLoader +from graphstorm.dataloading import GSgnnNodeTrainData +from graphstorm.utils import setup_device +from graphstorm.model.utils import load_gsgnn_embeddings + +def calculate_recall(pred, ground_truth): + # Convert list_data to a set if it's not already a set + if not isinstance(pred, set): + pred = set(pred) + + overlap = len(pred & ground_truth) + #if overlap > 0: + # return 1 + #else: + # return 0 + return overlap / len(ground_truth) + +def main(config_args): + """ main function + """ + config = GSConfig(config_args) + embs = load_gsgnn_embeddings(config.save_embed_path) + if False: + index_dimension = embs[config.target_ntype].size(1) + # Number of clusters (higher values lead to better recall but slower search) + #nlist = 750 + #quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization + #index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT) + #index.train(embs[config.target_ntype]) + index = faiss.IndexFlatIP(index_dimension) + index.add(embs[config.target_ntype]) + else: + scores = embs[config.target_ntype] @ embs[config.target_ntype].T + #scores.fill_diagonal_(-10) + + #print(scores.abs().mean()) + + gs.initialize(ip_config=config.ip_config, backend=config.backend) + device = setup_device(config.local_rank) + #index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(embedding_size)) + # Define the training dataset + train_data = GSgnnNodeTrainData( + config.graph_name, + config.part_config, + train_ntypes=config.target_ntype, + eval_ntypes=config.eval_target_ntype, + label_field=None, + node_feat_field=None, + ) + #for i in range(embs[config.target_ntype].shape[0]): + # print(embs[config.target_ntype][i,:].sum(), train_data.g.ndata['bert_h'][i].sum()) + # breakpoint() + # embs[config.target_ntype][i,:] = train_data.g.ndata['bert_h'][i] + + #print( train_data.g.ndata['bert_h'][0,:], embs[config.target_ntype][0,:]) + #print(train_data.g.ndata['bert_h']) + + # TODO: devise a dataloader that can exclude targets and add train_mask like LP Loader + test_dataloader = GSgnnNodeDataLoader( + train_data, + train_data.test_idxs, + fanout=[-1], + batch_size=config.eval_batch_size, + device=device, + train_task=False, + ) + dataloader_iter = iter(test_dataloader) + len_dataloader = max_num_batch = len(test_dataloader) + tensor = th.tensor([len_dataloader], device=device) + if is_distributed(): + th.distributed.all_reduce(tensor, op=th.distributed.ReduceOp.MAX) + max_num_batch = tensor[0] + recall = [] + max_ = [] + for iter_l in range(max_num_batch): + ground_truth = defaultdict(set) + input_nodes, seeds, blocks = next(dataloader_iter) + #block_graph = dgl.block_to_graph(blocks[0]) + src_id = blocks[0].srcdata[dgl.NID].tolist() + dst_id = blocks[0].dstdata[dgl.NID].tolist() + #print(blocks[0].edges(form='uv', etype='also_buy')) + #breakpoint() + # print(dgl.NID) + if 'also_buy' in blocks[0].etypes: + #src, dst = block_graph.edges(form='uv', etype='also_buy') + src, dst = blocks[0].edges(form='uv', etype='also_buy') + for s,d in zip(src.tolist(),dst.tolist()): + ground_truth[dst_id[d]].add(src_id[s]) + #ground_truth[src_id[s]].add(dst_id[d]) + if 'also_buy-rev' in blocks[0].etypes: + #src, dst = block_graph.edges(form='uv', etype='also_buy-rev') + src, dst = blocks[0].edges(form='uv', etype='also_buy-rev') + for s,d in zip(src.tolist(),dst.tolist()): + ground_truth[dst_id[d]].add(src_id[s]) + #ground_truth[src_id[s]].add(dst_id[d]) + query_idx = list(ground_truth.keys()) + #print(ground_truth) + #breakpoint() + #ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1) + #knn_result = lll.tolist() + + for idx,query in enumerate(query_idx): + #if len(ground_truth[query]) > 10: + rank_list = scores[query,:].argsort(descending=True).tolist() + #for ii in rank_list[:10]: + #print(ii, query, train_data.g.ndata['bert_h'][query] @train_data.g.ndata['bert_h'][ii].T) + # print(ii, query, scores[query, ii]) + #print(ground_truth[query]) + #breakpoint() + #recall.append(calculate_recall(lll[idx, 1:], ground_truth[query])) + recall.append(calculate_recall(rank_list[:100], ground_truth[query])) + #print(ground_truth) + max_.append(query) + #print(recall) + if gs.get_rank() == 0: + #print(query_idx, lll) + print(max_num_batch, len(recall), np.mean(recall)) + print(len(max_), len(set(max_))) + breakpoint() + +def generate_parser(): + """Generate an argument parser""" + parser = get_argument_parser() + return parser + +if __name__ == "__main__": + arg_parser = generate_parser() + + args = arg_parser.parse_args() + print(args) + main(args) \ No newline at end of file diff --git a/examples/peft_llm_gnn/AR_Video_Games.json b/examples/peft_llm_gnn/AR_Video_Games.json index e2dc5e4959..0a000a7913 100644 --- a/examples/peft_llm_gnn/AR_Video_Games.json +++ b/examples/peft_llm_gnn/AR_Video_Games.json @@ -51,7 +51,7 @@ "transform": {"name": "bert_hf", "bert_model": "bert-base-uncased", "infer_batch_size": 128, - "max_seq_length": 32} + "max_seq_length": 128} } ], "labels": [ diff --git a/examples/peft_llm_gnn/lp_config_Video_Games.yaml b/examples/peft_llm_gnn/lp_config_Video_Games.yaml index c14f3fd2c1..fe54aebba3 100644 --- a/examples/peft_llm_gnn/lp_config_Video_Games.yaml +++ b/examples/peft_llm_gnn/lp_config_Video_Games.yaml @@ -28,7 +28,7 @@ gsf: hyperparam: dropout: 0. lr: 0.0001 - num_epochs: 3 + num_epochs: 4 batch_size: 16 eval_batch_size: 16 wd_l2norm: 0.00001 diff --git a/examples/peft_llm_gnn/main_lp.py b/examples/peft_llm_gnn/main_lp.py index 7272fe9b48..649507e862 100644 --- a/examples/peft_llm_gnn/main_lp.py +++ b/examples/peft_llm_gnn/main_lp.py @@ -102,7 +102,7 @@ def main(config_args): save_model_frequency=config.save_model_frequency, use_mini_batch_infer=True ) - + # Load the best checkpoint best_model_path = trainer.get_best_model_path() model.restore_model(best_model_path) @@ -123,7 +123,10 @@ def main(config_args): # Run inference on the inference dataset and save the GNN embeddings in the specified path. infer.infer(train_data, test_dataloader, save_embed_path=config.save_embed_path, edge_mask_for_gnn_embeddings='train_mask', - use_mini_batch_infer=True, infer_batch_size=config.eval_batch_size) + use_mini_batch_infer=True, + node_id_mapping_file=config.node_id_mapping_file, + save_embed_format=config.save_embed_format, + infer_batch_size=config.eval_batch_size) def generate_parser(): """Generate an argument parser""" diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index 0969c95f5d..5315a21e8a 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -27,6 +27,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import dgl +import pandas as pd from ..config import GRAPHSTORM_LP_EMB_L2_NORMALIZATION from ..gconstruct.file_io import stream_dist_tensors_to_hdf5 @@ -1065,6 +1066,32 @@ def save_full_node_embeddings(g, save_embed_path, save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format) +def load_gsgnn_embeddings(emb_path): + '''Load from `save_full_node_embeddings` to a dict of DistTensor's + ''' + with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f: + emb_info = json.load(f) + embs = {} + for ntype in emb_info["emb_name"]: + path = os.path.join(emb_path, ntype) + ntype_emb_files = os.listdir(path) + nid_files = [fname for fname in ntype_emb_files \ + if fname.startswith("embed_nids-") and fname.endswith("pt")] + emb_files = [fname for fname in ntype_emb_files \ + if fname.startswith("embed-") and fname.endswith("pt")] + num_parts = len(emb_files) + embeddings_list = [] + nid_list = [] + for i in range(num_parts): + embeddings_list.append(th.load(os.path.join(path, emb_files[i]))) + nid_list.append(th.load(os.path.join(path, nid_files[i]))) + # Convert the list of embeddings to a PyTorch tensor + embeddings_tensor = th.cat(embeddings_list, dim=0) + nids_tensor = th.cat(nid_list, dim=0) + result_tensor = th.zeros_like(embeddings_tensor) + result_tensor[nids_tensor] = embeddings_tensor + embs[ntype] = result_tensor + return embs def save_embeddings(emb_path, embeddings, rank, world_size, device=th.device('cpu'), node_id_mapping_file=None, From 597b9a8b6fbea5a07ff85b7d5dc6314837a87f70 Mon Sep 17 00:00:00 2001 From: GentleZhu Date: Wed, 20 Mar 2024 21:24:35 +0000 Subject: [PATCH 2/4] fix bugs --- examples/knn_retriever/build_index.py | 41 ++++++----------- examples/knn_retriever/embedding_config.yaml | 46 ++++++++++++++++++++ examples/knn_retriever/run_knn.sh | 18 ++++++++ 3 files changed, 78 insertions(+), 27 deletions(-) create mode 100644 examples/knn_retriever/embedding_config.yaml create mode 100644 examples/knn_retriever/run_knn.sh diff --git a/examples/knn_retriever/build_index.py b/examples/knn_retriever/build_index.py index 2a81f64a11..b34f55fa29 100644 --- a/examples/knn_retriever/build_index.py +++ b/examples/knn_retriever/build_index.py @@ -30,18 +30,15 @@ def main(config_args): """ config = GSConfig(config_args) embs = load_gsgnn_embeddings(config.save_embed_path) - if False: - index_dimension = embs[config.target_ntype].size(1) - # Number of clusters (higher values lead to better recall but slower search) - #nlist = 750 - #quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization - #index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT) - #index.train(embs[config.target_ntype]) - index = faiss.IndexFlatIP(index_dimension) - index.add(embs[config.target_ntype]) - else: - scores = embs[config.target_ntype] @ embs[config.target_ntype].T - #scores.fill_diagonal_(-10) + + index_dimension = embs[config.target_ntype].size(1) + # Number of clusters (higher values lead to better recall but slower search) + #nlist = 750 + #quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization + #index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT) + #index.train(embs[config.target_ntype]) + index = faiss.IndexFlatIP(index_dimension) + index.add(embs[config.target_ntype]) #print(scores.abs().mean()) @@ -68,7 +65,7 @@ def main(config_args): # TODO: devise a dataloader that can exclude targets and add train_mask like LP Loader test_dataloader = GSgnnNodeDataLoader( train_data, - train_data.test_idxs, + train_data.train_idxs, fanout=[-1], batch_size=config.eval_batch_size, device=device, @@ -106,27 +103,17 @@ def main(config_args): query_idx = list(ground_truth.keys()) #print(ground_truth) #breakpoint() - #ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1) + ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1) #knn_result = lll.tolist() for idx,query in enumerate(query_idx): - #if len(ground_truth[query]) > 10: - rank_list = scores[query,:].argsort(descending=True).tolist() - #for ii in rank_list[:10]: - #print(ii, query, train_data.g.ndata['bert_h'][query] @train_data.g.ndata['bert_h'][ii].T) - # print(ii, query, scores[query, ii]) - #print(ground_truth[query]) - #breakpoint() - #recall.append(calculate_recall(lll[idx, 1:], ground_truth[query])) - recall.append(calculate_recall(rank_list[:100], ground_truth[query])) - #print(ground_truth) + recall.append(calculate_recall(lll[idx, 1:], ground_truth[query])) max_.append(query) #print(recall) if gs.get_rank() == 0: #print(query_idx, lll) - print(max_num_batch, len(recall), np.mean(recall)) - print(len(max_), len(set(max_))) - breakpoint() + #print(max_num_batch, len(recall), np.mean(recall)) + print(f'recall@100: {np.mean(recall)}') def generate_parser(): """Generate an argument parser""" diff --git a/examples/knn_retriever/embedding_config.yaml b/examples/knn_retriever/embedding_config.yaml new file mode 100644 index 0000000000..34ec9557fc --- /dev/null +++ b/examples/knn_retriever/embedding_config.yaml @@ -0,0 +1,46 @@ +gsf: + basic: + backend: gloo + verbose: false + save_perf_results_path: null + gnn: + model_encoder_type: mlp + fanout: "5,5" + node_feat_name: + - item:bert_h + num_layers: 2 + hidden_size: 768 + use_mini_batch_infer: true + input: + restore_model_path: null + output: + save_model_path: null + save_embed_path: /shared_data/graphstorm/examples/peft_llm_gnn/results/lp/Video_Games + hyperparam: + dropout: 0. + lr: 0.001 + num_epochs: 1 + batch_size: 512 + eval_batch_size: 512 + wd_l2norm: 0.00001 + no_validation: false + rgcn: + num_bases: -1 + use_self_loop: true + lp_decoder_type: dot_product + sparse_optimizer_lr: 1e-2 + use_node_embeddings: false + link_prediction: + num_negative_edges: 1 + num_negative_edges_eval: 100 + contrastive_loss_temperature: 0.1 + lp_loss_func: contrastive + lp_embed_normalizer: l2_norm + train_negative_sampler: inbatch_joint + target_ntype: item + eval_etype: + - "item,also_buy,item" + train_etype: + - "item,also_buy,item" + exclude_training_targets: true + reverse_edge_types_map: ["item,also_buy,also_buy-rev,item"] \ No newline at end of file diff --git a/examples/knn_retriever/run_knn.sh b/examples/knn_retriever/run_knn.sh new file mode 100644 index 0000000000..28a309d0e6 --- /dev/null +++ b/examples/knn_retriever/run_knn.sh @@ -0,0 +1,18 @@ +WORKSPACE=/shared_data/graphstorm/examples/knn_retriever/ +DATASPACE=/shared_data/graphstorm/examples/peft_llm_gnn/ +dataset=amazon_review +domain=$1 + +python -m graphstorm.run.launch \ + --workspace "$WORKSPACE" \ + --part-config "$DATASPACE"/datasets/amazon_review_"$domain"/amazon_review.json \ + --ip-config "$DATASPACE"/ip_list.txt \ + --num-trainers 1 \ + --num-servers 1 \ + --num-samplers 0 \ + --ssh-port 22 \ + --do-nid-remap False \ + build_index.py \ + --cf "$WORKSPACE"/embedding_config.yaml \ + --save-model-path "$DATASPACE"/model/lp/"$domain"/ \ + --save-embed-path "$DATASPACE"/results/lp/"$domain"/ \ No newline at end of file From 942e8c69bf4962dedb02bdd352c44e5e9f73b50f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 23 Apr 2024 00:57:52 +0000 Subject: [PATCH 3/4] for merge --- examples/peft_llm_gnn/preprocess_amazon_review.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/peft_llm_gnn/preprocess_amazon_review.py b/examples/peft_llm_gnn/preprocess_amazon_review.py index 2ce3ac2fab..bfb89164d3 100644 --- a/examples/peft_llm_gnn/preprocess_amazon_review.py +++ b/examples/peft_llm_gnn/preprocess_amazon_review.py @@ -44,8 +44,8 @@ def encode_parquet(sub_g, edge_dict, idx2asin, asin_data, field_name): pt_lvl3.append(math.nan) df = pd.DataFrame({'item': item, 'text': item_text, 'pt_lvl3': np.array(pt_lvl3)}) table = pa.Table.from_pandas(df) - pq.write_table(table, f'data/amazon_review/{field_name}/item.parquet') os.makedirs(f'data/amazon_review/{field_name}/', exist_ok=True) + pq.write_table(table, f'data/amazon_review/{field_name}/item.parquet') for etype in edge_dict: u,v = edge_dict[etype] edge_mask = u < v @@ -146,4 +146,4 @@ def construct_graph(directory_path, ood_fields = ['Video_Games, Automotive']): if __name__ == '__main__': directory_path = 'raw_data/' - construct_graph(directory_path, ['Video_Games']) \ No newline at end of file + construct_graph(directory_path, ['Video_Games']) From 29b60ebe99c1208b91d231bc586fc2a43496f444 Mon Sep 17 00:00:00 2001 From: GentleZhu Date: Tue, 23 Apr 2024 21:02:59 +0000 Subject: [PATCH 4/4] fix bug --- python/graphstorm/model/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index 5315a21e8a..0dfcc9eb92 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -1075,10 +1075,10 @@ def load_gsgnn_embeddings(emb_path): for ntype in emb_info["emb_name"]: path = os.path.join(emb_path, ntype) ntype_emb_files = os.listdir(path) - nid_files = [fname for fname in ntype_emb_files \ - if fname.startswith("embed_nids-") and fname.endswith("pt")] - emb_files = [fname for fname in ntype_emb_files \ - if fname.startswith("embed-") and fname.endswith("pt")] + nid_files = sorted([fname for fname in ntype_emb_files \ + if fname.startswith("embed_nids-") and fname.endswith("pt")]) + emb_files = sorted([fname for fname in ntype_emb_files \ + if fname.startswith("embed-") and fname.endswith("pt")]) num_parts = len(emb_files) embeddings_list = [] nid_list = []