Implemented evaluator for link-prediction in retrieval setting #667

Open · wants to merge 11 commits into main
12 changes: 12 additions & 0 deletions python/graphstorm/config/argument.py
@@ -312,6 +312,7 @@ def verify_arguments(self, is_train):
_ = self.restore_model_layers
_ = self.restore_model_path
_ = self.restore_optimizer_path
_ = self.restore_embed_path
_ = self.save_embed_path
_ = self.save_embed_format

@@ -976,6 +977,15 @@ def restore_optimizer_path(self):
return self._restore_optimizer_path
return None

@property
def restore_embed_path(self):
""" Path to the saved GNN embeddings.
"""
# pylint: disable=no-member
if hasattr(self, "_restore_embed_path"):
return self._restore_embed_path
return None

### Save model ###
@property
def save_embed_path(self):
@@ -2168,6 +2178,8 @@ def _add_input_args(parser):
help='Restore the model weights saved in the specified directory.')
group.add_argument('--restore-optimizer-path', type=str, default=argparse.SUPPRESS,
help='Restore the optimizer snapshot saved in the specified directory.')
group.add_argument('--restore-embed-path', type=str, default=argparse.SUPPRESS,
help='Restore the GNN embeddings saved in the specified directory.')
return parser

def _add_output_args(parser):
2 changes: 2 additions & 0 deletions python/graphstorm/dataloading/__init__.py
@@ -26,6 +26,7 @@
from .dataloading import GSgnnNodeDataLoader, GSgnnNodeSemiSupDataLoader
from .dataloading import GSgnnLinkPredictionTestDataLoader
from .dataloading import GSgnnLinkPredictionJointTestDataLoader
from .dataloading import GSgnnLinkPredictionRetrievalDataLoader
from .dataloading import (FastGSgnnLinkPredictionDataLoader,
FastGSgnnLPLocalJointNegDataLoader,
FastGSgnnLPJointNegDataLoader,
@@ -41,6 +42,7 @@

from .dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
BUILTIN_LP_JOINT_NEG_SAMPLER,
BUILTIN_LP_RETRIEVAL_NEG_SAMPLER,
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER,
BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER,
BUILTIN_LP_LOCALJOINT_NEG_SAMPLER)
28 changes: 28 additions & 0 deletions python/graphstorm/dataloading/dataloading.py
@@ -359,6 +359,7 @@ def fanout(self):

BUILTIN_LP_UNIFORM_NEG_SAMPLER = 'uniform'
BUILTIN_LP_JOINT_NEG_SAMPLER = 'joint'
BUILTIN_LP_RETRIEVAL_NEG_SAMPLER = 'full'
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER = 'inbatch_joint'
BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER = 'localuniform'
BUILTIN_LP_LOCALJOINT_NEG_SAMPLER = 'localjoint'
@@ -1059,6 +1060,33 @@ def _prepare_negative_sampler(self, num_negative_edges):
self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER
return negative_sampler

class GSgnnLinkPredictionRetrievalDataLoader(GSgnnLinkPredictionTestDataLoader):
""" Link prediction minibatch dataloader for validation and test
with the full train graph as the negative sampler.
This is intended for retrieval setting, where a model should compute negative scores
from all the nodes in the training graph
"""
def _prepare_negative_sampler(self, num_negative_edges):
# No negative sampler is needed: all nodes in the graph serve as negatives.
negative_sampler = None
self._neg_sample_type = BUILTIN_LP_RETRIEVAL_NEG_SAMPLER
return negative_sampler

def _next_data(self, etype):
""" Get postive edges for the next iteration for a specific edge type
"""
g = self._data.g
current_pos = self._current_pos[etype]
end_of_etype = current_pos + self._batch_size >= self._fixed_test_size[etype]

pos_eids = self._target_idx[etype][current_pos:self._fixed_test_size[etype]] \
if end_of_etype \
else self._target_idx[etype][current_pos:current_pos+self._batch_size]
pos_pairs = g.find_edges(pos_eids, etype=etype)
self._current_pos[etype] += self._batch_size
return {etype: pos_pairs}, end_of_etype


################ Minibatch DataLoader (Node classification) #######################

class GSgnnNodeDataLoaderBase():
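A minimal usage sketch of the new dataloader (the positional arguments mirror how the existing test dataloaders are constructed in lp_infer_gnn.py below; `infer_data` and `infer_data.test_idxs` are assumptions based on that script, not part of this diff):

from graphstorm.dataloading import (GSgnnLinkPredictionRetrievalDataLoader,
                                    BUILTIN_LP_RETRIEVAL_NEG_SAMPLER)

# num_negative_edges is effectively unused here: every node in the
# graph serves as a negative candidate.
dataloader = GSgnnLinkPredictionRetrievalDataLoader(
    infer_data,              # GSgnnEdgeInferData (assumption)
    infer_data.test_idxs,    # target edge ids per edge type (assumption)
    batch_size=1024,
    num_negative_edges=0,
    fanout=[10, 15])

# Each iteration yields the positive pairs of one batch plus the sampler
# tag 'full', which lp_mini_batch_predict dispatches on (see below).
for pos_pairs, neg_sample_type in dataloader:
    assert neg_sample_type == BUILTIN_LP_RETRIEVAL_NEG_SAMPLER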
29 changes: 18 additions & 11 deletions python/graphstorm/inference/lp_infer.py
@@ -19,7 +19,7 @@

from .graphstorm_infer import GSInferrer
from ..model.utils import save_full_node_embeddings as save_gsgnn_embeddings
from ..model.utils import save_relation_embeddings
from ..model.utils import save_relation_embeddings, load_gsgnn_embeddings
from ..model.edge_decoder import LinkPredictDistMultDecoder
from ..model.gnn import do_full_graph_inference, do_mini_batch_inference
from ..model.lp_gnn import lp_mini_batch_predict
@@ -43,7 +43,9 @@ def infer(self, data, loader, save_embed_path,
edge_mask_for_gnn_embeddings='train_mask',
use_mini_batch_infer=False,
node_id_mapping_file=None,
save_embed_format="pytorch"):
save_embed_format="pytorch",
load_embed_path=None
):
""" Do inference

The inference can do two things:
@@ -70,20 +72,25 @@
graph partition algorithm.
save_embed_format : str
Specify the format of saved embeddings.
load_embed_path : str
If provided, load the embeddings from disk instead of computing them.
"""
sys_tracker.check('start inferencing')
self._model.eval()
if use_mini_batch_infer:
embs = do_mini_batch_inference(self._model, data, fanout=loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
g = data.g
if load_embed_path is None:
if use_mini_batch_infer:
embs = do_mini_batch_inference(self._model, data, fanout=loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
else:
embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
else:
embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
embs = load_gsgnn_embeddings(load_embed_path, g)
device = self.device
g = data.g
if save_embed_path is not None:
save_gsgnn_embeddings(g,
save_embed_path,
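A hedged sketch of the two modes the new `load_embed_path` argument enables (the `infer`, `infer_data`, and `dataloader` objects are the ones built in lp_infer_gnn.py; the path is illustrative):

# Mode 1 (existing behavior): run GNN inference and save the embeddings.
infer.infer(infer_data, dataloader,
            save_embed_path='/data/infer-emb/',
            use_mini_batch_infer=False)

# Mode 2 (new in this PR): skip the GNN forward pass entirely and reuse
# the embeddings saved by a previous run.
infer.infer(infer_data, dataloader,
            save_embed_path=None,
            load_embed_path='/data/infer-emb/')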
42 changes: 42 additions & 0 deletions python/graphstorm/model/edge_decoder.py
@@ -694,6 +694,48 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device):
scores[canonical_etype] = (pos_scores, neg_scores)
return scores

def calc_retrieval_scores(self, emb, pos_pairs, device):
""" Compute scores for positive edges among all possible edges for retrieval setting

Parameters
----------
emb: dict of Tensor
Node embeddings.
pos_pairs: dict of tuple
Positive edges stored in a tuple:
tuple(positive source, positive destination).
device: th.device
Device used to compute scores

Returns
-------
dict of (Tensor, Tensor)
A dictionary mapping each edge type to
(positive scores, negative scores).
"""
assert isinstance(pos_pairs, dict) and len(pos_pairs) == 1, \
"DotDecoder is only applicable to link prediction task with " \
"single target training edge type"
canonical_etype = list(pos_pairs.keys())[0]
pos_src, pos_dst = pos_pairs[canonical_etype]
utype, _, vtype = canonical_etype
pos_src_emb = emb[utype][pos_src].to(device)
pos_dst_emb = emb[vtype][pos_dst].to(device)
scores = {}
pos_scores = calc_dot_pos_score(pos_src_emb, pos_dst_emb)
neg_dst_emb = emb[vtype][np.arange(emb[vtype].shape[0])].to(device)
neg_scores = th.mm(pos_src_emb, neg_dst_emb.transpose(0, 1)) # [n_pos, n_dst_nodes]
# with the gloo backend, keeping scores on CPU consumes less GPU memory
neg_scores = neg_scores.cpu() \
if is_distributed() and get_backend() == "gloo" \
else neg_scores
pos_scores = pos_scores.detach()
pos_scores = pos_scores.cpu() \
if is_distributed() and get_backend() == "gloo" \
else pos_scores
scores[canonical_etype] = (pos_scores, neg_scores)
return scores

@property
def in_dims(self):
""" The number of input dimensions.
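The core of calc_retrieval_scores is a dense dot product between each positive source embedding and the embeddings of every destination-type node. A self-contained sketch of that math with toy tensors (shapes and names are illustrative, not GraphStorm API):

import torch as th

n_pos, n_dst, dim = 4, 100, 16
pos_src_emb = th.randn(n_pos, dim)   # embeddings of the positive sources
pos_dst_emb = th.randn(n_pos, dim)   # embeddings of the positive destinations
all_dst_emb = th.randn(n_dst, dim)   # embeddings of ALL destination-type nodes

# per-edge dot product, as calc_dot_pos_score computes above
pos_scores = (pos_src_emb * pos_dst_emb).sum(dim=-1)   # [n_pos]
# each positive source scored against every candidate destination
neg_scores = pos_src_emb @ all_dst_emb.t()             # [n_pos, n_dst]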
11 changes: 8 additions & 3 deletions python/graphstorm/model/lp_gnn.py
@@ -139,9 +139,14 @@ def lp_mini_batch_predict(model, emb, loader, device):
with th.no_grad():
ranking = {}
for pos_neg_tuple, neg_sample_type in loader:
score = \
decoder.calc_test_scores(
emb, pos_neg_tuple, neg_sample_type, device)
if neg_sample_type == 'full': # BUILTIN_LP_RETRIEVAL_NEG_SAMPLER
score = \
decoder.calc_retrieval_scores(emb, pos_neg_tuple, device)
else:
score = \
decoder.calc_test_scores(
emb, pos_neg_tuple, neg_sample_type, device)

for canonical_etype, s in score.items():
# We do not concatenate rankings into a single
# ranking tensor to avoid unnecessary data copy.
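Downstream, the (positive scores, negative scores) pair can be ranked the same way as in the sampled case. A hedged sketch of the ranking step (GraphStorm's evaluator may differ, e.g. in tie handling or in filtering other true edges out of the candidate set):

# pos_scores: [n_pos]; neg_scores: [n_pos, n_candidates]; the positive
# itself is among the candidates, so ranks are 1-based automatically.
rank = (neg_scores >= pos_scores.unsqueeze(1)).sum(dim=1)
mrr = (1.0 / rank.float()).mean()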
27 changes: 27 additions & 0 deletions python/graphstorm/model/utils.py
@@ -872,6 +872,33 @@ def save_full_node_embeddings(g, save_embed_path,

save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format)

def load_gsgnn_embeddings(emb_path, g):
Review comment (Contributor): we didn't have a function to load GNN embeddings saved by GraphStorm? @classicsong

'''Load embeddings saved by `save_full_node_embeddings` into a dict of DistTensor's.
'''
with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f:
emb_info = json.load(f)
embs = {}
for ntype in emb_info["emb_name"]:
dist_emb = None

ntype_emb_path = os.path.join(emb_path, ntype)
emb_files = os.listdir(ntype_emb_path)
ntype_emb_files = [file for file in emb_files if file.endswith(".pt") and file.startswith("emb")]
ntype_nid_files = [file for file in emb_files if file.endswith(".pt") and file.startswith("nids")]
ntype_emb_files = sorted(ntype_emb_files)
ntype_nid_files = sorted(ntype_nid_files)
part_policy = g.get_node_partition_policy(ntype)
for emb_file, nid_file in zip(ntype_emb_files, ntype_nid_files):
# `weights_only=True` requires torch 1.13+
emb = th.load(os.path.join(ntype_emb_path, emb_file), weights_only=True)
nids = th.load(os.path.join(ntype_emb_path, nid_file), weights_only=True)
if dist_emb is None:
dist_emb = create_dist_tensor((part_policy.get_size(), emb.shape[1]), emb.dtype,
name=ntype, part_policy=part_policy)
dist_emb[nids] = emb
barrier()
embs[ntype] = dist_emb
return embs

def save_embeddings(emb_path, embeddings, rank, world_size,
device=th.device('cpu'), node_id_mapping_file=None,
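For reference, load_gsgnn_embeddings assumes the on-disk layout written by save_full_node_embeddings: an emb_info.json naming the node types, plus one directory per node type holding paired, identically sorted shard files. The shard names and node types below are illustrative (node types taken from the MovieLens example); the loader only requires matching sorted "emb*.pt" / "nids*.pt" pairs:

emb_path/
    emb_info.json            # contains {"emb_name": ["user", "movie"], ...}
    user/
        emb.part00000.pt     # one embedding shard per saving worker
        nids.part00000.pt    # node ids covered by that shard
    movie/
        emb.part00000.pt
        nids.part00000.pt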
9 changes: 8 additions & 1 deletion python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
@@ -24,8 +24,10 @@
from graphstorm.dataloading import GSgnnEdgeInferData
from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionRetrievalDataLoader
from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER
from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER
from graphstorm.dataloading import BUILTIN_LP_RETRIEVAL_NEG_SAMPLER
from graphstorm.utils import setup_device, get_lm_ntypes

def main(config_args):
@@ -62,6 +64,8 @@ def main(config_args):
test_dataloader_cls = GSgnnLinkPredictionTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_RETRIEVAL_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionRetrievalDataLoader
else:
raise ValueError('Unknown test negative sampler.'
'Supported test negative samplers include '
batch_size=config.eval_batch_size,
num_negative_edges=config.num_negative_edges_eval,
fanout=config.eval_fanout)
# the call below produces GNN embeddings for all nodes in the graph
infer.infer(infer_data, dataloader,
save_embed_path=config.save_embed_path,
edge_mask_for_gnn_embeddings=None if config.no_validation else \
'train_mask', # if no validation, any edge can be used in message passing.
use_mini_batch_infer=config.use_mini_batch_infer,
node_id_mapping_file=config.node_id_mapping_file,
save_embed_format=config.save_embed_format)
save_embed_format=config.save_embed_format,
load_embed_path=config.restore_embed_path
)

def generate_parser():
""" Generate an argument parser
Expand Down
76 changes: 76 additions & 0 deletions tests/end2end-tests/graphstorm-lp/lp_retrieval_eval.sh
@@ -0,0 +1,76 @@
DGL_HOME=/root/dgl
GS_HOME=$(pwd)
NUM_TRAINERS=4
NUM_INFO_TRAINERS=2
export PYTHONPATH=$GS_HOME/python/
cd $GS_HOME/training_scripts/gsgnn_lp
echo "127.0.0.1" > ip_list.txt
cd $GS_HOME/inference_scripts/lp_infer
echo "127.0.0.1" > ip_list.txt

# train a model, save model and embeddings
python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True


best_epoch_dot=$(grep "successfully save the model to" /tmp/train_log.txt | tail -1 | tr -d '\n' | tail -c 1)
echo "The best model is saved in epoch $best_epoch_dot"

echo "**************dataset: Movielens, do inference on saved model, decoder: dot"
python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True

# inference for retrieval setting
echo "**************dataset: Movielens, do inference on saved model, decoder: dot, retrieval setting:"
python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True --eval-negative-sampler full --save-embed-path none

# inference for retrieval setting: PPI
WORKSPACE=/industry-gml-benchmarks/primekg
cd $WORKSPACE
# 1. generate GNN embeddings
python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \
--num-samplers 0 \
--ssh-port 2222 \
--part-config $WORKSPACE/4p/primekg_graph_tasks/1_ppi/primekg.json \
--ip-config /data/ip_list_p4_zw.txt \
--cf 1_ppi/frozen_lm_rgcn_lp.yaml \
--batch-size 1024 \
--hidden-size 256 \
--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6 \
--save-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6/embs

# 2. calculate MRR in retrieval setting:
python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \
--num-samplers 0 \
--ssh-port 2222 \
--part-config $WORKSPACE/4p/primekg_graph_tasks/1_ppi/primekg.json \
--ip-config /data/ip_list_p4_zw.txt \
--cf 1_ppi/frozen_lm_rgcn_lp.yaml \
--batch-size 1024 \
--hidden-size 256 \
--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6 \
--restore-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/1_ppi/frozen_lm_rgcn_lp_model-sm/epoch-6/embs \
--eval-negative-sampler full --save-embed-path none

# 1. generate GNN embeddings
python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \
--num-samplers 0 \
--ssh-port 2222 \
--part-config $WORKSPACE/4p/primekg_graph_tasks/2_protein_function_prediction/primekg.json \
--ip-config /data/ip_list_p4_zw.txt \
--cf 2_protein_function_prediction/frozen_lm_rgcn_lp.yaml \
--batch-size 256 \
--hidden-size 256 \
--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49 \
--save-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49/embs

# 2. calculate MRR in retrieval setting:
python3 -m graphstorm.run.gs_link_prediction --inference --num-trainers 8 --num-servers 4 \
--num-samplers 0 \
--ssh-port 2222 \
--part-config $WORKSPACE/4p/primekg_graph_tasks/2_protein_function_prediction/primekg.json \
--ip-config /data/ip_list_p4_zw.txt \
--cf 2_protein_function_prediction/frozen_lm_rgcn_lp.yaml \
--batch-size 256 \
--hidden-size 256 \
--restore-model-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49 \
--restore-embed-path /industry-gml-benchmarks/primekg/4p/primekg_graph_tasks/2_protein_function_prediction/frozen_lm_rgcn_lp_model-sm/epoch-49/embs \
--eval-negative-sampler full --save-embed-path none
6 changes: 6 additions & 0 deletions tests/end2end-tests/graphstorm-lp/mgpu_test.sh
@@ -190,6 +190,12 @@
echo "Dot product inference does not output edge embedding"
exit -1
fi

echo "**************dataset: Movielens, do inference on saved model, decoder: dot, retrieval setting:"
python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --preserve-input True --eval-negative-sampler full --save-embed-path none

error_and_exit $?

rm -fr /data/gsgnn_lp_ml_dot/infer-emb/

echo "**************dataset: Movielens, do inference on saved model, decoder: dot, remap without shared file system"