Support multi-task inference. #861

Merged 135 commits on Jun 18, 2024


@classicsong classicsong commented Jun 2, 2024

Issue #, if available:
#789

Description of changes:
This PR adds inference support for multi-task learning. Users can run `python3 -m graphstorm.run.gs_multi_task_learning --inference` to launch an inference task.

This PR also changes `remap_result.py` to support remapping prediction results from multi-task learning inference. (The prediction results of each task are stored separately, in different folders named after the corresponding task ID.)
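
As a rough illustration of the per-task output layout described above, the sketch below lists the task folders under a prediction output directory. The helper name `list_task_prediction_dirs` and the task IDs `task_a` and `task_b` are hypothetical; the real task IDs are generated by GraphStorm, and the files inside each folder are not modeled here.

```python
import tempfile
from pathlib import Path


def list_task_prediction_dirs(predict_root):
    """Map each task id (sub-folder name) to its prediction folder.

    Assumes one sub-folder per task directly under `predict_root`,
    as the PR description states. Hypothetical helper, not GraphStorm API.
    """
    root = Path(predict_root)
    return {p.name: p for p in sorted(root.iterdir()) if p.is_dir()}


# Build a throwaway layout for illustration only.
with tempfile.TemporaryDirectory() as tmp:
    for task_id in ("task_a", "task_b"):  # hypothetical task ids
        (Path(tmp) / task_id).mkdir()
    task_dirs = list_task_prediction_dirs(tmp)
    task_names = list(task_dirs)
    print(task_names)
```

A post-processing script could loop over such a mapping to handle each task's predictions independently.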

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Xiang Song and others added 30 commits May 2, 2024 22:29
…-task learning (awslabs#828)

*Issue #, if available:*
Support multi-task learning. First PR for awslabs#789

*Description of changes:*
Update GraphStorm input config parsing to support multi-task learning.
Allow users to specify multiple training tasks for a training
job through a yaml file. By providing the `multi_task_learning`
configurations in the yaml file, users can define multiple training
tasks. The following config defines two training tasks, one for node
classification and one for edge classification.

```
---
version: 1.0
gsf:
  basic:
    ...
  ...
  multi_task_learning:
    - node_classification:
        target_ntype: "movie"
        label_field: "label"
        mask_fields:
          - "train_mask_field_nc"
          - "val_mask_field_nc"
          - "test_mask_field_nc"
        task_weight: 1.0
    - edge_classification:
        target_etype:
          - "user,rating,movie"
        label_field: "rate"
        mask_fields:
          - "train_mask_field_ec"
          - "val_mask_field_ec"
          - "test_mask_field_ec"
        task_weight: 0.5 # weight of the task
```
Task-specific hyperparameters in multi-task learning are the same as those
in single-task learning, except that two new configs are required:
`mask_fields` and `task_weight`. `mask_fields` provides the training,
validation and test masks for the task, and `task_weight` gives the
weight of the task's loss.
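
To make the role of `task_weight` concrete, here is a minimal sketch that assumes the per-task losses are combined as a weighted sum. The commit message only says that `task_weight` is the loss weight; the exact reduction used by GraphStorm, the function name `combine_task_losses`, and the loss values below are illustrative assumptions.

```python
def combine_task_losses(task_losses, task_weights):
    """Return the weighted sum of per-task losses.

    Assumption: the total loss is sum(task_weight * task_loss).
    This is a sketch, not GraphStorm's actual implementation.
    """
    if task_losses.keys() != task_weights.keys():
        raise ValueError("every task needs both a loss and a weight")
    return sum(task_weights[name] * loss for name, loss in task_losses.items())


# Made-up loss values; the weights mirror the yaml example above.
losses = {"node_classification": 0.8, "edge_classification": 1.2}
weights = {"node_classification": 1.0, "edge_classification": 0.5}
total_loss = combine_task_losses(losses, weights)
print(total_loss)
```

Under this assumption, halving a task's weight halves its contribution to the gradient signal relative to the other tasks.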



---------

Co-authored-by: Xiang Song
…ng (awslabs#834)

*Issue #, if available:*
awslabs#789 

*Description of changes:*
Add GSgnnMultiTaskDataLoader to support multi-task learning. 

When initializing a GSgnnMultiTaskDataLoader, users need to provide two
inputs: 1) a list of config.TaskInfo objects recording the information
of each task and 2) a list of dataloaders corresponding to each training
task.

In each training iteration, GSgnnMultiTaskDataLoader calls each
task dataloader in turn to generate a mini-batch and then returns
the list of mini-batches to the trainer.

The length of the dataloader (number of batches for an epoch) is
determined by the largest task in the GSgnnMultiTaskDataLoader.
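
The behaviour described above can be sketched with plain Python iterables standing in for the task dataloaders. This is not the `GSgnnMultiTaskDataLoader` implementation: the class name `MultiTaskLoaderSketch` is hypothetical, task names replace the `config.TaskInfo` objects, and restarting a shorter dataloader once it is exhausted is an assumption made so that every iteration yields one mini-batch per task.

```python
class MultiTaskLoaderSketch:
    """Toy multi-task loader: yields one list of mini-batches per iteration."""

    def __init__(self, task_names, dataloaders):
        if len(task_names) != len(dataloaders):
            raise ValueError("need exactly one dataloader per task")
        if any(len(dl) == 0 for dl in dataloaders):
            raise ValueError("empty dataloaders are not handled in this sketch")
        self.task_names = list(task_names)
        self.dataloaders = list(dataloaders)

    def __len__(self):
        # Number of batches per epoch is set by the largest task.
        return max(len(dl) for dl in self.dataloaders)

    def __iter__(self):
        iterators = [iter(dl) for dl in self.dataloaders]
        for _ in range(len(self)):
            batches = []
            for idx, name in enumerate(self.task_names):
                try:
                    batch = next(iterators[idx])
                except StopIteration:
                    # Assumption: restart a shorter task's dataloader.
                    iterators[idx] = iter(self.dataloaders[idx])
                    batch = next(iterators[idx])
                batches.append((name, batch))
            yield batches


# Lists stand in for real dataloaders; each element is one "mini-batch".
loader = MultiTaskLoaderSketch(["nc", "ec"], [[1, 2, 3], ["a", "b"]])
epoch = list(loader)
print(len(loader), epoch)
```

Here the node-classification stand-in has three batches and the edge-classification stand-in has two, so an epoch has three iterations and the shorter loader wraps around on the last one.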



---------

Co-authored-by: Xiang Song

@zhjwy9343 zhjwy9343 left a comment

Just finished reviewing everything except the tests. Let's see the test results first.


@zhjwy9343 zhjwy9343 left a comment

LGTM

@zhjwy9343 zhjwy9343 merged commit 973d228 into awslabs:main Jun 18, 2024
6 checks passed
@classicsong classicsong deleted the multi-task-infer branch June 18, 2024 01:31
Labels
0.3 ready able to trigger the CI