Skip to content

Commit

Permalink
Merge branch 'main' into jstjohn/geneformer_modelcard
Browse files Browse the repository at this point in the history
  • Loading branch information
jstjohn authored Nov 6, 2024
2 parents 16bef75 + 6773d00 commit 0ddf406
Show file tree
Hide file tree
Showing 28 changed files with 2,807 additions and 23 deletions.
118 changes: 117 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,19 @@ git tag MY-VERSION-TAG
uv build /sub-packages/bionemo-core
TWINE_PASSWORD="<pypi pass>" TWINE_USERNAME="<pypi user>" uvx twine upload /sub-packages/bionemo-core/dist/*
```
## Pydantic Configuration

BioNeMo 2 provides two entrypoints for models with both argparse and pydantic. Both documented in the `Models` section below.
Pydantic based configuration is designed to accept a configuration json file as input, along with context specific arguments (e.g., should we resume from existing checkpoints?). These JSON configs go through a Pydantic Validator, in this case referred to as `MainConfig`. This Config is composed of several other Pydantic models, see the class definition for details. To pre-populate a config with reasonable defaults for various standard models, we provide 'recipes.' These are simple methods that instantiate the config object and then serialize it to a JSON configuration file. From this file, you may either submit it directly, or modify the various parameters to meet your usecase. For example, Weights and biases, devices, precision, and dataset options are all extremely useful to modify. Then, you would submit this config for training.

These two workflows are packaged as executables when esm2 or geneformer are installed with pip. These commands will appear as:

```bash
bionemo-geneformer-recipe
bionemo-esm2-recipe
bionemo-geneformer-train
bionemo-esm2-train
```

## Models
### ESM-2
Expand Down Expand Up @@ -198,6 +210,62 @@ python \
--restore-from-checkpoint-path ${ESM2_650M_CKPT}
```

##### Running with Pydantic configs

Alternatively, we provide a validated and serialized configuration file entrypoint for executing the same workflow. Recipes
are available for 8m, 650m, and 3b ESM2 models. You may select which preset config to use by setting the `--recipe` parameter.

```bash
# The fastest transformer engine environment variables in testing were the following two
TEST_DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source $MY_DATA_SOURCE); \
bionemo-esm2-recipe \
--train-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/train_clusters_sanity.parquet \
--train-database-path ${TEST_DATA_DIR}/2024_03_sanity/train_sanity.db \
--valid-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/valid_clusters.parquet \
--valid-database-path ${TEST_DATA_DIR}/2024_03_sanity/validation.db \
--result-dir ./results \
--dest my_config.json \
--recipe 8m
```

> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.json as you see fit
> NOTE: To pretrain from an existing checkpoint, simply pass in the path --initial-ckpt-path to the recipe command. This will populate the JSON with the correct field to ensure pretraining is initialized from an existing checkpoint.
To submit a training job with the passed config, first update the json file with any additional execution parameters
of your choosing: number of devices, workers, steps, etc. Second, invoke our training entrypoint. To do this, we need
three things:

- Configuration file, the JSON produced by the previous step
- Model config type, in this case the pretraining config. This will validate the arguments in the config JSON against
those required for pretraining. Alternatively, things like fine-tuning with custom task heads may be specified here.
This allows for mixing/matching Data Modules with various tasks.
- Data Config type, this specifies how to parse, validate, and prepare the DataModule. This may change depending on task,
for example, pretraining ESM2 uses a protein cluster oriented sampling method. In the case of inference or fine-tuning
a pretrained model, a simple fasta file may be sufficient. There is a one-to-one relationship between DataConfig types
and DataModule types.

> ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.
```
export NVTE_FUSED_ATTN=1
export NVTE_FLASH_ATTN=0
bionemo-esm2-train \
--data-config-t bionemo.esm2.run.config_models.ESM2DataConfig \
--model-config-t bionemo.esm2.run.config_models.ExposedESM2PretrainConfig \
--config my_config.json
```

> NOTE: both data-config-t and model-config-t have default values corresponding to ESM2DataConfig and ExposedESM2PretrainingConfig
DataConfigT and ModelConfigT can also refer to locally defined types by the user. As long as python knows how to import
the specified path, they may be configured. For example, you may have a custom Dataset/DataModule that you would like to
mix with an existing recipe. In this case, you define a DataConfig object with the generic specified as your DataModule
type, and then pass in the config type to the training recipe.



### Geneformer
#### Running

Expand All @@ -221,7 +289,7 @@ train_geneformer \
--micro-batch-size 2
```

To fine-tune, you just need to specify a different combination of model and loss. Pass the path to the outputted config file from the previous step as the `--restore-from-checkpoint-path`, and also change
To fine-tune, you to specify a different combination of model and loss. Pass the path to the outputted config file from the previous step as the `--restore-from-checkpoint-path`, and also change
`--training-model-config-class` to the newly created model-config-class.

While no CLI option currently exists to hot swap in different data modules and processing functions _now_, you could
Expand All @@ -247,6 +315,54 @@ train_geneformer \
--restore-from-checkpoint-path results/test_experiment/dev/checkpoints/test_experiment--val_loss=4.3506-epoch=1-last
```

##### Running with Pydantic configs
Alternatively, we provide a validated and serialized configuration file entrypoint for executing the same workflow. Recipes
are available for 10m, and 106m geneformer models. Additionally we provide an example recipe of finetuning, where the objective
is to 'regress' on token IDs rather than the traditional masked language model approach. In practice, you will likely
need to implement your own DataModule, DataConfig, and Finetuning model. You can use the same overall approach, but with
customizations for your task.


```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
bionemo-geneformer-recipe \
--recipe 10m-pretrain \
--dest my_config.json \
--data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--result-dir ./results
```
> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.json as you see fit
> NOTE: To pretrain from an existing checkpoint, simply pass in the path --initial-ckpt-path to the recipe command. This will populate the JSON with the correct field to ensure pretraining is initialized from an existing checkpoint.
To submit a training job with the passed config, first update the json file with any additional execution parameters
of your choosing: number of devices, workers, steps, etc. Second, invoke our training entrypoint. To do this, we need
three things:

- Configuration file, the JSON produced by the previous step
- Model config type, in this case the pretraining config. This will validate the arguments in the config JSON against
those required for pretraining. Alternatively, things like fine-tuning with custom task heads may be specified here.
This allows for mixing/matching Data Modules with various tasks.
- Data Config type, this specifies how to parse, validate, and prepare the DataModule. This may change depending on task,
for example, while fine-tuning you may want to use a custom Dataset/DataModule that includes PERTURB-seq. In this case,
the default pretraining DataConfig and DataModule will be insufficient. See ESM2 for additional example usecases.

> ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.
```bash
bionemo-geneformer-train \
--data-config-t bionemo.geneformer.run.config_models.GeneformerPretrainingDataConfig \
--model-config-t bionemo.geneformer.run.config_models.ExposedGeneformerPretrainConfig \
--config my_config.json
```

> NOTE: both data-config-t and model-config-t have default values corresponding to GeneformerPretrainingDataConfig and ExposedGeneformerPretrainConfig
DataConfigT and ModelConfigT can also refer to locally defined types by the user. As long as python knows how to import
the specified path, they may be configured. For example, you may have a custom Dataset/DataModule that you would like to
mix with an existing recipe. In this case, you define a DataConfig object with the generic specified as your DataModule
type, and then pass in the config type to the training recipe.



## Updating License Header on Python Files
Expand Down
12 changes: 6 additions & 6 deletions scripts/protein/esm2/esm2_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
from bionemo.llm.lightning import PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger
from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger


__all__: Sequence[str] = ("main", "parser")
Expand Down Expand Up @@ -147,10 +147,10 @@ def main(

# for wandb integration
# Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html"
wandb_options: Optional[WandbLoggerOptions] = (
wandb_config: Optional[WandbConfig] = (
None
if wandb_project is None
else WandbLoggerOptions(
else WandbConfig(
offline=wandb_offline,
project=wandb_project,
entity=wandb_entity,
Expand Down Expand Up @@ -203,8 +203,8 @@ def main(
max_seq_length=max_seq_length,
num_workers=num_dataset_workers,
random_mask_strategy=random_mask_strategy,
tokenizer=tokenizer,
)

# Configure the model
esm2_config = ESM2Config(
seq_length=max_seq_length,
Expand Down Expand Up @@ -254,7 +254,7 @@ def main(
root_dir=result_dir,
name=experiment_name,
initialize_tensorboard_logger=create_tensorboard_logger,
wandb_kwargs=wandb_options,
wandb_config=wandb_config,
ckpt_callback=checkpoint_callback,
)

Expand Down
1 change: 1 addition & 0 deletions scripts/protein/esm2/test_esm2_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
def test_val_dataloader_in_main_runs_with_limit_val_batches(
monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, limit_val_batches
):
# TODO: pydantic.
"""Ensures doesn't run out of validation samples whenever updating limit_val_batches logic.
Args:
Expand Down
97 changes: 97 additions & 0 deletions scripts/protein/esm2/test_pydantic_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shlex
import subprocess
from pathlib import Path

import pytest
from lightning.fabric.plugins.environments.lightning import find_free_network_port

from bionemo.testing.data.esm2 import create_mock_parquet_train_val_inputs, create_mock_protein_dataset
from bionemo.testing.data.load import load


data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data"


def test_bionemo2_rootdir():
data_error_str = (
"Please download test data with:\n"
"`python scripts/download_artifacts.py --models all --model_dir ./models --data all --data_dir ./ --verbose --source pbss`"
)
assert data_path.exists(), f"Could not find test data directory.\n{data_error_str}"
assert data_path.is_dir(), f"Test data directory is supposed to be a directory.\n{data_error_str}"


@pytest.fixture
def dummy_protein_dataset(tmp_path):
"""Create a mock protein dataset."""
db_file = create_mock_protein_dataset(tmp_path)
return db_file


@pytest.fixture
def dummy_parquet_train_val_inputs(tmp_path):
"""Create a mock protein train and val cluster parquet."""
train_cluster_path, valid_cluster_path = create_mock_parquet_train_val_inputs(tmp_path)
return train_cluster_path, valid_cluster_path


def test_pretrain_pydantic_cli(dummy_protein_dataset, dummy_parquet_train_val_inputs, tmpdir):
result_dir = tmpdir.mkdir("results")
train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs

open_port = find_free_network_port()
config = f"{result_dir}/test_config.json"

# Invoke with blocking
cmd_str = f"""bionemo-esm2-recipe --dest {config} --recipe test
--train-database-path {dummy_protein_dataset}
--train-cluster-path {train_cluster_path}
--valid-database-path {dummy_protein_dataset}
--valid-cluster-path {valid_cluster_path}
--result-dir {result_dir}""".strip()

# continue when finished
env = dict(**os.environ) # a local copy of the environment
env["MASTER_PORT"] = str(open_port)
cmd = shlex.split(cmd_str)
result = subprocess.run(
cmd,
cwd=tmpdir,
env=env,
capture_output=True,
)
# Now do pretrain
if result.returncode != 0:
raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}")

cmd_str = f"""bionemo-esm2-train --conf {config}""".strip()
env = dict(**os.environ) # a local copy of the environment
open_port = find_free_network_port()
env["MASTER_PORT"] = str(open_port)
cmd = shlex.split(cmd_str)
result = subprocess.run(
cmd,
cwd=tmpdir,
env=env,
capture_output=True,
)
if result.returncode != 0:
raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}")
# NOTE this looks a lot like a magic value. But we also could do json.loads(config)['experiment_config']['experiment_name']
assert (result_dir / "default_experiment").exists(), "Could not find test experiment directory."
18 changes: 17 additions & 1 deletion sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.


from typing import Literal, Sequence
from typing import Dict, Literal, Sequence

import torch

Expand All @@ -24,7 +24,23 @@
"PrecisionTypes",
)


# NOTE(SKH) our precision types are a mess, but we inherit this problem from NeMo and Megatron.
PrecisionTypes = Literal["fp16", "bf16", "fp32", "bf16-mixed", "fp32-mixed", "16-mixed", "fp16-mixed", 16, 32]
precision_to_dtype: Dict[PrecisionTypes, torch.dtype] = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
"16-mixed": torch.float16,
"fp16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
"fp32-mixed": torch.float32,
16: torch.float16,
32: torch.float32,
}

# NOTE(SKH) these do not have a perfect 1-1 relationship, but we can use this to serialize/deserialize dtypes in ModelConfigs since its ultimately converted with precision_to_dtype.
dtype_to_precision: Dict[torch.dtype, PrecisionTypes] = {v: k for k, v in precision_to_dtype.items()}


def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype:
Expand Down
4 changes: 4 additions & 0 deletions sub-packages/bionemo-esm2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ dependencies = [
[tool.setuptools.package-data]
"bionemo.esm2" = ["data/tokenizer/*.json", "data/tokenizer/*.txt"]

[project.scripts]
bionemo-esm2-train= "bionemo.esm2.run.main:main"
bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main"

[tool.setuptools.packages.find]
where = ["src"]
include = ["bionemo.*"]
Expand Down
5 changes: 5 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def __init__(
rampup_batch_size=rampup_batch_size,
)

@property
def tokenizer(self) -> tokenizer.BioNeMoESMTokenizer:
"""Returns the tokenizer."""
return self._tokenizer

def setup(self, stage: str = "") -> None:
"""Setup the ESMDataModule.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
from bionemo.llm.lightning import batch_collator
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


Expand Down Expand Up @@ -57,7 +58,7 @@ def infer_model(
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)
module = biobert_lightning_module(config=config, tokenizer=tokenizer)
results = trainer.predict(module, datamodule=data_module)
results = batch_collator(trainer.predict(module, datamodule=data_module))

return results

Expand Down
1 change: 1 addition & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
return_only_hidden_states: bool = False # return logits

def __post_init__(self):
# TODO, as a validator?
"""Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
super().__post_init__()
if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
Expand Down
14 changes: 14 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/run/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 0ddf406

Please sign in to comment.