Multi-Task-Learning

This repository implements the generalized mean as an aggregation function for per-task losses in multi-task learning. The code is mainly based on LibMTL.
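For reference, the generalized (power) mean of T non-negative task losses ℓ_1, ..., ℓ_T with exponent p ≠ 0 is M_p = ((1/T) Σ_t ℓ_t^p)^(1/p). It recovers the arithmetic mean at p = 1 and the geometric mean in the limit p → 0. As p → +∞ it approaches the maximum, and as p → −∞ the minimum. Its gradient with respect to ℓ_t is (1/T) ℓ_t^(p−1) M_p^(1−p), so for p > 1 tasks with larger losses receive proportionally larger gradient weight. The sketch below is a minimal, framework-free illustration that assumes uniform task weights. The function name `generalized_mean` is hypothetical, and this is not necessarily how the repository computes the aggregate or chooses p.

```python
import math
from typing import Sequence


def generalized_mean(losses: Sequence[float], p: float) -> float:
    """Power mean of per-task losses with uniform weights (hypothetical helper).

    p = 1 -> arithmetic mean, p -> 0 -> geometric mean,
    p = +inf -> maximum, p = -inf -> minimum.
    """
    if len(losses) == 0:
        raise ValueError("need at least one task loss")
    if any(loss < 0 for loss in losses):
        raise ValueError("task losses must be non-negative")
    n = len(losses)
    if math.isinf(p):
        # Limits of the power mean: the largest or smallest task loss.
        return max(losses) if p > 0 else min(losses)
    if p <= 0 and any(loss == 0 for loss in losses):
        # For p <= 0 a zero loss pulls the mean to 0 (0 ** p is undefined for p < 0).
        return 0.0
    if p == 0:
        # Geometric mean, the limit of the power mean as p -> 0.
        return math.exp(sum(math.log(loss) for loss in losses) / n)
    # General case: ((1/T) * sum(l_t ** p)) ** (1/p).
    return (sum(loss ** p for loss in losses) / n) ** (1.0 / p)


task_losses = [0.5, 2.0]
for p in (-math.inf, 0, 1, 2, math.inf):
    print(p, generalized_mean(task_losses, p))
```

The printed values grow with p, moving from the smallest task loss toward the largest. In a PyTorch training step the same expression can be written on a stacked tensor of task losses, for example `torch.stack(losses).pow(p).mean().pow(1.0 / p)` for finite p ≠ 0, so that gradients flow back through every task loss. How this repository wires the aggregation into LibMTL is not shown here.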

Getting started

  1. Create a conda environment and install PyTorch 1.8.1 (CUDA 11.1 build)

    conda create -n gemtl python=3.8
    conda activate gemtl
    pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
  2. Clone this repository

  3. Install LibMTL from the cloned repository in editable mode

    cd GeMTL
    pip install -e .

Requirements

  • Python >= 3.8
  • PyTorch >= 1.8.1

Install the Python dependencies with:

    pip install -r requirements.txt
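The two version requirements above can be checked with a short standard-library script. This is a hedged sketch, not part of the repository. The helper names `parse_version` and `check_requirements` are hypothetical. The version parsing is deliberately simple: it keeps only the leading numeric release, such as `1.8.1` from `1.8.1+cu111`. It also assumes PyTorch was installed under the pip distribution name `torch`, as in the install command above.

```python
import re
import sys
from importlib import metadata  # available from Python 3.8 onward
from typing import List, Tuple

MIN_PYTHON = (3, 8)
MIN_TORCH = (1, 8, 1)


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep the leading numeric release, e.g. '1.8.1+cu111' -> (1, 8, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_requirements() -> List[str]:
    """Return a list of problems; an empty list means both requirements hold."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append(f"Python >= 3.8 required, found {sys.version.split()[0]}")
    try:
        # Looks up the installed pip distribution named 'torch' (assumption).
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch (pip package 'torch') is not installed")
    else:
        # Tuple comparison: (1, 10, 0) >= (1, 8, 1), while (1, 8) < (1, 8, 1).
        if parse_version(torch_version) < MIN_TORCH:
            problems.append(f"PyTorch >= 1.8.1 required, found {torch_version}")
    return problems


for problem in check_requirements():
    print("requirement not met:", problem)
```

Nothing is printed when both requirements hold. On Python older than 3.8 the `importlib.metadata` import itself fails, which already indicates that the first requirement is not met.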

Dataset

You can download the datasets from the following links.

Run

Training and testing code is in ./examples/nyusp/main.py and ./examples/office/main.py.
To run an example and check its results, use the following commands:

    cd ./examples/nyusp    # or ./examples/office
    bash run.sh

Reference

Our implementation builds on the following repositories. Thanks to their contributors!

License

This repository is released under the MIT license.