diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index a5d58c88ce..1eceec73ea 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -313,6 +313,7 @@ def verify_arguments(self, is_train): _ = self.restore_model_layers _ = self.restore_model_path _ = self.restore_optimizer_path + _ = self.restore_embed_path _ = self.save_embed_path _ = self.save_embed_format @@ -1003,6 +1004,15 @@ def restore_optimizer_path(self): return self._restore_optimizer_path return None + @property + def restore_embed_path(self): + """ Path to the saved GNN embeddings. + """ + # pylint: disable=no-member + if hasattr(self, "_restore_embed_path"): + return self._restore_embed_path + return None + ### Save model ### @property def save_embed_path(self): @@ -2209,6 +2219,8 @@ def _add_input_args(parser): help='Restore the model weights saved in the specified directory.') group.add_argument('--restore-optimizer-path', type=str, default=argparse.SUPPRESS, help='Restore the optimizer snapshot saved in the specified directory.') + group.add_argument('--restore-embed-path', type=str, default=argparse.SUPPRESS, + help='Restore the GNN embeddings saved in the specified directory.') return parser def _add_output_args(parser): diff --git a/python/graphstorm/dataloading/__init__.py b/python/graphstorm/dataloading/__init__.py index 1aed57bb12..27b833b3e1 100644 --- a/python/graphstorm/dataloading/__init__.py +++ b/python/graphstorm/dataloading/__init__.py @@ -26,6 +26,7 @@ from .dataloading import GSgnnNodeDataLoader, GSgnnNodeSemiSupDataLoader from .dataloading import GSgnnLinkPredictionTestDataLoader from .dataloading import GSgnnLinkPredictionJointTestDataLoader +from .dataloading import GSgnnLinkPredictionRetrievalDataLoader from .dataloading import (FastGSgnnLinkPredictionDataLoader, FastGSgnnLPLocalJointNegDataLoader, FastGSgnnLPJointNegDataLoader, @@ -41,6 +42,7 @@ from .dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, + BUILTIN_LP_RETRIEVAL_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, BUILTIN_LP_LOCALJOINT_NEG_SAMPLER) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index f367219dde..410d688c12 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -359,6 +359,7 @@ def fanout(self): BUILTIN_LP_UNIFORM_NEG_SAMPLER = 'uniform' BUILTIN_LP_JOINT_NEG_SAMPLER = 'joint' +BUILTIN_LP_RETRIEVAL_NEG_SAMPLER = 'full' BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER = 'inbatch_joint' BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER = 'localuniform' BUILTIN_LP_LOCALJOINT_NEG_SAMPLER = 'localjoint' @@ -1059,6 +1060,33 @@ def _prepare_negative_sampler(self, num_negative_edges): self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER return negative_sampler +class GSgnnLinkPredictionRetrievalDataLoader(GSgnnLinkPredictionTestDataLoader): + """ Link prediction minibatch dataloader for validation and test + with the full train graph as the negative sampler. + This is intended for retrieval setting, where a model should compute negative scores + from all the nodes in the training graph + """ + def _prepare_negative_sampler(self, num_negative_edges): + # set `_negative_sampler` to None + negative_sampler = None + self._neg_sample_type = BUILTIN_LP_RETRIEVAL_NEG_SAMPLER + return negative_sampler + + def _next_data(self, etype): + """ Get postive edges for the next iteration for a specific edge type + """ + g = self._data.g + current_pos = self._current_pos[etype] + end_of_etype = current_pos + self._batch_size >= self._fixed_test_size[etype] + + pos_eids = self._target_idx[etype][current_pos:self._fixed_test_size[etype]] \ + if end_of_etype \ + else self._target_idx[etype][current_pos:current_pos+self._batch_size] + pos_pairs = g.find_edges(pos_eids, etype=etype) + self._current_pos[etype] += self._batch_size + return {etype: pos_pairs}, end_of_etype + + ################ Minibatch DataLoader (Node classification) ####################### class GSgnnNodeDataLoaderBase(): diff --git a/python/graphstorm/inference/lp_infer.py b/python/graphstorm/inference/lp_infer.py index 80ba19f4d2..f146413fbd 100644 --- a/python/graphstorm/inference/lp_infer.py +++ b/python/graphstorm/inference/lp_infer.py @@ -19,7 +19,7 @@ from .graphstorm_infer import GSInferrer from ..model.utils import save_full_node_embeddings as save_gsgnn_embeddings -from ..model.utils import save_relation_embeddings +from ..model.utils import save_relation_embeddings, load_gsgnn_embeddings from ..model.edge_decoder import LinkPredictDistMultDecoder from ..model import do_full_graph_inference, do_mini_batch_inference from ..model.lp_gnn import lp_mini_batch_predict @@ -43,7 +43,9 @@ def infer(self, data, loader, save_embed_path, edge_mask_for_gnn_embeddings='train_mask', use_mini_batch_infer=False, node_id_mapping_file=None, - save_embed_format="pytorch"): + save_embed_format="pytorch", + load_embed_path=None + ): """ Do inference The inference can do two things: @@ -70,20 +72,25 @@ def infer(self, data, loader, save_embed_path, graph partition algorithm. save_embed_format : str Specify the format of saved embeddings. + load_embed_path : str + If provided, load the embedding from disk instead of computing them. """ sys_tracker.check('start inferencing') self._model.eval() - if use_mini_batch_infer: - embs = do_mini_batch_inference(self._model, data, fanout=loader.fanout, - edge_mask=edge_mask_for_gnn_embeddings, - task_tracker=self.task_tracker) + g = data.g + if load_embed_path is None: + if use_mini_batch_infer: + embs = do_mini_batch_inference(self._model, data, fanout=loader.fanout, + edge_mask=edge_mask_for_gnn_embeddings, + task_tracker=self.task_tracker) + else: + embs = do_full_graph_inference(self._model, data, fanout=loader.fanout, + edge_mask=edge_mask_for_gnn_embeddings, + task_tracker=self.task_tracker) + sys_tracker.check('compute embeddings') else: - embs = do_full_graph_inference(self._model, data, fanout=loader.fanout, - edge_mask=edge_mask_for_gnn_embeddings, - task_tracker=self.task_tracker) - sys_tracker.check('compute embeddings') + embs = load_gsgnn_embeddings(load_embed_path, g) device = self.device - g = data.g if save_embed_path is not None: save_gsgnn_embeddings(g, save_embed_path, diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index 427d15336c..43b4c5df98 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -694,6 +694,48 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): scores[canonical_etype] = (pos_scores, neg_scores) return scores + def calc_retrieval_scores(self, emb, pos_pairs, device): + """ Compute scores for positive edges among all possible edges for retrieval setting + + Parameters + ---------- + emb: dict of Tensor + Node embeddings. + pos_pairs: dict of tuple + Positive edges stored in a tuple: + tuple(positive source, postive destination). + device: th.device + Device used to compute scores + + Return + ------ + Dict of (Tensor, Tensor) + Return a dictionary of edge type to + (positive scores, negative scores) + """ + assert isinstance(pos_pairs, dict) and len(pos_pairs) == 1, \ + "DotDecoder is only applicable to link prediction task with " \ + "single target training edge type" + canonical_etype = list(pos_pairs.keys())[0] + pos_src, pos_dst = pos_pairs[canonical_etype] + utype, _, vtype = canonical_etype + pos_src_emb = emb[utype][pos_src].to(device) + pos_dst_emb = emb[vtype][pos_dst].to(device) + scores = {} + pos_scores = calc_dot_pos_score(pos_src_emb, pos_dst_emb) + neg_dst_emb = emb[vtype][np.arange(emb[vtype].shape[0])].to(device) + neg_scores = th.mm(pos_src_emb, neg_dst_emb.transpose(0, 1)) # [n_pos, n_embs] + # gloo with cpu will consume less GPU memory + neg_scores = neg_scores.cpu() \ + if is_distributed() and get_backend() == "gloo" \ + else neg_scores + pos_scores = pos_scores.detach() + pos_scores = pos_scores.cpu() \ + if is_distributed() and get_backend() == "gloo" \ + else pos_scores + scores[canonical_etype] = (pos_scores, neg_scores) + return scores + @property def in_dims(self): """ The number of input dimensions. diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index 91c2c3317c..899ee108b1 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -157,9 +157,14 @@ def lp_mini_batch_predict(model, emb, loader, device): with th.no_grad(): ranking = {} for pos_neg_tuple, neg_sample_type in loader: - score = \ - decoder.calc_test_scores( - emb, pos_neg_tuple, neg_sample_type, device) + if neg_sample_type == 'full': + score = \ + decoder.calc_retrieval_scores(emb, pos_neg_tuple, device) + else: + score = \ + decoder.calc_test_scores( + emb, pos_neg_tuple, neg_sample_type, device) + for canonical_etype, s in score.items(): # We do not concatenate rankings into a single # ranking tensor to avoid unnecessary data copy. diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index aea73096f4..6396d42669 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -907,6 +907,35 @@ def save_full_node_embeddings(g, save_embed_path, save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format) +def load_gsgnn_embeddings(emb_path, g): + '''Load from `save_full_node_embeddings` to a dict of DistTensor's + ''' + with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f: + emb_info = json.load(f) + embs = {} + for ntype in emb_info["emb_name"]: + dist_emb = None + + ntype_emb_path = os.path.join(emb_path, ntype) + emb_files = os.listdir(ntype_emb_path) + ntype_emb_files = [file for file in emb_files if file.endswith(".pt") and + file.startswith("emb")] + ntype_nid_files = [file for file in emb_files if file.endswith(".pt") and + file.startswith("nids")] + ntype_emb_files = sorted(ntype_emb_files) + ntype_nid_files = sorted(ntype_nid_files) + part_policy = g.get_node_partition_policy(ntype) + for emb_file, nid_file in zip(ntype_emb_files, ntype_nid_files): + # Only work with torch 1.13+ + emb = th.load(os.path.join(ntype_emb_path, emb_file),weights_only=True) + nids = th.load(os.path.join(ntype_emb_path, nid_file),weights_only=True) + if dist_emb is None: + dist_emb = create_dist_tensor((part_policy.get_size(), emb.shape[1]), emb.dtype, + name=ntype, part_policy=part_policy) + dist_emb[nids] = emb + barrier() + embs[ntype] = dist_emb + return embs def save_embeddings(emb_path, embeddings, rank, world_size, device=th.device('cpu'), node_id_mapping_file=None, diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 50a2e97acc..48fcf1ae24 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -24,8 +24,10 @@ from graphstorm.dataloading import GSgnnEdgeInferData from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import GSgnnLinkPredictionRetrievalDataLoader from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER +from graphstorm.dataloading import BUILTIN_LP_RETRIEVAL_NEG_SAMPLER from graphstorm.utils import setup_device, get_lm_ntypes def main(config_args): @@ -62,6 +64,8 @@ def main(config_args): test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_RETRIEVAL_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionRetrievalDataLoader else: raise ValueError('Unknown test negative sampler.' 'Supported test negative samplers include ' @@ -71,13 +75,16 @@ def main(config_args): batch_size=config.eval_batch_size, num_negative_edges=config.num_negative_edges_eval, fanout=config.eval_fanout) + # the line below produce gnn embeddings for all nodes on the graph infer.infer(infer_data, dataloader, save_embed_path=config.save_embed_path, edge_mask_for_gnn_embeddings=None if config.no_validation else \ 'train_mask', # if no validation,any edge can be used in message passing. use_mini_batch_infer=config.use_mini_batch_infer, node_id_mapping_file=config.node_id_mapping_file, - save_embed_format=config.save_embed_format) + save_embed_format=config.save_embed_format, + load_embed_path=config.restore_embed_path + ) def generate_parser(): """ Generate an argument parser diff --git a/tests/end2end-tests/graphstorm-lp/lp_retrieval_eval.sh b/tests/end2end-tests/graphstorm-lp/lp_retrieval_eval.sh new file mode 100644 index 0000000000..d246c5fc84 --- /dev/null +++ b/tests/end2end-tests/graphstorm-lp/lp_retrieval_eval.sh @@ -0,0 +1,76 @@ +DGL_HOME=/root/dgl +GS_HOME=$(pwd) +NUM_TRAINERS=4 +NUM_INFO_TRAINERS=2 +export PYTHONPATH=$GS_HOME/python/ +cd $GS_HOME/training_scripts/gsgnn_lp +echo "127.0.0.1" > ip_list.txt +cd $GS_HOME/inference_scripts/lp_infer +echo "127.0.0.1" > ip_list.txt + +# train a model, save model and embeddings +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True + + +best_epoch_dot=$(grep "successfully save the model to" /tmp/train_log.txt | tail -1 | tr -d '\n' | tail -c 1) +echo "The best model is saved in epoch $best_epoch_dot" + +echo "**************dataset: Movielens, do inference on saved model, decoder: dot" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True + +# inference for retrieval setting +echo "**************dataset: Movielens, do inference on saved model, decoder: dot, retrieval setting:" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True --eval-negative-sampler full --save-embed-path none + +# inferece for retrieval setting: ppi +WORKSPACE=/industry-gml-benchmarks/primekg +cd $WORKSPACE +# 1. generate GNN embeddings +python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \ +--num-samplers 0 \ +--ssh-port 2222 \ +--part-config $WORKSPACE/4p/primekg_graph_tasks/1_ppi/primekg.json \ +--ip-config /data/ip_list_p4_zw.txt \ +--cf 1_ppi/frozen_lm_rgcn_lp.yaml \ +--batch-size 1024 \ +--hidden-size 256 \ +--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6 \ +--save-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6/embs + +# 2. calculate MRR in retrieval setting: +python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \ +--num-samplers 0 \ +--ssh-port 2222 \ +--part-config $WORKSPACE/4p/primekg_graph_tasks/1_ppi/primekg.json \ +--ip-config /data/ip_list_p4_zw.txt \ +--cf 1_ppi/frozen_lm_rgcn_lp.yaml \ +--batch-size 1024 \ +--hidden-size 256 \ +--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6 \ +--restore-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6/embs \ +--eval-negative-sampler full --save-embed-path none + +# 1. generate GNN embeddings +python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \ +--num-samplers 0 \ +--ssh-port 2222 \ +--part-config $WORKSPACE/4p/primekg_graph_tasks/2_protein_function_prediction/primekg.json \ +--ip-config /data/ip_list_p4_zw.txt \ +--cf 2_protein_function_prediction/frozen_lm_rgcn_lp.yaml \ +--batch-size 256 \ +--hidden-size 256 \ +--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49 \ +--save-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49/embs + +# 2. calculate MRR in retrieval setting: +python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \ +--num-samplers 0 \ +--ssh-port 2222 \ +--part-config $WORKSPACE/4p/primekg_graph_tasks/2_protein_function_prediction/primekg.json \ +--ip-config /data/ip_list_p4_zw.txt \ +--cf 2_protein_function_prediction/frozen_lm_rgcn_lp.yaml \ +--batch-size 256 \ +--hidden-size 256 \ +--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49 \ +--restore-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49/embs \ +--eval-negative-sampler full --save-embed-path none diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh index 4ffd6bb261..3b1ae36965 100644 --- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh @@ -190,6 +190,12 @@ then echo "Dot product inference does not output edge embedding" exit -1 fi + +echo "**************dataset: Movielens, do inference on saved model, decoder: dot, retrieval setting:" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True --eval-negative-sampler full --save-embed-path none + +error_and_exit $? + rm -fr /data/gsgnn_lp_ml_dot/infer-emb/ echo "**************dataset: Movielens, do inference on saved model, decoder: dot, remap without shared file system" diff --git a/tools/partition_graph_lp.py b/tools/partition_graph_lp.py index d4444e32c6..380a88687d 100644 --- a/tools/partition_graph_lp.py +++ b/tools/partition_graph_lp.py @@ -49,7 +49,7 @@ help='split links for inductive settings: no overlapping nodes across ' + 'splits.') argparser.add_argument('--seed', type=int, default=42, - help='random seed for splitting links') + help='random seed for splitting links') # graph modification arguments argparser.add_argument('--add-reverse-edges', action='store_true', help='turn the graph into an undirected graph.')