support multi-dataset multi-task
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Aug 21, 2024
1 parent 061bd31 commit 48acc23
Showing 6 changed files with 240 additions and 13 deletions.
9 changes: 9 additions & 0 deletions docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
---
authors:
- Zhiyuan Chen
date: 2024-05-04
---

# MultiTask

::: multimolecule.data.multitask
1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -13,6 +13,7 @@ nav:
     - data.md
     - Dataset: data/dataset.md
     - PandasDataset: data/pandas.md
+    - MultiTask: data/multitask.md
   - datasets:
     - RNAcentral: datasets/rnacentral.md
   - module:
10 changes: 9 additions & 1 deletion multimolecule/data/__init__.py
@@ -15,7 +15,15 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

 from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
 from .pandas import PandasDataset
 from .utils import no_collate

-__all__ = ["Dataset", "PandasDataset", "no_collate"]
+__all__ = [
+    "Dataset",
+    "PandasDataset",
+    "MultiTaskDataset",
+    "MultiTaskSampler",
+    "DistributedMultiTaskSampler",
+    "no_collate",
+]
166 changes: 166 additions & 0 deletions multimolecule/data/multitask.py
@@ -0,0 +1,166 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from collections.abc import Mapping
from copy import deepcopy
from functools import cached_property
from random import choices

from chanfig import NestedDict
from torch import distributed as dist
from torch.utils import data

from .dataset import Dataset


class MultiTaskDataset(data.ConcatDataset):
    def __init__(self, datasets: Mapping) -> None:
        for key, dataset in datasets.items():
            if not isinstance(dataset, Dataset):
                raise TypeError(f"Dataset {key} should be an instance of Dataset")
        super().__init__(datasets.values())

    @cached_property
    def tasks(self) -> NestedDict:
        tasks = self.datasets[0].tasks
        for dataset in self.datasets[1:]:
            for n, t in dataset.tasks.items():
                if n not in tasks:
                    tasks[n] = t
                elif tasks[n] != t:
                    raise ValueError(f"Task {n} has different configurations across datasets")
        return tasks

    def __repr__(self) -> str:
        return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"


class MultiTaskSampler(data.BatchSampler):
    r"""
    Ensure all items in a batch come from the same dataset.

    Arguments:
        dataset (ConcatDataset): Concatenated dataset to draw batches from.
        batch_size (int): Size of mini-batch.
        shuffle (bool): If ``True``, samples within each dataset are shuffled and the
            dataset of each batch is chosen at random, weighted by ``weights``.
        drop_last (bool): If ``True``, the sampler will drop the last batch of each
            dataset if its size would be less than ``batch_size``.
        sampler_cls (type[Sampler] | None): Sampler class used within each dataset.
            Defaults to ``RandomSampler`` if ``shuffle`` is ``True``, otherwise
            ``SequentialSampler``.
        weights (list[int] | None): Sampling weight of each dataset.
            Defaults to the dataset sizes.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        dataset: data.ConcatDataset,
        batch_size: int,
        shuffle: bool = True,
        drop_last: bool = False,
        sampler_cls: type[data.Sampler] | None = None,
        weights: list[int] | None = None,
    ) -> None:
        self.datasets = dataset.datasets
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        if sampler_cls is None:
            sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
        self.samplers = [sampler_cls(d) for d in self.datasets]  # type: ignore
        self.dataset_sizes = [len(d) for d in self.datasets]  # type: ignore
        self.cumulative_sizes = dataset.cumulative_sizes
        self.num_datasets = len(self.datasets)
        self.weights = weights if weights is not None else self.dataset_sizes

    def __iter__(self):
        sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
        sampler_weights = deepcopy(self.weights)
        # Without shuffling, datasets are consumed one after another, in order.
        sampler_idx = 0
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            while sampler_iters:
                if self.shuffle:
                    # Draw the next batch's dataset at random, weighted by `weights`.
                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
                sampler_id, sampler_iter = sampler_iters[sampler_idx]
                # Shift local indices into the index space of the ConcatDataset.
                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
                try:
                    batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    sampler_iters.pop(sampler_idx)
                    sampler_weights.pop(sampler_idx)
        else:
            while sampler_iters:
                if self.shuffle:
                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
                sampler_id, sampler_iter = sampler_iters[sampler_idx]
                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
                batch = [0] * self.batch_size
                idx_in_batch = 0
                try:
                    for _ in range(self.batch_size):
                        batch[idx_in_batch] = next(sampler_iter) + cumulative_size
                        idx_in_batch += 1
                    yield batch
                    idx_in_batch = 0  # noqa: SIM113
                    batch = [0] * self.batch_size
                except StopIteration:
                    sampler_iters.pop(sampler_idx)
                    sampler_weights.pop(sampler_idx)
                    # Flush the final, partial batch of the exhausted dataset.
                    if idx_in_batch > 0:
                        yield batch[:idx_in_batch]

    def __len__(self):
        batch_size = self.batch_size
        if self.drop_last:
            return sum(len(d) // batch_size for d in self.datasets)
        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)


class DistributedMultiTaskSampler(MultiTaskSampler):  # pylint: disable=too-few-public-methods
    r"""
    Distributed version of MultiTaskSampler.

    Each dataset is sharded across processes with a ``DistributedSampler``, so every
    batch still contains data from only one dataset. Call ``set_epoch`` at the start
    of each epoch to reshuffle the shards.

    See Also:
        [MultiTaskSampler][MultiTaskSampler]
    """

    def __init__(
        self,
        dataset: data.ConcatDataset,
        batch_size: int,
        shuffle: bool = True,
        drop_last: bool = False,
        sampler_cls: type[data.Sampler] = data.RandomSampler,
        weights: list[int] | None = None,
    ) -> None:
        super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
        # Replace the per-dataset samplers with distributed ones so that each
        # process draws from its own shard of every dataset.
        self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]

    def set_epoch(self, epoch):
        for s in self.samplers:
            s.set_epoch(epoch)

    def __len__(self):
        batch_size = self.batch_size * self.world_size
        if self.drop_last:
            return sum(len(d) // batch_size for d in self.datasets)
        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)

    @cached_property
    def world_size(self) -> int:
        r"""Return the number of processes in the current process group."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1
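For orientation, a minimal usage sketch of the new classes follows. The task names, file paths, and checkpoint are illustrative assumptions, as is the exact PandasDataset signature (inferred from the runner changes below); treat this as a sketch, not part of the commit.

from torch.utils import data
from transformers import AutoTokenizer

from multimolecule.data import MultiTaskDataset, MultiTaskSampler, PandasDataset

tokenizer = AutoTokenizer.from_pretrained("multimolecule/rnafm")  # illustrative checkpoint

# Hypothetical per-task datasets; tasks that share a name must have identical
# configurations, or MultiTaskDataset.tasks raises a ValueError.
datasets = {
    "task_a": PandasDataset("data/task_a/train.csv", split="train", tokenizer=tokenizer),
    "task_b": PandasDataset("data/task_b/train.csv", split="train", tokenizer=tokenizer),
}
dataset = MultiTaskDataset(datasets)

# MultiTaskSampler yields lists of indices that all fall inside one dataset,
# so it plugs into DataLoader as a batch_sampler.
sampler = MultiTaskSampler(dataset, batch_size=32, shuffle=True)
loader = data.DataLoader(dataset, batch_sampler=sampler)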
7 changes: 6 additions & 1 deletion multimolecule/runners/config.py
@@ -49,16 +49,21 @@ class MultiMoleculeConfig(Config):
     save_interval: int = 10

     seed: int = 1013
+    data: DataConfig

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.data = DataConfig()
+        self.datas = Config(default_factory=DataConfig)
         self.dataloader.batch_size = 32
         self.optim.name = "AdamW"
         self.optim.lr = 1e-3
         self.optim.weight_decay = 1e-2
         self.sched.final_lr = 0

     def post(self):
+        if "data" in self:
+            if self.datas:
+                raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+            del self.datas
         self.network.backbone.sequence.name = self.pretrained
         self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
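Sketching how the new config surface is meant to be used (task names and paths are illustrative; the attribute-style nesting relies on chanfig's Config auto-creating sub-configs): a run specifies either a single data section or several named sections under datas, and post() rejects setting both.

from multimolecule.runners.config import MultiMoleculeConfig

# Single-dataset run: configure `data` only.
config = MultiMoleculeConfig()
config.data.train = "data/task_a/train.csv"  # illustrative path

# Multi-dataset run: each entry under `datas` is materialised as a DataConfig.
config = MultiMoleculeConfig()
config.datas.task_a.train = "data/task_a/train.csv"
config.datas.task_b.train = "data/task_b/train.csv"

# Setting both `data` and `datas` raises a ValueError once post() runs.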
60 changes: 49 additions & 11 deletions multimolecule/runners/runner.py
@@ -22,9 +22,10 @@
 from chanfig import NestedDict
 from danling import MultiTaskMetrics, TorchRunner
 from torch import optim
+from torch.utils import data
 from transformers import AutoTokenizer

-from multimolecule.data import PandasDataset
+from multimolecule.data import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler, PandasDataset
 from multimolecule.module import HeadConfig, ModelRegistry

 from .config import MultiMoleculeConfig
@@ -90,21 +91,58 @@ def network(self):
         return self.config.network

     def build_datasets(self) -> NestedDict:
-        datasets = NestedDict()
+        if "data" in self.config:
+            return self.build_dataset(self.config.data)
+        if "datas" in self.config:
+            datasets = {name: self.build_dataset(config) for name, config in self.config.datas.items()}
+            # Regroup {dataset_name: {split: dataset}} into {split: {dataset_name: dataset}}.
+            datasets = {
+                subkey: {key: subdict[subkey] for key, subdict in datasets.items() if subkey in subdict}
+                for subkey in {k for v in datasets.values() for k in v}
+            }
+            return NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+        raise ValueError("No data configuration found")
+
+    def build_dataset(self, config):
+        dataset = NestedDict()
         dataset_factory = partial(
             PandasDataset,
             tokenizer=self.tokenizer,
-            **{k: v for k, v in self.config.data.items() if k not in ("train", "val", "test", "root")},
+            **{k: v for k, v in config.items() if k not in ("train", "val", "test", "root")},
         )
-        if self.config.data.train:
-            datasets.train = dataset_factory(self.config.data.train, split="train")
-        if self.config.data.val:
-            datasets.val = dataset_factory(self.config.data.val, split="val")
-        if self.config.data.test:
-            datasets.test = dataset_factory(self.config.data.test, split="test")
-        if not datasets:
+        if "train" in config:
+            dataset.train = dataset_factory(config.train, split="train")
+        if "val" in config:
+            dataset.val = dataset_factory(config.val, split="val")
+        if "test" in config:
+            dataset.test = dataset_factory(config.test, split="test")
+        if not dataset:
             raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
-        return datasets
+        return dataset
+
+    def build_dataloaders(self):
+        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+        default_kwargs = self.config.get("dataloader", NestedDict())
+        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+        for k, d in datasets.items():
+            dataloader_kwargs.setdefault(k, NestedDict())
+            dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+            shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+            if isinstance(d, MultiTaskDataset):
+                # Multi-task datasets need a sampler that keeps each batch within one dataset.
+                sampler = (
+                    DistributedMultiTaskSampler(d, self.config.dataloader.batch_size, shuffle=shuffle)
+                    if self.distributed
+                    else MultiTaskSampler(d, self.config.dataloader.batch_size, shuffle=shuffle)
+                )
+            else:
+                sampler = (
+                    data.distributed.DistributedSampler(d, shuffle=shuffle)
+                    if self.distributed
+                    else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+                )
+            dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True))
+            self.dataloaders[k] = data.DataLoader(
+                d, sampler=sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+            )

     def build_metrics(self) -> MultiTaskMetrics:
         return MultiTaskMetrics(

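One behavioural note on the distributed path, again as an illustrative sketch rather than the runner's own loop: DistributedMultiTaskSampler shards every dataset across processes, so set_epoch should be called once per epoch to reshuffle the shards.

from multimolecule.data import DistributedMultiTaskSampler

# Continuing the sketch above, where `loader` wraps the sampler as a batch_sampler.
for epoch in range(10):  # illustrative epoch count
    if isinstance(loader.batch_sampler, DistributedMultiTaskSampler):
        loader.batch_sampler.set_epoch(epoch)  # reshuffles each dataset's shard
    for batch in loader:
        ...  # forward/backward as usual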