Skip to content

Commit

Permalink
Support multi-task inference. (#861)
Browse files Browse the repository at this point in the history
*Issue #, if available:*
#789 

*Description of changes:*
This PR adds inference support for multi-task learning. Users can run
`python3 -m graphstorm.run.gs_multi_task_learning --inference` to
launch an inference task.

This PR also changes remap_result.py to support remapping prediction
results from multi-task learning inference. (The prediction results of
each task are stored separately, in a folder named after the
corresponding task id.)


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Jun 18, 2024
1 parent 5199149 commit 973d228
Show file tree
Hide file tree
Showing 23 changed files with 2,182 additions and 218 deletions.
1 change: 1 addition & 0 deletions inference_scripts/mt_infer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Inference-only example configs for multi-task learning
79 changes: 79 additions & 0 deletions inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
---
version: 1.0
gsf:
basic:
backend: gloo
verbose: false
save_perf_results_path: null
batch_size: 32
node_feat_name:
- user:feat
- movie:title
gnn:
model_encoder_type: rgcn
num_layers: 1
hidden_size: 32
use_mini_batch_infer: true
input:
restore_model_path: null
output:
save_model_path: null
save_embed_path: null
hyperparam:
dropout: 0.
lr: 0.001
no_validation: true
rgcn:
num_bases: -1
use_self_loop: true
use_node_embeddings: false
multi_task_learning:
- node_classification:
target_ntype: "movie"
label_field: "label"
multilabel: false
num_classes: 19
batch_size: 16 # will overwrite the global batch_size
eval_metric:
- "accuracy"
- node_classification:
target_ntype: "movie"
label_field: "label2"
multilabel: false
num_classes: 19
batch_size: 16 # will overwrite the global batch_size
eval_metric:
- "accuracy"
- edge_classification:
target_etype:
- "user,rating,movie"
label_field: "rate_class"
multilabel: false
num_classes: 6
num_decoder_basis: 2
remove_target_edge_type: false
batch_size: 64 # will overwrite the global batch_size
- edge_regression:
target_etype:
- "user,rating,movie"
label_field: "rate"
num_decoder_basis: 32
remove_target_edge_type: false
- link_prediction:
num_negative_edges: 4
num_negative_edges_eval: 100
train_negative_sampler: joint
eval_etype:
- "user,rating,movie"
train_etype:
- "user,rating,movie"
exclude_training_targets: true
reverse_edge_types_map:
- user,rating,rating-rev,movie
batch_size: 128 # will overwrite the global batch_size
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
batch_size: 128
eval_metric:
- "mse"
103 changes: 103 additions & 0 deletions inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
---
version: 1.0
gsf:
basic:
backend: gloo
verbose: false
save_perf_results_path: null
batch_size: 32
node_feat_name:
- user:feat
- movie:title
gnn:
model_encoder_type: rgcn
num_layers: 1
hidden_size: 32
use_mini_batch_infer: true
input:
restore_model_path: null
output:
save_model_path: null
save_embed_path: null
hyperparam:
dropout: 0.
lr: 0.001
no_validation: false
rgcn:
num_bases: -1
use_self_loop: true
use_node_embeddings: false
multi_task_learning:
- node_classification:
target_ntype: "movie"
label_field: "label"
multilabel: false
num_classes: 19
batch_size: 16 # will overwrite the global batch_size
mask_fields:
- "train_mask_c0" # node classification mask 0
- "val_mask_c0"
- "test_mask_c0"
eval_metric:
- "accuracy"
- node_classification:
target_ntype: "movie"
label_field: "label2"
multilabel: false
num_classes: 19
batch_size: 16 # will overwrite the global batch_size
mask_fields:
- "train_mask_c1" # node classification mask 1
- "val_mask_c1"
- "test_mask_c1"
eval_metric:
- "accuracy"
- edge_classification:
target_etype:
- "user,rating,movie"
label_field: "rate_class"
multilabel: false
num_classes: 6
num_decoder_basis: 2
remove_target_edge_type: false
batch_size: 64 # will overwrite the global batch_size
mask_fields:
- "train_mask_field_c" # edge classification mask
- "val_mask_field_c"
- "test_mask_field_c"
- edge_regression:
target_etype:
- "user,rating,movie"
label_field: "rate"
num_decoder_basis: 32
remove_target_edge_type: false
mask_fields:
- "train_mask_field_r" # edge regression mask
- "val_mask_field_r"
- "test_mask_field_r"
- link_prediction:
num_negative_edges: 4
num_negative_edges_eval: 100
train_negative_sampler: joint
eval_etype:
- "user,rating,movie"
train_etype:
- "user,rating,movie"
exclude_training_targets: true
reverse_edge_types_map:
- user,rating,rating-rev,movie
batch_size: 128 # will overwrite the global batch_size
mask_fields:
- "train_mask_field_lp"
- null # null means there is no validation mask
- "test_mask_field_lp"
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
batch_size: 128
mask_fields:
- "train_mask_c0" # use the same mask as node classification c0
- "val_mask_c0"
- "test_mask_c0"
eval_metric:
- "mse"
3 changes: 2 additions & 1 deletion python/graphstorm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from .gsf import create_builtin_lp_model
from .gsf import create_builtin_edge_model
from .gsf import create_builtin_node_model
from .gsf import create_task_decoder
from .gsf import (create_task_decoder,
create_evaluator)

from .gsf import (create_builtin_node_decoder,
create_builtin_edge_decoder,
Expand Down
8 changes: 4 additions & 4 deletions python/graphstorm/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def _check_node_mask(self, ntypes, masks):
# ntypes is a string, convert it into list
ntypes = [ntypes]

if isinstance(masks, str):
# Mask is a string
if masks is None or isinstance(masks, str):
# Mask is a string or None
# All the masks are using the same name
masks = [masks] * len(ntypes)

Expand Down Expand Up @@ -710,8 +710,8 @@ def _check_edge_mask(self, etypes, masks):
# etypes is a tuple of strings, convert it into list
etypes = [etypes]

if isinstance(masks, str):
# Mask is a string
if masks is None or isinstance(masks, str):
# Mask is a string or None
# All the masks are using the same name
masks = [masks] * len(etypes)

Expand Down
6 changes: 4 additions & 2 deletions python/graphstorm/eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,8 +1294,10 @@ def evaluate(self, val_results, test_results, total_iters):
task_evaluator = self._task_evaluators[task_id]

if isinstance(task_evaluator, GSgnnPredictionEvalInterface):
val_preds, val_labels = eval_task[0]
test_preds, test_labels = eval_task[1]
val_preds, val_labels = eval_task[0] \
if eval_task[0] is not None else (None, None)
test_preds, test_labels = eval_task[1] \
if eval_task[0] is not None else (None, None)
val_score, test_score = task_evaluator.evaluate(
val_preds, test_preds, val_labels, test_labels, total_iters)
elif isinstance(task_evaluator, GSgnnLPRankingEvalInterface):
Expand Down
Loading

0 comments on commit 973d228

Please sign in to comment.