From ed0b2090ad56387b1952ce47e5d2a9c0ceb6c334 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sat, 7 Sep 2024 02:14:26 +0800 Subject: [PATCH] add data Signed-off-by: Zhiyuan Chen --- .github/workflows/push.yaml | 9 +- .gitmodules | 3 + data | 1 + demo/data/datasets.py | 19 ++ demo/data/local-file.py | 19 ++ demo/{ => models}/direct-access.py | 0 demo/{ => models}/multimolecule-automodel.py | 0 demo/{ => models}/transformers-automodel.py | 0 demo/{ => models}/vanilla.py | 0 docs/docs/data/dataset.md | 9 + docs/docs/data/index.md | 9 + docs/mkdocs.yml | 5 + multimolecule/__init__.py | 4 +- multimolecule/data/README.md | 27 ++ multimolecule/data/README.zh.md | 27 ++ multimolecule/data/__init__.py | 20 ++ multimolecule/data/dataset.py | 338 +++++++++++++++++++ multimolecule/data/utils.py | 23 ++ multimolecule/defaults.py | 20 ++ multimolecule/models/README.md | 8 +- multimolecule/models/README.zh.md | 8 +- pyproject.toml | 1 + requirements.txt | 4 + tests/data/test_dataset.py | 121 +++++++ 24 files changed, 662 insertions(+), 13 deletions(-) create mode 100644 .gitmodules create mode 160000 data create mode 100644 demo/data/datasets.py create mode 100644 demo/data/local-file.py rename demo/{ => models}/direct-access.py (100%) rename demo/{ => models}/multimolecule-automodel.py (100%) rename demo/{ => models}/transformers-automodel.py (100%) rename demo/{ => models}/vanilla.py (100%) create mode 100644 docs/docs/data/dataset.md create mode 100644 docs/docs/data/index.md create mode 100644 multimolecule/data/README.md create mode 100644 multimolecule/data/README.zh.md create mode 100644 multimolecule/data/__init__.py create mode 100644 multimolecule/data/dataset.py create mode 100644 multimolecule/data/utils.py create mode 100644 multimolecule/defaults.py create mode 100644 requirements.txt create mode 100644 tests/data/test_dataset.py diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml index 707f80b2..b591ed92 100644 --- a/.github/workflows/push.yaml +++ b/.github/workflows/push.yaml @@ -18,14 +18,16 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 + with: + submodules: true - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} cache: "pip" - name: Install dependencies for testing - run: pip install pytest pytest-cov torch torchvision + run: pip install pytest pytest-cov - name: Install module - run: pip install -e . + run: pip install -r requirements.txt && pip install -e . - name: pytest run: pytest --cov=materialx --cov-report=xml --cov-report=html . - name: Upload coverage report for documentation @@ -83,11 +85,11 @@ jobs: release: if: startsWith(github.event.ref, 'refs/tags/v') needs: [lint, test] + environment: pypi permissions: contents: write id-token: write runs-on: ubuntu-latest - environment: pypi steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -110,6 +112,7 @@ jobs: develop: if: contains(fromJson('["refs/heads/master", "refs/heads/main"]'), github.ref) needs: [lint, test] + environment: pypi permissions: contents: write runs-on: ubuntu-latest diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..aa89b5f4 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "data"] + path = data + url = git@github.com:MultiMolecule/data.git diff --git a/data b/data new file mode 160000 index 00000000..a7c7a848 --- /dev/null +++ b/data @@ -0,0 +1 @@ +Subproject commit a7c7a84835bf406b3ed9c99384b544841b0cfaa3 diff --git a/demo/data/datasets.py b/demo/data/datasets.py new file mode 100644 index 00000000..69396a88 --- /dev/null +++ b/demo/data/datasets.py @@ -0,0 +1,19 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from multimolecule.data import Dataset + +data = Dataset("multimolecule/bprna-new", split="train") diff --git a/demo/data/local-file.py b/demo/data/local-file.py new file mode 100644 index 00000000..6831cc22 --- /dev/null +++ b/demo/data/local-file.py @@ -0,0 +1,19 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from multimolecule.data import Dataset + +data = Dataset("data/rna/5utr.csv", split="train") diff --git a/demo/direct-access.py b/demo/models/direct-access.py similarity index 100% rename from demo/direct-access.py rename to demo/models/direct-access.py diff --git a/demo/multimolecule-automodel.py b/demo/models/multimolecule-automodel.py similarity index 100% rename from demo/multimolecule-automodel.py rename to demo/models/multimolecule-automodel.py diff --git a/demo/transformers-automodel.py b/demo/models/transformers-automodel.py similarity index 100% rename from demo/transformers-automodel.py rename to demo/models/transformers-automodel.py diff --git a/demo/vanilla.py b/demo/models/vanilla.py similarity index 100% rename from demo/vanilla.py rename to demo/models/vanilla.py diff --git a/docs/docs/data/dataset.md b/docs/docs/data/dataset.md new file mode 100644 index 00000000..58508f35 --- /dev/null +++ b/docs/docs/data/dataset.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# Dataset + +::: multimolecule.data.Dataset diff --git a/docs/docs/data/index.md b/docs/docs/data/index.md new file mode 100644 index 00000000..c84872ac --- /dev/null +++ b/docs/docs/data/index.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# data + +--8<-- "multimolecule/data/README.md:8:" diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 57532e19..ead43c12 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -9,6 +9,9 @@ repo_url: https://github.com/DLS5-Omics/multimolecule nav: - index.md + - data: + - data/index.md + - Dataset: data/dataset.md - module: - module/index.md - heads: module/heads.md @@ -182,6 +185,8 @@ plugins: - https://docs.python.org/3/objects.inv - https://pytorch.org/docs/stable/objects.inv - https://huggingface.co/docs/transformers/master/en/objects.inv + - https://huggingface.co/docs/datasets/master/en/objects.inv + - https://pandas.pydata.org/docs/objects.inv - https://danling.org/objects.inv - https://chanfig.danling.org/objects.inv - section-index diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index 5f0dc995..fa80236c 100644 --- a/multimolecule/__init__.py +++ b/multimolecule/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from . import models, tokenisers +from .data import Dataset from .models import ( AutoModelForContactPrediction, AutoModelForNucleotidePrediction, @@ -136,11 +136,11 @@ __all__ = [ "modeling_auto", "modeling_outputs", + "Dataset", "PreTrainedConfig", "HeadConfig", "BaseHeadConfig", "MaskedLMHeadConfig", - "tokenisers", "DnaTokenizer", "RnaTokenizer", "ProteinTokenizer", diff --git a/multimolecule/data/README.md b/multimolecule/data/README.md new file mode 100644 index 00000000..5edc5375 --- /dev/null +++ b/multimolecule/data/README.md @@ -0,0 +1,27 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# data + +`data` provides a collection of data processing utilities for handling data. + +While :hugs: [`datasets`](https://huggingface.co/docs/datasets) is a powerful library for managing datasets, it is a general-purpose tool that may not cover all the specific functionalities of scientific applications. + +The `data` package is designed to complement [`datasets`](https://huggingface.co/docs/datasets) by offering additional data processing utilities that are commonly used in scientific tasks. + +## Usage + +### Load from local data file + +```python +--8<-- "demo/data/local-file.py:17:" +``` + +### Load from :hugs: [`datasets`](https://huggingface.co/docs/datasets) + +```python +--8<-- "demo/data/datasets.py:17:" +``` diff --git a/multimolecule/data/README.zh.md b/multimolecule/data/README.zh.md new file mode 100644 index 00000000..7036507a --- /dev/null +++ b/multimolecule/data/README.zh.md @@ -0,0 +1,27 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# data + +`data` 提供了一系列用于处理数据的实用工具。 + +尽管 :hugs: [`datasets`](https://huggingface.co/docs/datasets) 是一个强大的管理数据集的库,但它是一个通用工具,可能无法涵盖科学应用程序的所有特定功能。 + +`data` 包旨在通过提供在科学任务中常用的数据处理实用程序来补充 [`datasets`](https://huggingface.co/docs/datasets)。 + +## Usage + +### 从本地数据文件加载 + +```python +--8<-- "demo/data/local-file.py:17:" +``` + +### 从:hugs: [`datasets`](https://huggingface.co/docs/datasets)加载 + +```python +--8<-- "demo/data/datasets.py:17:" +``` diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py new file mode 100644 index 00000000..62196c10 --- /dev/null +++ b/multimolecule/data/__init__.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .dataset import Dataset +from .utils import no_collate + +__all__ = ["Dataset", "no_collate"] diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py new file mode 100644 index 00000000..080b6421 --- /dev/null +++ b/multimolecule/data/dataset.py @@ -0,0 +1,338 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, List +from warnings import warn + +import danling as dl +import datasets +import pyarrow as pa +import torch +from danling import NestedTensor +from datasets.table import Table +from pandas import DataFrame +from torch import Tensor +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from multimolecule import defaults +from multimolecule.tokenisers.dot_bracket.utils import STANDARD_ALPHABET as DOT_BRACKET_ALPHABET + + +class Dataset(datasets.Dataset): + r""" + The base class for all datasets. + + Dataset is a subclass of [`datasets.Dataset`][] that provides additional functionality for handling structured data. + + Attributes: + tokenizer: The pretrained tokenizer to use for tokenization. + truncation: Whether to truncate sequences that exceed the maximum length of the tokenizer. + max_length: The maximum length of the input sequences. + data_cols: The names of all columns in the dataset. + feature_cols: The names of the feature columns in the dataset. + label_cols: The names of the label columns in the dataset. + sequence_cols: The names of the sequence columns in the dataset. + structure_cols: The names of the structure columns in the dataset. + column_names_map: A mapping of column names to new column names. + preprocess: Whether to preprocess the dataset. + + Args: + data: The dataset. This can be a path to a file, a tag on the Hugging Face Hub, a pyarrow.Table, + a [dict][], a [list][], or a [pandas.DataFrame][]. + split: The split of the dataset. + tokenizer: A pretrained tokenizer to use for tokenization. + Either `tokenizer` or `pretrained` must be specified. + pretrained: The name of a pretrained tokenizer to use for tokenization. + Either `tokenizer` or `pretrained` must be specified. + feature_cols: The names of the feature columns in the dataset. + Will be inferred automatically if not specified. + label_cols: The names of the label columns in the dataset. + Will be inferred automatically if not specified. + preprocess: Whether to preprocess the dataset. + Preprocessing involves pre-tokenizing the sequences using the tokenizer. + Defaults to `True`. + auto_rename_cols: Whether to automatically rename columns to standard names. + Only works when there is exactly one feature column / one label column. + You can control the naming through `multimolecule.defaults.SEQUENCE_COL_NAME` and + `multimolecule.defaults.LABEL_COL_NAME`. + For more refined control, use `column_names_map`. + column_names_map: A mapping of column names to new column names. + This is useful for renaming columns to inputs that are expected by a model. + Defaults to `None`. + truncation: Whether to truncate sequences that exceed the maximum length of the tokenizer. + Defaults to `False`. + max_length: The maximum length of the input sequences. + Defaults to the `model_max_length` of the tokenizer. + info: The dataset info. + indices_table: The indices table. + fingerprint: The fingerprint of the dataset. + """ + + tokenizer: PreTrainedTokenizerBase + truncation: bool = False + max_length: int + + feature_cols: List + label_cols: List + + data_cols: List + sequence_cols: List + structure_cols: List + + preprocess: bool = True + auto_rename_cols: bool = False + column_names_map: Mapping[str, str] | None = None + + def __init__( + self, + data: Table | DataFrame | dict | list | str, + split: datasets.NamedSplit, + tokenizer: PreTrainedTokenizerBase | None = None, + pretrained: str | None = None, + feature_cols: List | None = None, + label_cols: List | None = None, + preprocess: bool | None = None, + auto_rename_cols: bool | None = None, + column_names_map: Mapping[str, str] | None = None, + truncation: bool | None = None, + max_length: int | None = None, + info: datasets.DatasetInfo | None = None, + indices_table: Table | None = None, + fingerprint: str | None = None, + nan_process: str = "ignore", + fill_value: str | int | float = 0, + ): + arrow_table = self.build_table( + data, split, feature_cols, label_cols, nan_process=nan_process, fill_value=fill_value + ) + super().__init__( + arrow_table=arrow_table, split=split, info=info, indices_table=indices_table, fingerprint=fingerprint + ) + if tokenizer is None: + if pretrained is None: + raise ValueError("tokenizer and pretrained can not be both None.") + tokenizer = AutoTokenizer.from_pretrained(pretrained) + if max_length is None: + max_length = tokenizer.model_max_length + else: + tokenizer.model_max_length = max_length + self.max_length = max_length + if truncation is not None: + self.truncation = truncation + self.tokenizer = tokenizer + if preprocess is not None: + self.preprocess = preprocess + self.post( + feature_cols=feature_cols, + label_cols=label_cols, + auto_rename_cols=auto_rename_cols, + column_names_map=column_names_map, + ) + + def build_table( + self, + data: Table | DataFrame | dict | str, + split: datasets.NamedSplit, + feature_cols: List | None = None, + label_cols: List | None = None, + nan_process: str | None = "ignore", + fill_value: str | int | float = 0, + ) -> datasets.table.Table: + if isinstance(data, str): + try: + data = datasets.load_dataset(data, split=split).data + except FileNotFoundError: + data = dl.load_pandas(data) + if isinstance(data, DataFrame): + data = data.loc[:, ~data.columns.str.contains("^Unnamed")] + data = pa.Table.from_pandas(data) + elif isinstance(data, dict): + data = pa.Table.from_pydict(data) + elif isinstance(data, list): + data = pa.Table.from_pylist(data) + elif isinstance(data, DataFrame): + data = pa.Table.from_pandas(data) + if feature_cols is not None and label_cols is not None: + data = data.select(feature_cols + label_cols) + data = self.process_nan(data, nan_process=nan_process, fill_value=fill_value) + return data + + def post( + self, + feature_cols: List | None = None, + label_cols: List | None = None, + auto_rename_cols: bool | None = None, + column_names_map: Mapping[str, str] | None = None, + ) -> None: + r""" + Perform pre-processing steps after initialization. + + It first identifies the special columns (sequence and structure columns) in the dataset. + Then it sets the feature and label columns based on the input arguments. + If `auto_rename_cols` is `True`, it will automatically rename the columns to model inputs. + Finally, it sets the [`transform`][datasets.Dataset.set_transform] function based on the `preprocess` flag. + """ + self.identify_special_cols() + data_cols = list(self._info.features.keys()) + if label_cols is None: + if feature_cols is None: + feature_cols = [i for i in data_cols if i in defaults.SEQUENCE_COL_NAMES] + label_cols = [i for i in data_cols if i not in feature_cols] + if feature_cols is None: + feature_cols = [i for i in data_cols if i not in label_cols] + missing_feature_cols = set(feature_cols).difference(data_cols) + if missing_feature_cols: + raise ValueError(f"{missing_feature_cols} are specified in feature_cols, but not found in dataset.") + missing_label_cols = set(label_cols).difference(data_cols) + if missing_label_cols: + raise ValueError(f"{missing_label_cols} are specified in label_cols, but not found in dataset.") + self.feature_cols = list(feature_cols) + self.label_cols = list(label_cols) + self.data_cols = self.feature_cols + self.label_cols + + if auto_rename_cols is not None: + self.auto_rename_cols = auto_rename_cols + if self.auto_rename_cols: + if column_names_map is not None: + raise ValueError("auto_rename_cols and column_names_map are mutually exclusive.") + column_names_map = {} + if len(self.feature_cols) == 1: + column_names_map[self.feature_cols[0]] = defaults.SEQUENCE_COL_NAME + if len(self.label_cols) == 1: + column_names_map[self.label_cols[0]] = defaults.LABEL_COL_NAME + self.column_names_map = column_names_map + if self.column_names_map: + self.rename_columns(self.column_names_map) + + if self.preprocess: + self.update(self.map(self.tokenization)) + self.set_transform(self.torch_transform) + else: + self.set_transform(self.tokenize_transform) + + def torch_transform(self, batch: Mapping) -> Mapping: + r""" + Default [`transform`][datasets.Dataset.set_transform] function when `preprocess` is `True`. + + See Also: + [`collate`](multimolecule.Dataset.collate) + """ + return {k: self.collate(k, v) for k, v in batch.items()} + + def tokenize_transform(self, batch: Mapping) -> Mapping: + r""" + Default [`transform`][datasets.Dataset.set_transform] function when `preprocess` is `False`. + + See Also: + [`collate`](multimolecule.Dataset.collate) + """ + return {k: self.collate(k, v, tokenize=True) for k, v in batch.items()} + + def collate(self, col: str, data: Any, tokenize: bool = False) -> Tensor | NestedTensor | None: + r""" + Collate the data for a column. + + If the column is a sequence column, it will tokenize the data if `tokenize` is `True`. + Otherwise, it will return a tensor or nested tensor. + """ + if col in self.sequence_cols: + if tokenize: + data = self.tokenize(data) + return dl.tensor(data) if len(data) == 1 else NestedTensor(data) + try: + return torch.tensor(data) + except ValueError: + return NestedTensor(data) + + def __getitems__(self, key: int | slice | Iterable[int]) -> Any: + return self.__getitem__(key) + + def identify_special_cols(self) -> Sequence: + self.sequence_cols, self.structure_cols = [], [] + string_cols = [k for k, v in self.features.items() if v.dtype == "string"] + for col in string_cols: + unique_values = set() + for chunk in self._data.column(col): + unique_values.update(chunk.as_py()) + if not unique_values.difference(DOT_BRACKET_ALPHABET): + self.structure_cols.append(col) + else: + self.sequence_cols.append(col) + return string_cols + + def tokenization(self, data: Mapping[str, str]) -> Mapping[str, Tensor]: + return {col: self.tokenize(data[col]) for col in self.sequence_cols} + + def tokenize(self, string: str) -> Tensor: + return self.tokenizer(string, return_attention_mask=False, truncation=self.truncation)["input_ids"] + + def update(self, dataset: datasets.Dataset): + r""" + Perform an in-place update of the dataset. + + This method is used to update the dataset after changes have been made to the underlying data. + It updates the format columns, data, info, and fingerprint of the dataset. + """ + # pylint: disable=W0212 + # Why datasets won't support in-place changes? + # It's just impossible to extend. + self._format_columns = dataset._format_columns + self._data = dataset._data + self._info = dataset._info + self._fingerprint = dataset._fingerprint + + def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str | None = None) -> datasets.Dataset: + self.update(super().rename_columns(column_mapping, new_fingerprint=new_fingerprint)) + self.feature_cols = [column_mapping.get(i, i) for i in self.feature_cols] + self.label_cols = [column_mapping.get(i, i) for i in self.label_cols] + self.sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols] + self.structure_cols = [column_mapping.get(i, i) for i in self.structure_cols] + return self + + def rename_column( + self, original_column_name: str, new_column_name: str, new_fingerprint: str | None = None + ) -> datasets.Dataset: + self.update(super().rename_column(original_column_name, new_column_name, new_fingerprint)) + self.feature_cols = [new_column_name if i == original_column_name else i for i in self.feature_cols] + self.label_cols = [new_column_name if i == original_column_name else i for i in self.label_cols] + self.sequence_cols = [new_column_name if i == original_column_name else i for i in self.sequence_cols] + self.structure_cols = [new_column_name if i == original_column_name else i for i in self.structure_cols] + return self + + def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table: + if nan_process == "ignore": + return data + data = data.to_pandas() + data = data.replace([float("inf"), -float("inf")], float("nan")) + if data.isnull().values.any(): + if nan_process is None or nan_process == "error": + raise ValueError("NaN / inf values have been found in the dataset.") + warn( + "NaN / inf values have been found in the dataset.\n" + "While we can handle them, the data type of the corresponding column may be set to float, " + "which can and very likely will disrupt the auto task recognition.\n" + "It is recommended to address these values before loading the dataset." + ) + if nan_process == "drop": + data = data.dropna() + elif nan_process == "fill": + data = data.fillna(fill_value) + else: + raise ValueError(f"Invalid nan_process: {nan_process}") + return pa.Table.from_pandas(data) diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py new file mode 100644 index 00000000..4a9be6b7 --- /dev/null +++ b/multimolecule/data/utils.py @@ -0,0 +1,23 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import Any + + +def no_collate(batch: Any) -> Any: + return batch diff --git a/multimolecule/defaults.py b/multimolecule/defaults.py new file mode 100644 index 00000000..7eb1959a --- /dev/null +++ b/multimolecule/defaults.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +SEQUENCE_COL_NAMES = ["input_ids", "sequence", "seq"] +LABEL_COL_NAMES = ["label", "labels"] +SEQUENCE_COL_NAME = "input_ids" +LABEL_COL_NAME = "labels" diff --git a/multimolecule/models/README.md b/multimolecule/models/README.md index c7109808..9ef2f7fe 100644 --- a/multimolecule/models/README.md +++ b/multimolecule/models/README.md @@ -41,7 +41,7 @@ Similar to [Token Classification](https://huggingface.co/docs/transformers/en/ta ### Build with `multimolecule.AutoModel`s ```python ---8<-- "demo/multimolecule-automodel.py:17:" +--8<-- "demo/models/multimolecule-automodel.py:17:" ``` ### Direct Access @@ -49,7 +49,7 @@ Similar to [Token Classification](https://huggingface.co/docs/transformers/en/ta All models can be directly loaded with the `from_pretrained` method. ```python ---8<-- "demo/direct-access.py:17:" +--8<-- "demo/models/direct-access.py:17:" ``` ### Build with [`transformers.AutoModel`][]s @@ -57,7 +57,7 @@ All models can be directly loaded with the `from_pretrained` method. While we use a different naming convention for model classes, the models are still registered to corresponding [`transformers.AutoModel`][]s. ```python ---8<-- "demo/transformers-automodel.py:17:" +--8<-- "demo/models/transformers-automodel.py:17:" ``` !!! danger "`import multimolecule` before use" @@ -76,7 +76,7 @@ While we use a different naming convention for model classes, the models are sti You can also initialize a vanilla model using the model class. ```python ---8<-- "demo/vanilla.py:17:" +--8<-- "demo/models/vanilla.py:17:" ``` ## Available Models diff --git a/multimolecule/models/README.zh.md b/multimolecule/models/README.zh.md index 1e632f43..70ee7a28 100644 --- a/multimolecule/models/README.zh.md +++ b/multimolecule/models/README.zh.md @@ -41,7 +41,7 @@ date: 2024-05-04 ### 使用 `multimolecule.AutoModel` 构建 ```python ---8<-- "demo/multimolecule-automodel.py:17:" +--8<-- "demo/models/multimolecule-automodel.py:17:" ``` ### 直接访问 @@ -49,7 +49,7 @@ date: 2024-05-04 所有模型可以通过 `from_pretrained` 方法直接加载。 ```python ---8<-- "demo/direct-access.py:17:" +--8<-- "demo/models/direct-access.py:17:" ``` ### 使用 [`transformers.AutoModel`][] 构建 @@ -57,7 +57,7 @@ date: 2024-05-04 虽然我们为模型类使用了不同的命名约定,但模型仍然注册到相应的 [`transformers.AutoModel`][] 中。 ```python ---8<-- "demo/transformers-automodel.py:17:" +--8<-- "demo/models/transformers-automodel.py:17:" ``` !!! danger "使用前先 `import multimolecule`" @@ -76,7 +76,7 @@ date: 2024-05-04 你也可以使用模型类初始化一个基础模型。 ```python ---8<-- "demo/vanilla.py:17:" +--8<-- "demo/models/vanilla.py:17:" ``` ## 可用模型 diff --git a/pyproject.toml b/pyproject.toml index 748b911b..4abca002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "accelerate", "chanfig>=0.0.99", "danling>=0.3.6", + "datasets", "torch", "transformers", ] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..a6e42ae7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +biopython +pandas +psycopg2 +torch diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py new file mode 100644 index 00000000..4ed8f5f0 --- /dev/null +++ b/tests/data/test_dataset.py @@ -0,0 +1,121 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import os +from functools import partial + +import danling as dl +import pytest +import torch + +from multimolecule import Dataset + + +@pytest.mark.lfs +class TestRNADataset: + + pretrained = "multimolecule/rna" + root = os.path.join("data", "rna") + + @pytest.mark.parametrize("preprocess", [True, False]) + def test_5utr(self, preprocess: bool): + file = os.path.join(self.root, "5utr.csv") + dataset = Dataset(file, split="train", pretrained=self.pretrained, preprocess=preprocess, auto_rename_cols=True) + elem = dataset[0] + assert isinstance(elem["input_ids"], dl.PNTensor) + assert isinstance(elem["labels"], torch.FloatTensor) + batch = dataset[list(range(3))] + assert isinstance(batch["input_ids"], dl.NestedTensor) + assert isinstance(batch["labels"], torch.FloatTensor) + + @pytest.mark.parametrize("preprocess", [True, False]) + def test_ncrna(self, preprocess: bool): + file = os.path.join(self.root, "ncrna.csv") + dataset = Dataset(file, split="train", pretrained=self.pretrained, preprocess=preprocess, auto_rename_cols=True) + elem = dataset[0] + assert isinstance(elem["input_ids"], dl.PNTensor) + assert isinstance(elem["labels"], torch.LongTensor) + batch = dataset[list(range(3))] + assert isinstance(batch["input_ids"], dl.NestedTensor) + assert isinstance(batch["labels"], torch.LongTensor) + + @pytest.mark.parametrize("preprocess", [True, False]) + def test_rnaswitches(self, preprocess: bool): + file = os.path.join(self.root, "rnaswitches.csv") + label_cols = ["ON", "OFF", "ON_OFF"] + dataset = Dataset(file, split="train", pretrained=self.pretrained, preprocess=preprocess, label_cols=label_cols) + elem = dataset[0] + assert isinstance(elem["sequence"], dl.PNTensor) + assert isinstance(elem["ON"], torch.FloatTensor) + assert isinstance(elem["OFF"], torch.FloatTensor) + batch = dataset[list(range(3))] + assert isinstance(batch["sequence"], dl.NestedTensor) + assert isinstance(batch["ON_OFF"], torch.FloatTensor) + + @pytest.mark.parametrize("preprocess", [True, False]) + def test_modifications(self, preprocess: bool): + file = os.path.join(self.root, "modifications.json") + dataset = Dataset(file, split="train", pretrained=self.pretrained, preprocess=preprocess) + elem = dataset[0] + assert isinstance(elem["sequence"], dl.PNTensor) + assert isinstance(elem["label"], torch.LongTensor) + batch = dataset[list(range(3))] + assert isinstance(batch["sequence"], dl.NestedTensor) + assert isinstance(batch["label"], torch.LongTensor) + + @pytest.mark.parametrize("preprocess", [True, False]) + def test_degradation(self, preprocess: bool): + file = os.path.join(self.root, "degradation.json") + feature_cols = ["sequence"] # , "structure", "predicted_loop_type"] + label_cols = ["reactivity", "deg_Mg_pH10", "deg_Mg_50C", "deg_pH10", "deg_50C"] + dataset = Dataset( + file, + split="train", + pretrained=self.pretrained, + preprocess=preprocess, + feature_cols=feature_cols, + label_cols=label_cols, + ) + elem = dataset[0] + assert isinstance(elem["sequence"], dl.PNTensor) + assert isinstance(elem["deg_pH10"], torch.FloatTensor) + assert isinstance(elem["deg_50C"], torch.FloatTensor) + batch = dataset[list(range(3))] + assert isinstance(batch["sequence"], dl.NestedTensor) + assert isinstance(batch["reactivity"], torch.FloatTensor) + + +@pytest.mark.lfs +class TestSyntheticDataset: + + pretrained = "multimolecule/rna" + root = os.path.join("data", "synthetic") + + def test_null(self): + file = os.path.join(self.root, "null.csv") + dataset_factory = partial(Dataset, file, split="train", pretrained=self.pretrained) + dataset = dataset_factory(nan_process="ignore") + assert len(dataset) == 67 + with pytest.raises(RuntimeError): + dataset[0] + with pytest.raises(ValueError): + dataset = dataset_factory(nan_process="raise") + dataset = dataset_factory(nan_process="fill", fill_value=0) + assert dataset[0]["label"] == 0 + dataset = dataset_factory(nan_process="fill", fill_value=1) + assert dataset[0]["label"] == 1 + dataset = dataset_factory(nan_process="drop") + assert len(dataset) == 61