Skip to content

Commit

Permalink
support task auto recognition
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 18, 2024
1 parent 3833000 commit 2e367f6
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 2 deletions.
4 changes: 4 additions & 0 deletions multimolecule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
TokenKMerHead,
TokenPredictionHead,
)
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters

Expand Down Expand Up @@ -255,6 +256,9 @@
"SinusoidalEmbedding",
"Criterion",
"count_parameters",
"Task",
"TaskLevel",
"TaskType",
"SEQUENCE_COL_NAMES",
"LABEL_COL_NAMES",
"SEQUENCE_COL_NAME",
Expand Down
17 changes: 17 additions & 0 deletions multimolecule/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

LABLE_TYPE_THRESHOLD = 0.5
24 changes: 24 additions & 0 deletions multimolecule/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List

import danling as dl
import datasets
import torch
from chanfig import NestedDict
from danling import NestedTensor
from torch import Tensor
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from multimolecule import defaults
from multimolecule.tasks import Task

from .utils import infer_task


class Dataset(datasets.Dataset):
Expand Down Expand Up @@ -202,6 +207,25 @@ def post(
else:
self.set_transform(self.tokenize_transform)

@cached_property
def tasks(self) -> NestedDict:
return self.infer_tasks()

def infer_tasks(self, sequence_col: str | None = None) -> NestedDict:
return NestedDict({col: self.infer_task(col, sequence_col) for col in self.label_cols})

def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task:
if sequence_col is None:
if len(self.sequence_cols) != 1:
raise ValueError("sequence_col must be specified if there are multiple sequence columns.")
sequence_col = self.sequence_cols[0]
sequence = self._data.column(sequence_col)
column = self._data.column(label_col)
# is_nested = isinstance(column.type, ListType)
# if is_nested:
# column = column.combine_chunks().flatten()
return infer_task(sequence, column)

def update(self, dataset: datasets.Dataset):
# pylint: disable=W0212
# Why datasets won't support in-place changes?
Expand Down
70 changes: 69 additions & 1 deletion multimolecule/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,76 @@

from __future__ import annotations

from typing import Any
from typing import Any, Tuple

import pyarrow as pa
from pyarrow import Array, ChunkedArray, ListArray, StringArray

from multimolecule import defaults
from multimolecule.tasks import Task, TaskLevel, TaskType


def no_collate(batch: Any) -> Any:
return batch


def infer_task(sequence: ChunkedArray | ListArray, column: Array | ChunkedArray | ListArray) -> Task:
if isinstance(sequence, ChunkedArray) and sequence.num_chunks == 1:
sequence = sequence.chunks[0]
if isinstance(column, ChunkedArray) and column.num_chunks == 1:
column = column.chunks[0]
flattened, levels = flatten_column(column)
dtype = flattened.type
unique = flattened.unique()
num_elem = len(sequence)
num_tokens, num_contacts = get_num_tokens(sequence)

if levels == 0 or (levels == 1 and len(flattened) % len(column) == 0):
level = TaskLevel.Sequence
num_labels = len(flattened) // num_elem
else:
num_rows = defaults.TASK_INFERENCE_NUM_ROWS
sequence, column = sequence[:num_rows], column[:num_rows]
if len(flattened) % num_contacts == 0:
level = TaskLevel.Contact
num_labels = len(flattened) // num_contacts
elif len(flattened) % num_tokens == 0:
level = TaskLevel.Token
num_labels = len(flattened) // num_tokens
else:
raise ValueError("Unable to infer task: inconsistent number of values in sequence and column")

if dtype in (pa.float16(), pa.float32(), pa.float64()):
return Task(TaskType.Regression, level=level, num_labels=num_labels)
if dtype in (pa.int8(), pa.int16(), pa.int32(), pa.int64()):
if len(unique) == 2:
if len(flattened) in (num_elem, num_tokens, num_contacts):
return Task(TaskType.Binary, level=level, num_labels=1)
return Task(TaskType.MultiLabel, level=level, num_labels=num_labels)
if len(unique) / len(column) > defaults.LABLE_TYPE_THRESHOLD:
return Task(TaskType.Regression, level=level, num_labels=num_labels)
return Task(TaskType.MultiClass, level=level, num_labels=len(unique))
raise ValueError(f"Unable to infer task: unsupported dtype {dtype}")


def flatten_column(column: Array | ChunkedArray | ListArray) -> Tuple[Array, int]:
levels = 0
while isinstance(column, (ChunkedArray, ListArray)):
if isinstance(column, ChunkedArray):
column = column.combine_chunks()
elif isinstance(column, ListArray):
column = column.flatten()
levels += 1
return column, levels


def get_num_tokens(sequence: Array | ListArray) -> Tuple[int, int]:
if isinstance(sequence, StringArray):
return sum(len(i.as_py()) for i in sequence), sum(len(i.as_py()) ** 2 for i in sequence)
# remove <bos> and <eos> tokens in length calculation
offset = 0
if len({i[0] for i in sequence}) == 1:
offset += 1
if len({i[-1] for i in sequence}) == 1:
offset += 1
return sum((len(i) - offset) for i in sequence), sum((len(i) - offset) ** 2 for i in sequence)
2 changes: 2 additions & 0 deletions multimolecule/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
LABEL_COL_NAMES = ["label", "labels"]
SEQUENCE_COL_NAME = "input_ids"
LABEL_COL_NAME = "labels"
LABLE_TYPE_THRESHOLD = 0.5
TASK_INFERENCE_NUM_ROWS = 100
19 changes: 19 additions & 0 deletions multimolecule/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .task import Task, TaskLevel, TaskType

__all__ = ["Task", "TaskType", "TaskLevel"]
49 changes: 49 additions & 0 deletions multimolecule/tasks/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from chanfig import NestedDict


class TaskType(str, Enum):
Binary = "binary"
MultiClass = "multiclass"
MultiLabel = "multilabel"
Regression = "regression"


class TaskLevel(str, Enum):
Sequence = "sequence"
Token = "token"
Contact = "contact"


@dataclass
class Task(NestedDict):
type: TaskType
level: TaskLevel
num_labels: int = 1

def __post_init__(self):
if self.type in (TaskType.Binary) and self.num_labels != 1:
raise ValueError(f"num_labels must be 1 for {self.type} task")
if self.type in (TaskType.MultiClass, TaskType.MultiLabel) and self.num_labels == 1:
raise ValueError(f"num_labels must not be 1 for {self.type} task")
super().__post_init__()
3 changes: 3 additions & 0 deletions tests/data/datasets/synthetic/rna.json
Git LFS file not shown
57 changes: 56 additions & 1 deletion tests/data/test_pandas_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest
import torch

from multimolecule import PandasDataset
from multimolecule import PandasDataset, Task, TaskLevel, TaskType


@pytest.mark.lfs
Expand All @@ -36,25 +36,29 @@ def test_5utr(self, preprocess: bool):
dataset = PandasDataset(
file, split="train", pretrained=self.pretrained, preprocess=preprocess, auto_rename_cols=True
)
task = Task(type=TaskType.Regression, level=TaskLevel.Sequence)
elem = dataset[0]
assert isinstance(elem["input_ids"], dl.PNTensor)
assert isinstance(elem["labels"], torch.FloatTensor)
batch = dataset[list(range(3))]
assert isinstance(batch["input_ids"], dl.NestedTensor)
assert isinstance(batch["labels"], torch.FloatTensor)
assert dataset.tasks["labels"] == task

@pytest.mark.parametrize("preprocess", [True, False])
def test_ncrna(self, preprocess: bool):
file = os.path.join(self.root, "ncrna.csv")
dataset = PandasDataset(
file, split="train", pretrained=self.pretrained, preprocess=preprocess, auto_rename_cols=True
)
task = Task(type=TaskType.MultiClass, level=TaskLevel.Sequence, num_labels=13)
elem = dataset[0]
assert isinstance(elem["input_ids"], dl.PNTensor)
assert isinstance(elem["labels"], torch.LongTensor)
batch = dataset[list(range(3))]
assert isinstance(batch["input_ids"], dl.NestedTensor)
assert isinstance(batch["labels"], torch.LongTensor)
assert dataset.tasks["labels"] == task

@pytest.mark.parametrize("preprocess", [True, False])
def test_rnaswitches(self, preprocess: bool):
Expand All @@ -63,24 +67,29 @@ def test_rnaswitches(self, preprocess: bool):
dataset = PandasDataset(
file, split="train", pretrained=self.pretrained, preprocess=preprocess, label_cols=label_cols
)
task = Task(type=TaskType.Regression, level=TaskLevel.Sequence)
elem = dataset[0]
assert isinstance(elem["sequence"], dl.PNTensor)
assert isinstance(elem["ON"], torch.FloatTensor)
assert isinstance(elem["OFF"], torch.FloatTensor)
batch = dataset[list(range(3))]
assert isinstance(batch["sequence"], dl.NestedTensor)
assert isinstance(batch["ON_OFF"], torch.FloatTensor)
for t in dataset.tasks.values():
assert t == task

@pytest.mark.parametrize("preprocess", [True, False])
def test_modifications(self, preprocess: bool):
file = os.path.join(self.root, "modifications.json")
dataset = PandasDataset(file, split="train", pretrained=self.pretrained, preprocess=preprocess)
task = Task(type=TaskType.MultiLabel, level=TaskLevel.Sequence, num_labels=12)
elem = dataset[0]
assert isinstance(elem["sequence"], dl.PNTensor)
assert isinstance(elem["label"], torch.LongTensor)
batch = dataset[list(range(3))]
assert isinstance(batch["sequence"], dl.NestedTensor)
assert isinstance(batch["label"], torch.LongTensor)
assert dataset.tasks["label"] == task

@pytest.mark.parametrize("preprocess", [True, False])
def test_degradation(self, preprocess: bool):
Expand All @@ -95,13 +104,16 @@ def test_degradation(self, preprocess: bool):
feature_cols=feature_cols,
label_cols=label_cols,
)
task = Task(type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=68)
elem = dataset[0]
assert isinstance(elem["sequence"], dl.PNTensor)
assert isinstance(elem["deg_pH10"], torch.FloatTensor)
assert isinstance(elem["deg_50C"], torch.FloatTensor)
batch = dataset[list(range(3))]
assert isinstance(batch["sequence"], dl.NestedTensor)
assert isinstance(batch["reactivity"], torch.FloatTensor)
for t in dataset.tasks.values():
assert t == task


@pytest.mark.lfs
Expand All @@ -125,3 +137,46 @@ def test_null(self):
assert dataset[0]["label"] == 1
dataset = dataset_factory(nan_process="drop")
assert len(dataset) == 61

def test_rna_task_recognition_json(self):
file = os.path.join(self.root, "rna.json")
dataset = PandasDataset(file, split="train", pretrained=self.pretrained)
assert dataset.tasks["sequence_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Sequence, num_labels=1)
assert dataset.tasks["sequence_multiclass"] == Task(
type=TaskType.MultiClass, level=TaskLevel.Sequence, num_labels=7
)
assert dataset.tasks["sequence_multilabel"] == Task(
type=TaskType.MultiLabel, level=TaskLevel.Sequence, num_labels=7
)
assert dataset.tasks["sequence_multireg"] == Task(
type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=7
)
assert dataset.tasks["sequence_regression"] == Task(
type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=1
)
assert dataset.tasks["nucleotide_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1)
assert dataset.tasks["nucleotide_multiclass"] == Task(
type=TaskType.MultiClass, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_multilabel"] == Task(
type=TaskType.MultiLabel, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_multireg"] == Task(
type=TaskType.Regression, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_regression"] == Task(
type=TaskType.Regression, level=TaskLevel.Token, num_labels=1
)
assert dataset.tasks["contact_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
assert dataset.tasks["contact_multiclass"] == Task(
type=TaskType.MultiClass, level=TaskLevel.Contact, num_labels=3
)
assert dataset.tasks["contact_multilabel"] == Task(
type=TaskType.MultiLabel, level=TaskLevel.Contact, num_labels=3
)
assert dataset.tasks["contact_multireg"] == Task(
type=TaskType.Regression, level=TaskLevel.Contact, num_labels=3
)
assert dataset.tasks["contact_regression"] == Task(
type=TaskType.Regression, level=TaskLevel.Contact, num_labels=1
)

0 comments on commit 2e367f6

Please sign in to comment.