diff --git a/README.md b/README.md
index ae4d0808c2..266387f1f1 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,19 @@ git tag MY-VERSION-TAG
 uv build /sub-packages/bionemo-core
 TWINE_PASSWORD="" TWINE_USERNAME="" uvx twine upload /sub-packages/bionemo-core/dist/*
 ```
+## Pydantic Configuration
+BioNeMo 2 provides two entrypoints for models: one based on argparse and one based on pydantic. Both are documented in the `Models` section below.
+Pydantic-based configuration accepts a configuration JSON file as input, along with context-specific arguments (e.g., should we resume from an existing checkpoint?). These JSON configs are validated by a Pydantic model, referred to here as `MainConfig`. This config is composed of several other Pydantic models; see the class definition for details. To pre-populate a config with reasonable defaults for various standard models, we provide 'recipes.' These are simple methods that instantiate the config object and then serialize it to a JSON configuration file. You may either submit that file directly or first modify its parameters to fit your use case; for example, Weights and Biases, devices, precision, and dataset options are all useful to adjust. You then submit this config for training.
+
+These two workflows are packaged as executables when esm2 or geneformer are installed with pip. The commands are:
+
+```bash
+bionemo-geneformer-recipe
+bionemo-esm2-recipe
+bionemo-geneformer-train
+bionemo-esm2-train
+```
 
 ## Models
 
 ### ESM-2
@@ -198,6 +210,62 @@ python \
 --restore-from-checkpoint-path ${ESM2_650M_CKPT}
 ```
+##### Running with Pydantic configs
+
+Alternatively, we provide a validated and serialized configuration file entrypoint for executing the same workflow. Recipes
+are available for 8m, 650m, and 3b ESM2 models. You may select which preset config to use by setting the `--recipe` parameter.
+
+```bash
+TEST_DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source $MY_DATA_SOURCE); \
+bionemo-esm2-recipe \
+--train-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/train_clusters_sanity.parquet \
+--train-database-path ${TEST_DATA_DIR}/2024_03_sanity/train_sanity.db \
+--valid-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/valid_clusters.parquet \
+--valid-database-path ${TEST_DATA_DIR}/2024_03_sanity/validation.db \
+--result-dir ./results \
+--dest my_config.json \
+--recipe 8m
+```
+
+> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.json as you see fit.
+
+> NOTE: To pretrain from an existing checkpoint, simply pass the checkpoint path via `--initial-ckpt-path` to the recipe command. This will populate the JSON with the correct field to ensure pretraining is initialized from that checkpoint.
+
+To submit a training job with the generated config, first update the JSON file with any additional execution parameters
+of your choosing: number of devices, workers, steps, etc. Second, invoke our training entrypoint. To do this, we need
+three things:
+
+- Configuration file: the JSON produced by the previous step.
+- Model config type: in this case the pretraining config. This will validate the arguments in the config JSON against
+  those required for pretraining. Alternatively, things like fine-tuning with custom task heads may be specified here.
+  This allows for mixing and matching Data Modules with various tasks.
+- Data Config type: this specifies how to parse, validate, and prepare the DataModule. This may change depending on the task;
+  for example, pretraining ESM2 uses a protein-cluster-oriented sampling method, whereas for inference or fine-tuning
+  a pretrained model, a simple FASTA file may be sufficient. There is a one-to-one relationship between DataConfig types
+  and DataModule types.
+
+> ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.
+
+```bash
+# The fastest transformer engine environment variables in testing were the following two
+export NVTE_FUSED_ATTN=1
+export NVTE_FLASH_ATTN=0
+
+bionemo-esm2-train \
+--data-config-t bionemo.esm2.run.config_models.ESM2DataConfig \
+--model-config-t bionemo.esm2.run.config_models.ExposedESM2PretrainConfig \
+--config my_config.json
+```
+
+> NOTE: both data-config-t and model-config-t have default values corresponding to ESM2DataConfig and ExposedESM2PretrainConfig.
+
+DataConfigT and ModelConfigT can also refer to types defined locally by the user. As long as Python knows how to import
+the specified path, they may be configured. For example, you may have a custom Dataset/DataModule that you would like to
+mix with an existing recipe. In this case, you define a DataConfig object with the generic specified as your DataModule
+type, and then pass the config type to the training entrypoint.
+
+
 ### Geneformer
 
 #### Running
 
@@ -221,7 +289,7 @@ train_geneformer \
     --micro-batch-size 2
 ```
 
-To fine-tune, you just need to specify a different combination of model and loss. Pass the path to the outputted config file from the previous step as the `--restore-from-checkpoint-path`, and also change
+To fine-tune, you need to specify a different combination of model and loss. Pass the path to the outputted config file from the previous step as the `--restore-from-checkpoint-path`, and also change
 `--training-model-config-class` to the newly created model-config-class.
 
 While no CLI option currently exists to hot swap in different data modules and processing functions _now_, you could
@@ -247,6 +315,54 @@ train_geneformer \
     --restore-from-checkpoint-path results/test_experiment/dev/checkpoints/test_experiment--val_loss=4.3506-epoch=1-last
 ```
 
+##### Running with Pydantic configs
+Alternatively, we provide a validated and serialized configuration file entrypoint for executing the same workflow. Recipes
+are available for 10m and 106m Geneformer models. Additionally, we provide an example fine-tuning recipe, where the objective
+is to 'regress' on token IDs rather than the traditional masked language model approach. In practice, you will likely
+need to implement your own DataModule, DataConfig, and fine-tuning model. You can use the same overall approach, but with
+customizations for your task.
+
+
+```bash
+TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
+bionemo-geneformer-recipe \
+    --recipe 10m-pretrain \
+    --dest my_config.json \
+    --data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
+    --result-dir ./results
+```
+> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.json as you see fit.
+
+> NOTE: To pretrain from an existing checkpoint, simply pass the checkpoint path via `--initial-ckpt-path` to the recipe command. This will populate the JSON with the correct field to ensure pretraining is initialized from that checkpoint.
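+
+If you prefer to script your edits rather than change the file by hand, the generated JSON mirrors the structure of `MainConfig` (top-level keys such as `data_config`, `parallel_config`, `training_config`, `bionemo_model_config`, `optim_config`, `experiment_config`, and `wandb_config`). The snippet below is a minimal sketch; the nested field names (`num_devices`, `max_steps`, and the WandB fields) are taken from the bundled recipe configs, so check your own generated file for its exact contents:
+
+```python
+import json
+
+# Load the config produced by the recipe command above.
+with open("my_config.json") as f:
+    config = json.load(f)
+
+# Adjust a few execution parameters before training.
+config["parallel_config"]["num_devices"] = 2          # e.g., train on two GPUs
+config["training_config"]["max_steps"] = 1000         # shorten the run
+config["wandb_config"]["project"] = "my-geneformer"   # point logging at your own W&B project
+config["wandb_config"]["offline"] = False
+
+with open("my_config.json", "w") as f:
+    json.dump(config, f, indent=2)
+```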
+
+To submit a training job with the generated config, first update the JSON file with any additional execution parameters
+of your choosing: number of devices, workers, steps, etc. Second, invoke our training entrypoint. To do this, we need
+three things:
+
+- Configuration file: the JSON produced by the previous step.
+- Model config type: in this case the pretraining config. This will validate the arguments in the config JSON against
+  those required for pretraining. Alternatively, things like fine-tuning with custom task heads may be specified here.
+  This allows for mixing and matching Data Modules with various tasks.
+- Data Config type: this specifies how to parse, validate, and prepare the DataModule. This may change depending on the task;
+  for example, while fine-tuning you may want to use a custom Dataset/DataModule that includes PERTURB-seq. In this case,
+  the default pretraining DataConfig and DataModule will be insufficient. See ESM2 for additional example use cases.
+
+> ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.
+
+```bash
+bionemo-geneformer-train \
+--data-config-t bionemo.geneformer.run.config_models.GeneformerPretrainingDataConfig \
+--model-config-t bionemo.geneformer.run.config_models.ExposedGeneformerPretrainConfig \
+--config my_config.json
+```
+
+> NOTE: both data-config-t and model-config-t have default values corresponding to GeneformerPretrainingDataConfig and ExposedGeneformerPretrainConfig.
+
+DataConfigT and ModelConfigT can also refer to types defined locally by the user. As long as Python knows how to import
+the specified path, they may be configured. For example, you may have a custom Dataset/DataModule that you would like to
+mix with an existing recipe. In this case, you define a DataConfig object with the generic specified as your DataModule
+type, and then pass the config type to the training entrypoint, as shown in the sketch below.
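+
+As an illustration, a custom DataConfig might look like the following sketch. Here `my_project` and `MyDataModule` are hypothetical placeholders for your own code; only the `DataConfig` base class and the `construct_data_module` hook come from `bionemo.llm.run.config_models`:
+
+```python
+from bionemo.llm.run.config_models import DataConfig
+
+from my_project.datamodule import MyDataModule  # hypothetical user-defined LightningDataModule
+
+
+class MyDataConfig(DataConfig[MyDataModule]):
+    """Collects and validates the fields needed to build MyDataModule."""
+
+    data_path: str
+    micro_batch_size: int = 8
+    num_dataset_workers: int = 0
+
+    def construct_data_module(self, global_batch_size: int) -> MyDataModule:
+        # Acquire any prerequisites (tokenizers, preprocessing) here, then build the DataModule.
+        return MyDataModule(
+            data_path=self.data_path,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=global_batch_size,
+            num_workers=self.num_dataset_workers,
+        )
+```
+
+Assuming the class is importable, it can then be passed to the training entrypoint with `--data-config-t my_project.config.MyDataConfig` (substituting whatever module path the class actually lives at).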
+ ## Updating License Header on Python Files diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py index 8fa25ffa1e..25741e7bd2 100644 --- a/scripts/protein/esm2/esm2_pretrain.py +++ b/scripts/protein/esm2/esm2_pretrain.py @@ -30,12 +30,12 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.dataset import RandomMaskStrategy from bionemo.esm2.data.tokenizer import get_tokenizer -from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler from bionemo.llm.lightning import PerplexityLoggingCallback from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.llm.model.biobert.model import BiobertSpecOption +from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size -from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger +from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger __all__: Sequence[str] = ("main", "parser") @@ -147,10 +147,10 @@ def main( # for wandb integration # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html" - wandb_options: Optional[WandbLoggerOptions] = ( + wandb_config: Optional[WandbConfig] = ( None if wandb_project is None - else WandbLoggerOptions( + else WandbConfig( offline=wandb_offline, project=wandb_project, entity=wandb_entity, @@ -203,8 +203,8 @@ def main( max_seq_length=max_seq_length, num_workers=num_dataset_workers, random_mask_strategy=random_mask_strategy, + tokenizer=tokenizer, ) - # Configure the model esm2_config = ESM2Config( seq_length=max_seq_length, @@ -254,7 +254,7 @@ def main( root_dir=result_dir, name=experiment_name, initialize_tensorboard_logger=create_tensorboard_logger, - wandb_kwargs=wandb_options, + wandb_config=wandb_config, ckpt_callback=checkpoint_callback, ) diff --git a/scripts/protein/esm2/test_esm2_pretrain.py b/scripts/protein/esm2/test_esm2_pretrain.py index 8e4303b32b..c2253d46da 100644 --- a/scripts/protein/esm2/test_esm2_pretrain.py +++ b/scripts/protein/esm2/test_esm2_pretrain.py @@ -144,6 +144,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra def test_val_dataloader_in_main_runs_with_limit_val_batches( monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, limit_val_batches ): + # TODO: pydantic. """Ensures doesn't run out of validation samples whenever updating limit_val_batches logic. Args: diff --git a/scripts/protein/esm2/test_pydantic_train.py b/scripts/protein/esm2/test_pydantic_train.py new file mode 100644 index 0000000000..2522e538f8 --- /dev/null +++ b/scripts/protein/esm2/test_pydantic_train.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
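+
+"""End-to-end test of the pydantic-based ESM2 entrypoints: generate a config JSON with the `test` recipe via
+`bionemo-esm2-recipe`, then launch a short training run from that config via `bionemo-esm2-train`."""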
+ +import os +import shlex +import subprocess +from pathlib import Path + +import pytest +from lightning.fabric.plugins.environments.lightning import find_free_network_port + +from bionemo.testing.data.esm2 import create_mock_parquet_train_val_inputs, create_mock_protein_dataset +from bionemo.testing.data.load import load + + +data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" + + +def test_bionemo2_rootdir(): + data_error_str = ( + "Please download test data with:\n" + "`python scripts/download_artifacts.py --models all --model_dir ./models --data all --data_dir ./ --verbose --source pbss`" + ) + assert data_path.exists(), f"Could not find test data directory.\n{data_error_str}" + assert data_path.is_dir(), f"Test data directory is supposed to be a directory.\n{data_error_str}" + + +@pytest.fixture +def dummy_protein_dataset(tmp_path): + """Create a mock protein dataset.""" + db_file = create_mock_protein_dataset(tmp_path) + return db_file + + +@pytest.fixture +def dummy_parquet_train_val_inputs(tmp_path): + """Create a mock protein train and val cluster parquet.""" + train_cluster_path, valid_cluster_path = create_mock_parquet_train_val_inputs(tmp_path) + return train_cluster_path, valid_cluster_path + + +def test_pretrain_pydantic_cli(dummy_protein_dataset, dummy_parquet_train_val_inputs, tmpdir): + result_dir = tmpdir.mkdir("results") + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + open_port = find_free_network_port() + config = f"{result_dir}/test_config.json" + + # Invoke with blocking + cmd_str = f"""bionemo-esm2-recipe --dest {config} --recipe test + --train-database-path {dummy_protein_dataset} + --train-cluster-path {train_cluster_path} + --valid-database-path {dummy_protein_dataset} + --valid-cluster-path {valid_cluster_path} + --result-dir {result_dir}""".strip() + + # continue when finished + env = dict(**os.environ) # a local copy of the environment + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + # Now do pretrain + if result.returncode != 0: + raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + + cmd_str = f"""bionemo-esm2-train --conf {config}""".strip() + env = dict(**os.environ) # a local copy of the environment + open_port = find_free_network_port() + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + if result.returncode != 0: + raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + # NOTE this looks a lot like a magic value. But we also could do json.loads(config)['experiment_config']['experiment_name'] + assert (result_dir / "default_experiment").exists(), "Could not find test experiment directory." diff --git a/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py b/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py index 9520e62f3b..31f213633a 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py +++ b/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py @@ -14,7 +14,7 @@ # limitations under the License. -from typing import Literal, Sequence +from typing import Dict, Literal, Sequence import torch @@ -24,7 +24,23 @@ "PrecisionTypes", ) + +# NOTE(SKH) our precision types are a mess, but we inherit this problem from NeMo and Megatron. 
PrecisionTypes = Literal["fp16", "bf16", "fp32", "bf16-mixed", "fp32-mixed", "16-mixed", "fp16-mixed", 16, 32] +precision_to_dtype: Dict[PrecisionTypes, torch.dtype] = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + "16-mixed": torch.float16, + "fp16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + "fp32-mixed": torch.float32, + 16: torch.float16, + 32: torch.float32, +} + +# NOTE(SKH) these do not have a perfect 1-1 relationship, but we can use this to serialize/deserialize dtypes in ModelConfigs since its ultimately converted with precision_to_dtype. +dtype_to_precision: Dict[torch.dtype, PrecisionTypes] = {v: k for k, v in precision_to_dtype.items()} def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype: diff --git a/sub-packages/bionemo-esm2/pyproject.toml b/sub-packages/bionemo-esm2/pyproject.toml index 9fb3c782ed..119e7b36cc 100644 --- a/sub-packages/bionemo-esm2/pyproject.toml +++ b/sub-packages/bionemo-esm2/pyproject.toml @@ -21,6 +21,10 @@ dependencies = [ [tool.setuptools.package-data] "bionemo.esm2" = ["data/tokenizer/*.json", "data/tokenizer/*.txt"] +[project.scripts] +bionemo-esm2-train= "bionemo.esm2.run.main:main" +bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main" + [tool.setuptools.packages.find] where = ["src"] include = ["bionemo.*"] diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py index 9ee3206710..e065702d3d 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py @@ -108,6 +108,11 @@ def __init__( rampup_batch_size=rampup_batch_size, ) + @property + def tokenizer(self) -> tokenizer.BioNeMoESMTokenizer: + """Returns the tokenizer.""" + return self._tokenizer + def setup(self, stage: str = "") -> None: """Setup the ESMDataModule. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py index 8baea604a4..40fd6023b1 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py @@ -24,6 +24,7 @@ from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset +from bionemo.llm.lightning import batch_collator from bionemo.llm.model.biobert.lightning import biobert_lightning_module @@ -57,7 +58,7 @@ def infer_model( plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), ) module = biobert_lightning_module(config=config, tokenizer=tokenizer) - results = trainer.predict(module, datamodule=data_module) + results = batch_collator(trainer.predict(module, datamodule=data_module)) return results diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py index b5da3366e2..b5b53b036c 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py @@ -323,6 +323,7 @@ class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]): return_only_hidden_states: bool = False # return logits def __post_init__(self): + # TODO, as a validator? 
"""Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.""" super().__post_init__() if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec: diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/__init__.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py new file mode 100644 index 0000000000..7437e75f3e --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +from pathlib import Path +from typing import Optional, Type + +import torch +from nemo.utils import logging +from pydantic import field_serializer, field_validator, model_validator + +from bionemo.esm2.data.datamodule import ESMDataModule +from bionemo.esm2.data.dataset import RandomMaskStrategy +from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.attention import ESM2DotProductAttention, ESM2TEDotProductAttention +from bionemo.esm2.model.model import ESM2Config +from bionemo.llm.model.biobert.model import BiobertSpecOption +from bionemo.llm.run.config_models import ( + DataConfig, + ExposedModelConfig, + MainConfig, +) + + +class ESM2DataConfig(DataConfig[ESMDataModule]): + """ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2. + + The ESM2DataModule implements the cluster oriented sampling method defined in the ESM2 publication. + + Attributes: + train_cluster_path (Path): Path to the training cluster data. + train_database_path (Path): Path to the training database. + valid_cluster_path (Path): Path to the validation cluster data. + valid_database_path (Path): Path to the validation database. + micro_batch_size (int): Size of the micro-batch. Default is 8. + result_dir (str): Directory to store results. Default is "./results". 
+ min_seq_length (int): Minimum sequence length. Default is 128. + max_seq_length (int): Maximum sequence length. Default is 128. + random_mask_strategy (RandomMaskStrategy): Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS. + num_dataset_workers (int): Number of workers for the dataset. Default is 0. + + Methods: + construct_data_module(global_batch_size: int) -> ESMDataModule: + Constructs and returns an ESMDataModule instance with the provided global batch size. + """ + + train_cluster_path: Path + train_database_path: Path + valid_cluster_path: Path + valid_database_path: Path + + micro_batch_size: int = 8 + result_dir: str = "./results" + min_seq_length: int = 128 + max_seq_length: int = 128 + random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS + num_dataset_workers: int = 0 + + def construct_data_module(self, global_batch_size: int) -> ESMDataModule: + """Constructs and returns an ESMDataModule instance with the provided global batch size. + + This method provides means for constructing the datamodule, any pre-requisites for the DataModule should be + aquired here. For example, tokenizers, preprocessing, may want to live in this method. + + Args: + global_batch_size (int): Global batch size for the data module. Global batch size must be a function of + parallelism settings and the `micro_batch_size` attribute. Since the DataConfig has no ownership over + parallelism configuration, we expect someone higher up on the ownership chain to provide the value to + this method. + + """ + tokenizer = get_tokenizer() + data = ESMDataModule( + train_cluster_path=self.train_cluster_path, + train_database_path=self.train_database_path, + valid_cluster_path=self.valid_cluster_path, + valid_database_path=self.valid_database_path, + global_batch_size=global_batch_size, + micro_batch_size=self.micro_batch_size, + min_seq_length=self.min_seq_length, + max_seq_length=self.max_seq_length, + num_workers=self.num_dataset_workers, + random_mask_strategy=self.random_mask_strategy, + tokenizer=tokenizer, + ) + return data + + +class ExposedESM2PretrainConfig(ExposedModelConfig[ESM2Config]): + """Configuration class for ESM2 pretraining with select exposed parameters. + + See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either + as a template or extension for custom configurations. Importantly, these kinds of classes should do two things, + select attributes to expose to the user, and provide validation and serialization any attributes. + + Attributes: + use_esm_attention (bool): Flag to skip ESM2 custom attention for TE acceleration. Defaults to False. + token_dropout (bool): Flag to enable token dropout. Defaults to True. + normalize_attention_scores (bool): Flag to normalize attention scores. Defaults to False. + variable_seq_lengths (bool): Flag to enable variable sequence lengths. Defaults to False. + core_attention_override (Optional[Type[torch.nn.Module]]): Optional override for core attention module. Defaults to None. + + Methods: + restrict_biobert_spec_to_esm2(cls, biobert_spec_option: BiobertSpecOption) -> BiobertSpecOption: + Validates the BiobertSpecOption to ensure it is compatible with ESM2. + serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]: + Serializes the core attention override module to a string. + validate_core_attention_override(cls, value): + Validates the core attention override module, ensuring it is a subclass of torch.nn.Module. 
+ validate_and_set_attention_and_scaling(self): + Validates and sets the attention and scaling parameters based on the biobert_spec_option. + model_validator(self, global_cfg: MainConfig) -> MainConfig: + Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings. + model_class(self) -> Type[ESM2Config]: + Returns the model class associated with this configuration. + """ + + use_esm_attention: bool = False # Skip ESM2 custom attention for TE acceleration. Still passes golden value test. + token_dropout: bool = True + normalize_attention_scores: bool = False + variable_seq_lengths: bool = False + core_attention_override: Type[torch.nn.Module] | None = None + + @field_serializer("core_attention_override") + def serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]: + """Serializes the core attention override module to a string.""" + if value is None: + return None + return f"{value.__module__}.{value.__name__}" + + @field_validator("core_attention_override", mode="before") + def validate_core_attention_override(cls, value): + """Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.""" + if value is None: + return None + if isinstance(value, str): + module_name, class_name = value.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + if not issubclass(cls, torch.nn.Module): + raise ValueError(f"{cls} is not a subclass of torch.nn.Module") + return cls + except (ImportError, AttributeError): + raise ValueError(f"Cannot import {value}") + return value + + @model_validator(mode="after") + def validate_and_set_attention_and_scaling(self): + """Validates and sets the attention and scaling parameters based on the biobert_spec_option.""" + logging.info( + "Mutating apply_query_key_layer_scaling and core_attention_override based on biobert_spec_option.." + ) + if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec: + self.apply_query_key_layer_scaling = False + self.core_attention_override = ESM2TEDotProductAttention + elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec: + logging.warning( + "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. " + "Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead." + ) + self.apply_query_key_layer_scaling = True + self.core_attention_override = ESM2DotProductAttention + return self + + def model_validator(self, global_cfg: MainConfig) -> MainConfig: + """Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings. + + The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig. + Additionally, it provides validation for sequence length and parallelism settings. + + Args: + global_cfg (MainConfig): The global configuration object. 
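+
+        Returns:
+            MainConfig: the validated global configuration.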
+ """ + global_cfg = super().model_validator(global_cfg) + # Need to ensure that at the least we have access to min_seq_length and max_seq_length + if not isinstance(global_cfg.data_config, ESM2DataConfig): + raise TypeError(f"ESM2PretrainConfig requires ESM2DataConfig, got {global_cfg.data_config=}") + + pipeline_model_parallel_size, tensor_model_parallel_size = ( + global_cfg.parallel_config.pipeline_model_parallel_size, + global_cfg.parallel_config.tensor_model_parallel_size, + ) + min_seq_length, max_seq_length = global_cfg.data_config.min_seq_length, global_cfg.data_config.max_seq_length + assert ( + self.variable_seq_lengths + == (pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length) + ), "Must set variable_seq_lengths to True when min_seq_length != max_seq_length under pipeline or tensor parallelism." + return global_cfg + + def model_class(self) -> Type[ESM2Config]: + """Returns the model class associated with this configuration.""" + return ESM2Config diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py new file mode 100644 index 0000000000..5820e48438 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import json +from typing import Optional + +from bionemo.esm2.run.config_models import ESM2DataConfig, ExposedESM2PretrainConfig +from bionemo.llm.run.config_models import MainConfig +from bionemo.llm.train import NsysConfig, train + + +def main(): # noqa: D103 + def parse_args(): + parser = argparse.ArgumentParser(description="Run ESM2 pretraining") + parser.add_argument("--config", type=str, required=True, help="Path to the JSON configuration file") + parser.add_argument( + "--model-config-t", + default=ExposedESM2PretrainConfig, + required=False, + help="fully resolvable python import path to the ModelConfig object. Builtin options are ExposedESM2PretrainConfig.", + ) + parser.add_argument( + "--data-config-t", + default=ESM2DataConfig, + required=False, + help="fully resolvable python import path to the ModelConfig object.", + ) + parser.add_argument( + "--resume-if-exists", + default=False, + action="store_true", + help="Resume training if a checkpoint exists that matches the current experiment configuration.", + ) + + # Debug options. + parser.add_argument( + "--nsys-profiling", + action="store_true", + default=False, + help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. 
For example: " + " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", + ) + # start, end, rank + parser.add_argument( + "--nsys-start-step", + type=int, + required=False, + default=0, + help="Start nsys profiling after this step.", + ) + parser.add_argument( + "--nsys-end-step", + type=int, + required=False, + help="End nsys profiling after this step.", + ) + # rank as list of integers + parser.add_argument( + "--nsys-ranks", + type=int, + nargs="+", + required=False, + default=[0], + help="Enable nsys profiling for these ranks.", + ) + return parser.parse_args() + + def string_to_class(path: str): + import importlib + + module_path, class_name = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + def load_config(config_path: str, model_config_t: Optional[str], data_config_t: Optional[str]) -> MainConfig: + with open(config_path, "r") as f: + config_dict = json.load(f) + + # model/data_config_t is used to select the parser dynamically. + if model_config_t is None or model_config_t == "ExposedESM2PretrainConfig": + model_config_t = ExposedESM2PretrainConfig + elif model_config_t == "ExposedFineTuneSeqModel": + # Hardcoded path for those who do not know the full path + # model_config_t = ExposedFineTuneSeqLenBioBertConfig + raise NotImplementedError() + elif model_config_t == "ExposedFineTuneTokenModel": + raise NotImplementedError() + elif isinstance(model_config_t, str): + # We assume we get a string to some importable config... e.g. in the sub-package jensen, 'bionemo.jensen.configs.MyConfig' + model_config_t = string_to_class(model_config_t) + + if data_config_t is None: + data_config_t = ESM2DataConfig + elif isinstance(data_config_t, str): + data_config_t = string_to_class(data_config_t) + + return MainConfig[model_config_t, data_config_t](**config_dict) + + args = parse_args() + config = load_config(args.config, args.model_config_t, args.data_config_t) + + if args.nsys_profiling: + nsys_config = NsysConfig( + start_step=args.nsys_start_step, + end_step=args.nsys_end_step, + ranks=args.nsys_ranks, + ) + else: + nsys_config = None + + train( + bionemo_exposed_model_config=config.bionemo_model_config, + data_config=config.data_config, + parallel_config=config.parallel_config, + training_config=config.training_config, + optim_config=config.optim_config, + experiment_config=config.experiment_config, + wandb_config=config.wandb_config, + nsys_config=nsys_config, + resume_if_exists=args.resume_if_exists, + ) + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py new file mode 100644 index 0000000000..9473cc69ce --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py @@ -0,0 +1,445 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +from pathlib import Path +from typing import Optional + +from nemo.utils import logging + +from bionemo.core.utils.dtypes import PrecisionTypes +from bionemo.esm2.run.config_models import ESM2DataConfig, ExposedESM2PretrainConfig +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.run.config_models import ( + ExperimentConfig, + MainConfig, + OptimizerSchedulerConfig, + ParallelConfig, + TrainingConfig, +) +from bionemo.llm.utils.logger_utils import WandbConfig + + +def esm2_base_training_config() -> TrainingConfig: + """Base training config for ESM2.""" + return TrainingConfig( + max_steps=500000, + limit_val_batches=1.0, + val_check_interval=10_000, + precision="bf16-mixed", + include_perplexity=True, + ) + + +def esm2_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig: + """Base optimizer scheduler config for ESM2.""" + return OptimizerSchedulerConfig( + optimizer="adam", lr=4e-4, interval="step", monitor="val_loss", lr_scheduler="warmup_anneal", warmup_steps=2000 + ) + + +def esm2_base_parallel_config() -> ParallelConfig: + """Base parallel config for ESM2.""" + return ParallelConfig( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + accumulate_grad_batches=1, + ddp="megatron", + num_devices=1, + num_nodes=1, + ) + + +def esm2_base_data_config(args) -> ESM2DataConfig: + """Base data config for ESM2.""" + data_config = ESM2DataConfig( + min_seq_length=1024, + max_seq_length=1024, + micro_batch_size=1, + num_dataset_workers=8, + train_cluster_path=args.train_cluster_path, + train_database_path=args.train_database_path, + valid_cluster_path=args.valid_cluster_path, + valid_database_path=args.valid_database_path, + ) + return data_config + + +def esm2_8m_wandb_config() -> WandbConfig: + """Wandb config for ESM2 8m.""" + wandb_config = WandbConfig( + entity="esm2-8m_pretraining", + project="esm2-8m_pretraining", + group="esm2-8m", + tags=["esm2", "pretraining"], + offline=True, + anonymous=True, + id="1", + log_model=False, + ) + return wandb_config + + +def esm2_8m_experiment_config(result_dir) -> ExperimentConfig: + """Experiment config for ESM2 8m.""" + return ExperimentConfig( + save_every_n_steps=50, # default set in previous script. 
+ result_dir=result_dir, + experiment_name="esm2-8m-pretraining", + restore_from_checkpoint_path=None, + ) + + +def esm2_8m_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig: + """Model config for ESM2 8m.""" + return ExposedESM2PretrainConfig( + num_layers=6, + hidden_size=320, + ffn_hidden_size=320 * 4, + num_attention_heads=20, + seq_length=1024, + biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, + initial_ckpt_path=initial_ckpt_path, + get_attention_mask_from_fusion=True, + params_dtype="bf16-mixed", + pipeline_dtype="bf16-mixed", + autocast_dtype="bf16-mixed", + ) + + +def esm2_8m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]: + """Recipe for ESM2 8m.""" + return MainConfig( + data_config=esm2_base_data_config(args), + parallel_config=esm2_base_parallel_config(), + training_config=esm2_base_training_config(), # no changes for 8m + bionemo_model_config=esm2_8m_model_config(args.initial_ckpt_path), + optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m + experiment_config=esm2_8m_experiment_config(args.result_dir), + wandb_config=esm2_8m_wandb_config(), + ) + + +def esm2_650m_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig: + """Model config for ESM2 650m.""" + return ExposedESM2PretrainConfig( + num_layers=33, + hidden_size=1280, + ffn_hidden_size=1280 * 4, + seq_length=1024, + num_attention_heads=20, + biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, + initial_ckpt_path=initial_ckpt_path, + get_attention_mask_from_fusion=True, + params_dtype="bf16-mixed", + pipeline_dtype="bf16-mixed", + autocast_dtype="bf16-mixed", + ) + + +def esm2_650m_wandb_config() -> WandbConfig: + """Wandb config for ESM2 650m.""" + return WandbConfig( + entity="esm2-650m_pretraining", + project="esm2-650m_pretraining", + group="esm2-650m", + tags=["esm2", "pretraining"], + offline=True, + anonymous=True, + id="1", + log_model=False, + ) + + +def esm2_650m_experiment_config(result_dir) -> ExperimentConfig: + """Experiment config for ESM2 650m.""" + return ExperimentConfig( + save_every_n_steps=50, + result_dir=result_dir, + experiment_name="esm2-650m-pretraining", + # TODO should this be exposed? + restore_from_checkpoint_path=None, + ) + + +def esm2_650m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]: + """Recipe for ESM2 650m.""" + return MainConfig( + data_config=esm2_base_data_config(args), + parallel_config=esm2_base_parallel_config(), + training_config=esm2_base_training_config(), # no changes for 8m + bionemo_model_config=esm2_650m_model_config(args.initial_ckpt_path), + optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m + experiment_config=esm2_650m_experiment_config(args.result_dir), + wandb_config=esm2_650m_wandb_config(), + ) + + +def esm2_3b_parallel_config() -> ParallelConfig: + """Parallel config for ESM2 3b.""" + return ParallelConfig( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=1, + # TODO: is this correct? + accumulate_grad_batches=1, + ddp="megatron", + # NOTE assumes 8xGPU node. Can always edit the config. 
+ num_devices=8, + ) + + +def esm2_3b_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig: + """Model config for ESM2 3b.""" + return ExposedESM2PretrainConfig( + num_layers=36, + hidden_size=2560, + ffn_hidden_size=2560 * 4, + num_attention_heads=40, + seq_length=1024, + biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, + initial_ckpt_path=initial_ckpt_path, + get_attention_mask_from_fusion=True, + params_dtype="bf16-mixed", + pipeline_dtype="bf16-mixed", + autocast_dtype="bf16-mixed", + ) + + +def esm2_3b_wandb_config() -> WandbConfig: + """Wandb config for ESM2 3b.""" + return WandbConfig( + entity="esm2-3b_pretraining", + project="esm2-3b_pretraining", + group="esm2-3b", + tags=["esm2-650m"], + offline=True, + anonymous=True, + id="1", + log_model=False, + ) + + +def esm2_3b_experiment_config(result_dir) -> ExperimentConfig: + """Experiment config for ESM2 650m.""" + return ExperimentConfig( + save_every_n_steps=50, + result_dir=result_dir, + experiment_name="esm2-3b-pretraining", + # TODO should this be exposed? + restore_from_checkpoint_path=None, + ) + + +def esm2_3b_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]: + """Recipe for ESM2 3b.""" + return MainConfig( + data_config=esm2_base_data_config(args), + parallel_config=esm2_3b_parallel_config(), + training_config=esm2_base_training_config(), # no changes for 8m + bionemo_model_config=esm2_3b_model_config(args.initial_ckpt_path), + optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m + experiment_config=esm2_3b_experiment_config(args.result_dir), + wandb_config=esm2_3b_wandb_config(), + ) + + +def simple_parallel_recipe( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + num_devices: int = 1, + accumulate_grad_batches: int = 1, +) -> ParallelConfig: + """Simple parallel recipe for ESM2.""" + assert ( + num_devices >= tensor_model_parallel_size * pipeline_model_parallel_size + ), "devices must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size" + return ParallelConfig( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + num_devices=num_devices, + accumulate_grad_batches=accumulate_grad_batches, + ) + + +def tiny_train_config_recipe() -> TrainingConfig: + """Tiny training config for ESM2.""" + return TrainingConfig(max_steps=10, limit_val_batches=2, val_check_interval=2) + + +def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig: + """Default optimizer scheduler config for ESM2.""" + return OptimizerSchedulerConfig() + + +def experiment_config_recipe(result_dir="./results") -> ExperimentConfig: + """Experiment config for ESM2.""" + return ExperimentConfig( + save_every_n_steps=100, + result_dir=result_dir, + experiment_name="default_experiment", + restore_from_checkpoint_path=None, + save_last_checkpoint=True, + metric_to_monitor_for_checkpoints="val_loss", + save_top_k=2, + create_tensorboard_logger=False, + ) + + +def esm2_tiny_model_config( + seq_length: int = 2048, + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, + variable_seq_lengths: bool = False, +) -> ExposedESM2PretrainConfig: + """Model config for ESM2 tiny, used for testing.""" + return ExposedESM2PretrainConfig( + seq_length=seq_length, + 
num_layers=2, + hidden_size=32, + num_attention_heads=2, + ffn_hidden_size=4 * 32, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + biobert_spec_option=biobert_spec_option, + get_attention_mask_from_fusion=True, + nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None, + # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities + initial_ckpt_path=str(initial_ckpt_path) if initial_ckpt_path is not None else None, + variable_seq_lengths=variable_seq_lengths, + ) + + +def esm2_tiny_test_recipe(args): + """Test recipe for ESM2 tiny, used for testing.""" + parallel_config = simple_parallel_recipe() + training_config = tiny_train_config_recipe() + + data_config = ESM2DataConfig( + min_seq_length=128, + max_seq_length=128, + micro_batch_size=2, + num_dataset_workers=1, + train_cluster_path=args.train_cluster_path, + train_database_path=args.train_database_path, + valid_cluster_path=args.valid_cluster_path, + valid_database_path=args.valid_database_path, + ) + bionemo_model_config = esm2_tiny_model_config( + seq_length=data_config.max_seq_length, initial_ckpt_path=args.initial_ckpt_path + ) + + optim_config = default_adam_optimizer_with_cosine_annealing_recipe() + experiment_config = experiment_config_recipe(args.result_dir) + wandb_config = WandbConfig( + project="bionemo2-demo", + entity="nvidia", + offline=True, + tags=[], + group="dev", + id="dev", + log_model=False, + anonymous=True, + ) + main_config = MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=bionemo_model_config, + optim_config=optim_config, + experiment_config=experiment_config, + wandb_config=wandb_config, + ) + return main_config + + +def main(): # noqa: D103 + def parse_args(): + parser = argparse.ArgumentParser(description="Create ESM2 configuration JSON.") + parser.add_argument( + "--recipe", + type=str, + choices=["test", "8m", "650m", "3b"], + required=True, + help="Use one of the preconfigured recipes to create a template config file.", + ) + + parser.add_argument( + "--dest", + type=str, + default="./esm2-recipe.json", + required=True, + help="Path to the JSON configuration file.", + ) + + parser.add_argument( + "--train-cluster-path", type=Path, required=True, help="Path to the training cluster file." + ) + parser.add_argument( + "--train-database-path", type=Path, required=True, help="Path to the training database file." + ) + parser.add_argument( + "--valid-cluster-path", type=Path, required=True, help="Path to the validation cluster file." + ) + parser.add_argument( + "--valid-database-path", type=Path, required=True, help="Path to the validation database file." + ) + + parser.add_argument("--result-dir", type=Path, required=True, default="results", help="Path to store results") + + # Extra argument. + parser.add_argument( + "--initial-ckpt-path", + type=str, + required=False, + default=None, + help="Path to an existing to a checkpoint directory to restore an existing checkpoint. Not compatible with all recipes.", + ) + + args = parser.parse_args() + return args + + # Simple example for creating a JSON from recipes. + args = parse_args() + + if args.recipe == "8m": + config = esm2_8m_recipe(args) + elif args.recipe == "650m": + config = esm2_650m_recipe(args) + elif args.recipe == "3b": + config = esm2_3b_recipe(args) + elif args.recipe == "test": + # Hardcoded test recipe. 
+ config = esm2_tiny_test_recipe(args) + else: + raise ValueError(f"Invalid recipe choice. {args.recipe=}") + + # Serialize to JSON + json_str = config.model_dump_json(indent=2) + + # Save to file + with open( + args.dest, + "w", + ) as f: + f.write(json_str) + logging.info(f"Saved configuration to {args.dest=}") + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py index 41ba01fbba..18be7eccf3 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py @@ -28,8 +28,8 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.dataset import RandomMaskStrategy from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer -from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler from bionemo.testing.data.load import load from bionemo.testing.harnesses import stop_and_go from bionemo.testing.harnesses.mode import Mode diff --git a/sub-packages/bionemo-geneformer/pyproject.toml b/sub-packages/bionemo-geneformer/pyproject.toml index d6709cc836..5bde43b92b 100644 --- a/sub-packages/bionemo-geneformer/pyproject.toml +++ b/sub-packages/bionemo-geneformer/pyproject.toml @@ -20,6 +20,8 @@ dependencies = [ ] [project.scripts] +bionemo-geneformer-train= "bionemo.geneformer.run.main:main" +bionemo-geneformer-recipe= "bionemo.geneformer.run.recipes:main" sc_memmap = "bionemo.geneformer.scripts.sc_memmap:main_cli" infer_geneformer = "bionemo.geneformer.scripts.infer_geneformer:geneformer_infer_entrypoint" train_geneformer = "bionemo.geneformer.scripts.train_geneformer:entrypoint" diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/__init__.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/config_models.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/config_models.py new file mode 100644 index 0000000000..ff64d45f58 --- /dev/null +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/config_models.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +from dataclasses import dataclass, field +from typing import List, Optional, Type + +from nemo.utils import logging +from tokenizers import Tokenizer + +from bionemo.geneformer.api import GeneformerConfig +from bionemo.geneformer.data.singlecell.datamodule import SingleCellDataModule +from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess +from bionemo.geneformer.model.finetune_token_regressor import FineTuneSeqLenBioBertConfig +from bionemo.llm.run.config_models import ( + DataConfig, + ExposedModelConfig, +) + + +@dataclass +class GeneformerDataArtifacts: + """Data artifacts produced by the geneformer preprocess.""" + + tokenizer: Tokenizer + median_dict: dict + + +class GeneformerPretrainingDataConfig(DataConfig[SingleCellDataModule]): + """Configuration class for Geneformer pretraining data. + + Expects train/test/val to be prior split by directory and processed by `sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py`. + + Attributes: + data_dir (str): Directory where the data is stored. + result_dir (str | pathlib.Path): Directory where the results will be stored. Defaults to "./results". + micro_batch_size (int): Size of the micro-batch. Defaults to 8. + seq_length (int): Sequence length for the data. Defaults to 2048. + num_dataset_workers (int): Number of workers for data loading. Defaults to 0. + + Properties: + train_data_path (str): Path to the training data. + val_data_path (str): Path to the validation data. + test_data_path (str): Path to the test data. + + Methods: + geneformer_preprocess() -> GeneformerDataArtifacts: + Preprocesses the data using a legacy preprocessor from BioNeMo 1 and returns the necessary artifacts. + construct_data_module(global_batch_size: int) -> SingleCellDataModule: + Constructs and returns a SingleCellDataModule using the preprocessed data artifacts. + """ + + # Shadow two attributes from the parent for visibility. + data_dir: str + result_dir: str | pathlib.Path = "./results" + micro_batch_size: int = 8 + + seq_length: int = 2048 + num_dataset_workers: int = 0 + + @property + def train_data_path(self) -> str: # noqa: D102 + return self.data_dir + "/train" + + @property + def val_data_path(self) -> str: # noqa: D102 + return self.data_dir + "/val" + + @property + def test_data_path(self) -> str: # noqa: D102 + return self.data_dir + "/test" + + def geneformer_preprocess(self) -> GeneformerDataArtifacts: + """Geneformer datamodule expects certain artifacts to be present in the data directory. + + This method uses a legacy 'preprocessor' from BioNeMo 1 to acquire the associated artifacts. 
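+
+        Returns:
+            GeneformerDataArtifacts: the tokenizer and median dictionary consumed by the single-cell DataModule.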
+ """ + preprocessor = GeneformerPreprocess( + download_directory=pathlib.Path(self.train_data_path), + medians_file_path=pathlib.Path(self.train_data_path + "/medians.json"), + tokenizer_vocab_path=pathlib.Path(self.train_data_path + "/geneformer.vocab"), + ) + result = preprocessor.preprocess() + if "tokenizer" in result and "median_dict" in result: + logging.info("*************** Preprocessing Finished ************") + return GeneformerDataArtifacts(tokenizer=result["tokenizer"], median_dict=result["median_dict"]) + else: + logging.error("Preprocessing failed.") + raise ValueError("Preprocessing failed to create tokenizer and/or median dictionary.") + + def construct_data_module(self, global_batch_size: int) -> SingleCellDataModule: + """Downloads the requisite data artifacts and instantiates the DataModule.""" + geneformer_data_artifacts: GeneformerDataArtifacts = self.geneformer_preprocess() + data = SingleCellDataModule( + seq_length=self.seq_length, + tokenizer=geneformer_data_artifacts.tokenizer, + train_dataset_path=self.train_data_path, + val_dataset_path=self.val_data_path, + test_dataset_path=self.test_data_path, + random_token_prob=0.02, + median_dict=geneformer_data_artifacts.median_dict, + micro_batch_size=self.micro_batch_size, + global_batch_size=global_batch_size, + persistent_workers=self.num_dataset_workers > 0, + pin_memory=False, + num_workers=self.num_dataset_workers, + ) + return data + + +class ExposedGeneformerPretrainConfig(ExposedModelConfig[GeneformerConfig]): + """Exposes custom parameters for pretraining and binds the class to GeneformerConfig. + + Attributes: + initial_ckpt_path (str): Path to a directory containing checkpoint files for initializing the model. This is only + initial_ckpt_skip_keys_with_these_prefixes (List[str]): Skip any layer that contains this key during restoration. Useful for finetuning, set the names of the task heads so checkpoint restoration does not errorniously try to restore these. + """ + + # Custom parameters for FineTuning + initial_ckpt_path: Optional[str] = None + initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list) + + def model_class(self) -> Type[GeneformerConfig]: # noqa: D102 + return GeneformerConfig + + +class ExposedFineTuneSeqLenBioBertConfig(ExposedModelConfig[FineTuneSeqLenBioBertConfig]): + """Config for models that fine-tune a BioBERT model from a pre-trained checkpoint. + + Parameters: + initial_ckpt_path - path to a directory containing checkpoint files for initializing the model. This is only + required on the first execution of the model, any restored checkpoints should skip this step. + initial_ckpt_skip_keys_with_these_prefixes - skip any layer that contains this key during restoration. Useful + for ignoring extra additional layers used for finetuning. Layers with these keys are then randomly initialized. 
+    """
+
+    # Custom parameters for FineTuning
+    initial_ckpt_path: Optional[str] = None
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    def model_class(self) -> Type[FineTuneSeqLenBioBertConfig]:
+        """Binds the class to FineTuneSeqLenBioBertConfig."""
+        return FineTuneSeqLenBioBertConfig
diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py
new file mode 100644
index 0000000000..24f1682e18
--- /dev/null
+++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import json
+from typing import Optional
+
+from bionemo.geneformer.run.config_models import (
+    ExposedFineTuneSeqLenBioBertConfig,
+    ExposedGeneformerPretrainConfig,
+    GeneformerPretrainingDataConfig,
+)
+from bionemo.llm.run.config_models import MainConfig
+from bionemo.llm.train import NsysConfig, train
+
+
+def main():  # noqa: D103
+    def parse_args():
+        parser = argparse.ArgumentParser(description="Run Geneformer pretraining")
+        parser.add_argument("--config", type=str, required=True, help="Path to the JSON configuration file")
+        parser.add_argument(
+            "--model-config-t",
+            default=ExposedGeneformerPretrainConfig,
+            required=False,
+            help="Fully resolvable python import path to the ModelConfig object. Builtin options are ExposedGeneformerPretrainConfig and ExposedFineTuneSeqLenBioBertConfig.",
+        )
+        parser.add_argument(
+            "--data-config-t",
+            default=GeneformerPretrainingDataConfig,
+            required=False,
+            help="Fully resolvable python import path to the DataConfig object.",
+        )
+        parser.add_argument(
+            "--resume-if-exists",
+            default=False,
+            action="store_true",
+            help="Resume training if a checkpoint exists that matches the current experiment configuration.",
+        )
+
+        # Debug options.
+        parser.add_argument(
+            "--nsys-profiling",
+            action="store_true",
+            default=False,
+            help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. 
For example: " + " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", + ) + # start, end, rank + parser.add_argument( + "--nsys-start-step", + type=int, + required=False, + default=0, + help="Start nsys profiling after this step.", + ) + parser.add_argument( + "--nsys-end-step", + type=int, + required=False, + help="End nsys profiling after this step.", + ) + # rank as list of integers + parser.add_argument( + "--nsys-ranks", + type=int, + nargs="+", + required=False, + default=[0], + help="Enable nsys profiling for these ranks.", + ) + + return parser.parse_args() + + def string_to_class(path: str): + import importlib + + module_path, class_name = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + def load_config(config_path: str, model_config_t: Optional[str], data_config_t: Optional[str]) -> MainConfig: + with open(config_path, "r") as f: + config_dict = json.load(f) + + # model/data_config_t is used to select the parser dynamically. + if model_config_t is None or model_config_t == "ExposedGeneformerPretrainConfig": + model_config_t = ExposedGeneformerPretrainConfig + elif model_config_t == "ExposedFineTuneSeqLenBioBertConfig": + # Hardcoded path for those who do not know the full path + model_config_t = ExposedFineTuneSeqLenBioBertConfig + elif isinstance(model_config_t, str): + # We assume we get a string to some importable config... e.g. in the sub-package jensen, 'bionemo.jensen.configs.MyConfig' + model_config_t = string_to_class(model_config_t) + + if data_config_t is None: + data_config_t = GeneformerPretrainingDataConfig + elif isinstance(data_config_t, str): + data_config_t = string_to_class(data_config_t) + return MainConfig[model_config_t, data_config_t](**config_dict) + + args = parse_args() + config = load_config(args.config, args.model_config_t, args.data_config_t) + + if args.nsys_profiling: + nsys_config = NsysConfig( + start_step=args.nsys_start_step, + end_step=args.nsys_end_step, + ranks=args.nsys_ranks, + ) + else: + nsys_config = None + + train( + bionemo_exposed_model_config=config.bionemo_model_config, + data_config=config.data_config, + parallel_config=config.parallel_config, + training_config=config.training_config, + optim_config=config.optim_config, + experiment_config=config.experiment_config, + wandb_config=config.wandb_config, + resume_if_exists=args.resume_if_exists, + nsys_config=nsys_config, + ) + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/recipes.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/recipes.py new file mode 100644 index 0000000000..2cbc1e3c1b --- /dev/null +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/recipes.py @@ -0,0 +1,609 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from functools import partial +from typing import List, Optional + +from nemo.utils import logging + +from bionemo.core.utils.dtypes import PrecisionTypes +from bionemo.geneformer.run.config_models import ( + ExposedFineTuneSeqLenBioBertConfig, + ExposedGeneformerPretrainConfig, + GeneformerPretrainingDataConfig, +) +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.run.config_models import ( + ExperimentConfig, + MainConfig, + OptimizerSchedulerConfig, + ParallelConfig, + TrainingConfig, +) +from bionemo.llm.utils.logger_utils import WandbConfig + + +def geneformer_base_parallel_config() -> ParallelConfig: + """Base parallel config for Geneformer.""" + return ParallelConfig( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + accumulate_grad_batches=1, + ddp="megatron", + num_devices=1, + num_nodes=1, + ) + + +def geneformer_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig: + """Base optimizer scheduler config for Geneformer.""" + return OptimizerSchedulerConfig(lr=1e-3, lr_scheduler="cosine") # Matches bionemo1 + + +def geneformer_base_training_config() -> TrainingConfig: + """Base training config for Geneformer.""" + return TrainingConfig( + max_steps=400000, limit_val_batches=8, val_check_interval=100, precision="bf16-mixed" + ) # matches bionemo1 + + +def geneformer_data_recipe(data_dir) -> GeneformerPretrainingDataConfig: + """Recipe that produces the base geneformer small data configuration.""" + return GeneformerPretrainingDataConfig(data_dir=data_dir) + + +# 10m definitions +def geneformer_10m_model_config( + seq_length: int = 2048, + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec, +) -> ExposedGeneformerPretrainConfig: + """Geneformer 10m model config settings.""" + geneformer_config = ExposedGeneformerPretrainConfig( + num_layers=6, + hidden_size=256, + ffn_hidden_size=512, + num_attention_heads=4, + seq_length=seq_length, + fp32_residual_connection=False, + hidden_dropout=0.02, + init_method_std=0.02, + kv_channels=None, + apply_query_key_layer_scaling=False, + make_vocab_size_divisible_by=128, + masked_softmax_fusion=True, + fp16_lm_cross_entropy=False, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + gradient_accumulation_fusion=False, + layernorm_zero_centered_gamma=False, + layernorm_epsilon=1.0e-12, + activation_func="gelu", + qk_layernorm=False, + apply_residual_connection_post_layernorm=False, + bias_activation_fusion=True, + bias_dropout_fusion=True, + get_attention_mask_from_fusion=True, + attention_dropout=0.1, + share_embeddings_and_output_weights=True, + enable_autocast=False, + biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=nemo1_init_path, + initial_ckpt_path=initial_ckpt_path, + ) + return geneformer_config + + +def geneformer_10m_experiment_config(result_dir) -> ExperimentConfig: + """Experiment config for Geneformer 10m.""" + return ExperimentConfig( + save_every_n_steps=100, + result_dir=result_dir, + experiment_name="geneformer-10m", + restore_from_checkpoint_path=None, + ) + + +def geneformer_10m_wandb_config() -> WandbConfig: + """Wandb config for Geneformer 10m.""" + wandb_config = WandbConfig( + entity="geneformer-10m_pretraining", + 
project="geneformer-10m_pretraining", + group="geneformer-10m", + tags=["geneformer-10m"], + offline=True, + anonymous=True, + id="1", + log_model=False, + ) + return wandb_config + + +# 106m definition, model, experiment, wandb, parallel +def geneformer_106m_parallel_config() -> ParallelConfig: + """Base parallel config for Geneformer.""" + return ParallelConfig( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + accumulate_grad_batches=1, + ddp="megatron", + num_devices=8, + num_nodes=1, + ) + + +def geneformer_106m_experiment_config(result_dir) -> ExperimentConfig: + """Experiment config for Geneformer 106m.""" + return ExperimentConfig( + save_every_n_steps=100, + result_dir=result_dir, + experiment_name="geneformer-106m", + restore_from_checkpoint_path=None, + ) + + +def geneformer_106m_wandb_config() -> WandbConfig: + """Wandb config for Geneformer 106m.""" + wandb_config = WandbConfig( + entity="geneformer-106m_pretraining", + project="geneformer-106m_pretraining", + group="geneformer-106m", + tags=["geneformer-106m"], + offline=True, + anonymous=True, + id="1", + log_model=False, + ) + return wandb_config + + +def geneformer_106m_model_config( + seq_length: int = 2048, + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec, +) -> ExposedGeneformerPretrainConfig: + """Geneformer 106m model config settings.""" + geneformer_config = ExposedGeneformerPretrainConfig( + num_layers=12, + hidden_size=768, + ffn_hidden_size=3072, + num_attention_heads=12, + seq_length=seq_length, + fp32_residual_connection=False, + hidden_dropout=0.02, + init_method_std=0.02, + kv_channels=None, + apply_query_key_layer_scaling=False, + make_vocab_size_divisible_by=128, + masked_softmax_fusion=True, + fp16_lm_cross_entropy=False, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + gradient_accumulation_fusion=False, + layernorm_zero_centered_gamma=False, + layernorm_epsilon=1.0e-12, + activation_func="gelu", + qk_layernorm=False, + apply_residual_connection_post_layernorm=False, + bias_activation_fusion=True, + bias_dropout_fusion=True, + get_attention_mask_from_fusion=True, + attention_dropout=0.1, + share_embeddings_and_output_weights=True, + enable_autocast=False, + biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=nemo1_init_path, + initial_ckpt_path=initial_ckpt_path, + ) + return geneformer_config + + +def simple_parallel_recipe( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + num_devices: int = 1, + accumulate_grad_batches: int = 1, +) -> ParallelConfig: + """Simple parallel config for Geneformer, only used in testing.""" + assert ( + num_devices >= tensor_model_parallel_size * pipeline_model_parallel_size + ), "devices must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size" + return ParallelConfig( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + accumulate_grad_batches=accumulate_grad_batches, + num_devices=num_devices, + ) + + +def geneformer_finetuning_regression_head_recipe( + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + initial_ckpt_skip_keys_with_these_prefixes: Optional[List[str]] = None, +) -> ExposedFineTuneSeqLenBioBertConfig: + """Recipe for finetuning a 
regression head on the masked tokens.""" + partial_finetuning_config = partial( + ExposedFineTuneSeqLenBioBertConfig, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + nemo1_ckpt_path=nemo1_init_path, + initial_ckpt_path=initial_ckpt_path, + biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec, + ) + if initial_ckpt_skip_keys_with_these_prefixes: + finetuning_config = partial_finetuning_config( + initial_ckpt_skip_keys_with_these_prefixes=initial_ckpt_skip_keys_with_these_prefixes + ) + else: + # Use the sensible default when None is passed + finetuning_config = partial_finetuning_config() + return finetuning_config + + +def default_trainer_config_recipe() -> TrainingConfig: + """Default trainer config for Geneformer.""" + return TrainingConfig(max_steps=55000, limit_val_batches=2, val_check_interval=100) + + +def geneformer_10m_finetune_config( + seq_length: int = 2048, + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec, +) -> ExposedFineTuneSeqLenBioBertConfig: + """Geneformer 10m finetuning config settings.""" + geneformer_config = ExposedFineTuneSeqLenBioBertConfig( + num_layers=6, + hidden_size=256, + ffn_hidden_size=512, + num_attention_heads=4, + seq_length=seq_length, + fp32_residual_connection=False, + hidden_dropout=0.02, + init_method_std=0.02, + kv_channels=None, + apply_query_key_layer_scaling=False, + make_vocab_size_divisible_by=128, + masked_softmax_fusion=True, + fp16_lm_cross_entropy=False, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + gradient_accumulation_fusion=False, + layernorm_zero_centered_gamma=False, + layernorm_epsilon=1.0e-12, + activation_func="gelu", + qk_layernorm=False, + apply_residual_connection_post_layernorm=False, + bias_activation_fusion=True, + bias_dropout_fusion=True, + get_attention_mask_from_fusion=True, + attention_dropout=0.1, + share_embeddings_and_output_weights=True, + enable_autocast=False, + biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=nemo1_init_path, + initial_ckpt_path=initial_ckpt_path, + ) + return geneformer_config + + +def geneformer_tiny_config( + seq_length: int = 2048, + precision: PrecisionTypes = "bf16-mixed", + nemo1_init_path: Optional[str] = None, + initial_ckpt_path: Optional[str] = None, + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec, +) -> ExposedGeneformerPretrainConfig: + """Geneformer tiny model config settings, used in testing.""" + geneformer_config = ExposedGeneformerPretrainConfig( + num_layers=2, + hidden_size=32, + ffn_hidden_size=4 * 32, + num_attention_heads=2, + seq_length=seq_length, + fp32_residual_connection=False, + hidden_dropout=0.02, + init_method_std=0.02, + kv_channels=None, + apply_query_key_layer_scaling=False, + make_vocab_size_divisible_by=128, + masked_softmax_fusion=True, + fp16_lm_cross_entropy=False, + params_dtype=precision, + pipeline_dtype=precision, + autocast_dtype=precision, + gradient_accumulation_fusion=False, + layernorm_zero_centered_gamma=False, + layernorm_epsilon=1.0e-12, + activation_func="gelu", + qk_layernorm=False, + apply_residual_connection_post_layernorm=False, + bias_activation_fusion=True, + bias_dropout_fusion=True, + get_attention_mask_from_fusion=True, + attention_dropout=0.1, + share_embeddings_and_output_weights=True, + enable_autocast=False, + 
biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=nemo1_init_path, + initial_ckpt_path=initial_ckpt_path, + ) + return geneformer_config + + +def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig: + """Default optimizer scheduler config for Geneformer. See OptimizerSchedulerConfig for defaults.""" + return OptimizerSchedulerConfig() + + +def experiment_config_recipe() -> ExperimentConfig: + """Default experiment config for Geneformer. Used in testing.""" + return ExperimentConfig( + save_every_n_steps=100, + result_dir="./results", + experiment_name="default_experiment", + restore_from_checkpoint_path=None, + save_last_checkpoint=True, + metric_to_monitor_for_checkpoints="reduced_train_loss", + save_top_k=2, + create_tensorboard_logger=False, + ) + + +def finetune_test_recipe(args) -> MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig]: + """Recipe for finetuning a regression head on the masked tokens.""" + data_path = args.data_path + result_dir = args.result_dir + + parallel_config = ParallelConfig( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=2 + ) + training_config = TrainingConfig( + max_steps=10, limit_val_batches=2, val_check_interval=2, precision="bf16-mixed", accelerator="gpu" + ) + data_config = GeneformerPretrainingDataConfig( + seq_length=128, + micro_batch_size=2, + num_dataset_workers=0, + data_dir=data_path, + ) + experiment_config = ExperimentConfig( + save_every_n_steps=training_config.val_check_interval, + result_dir=result_dir, + experiment_name="test-experiment", + restore_from_checkpoint_path=None, + save_last_checkpoint=True, + metric_to_monitor_for_checkpoints="reduced_train_loss", + save_top_k=2, + create_tensorboard_logger=False, + ) + + optim_config = OptimizerSchedulerConfig(lr_scheduler="cosine") + geneformer_config = geneformer_10m_finetune_config( + seq_length=data_config.seq_length, initial_ckpt_path=args.initial_ckpt_path + ) + + return MainConfig( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=geneformer_config, + optim_config=optim_config, + experiment_config=experiment_config, + ) + + +def pretrain_tiny_test_recipe(args) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]: + """Recipe for pretraining a tiny model. 
Used in testing.""" + data_path = args.data_path + result_dir = args.result_dir + + parallel_config = ParallelConfig( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=2 + ) + training_config = TrainingConfig( + max_steps=10, limit_val_batches=2, val_check_interval=2, precision="bf16-mixed", accelerator="gpu" + ) + data_config = GeneformerPretrainingDataConfig( + seq_length=128, + micro_batch_size=2, + num_dataset_workers=0, + data_dir=data_path, + ) + experiment_config = ExperimentConfig( + save_every_n_steps=training_config.val_check_interval, + result_dir=result_dir, + experiment_name="test-experiment", + restore_from_checkpoint_path=None, + save_last_checkpoint=True, + metric_to_monitor_for_checkpoints="reduced_train_loss", + save_top_k=2, + create_tensorboard_logger=False, + ) + + optim_config = OptimizerSchedulerConfig(lr_scheduler="cosine") + geneformer_config = geneformer_tiny_config( + seq_length=data_config.seq_length, initial_ckpt_path=args.initial_ckpt_path + ) + + return MainConfig( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=geneformer_config, + optim_config=optim_config, + experiment_config=experiment_config, + ) + + +def geneformer_10m_pretrain_recipe( + args, +) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]: + """Recipe for pretraining the 10m model.""" + data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path) + parallel_config = simple_parallel_recipe() + training_config = geneformer_base_training_config() + bionemo_model_config = geneformer_10m_model_config(initial_ckpt_path=args.initial_ckpt_path) + optim_config = geneformer_base_optimizer_scheduler_config() + experiment_config = geneformer_10m_experiment_config(result_dir=args.result_dir) + wandb_config = geneformer_10m_wandb_config() + main_config = MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=bionemo_model_config, + optim_config=optim_config, + experiment_config=experiment_config, + wandb_config=wandb_config, + ) + return main_config + + +def geneformer_106m_pretrain_recipe( + args, +) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]: + """Recipe for pretraining the 106m model. 
Uses 8 GPUs for data parallelism.""" + data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path) + parallel_config = geneformer_106m_parallel_config() + training_config = geneformer_base_training_config() + bionemo_model_config = geneformer_106m_model_config(initial_ckpt_path=args.initial_ckpt_path) + optim_config = geneformer_base_optimizer_scheduler_config() + experiment_config = geneformer_106m_experiment_config(result_dir=args.result_dir) + wandb_config = geneformer_106m_wandb_config() + main_config = MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=bionemo_model_config, + optim_config=optim_config, + experiment_config=experiment_config, + wandb_config=wandb_config, + ) + return main_config + + +def geneformer_10m_finetune_recipe( + args, +) -> MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig]: + """Recipe for finetuning the 10m model on a token regression head. Used as an example and for testing.""" + data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path) + parallel_config = simple_parallel_recipe() + training_config = default_trainer_config_recipe() + bionemo_model_config = geneformer_finetuning_regression_head_recipe(initial_ckpt_path=args.initial_ckpt_path) + optim_config = default_adam_optimizer_with_cosine_annealing_recipe() + experiment_config = experiment_config_recipe() + wandb_config = WandbConfig( + project="bionemo2-demo", + entity="nvidia", + offline=True, + tags=[], + group="dev", + id="dev", + log_model=False, + anonymous=True, + ) + main_config = MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig]( + data_config=data_config, + parallel_config=parallel_config, + training_config=training_config, + bionemo_model_config=bionemo_model_config, + optim_config=optim_config, + experiment_config=experiment_config, + wandb_config=wandb_config, + ) + return main_config + + +def main(): # noqa: D103 + def parse_args(): + parser = argparse.ArgumentParser(description="Create Geneformer configuration JSON.") + parser.add_argument( + "--recipe", + type=str, + choices=["test", "10m-pretrain", "106m-pretrain", "test-finetune", "finetune"], + required=True, + help="Use one of the preconfigured recipes to create a template config file.", + ) + + parser.add_argument( + "--dest", + type=str, + default="./geneformer-recipe.json", + required=True, + help="Path to the JSON configuration file.", + ) + + parser.add_argument( + "--data-path", type=str, required=True, help="Path to the directory containing pretraining data." + ) + parser.add_argument( + "--result-dir", type=str, required=True, help="Path to the directory used to save results." + ) + + parser.add_argument( + "--initial-ckpt-path", + type=str, + required=False, + default=None, + help="Path to an existing to a checkpoint directory to restore an existing checkpoint. 
Not compatible with all recipes.", + ) + + args = parser.parse_args() + return args + + """Simple example for creating a JSON from recipes.""" + args = parse_args() + + if args.recipe == "test": + config = pretrain_tiny_test_recipe(args) + elif args.recipe == "10m-pretrain": + config = geneformer_10m_pretrain_recipe(args) + elif args.recipe == "106m-pretrain": + config = geneformer_106m_pretrain_recipe(args) + elif args.recipe == "test-finetune": + # Uses a bigger model because we have a pretrained model for it. + config = finetune_test_recipe(args) + elif args.recipe == "finetune": + # NOTE: this recipe finetunes a regression model on the masked tokens, if youre looking to finetune with a custom task, youll need to define your own classes. + config = geneformer_10m_finetune_recipe(args) + else: + raise ValueError("Invalid recipe choice.") + + # Serialize to JSON + json_str = config.model_dump_json(indent=2) + + # Save to file + with open( + args.dest, + "w", + ) as f: + f.write(json_str) + + logging.info(f"Saved configuration to {args.dest=}") + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py index 131d149fe7..2fa9e4a68f 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py @@ -44,7 +44,7 @@ from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.llm.model.biobert.model import BioBertConfig, BiobertSpecOption from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size -from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger +from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger __all__: Sequence[str] = ("main", "get_parser") @@ -190,10 +190,10 @@ def main( # for wandb integration # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html" - wandb_options: Optional[WandbLoggerOptions] = ( + wandb_options: Optional[WandbConfig] = ( None if wandb_project is None - else WandbLoggerOptions( + else WandbConfig( offline=wandb_offline, project=wandb_project, entity=wandb_entity, @@ -322,7 +322,7 @@ def main( root_dir=result_dir, name=experiment_name, initialize_tensorboard_logger=create_tensorboard_logger, - wandb_kwargs=wandb_options, + wandb_config=wandb_options, ckpt_callback=checkpoint_callback, ) llm.train( diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py new file mode 100644 index 0000000000..7eeb47a613 --- /dev/null +++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shlex +import subprocess +from pathlib import Path + +from lightning.fabric.plugins.environments.lightning import find_free_network_port + +from bionemo.testing.data.load import load + + +data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" + + +def test_bionemo2_rootdir(): + data_error_str = ( + "Please download test data with:\n" + "`python scripts/download_artifacts.py --models all --model_dir ./models --data all --data_dir ./ --verbose --source pbss`" + ) + assert data_path.exists(), f"Could not find test data directory.\n{data_error_str}" + assert data_path.is_dir(), f"Test data directory is supposed to be a directory.\n{data_error_str}" + + +def test_pretrain_cli_from_ckpt(tmpdir): + # Same as test_pretrain, but includes a checkpoint to initialize from. + data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" + result_dir = Path(tmpdir.mkdir("results")) + + open_port = find_free_network_port() + # NOTE: if this test is ever failing, you may want to put the config somewhere easily accessible. + config = f"{result_dir}/test_config.json" + # Invoke with blocking, continue when finished (and the json config is generated) + checkpoint_path: Path = load("geneformer/10M_240530:2.0") + cmd_str = f"""bionemo-geneformer-recipe --dest {config} --recipe test --data-path {data_path} --result-dir {result_dir} --initial-ckpt-path {checkpoint_path}""".strip() + env = dict(**os.environ) # a local copy of the environment + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + if result.returncode != 0: + raise Exception(f"Pretrain recipe failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + + cmd_str = f"""bionemo-geneformer-train --conf {config}""".strip() + env = dict(**os.environ) # a local copy of the environment + open_port = find_free_network_port() + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + if result.returncode != 0: + # More helpful failure + raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + + # Must match the experiment directory configured. + assert (result_dir / "test-experiment").exists(), "Could not find test experiment directory." 
+ + +def test_pretrain_cli(tmpdir): + """trains from scratch""" + data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" + result_dir = Path(tmpdir.mkdir("results")) + + open_port = find_free_network_port() + config = f"{result_dir}/test_config.json" + # Invoke with blocking + cmd_str = f"""bionemo-geneformer-recipe --dest {config} --recipe test --data-path {data_path} --result-dir {result_dir}""".strip() + # continue when finished + env = dict(**os.environ) # a local copy of the environment + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + # Now do pretrain + if result.returncode != 0: + raise Exception(f"Pretrain recipe failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + + cmd_str = f"""bionemo-geneformer-train --conf {config}""".strip() + env = dict(**os.environ) # a local copy of the environment + open_port = find_free_network_port() + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + if result.returncode != 0: + raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + # NOTE this looks a lot like a magic value. But we also could do json.loads(config)['experiment_config']['experiment_name'] + assert (result_dir / "test-experiment").exists(), "Could not find test experiment directory." + + +def test_finetune_cli(tmpdir): + """Uses CLI to invoke the entrypoint""" + data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" + result_dir = Path(tmpdir.mkdir("results")) + checkpoint_path: Path = load("geneformer/10M_240530:2.0") + + open_port = find_free_network_port() + + config = f"{result_dir}/test_config.json" + + # TODO add initial path + cmd_str = f"""bionemo-geneformer-recipe --dest {config} --recipe test-finetune --data-path {data_path} --result-dir {result_dir} --initial-ckpt-path {checkpoint_path}""".strip() + # continue when finished + env = dict(**os.environ) # a local copy of the environment + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + import sys + + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + # Now do pretrain + if result.returncode != 0: + raise Exception(f"Pretrain recipe failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + + cmd_str = f"""bionemo-geneformer-train --conf {config} """.strip() + env = dict(**os.environ) # a local copy of the environment + open_port = find_free_network_port() + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + if result.returncode != 0: + raise Exception(f"Pretrain script failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") + assert (result_dir / "test-experiment").exists(), "Could not find test experiment directory." 
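
The tests above exercise the packaged `bionemo-geneformer-recipe` and `bionemo-geneformer-train` entrypoints end to end. Because the training entrypoint resolves `--data-config-t` (and `--model-config-t`) through a dotted import path, the same flow also works with user-defined configs. Below is a minimal sketch of what such a custom `DataConfig` could look like; `my_project`, `MyDataModule`, and `MyDataConfig` are hypothetical placeholders and not part of this change, and the real DataModule must expose a `tokenizer` attribute because the training entrypoint reads it.

```python
# my_project/configs.py -- hypothetical module, illustrative only.
# Assumes MyDataModule is a LightningDataModule that exposes a `tokenizer` attribute.
from bionemo.llm.run.config_models import DataConfig

from my_project.datamodules import MyDataModule  # hypothetical DataModule


class MyDataConfig(DataConfig[MyDataModule]):
    """Builds MyDataModule from a directory of preprocessed files."""

    data_dir: str

    def construct_data_module(self, global_batch_size: int) -> MyDataModule:
        # The train entrypoint infers global_batch_size from the micro batch size and the parallel config.
        return MyDataModule(
            data_dir=self.data_dir,
            seq_length=self.seq_length,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=global_batch_size,
            num_workers=self.num_dataset_workers,
        )
```

Assuming the class is importable, it can then be selected at run time, e.g. `bionemo-geneformer-train --config my_config.json --data-config-t my_project.configs.MyDataConfig`.
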
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py index 65c1cecc5d..2ee35bb38b 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py @@ -479,6 +479,7 @@ class BioBertConfig( nemo1_ckpt_path: Optional[str] = None initial_ckpt_path: Optional[str] = None + # TODO(@jstjohn, @skothenhill) Was this supposed to be only on the child? initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list) # Used if initializing from a checkpoint, set this to any fields you want to override rather than re-set. # by default all fields will be overridden. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/lr_scheduler.py similarity index 100% rename from sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py rename to sub-packages/bionemo-llm/src/bionemo/llm/model/lr_scheduler.py diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/run/__init__.py b/sub-packages/bionemo-llm/src/bionemo/llm/run/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/run/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py new file mode 100644 index 0000000000..c3c2ef292d --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pathlib +from abc import ABC, abstractmethod +from dataclasses import field +from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar + +import pytorch_lightning as pl +import torch +from pydantic import BaseModel, field_serializer, field_validator, model_validator +from torch.nn import functional as F + +from bionemo.core.utils import dtypes +from bionemo.llm.model.biobert.model import BioBertConfig +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.utils.logger_utils import WandbConfig + + +ModelConfigT = TypeVar("ModelConfigT", bound=BioBertConfig) +DataModuleT = TypeVar("DataModuleT", bound=pl.LightningDataModule) + +# Activation functions not available in torch.nn.functional require custom serialization/validation. Add them here with a lookup key. +CUSTOM_ACTIVATION_FNS: Dict[str, Callable[[torch.Tensor, Any], torch.Tensor]] = {} + +# DO NOT use keys that already exist in torch.nn.functional, as the torch.nn.functional functions are selected first. +for key in CUSTOM_ACTIVATION_FNS: + assert key not in dir(torch.nn.functional), f"Key {key} already exists in torch.nn.functional" + +# It does not matter if values are duplicated as the key=>value mapping still does the right thing. Repeat values should be considered aliases. +REVERSE_CUSTOM_ACTIVATION_FNS: Dict[Callable[[torch.Tensor, Any], torch.Tensor], str] = { + v: k for k, v in CUSTOM_ACTIVATION_FNS.items() +} + + +class DataConfig(BaseModel, Generic[DataModuleT], ABC): + """Base class for all data configurations. + + This class is used to define the interface for all data configurations. It is used to define the data module that + will be used in the training loop. + """ + + micro_batch_size: int = 8 + result_dir: str | pathlib.Path = "./results" + num_dataset_workers: int = 0 + seq_length: int = 128 + + @abstractmethod + def construct_data_module(self, global_batch_size: int) -> DataModuleT: + """Construct the data module from the configuration. Cannot be defined generically.""" + ... + + def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig": + """Use custom implementation of this method to define the things inside global_config. + + The following expression will always be true: + + global_cfg.data_config == self + """ + return global_cfg + + +class ExposedModelConfig(BaseModel, Generic[ModelConfigT], ABC): + """BioNeMo model configuration class, wraps TransformerConfig and friends. + + This class is used to define the interface for all model configurations. It is **Exposed** to guard against ill-typed + or poorly defined fields in the underlying configuration objects. `ModelConfigT` declares the associated type of the + underlying config (most commonly a BioBertGenericConfig, but could also be a TransformerConfig or something similar). + Children should try to expose the minimal set of fields necessary for the user to configure the model while keeping + the more esoteric configuration private to the underlying ModelConfigT. 
+ """ + + # Restores weights from a pretrained checkpoint + initial_ckpt_path: Optional[str] = None + # Does not attempt to load keys with these prefixes (useful if you attached extra parameters and still want to load a set of weights) + initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list) + + # Pydantic stuff to allow arbitrary types + validators + serializers + class Config: # noqa: D106 + arbitrary_types_allowed = True + + def model_class(self) -> Type[ModelConfigT]: + """Returns the underlying model class that this config wraps.""" + raise NotImplementedError + + def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig": + """Use custom implementation of this method to define the things inside global_config. + + The following expression will always be true: + + global_cfg.bionemo_model_config == self + """ + return global_cfg + + def exposed_to_internal_bionemo_model_config(self) -> ModelConfigT: + """Converts the exposed dataclass to the underlying Transformer config. + + The underlying ModelConfigT may both be incomplete and unserializable. We use this transformation as a way to + hide fields that are either not serializable by Pydantic or that we do not want to expose. + """ + cls: Type[ModelConfigT] = self.model_class() + model_dict = {} + for attr in self.model_fields: + if attr not in model_dict and attr in cls.__dataclass_fields__: + model_dict[attr] = getattr(self, attr) + + # Now set fp16 and bf16 based on the precision for the underlying TransformerConfig=>ParallelConfig + # the only constraint is that both must not be true. + model_dict["bf16"] = self.pipeline_dtype == dtypes.precision_to_dtype["bf16-mixed"] + model_dict["fp16"] = self.pipeline_dtype == dtypes.precision_to_dtype["16-mixed"] + result = cls(**model_dict) + + return result + + # NOTE: See PrecisionTypes for a list of valid literals that may be deserialized. + params_dtype: torch.dtype + pipeline_dtype: torch.dtype + autocast_dtype: torch.dtype + + num_layers: int = 6 + hidden_size: int = 256 + ffn_hidden_size: int = 512 + num_attention_heads: int = 4 + seq_length: int = 512 + fp32_residual_connection: bool = False + hidden_dropout: float = 0.02 + init_method_std: float = 0.02 + kv_channels: Optional[int] = None + apply_query_key_layer_scaling: bool = False + make_vocab_size_divisible_by: int = 128 + masked_softmax_fusion: bool = True + fp16_lm_cross_entropy: bool = False + gradient_accumulation_fusion: bool = False + layernorm_zero_centered_gamma: bool = False + layernorm_epsilon: float = 1.0e-12 + activation_func: Callable[[torch.Tensor, Any], torch.Tensor] = F.gelu + qk_layernorm: bool = False + apply_residual_connection_post_layernorm: bool = False + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True + get_attention_mask_from_fusion: bool = False + attention_dropout: float = 0.1 + share_embeddings_and_output_weights: bool = True + enable_autocast: bool = False + nemo1_ckpt_path: Optional[str] = None + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec + + @field_validator("activation_func", mode="before") + @classmethod + def validate_activation_func(cls, activation_func: str) -> Callable: + """Validates the activation function, assumes this function exists in torch.nn.functional. + + For custom activation functions, use the CUSTOM_ACTIVATION_FUNCTIONS dictionary in the module. 
This method
+        validates the provided activation function string and returns the corresponding callable.
+
+        Args:
+            activation_func (str): The name of the activation function to be validated.
+
+        Returns:
+            Callable: A callable function after validation.
+
+        See Also:
+            CUSTOM_ACTIVATION_FNS
+        """
+        func = getattr(torch.nn.functional, activation_func.lower(), None)
+        if func is None and activation_func in CUSTOM_ACTIVATION_FNS:
+            func = CUSTOM_ACTIVATION_FNS[activation_func]
+            return func
+        elif func is None:
+            raise ValueError(
+                f"activation_func must be a valid function in `torch.nn.functional`, got {activation_func=}"
+            )
+        else:
+            return func
+
+    @field_serializer("activation_func")
+    def serialize_activation_func(self, v: Callable[[torch.Tensor, Any], torch.Tensor]) -> str:
+        """Serializes a given activation function to its corresponding string representation.
+
+        By default, all activation functions from `torch.nn.functional` are serialized to their name. User defined
+        activation functions should also be defined here with a custom mapping in CUSTOM_ACTIVATION_FNS defined at the
+        top of this file. This allows our Pydantic model to serialize and deserialize the activation function.
+
+        Args:
+            v (Callable[[torch.Tensor, Any], torch.Tensor]): The activation function to serialize.
+
+        Returns:
+            str: The name of the activation function if it is a standard PyTorch function,
+                 or the corresponding serialization key if it is a custom activation function.
+
+        Raises:
+            ValueError: If the activation function is not supported.
+        """
+        func_name = v.__name__
+        func = getattr(torch.nn.functional, func_name, None)
+        if func is not None:
+            return func_name
+        elif v in REVERSE_CUSTOM_ACTIVATION_FNS:
+            # Not a torch.nn.functional function, so look up the serialization key in the custom registry.
+            return REVERSE_CUSTOM_ACTIVATION_FNS[v]
+        else:
+            raise ValueError(f"Unsupported activation function: {v}")
+
+    @field_validator("params_dtype", "pipeline_dtype", "autocast_dtype", mode="before")
+    @classmethod
+    def precision_validator(cls, v: dtypes.PrecisionTypes) -> torch.dtype:
+        """Validates the precision type and returns the corresponding torch dtype."""
+        return dtypes.get_autocast_dtype(v)
+
+    @field_serializer("params_dtype", "pipeline_dtype", "autocast_dtype")
+    def serialize_dtypes(self, v: torch.dtype) -> dtypes.PrecisionTypes:
+        """Serializes the torch dtype to the corresponding precision type."""
+        return dtypes.dtype_to_precision[v]
+
+
+class ParallelConfig(BaseModel):
+    """ParallelConfig is a configuration class for setting up parallelism in model training.
+
+    Attributes:
+        tensor_model_parallel_size (int): The size of the tensor model parallelism. Default is 1.
+        pipeline_model_parallel_size (int): The size of the pipeline model parallelism. Default is 1.
+        accumulate_grad_batches (int): The number of batches to accumulate gradients over. Default is 1.
+        ddp (Literal["megatron"]): The distributed data parallel method to use. Default is "megatron".
+        remove_unused_parameters (bool): Whether to remove unused parameters. Default is True.
+        num_devices (int): The number of devices to use. Default is 1.
+        num_nodes (int): The number of nodes to use. Default is 1.
+
+    Methods:
+        validate_devices(): Validates the number of devices based on the tensor and pipeline model parallel sizes.
+    """
+
+    tensor_model_parallel_size: int = 1
+    pipeline_model_parallel_size: int = 1
+    accumulate_grad_batches: int = 1
+    ddp: Literal["megatron"] = "megatron"
+    remove_unused_parameters: bool = True
+    num_devices: int = 1
+    num_nodes: int = 1
+
+    @model_validator(mode="after")
+    def validate_devices(self):
+        """Validates the number of devices based on the tensor and pipeline model parallel sizes."""
+        if self.num_devices < self.tensor_model_parallel_size * self.pipeline_model_parallel_size:
+            raise ValueError("num_devices must be at least tensor_model_parallel_size * pipeline_model_parallel_size")
+        return self
+
+
+class TrainingConfig(BaseModel):
+    """TrainingConfig is a configuration class for training models.
+
+    Attributes:
+        max_steps (int): The maximum number of training steps.
+        limit_val_batches (int | float): The number of validation batches to use. Can be a fraction or a count.
+        val_check_interval (int): The interval (in steps) at which to check validation.
+        precision (Literal["32", "bf16-mixed", "16-mixed"], optional): The precision to use for training. Defaults to "bf16-mixed".
+        accelerator (str, optional): The type of accelerator to use for training. Defaults to "gpu".
+        gc_interval (int, optional): The interval of global steps at which to run synchronized garbage collection. Useful for synchronizing garbage collection when performing distributed training. Defaults to 0.
+        include_perplexity (bool, optional): Whether to include perplexity in the validation logs. Defaults to False.
+    """
+
+    max_steps: int
+    limit_val_batches: int | float  # Because this can be a fraction or a count...
+    val_check_interval: int
+    precision: Literal["32", "bf16-mixed", "16-mixed"] = "bf16-mixed"
+    accelerator: str = "gpu"
+    # NOTE: VERY important for distributed training performance.
+    gc_interval: int = 0
+    include_perplexity: bool = False
+
+
+class OptimizerSchedulerConfig(BaseModel):
+    """Configuration for the optimizer and learning rate scheduler.
+
+    Attributes:
+        lr (float): Learning rate for the optimizer. Default is 1e-4.
+        optimizer (str): Type of optimizer to use. Default is "adam".
+        interval (str): Interval for updating the learning rate scheduler. Default is "step".
+        monitor (str): Metric to monitor for learning rate adjustments. Default is "val_loss".
+        cosine_rampup_frac (float): Fraction of max_steps used to ramp up the learning rate when the cosine scheduler is selected. Default is 0.01.
+        cosine_hold_frac (float): Fraction of max_steps to hold the learning rate constant when the cosine scheduler is selected. Default is 0.05.
+        warmup_steps (int): Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0.
+        lr_scheduler (Literal['warmup_anneal', 'cosine']): Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change.
+    """
+
+    lr: float = 1e-4
+    optimizer: str = "adam"
+    interval: str = "step"
+    monitor: str = "val_loss"
+    cosine_rampup_frac: float = 0.01
+    cosine_hold_frac: float = 0.05
+    warmup_steps: int = 0
+    lr_scheduler: Literal["warmup_anneal", "cosine"] = "warmup_anneal"
+
+
+class ExperimentConfig(BaseModel):
+    """Configuration class for setting up and managing experiment parameters.
+
+    Attributes:
+        save_every_n_steps (int): Number of steps between saving checkpoints.
+        result_dir (str | pathlib.Path): Directory where results will be saved.
+        experiment_name (str): Name of the experiment.
+        restore_from_checkpoint_path (Optional[str]): Path to restore from a checkpoint. Note: This does not invoke the checkpoint callback as expected.
+        save_last_checkpoint (bool): Flag to save the last checkpoint. Default is True.
+        metric_to_monitor_for_checkpoints (str): Metric to monitor for saving top-k checkpoints. Default is "reduced_train_loss".
+        save_top_k (int): Number of top checkpoints to save based on the monitored metric. Default is 2.
+        create_tensorboard_logger (bool): Flag to create a TensorBoard logger. Default is False.
+    """
+
+    save_every_n_steps: int
+    result_dir: str | pathlib.Path
+    experiment_name: str
+    # NOTE: restore_from_checkpoint_path does not invoke the checkpoint callback in the way we'd like. Avoid using.
+    restore_from_checkpoint_path: Optional[str]
+    save_last_checkpoint: bool = True
+    metric_to_monitor_for_checkpoints: str = "reduced_train_loss"
+    save_top_k: int = 2
+    create_tensorboard_logger: bool = False
+
+
+# DataConfig -> some config that can make a data module (see ABC definition.)
+DataConfigT = TypeVar("DataConfigT", bound=DataConfig)
+# ExposedModelConfig -> some config that can make a non-exposed model config (see ABC definition.)
+ExModelConfigT = TypeVar("ExModelConfigT", bound=ExposedModelConfig)
+
+
+class MainConfig(BaseModel, Generic[ExModelConfigT, DataConfigT]):
+    """Main configuration class for BioNeMo. All serialized configs that are a valid MainConfig should be Runnable.
+
+    This class is used to define the main configuration for BioNeMo. It defines the minimal pieces of configuration
+    required to execute a training job with the NeMo2 training API. It accepts two generic type parameters which users
+    must define in their own environment for execution.
+
+    Additionally, this class assumes that the configs for ExposedModelConfig and DataConfig may have custom validators
+    implemented that operate on the entire MainConfig. This removes the need for type-based conditionals inside this
+    class while still allowing global custom validation logic to be implemented in the underlying classes. For example,
+    some models may want to restrict their DataModule's seq_length to a certain value.
+
+
+    Args:
+        data_config: Generic config type that contains instructions on instantiating the required DataModule.
+        parallel_config: The parallel configuration for the model.
+        training_config: The training configuration for the model.
+        bionemo_model_config: Generic ExposedModelConfig type. This class hides extra configuration parameters in the
+            underlying model configuration, and provides the conversion back to that internal configuration
+            (see `exposed_to_internal_bionemo_model_config`).
+        optim_config: The optimizer/scheduler configuration for the model.
+        experiment_config: The experiment configuration for the model.
+        wandb_config: Optional, the wandb configuration for the model.
+ """ + + data_config: DataConfigT + parallel_config: ParallelConfig + training_config: TrainingConfig + bionemo_model_config: ExModelConfigT + optim_config: OptimizerSchedulerConfig + experiment_config: ExperimentConfig + wandb_config: Optional[WandbConfig] = None + + @model_validator(mode="after") + def validate_master_config(self) -> "MainConfig": + """Validates the master configuration object.""" + self.bionemo_model_config.seq_length = self.data_config.seq_length + return self + + @model_validator(mode="after") + def run_bionemo_model_config_model_validators(self) -> "MainConfig": + """Runs the model validators on the bionemo_model_config.""" + return self.bionemo_model_config.custom_model_validator(self) + + @model_validator(mode="after") + def run_data_config_model_validators(self) -> "MainConfig": + """Runs the model validators on the data_config.""" + return self.data_config.custom_model_validator(self) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/train.py b/sub-packages/bionemo-llm/src/bionemo/llm/train.py new file mode 100644 index 0000000000..18ec5b1b83 --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/train.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import pathlib +from dataclasses import field +from typing import Optional + +from megatron.core.optimizer import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning import resume +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.optim import MegatronOptimizerModule +from nemo.lightning.pytorch.optim.lr_scheduler import CosineAnnealingScheduler +from nemo.utils import logging +from pydantic import BaseModel +from pytorch_lightning.callbacks import LearningRateMonitor, RichModelSummary + +from bionemo.llm.lightning import BionemoLightningModule, PerplexityLoggingCallback +from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler +from bionemo.llm.run.config_models import ( + DataConfig, + DataModuleT, + ExperimentConfig, + ExposedModelConfig, + OptimizerSchedulerConfig, + ParallelConfig, + TrainingConfig, +) +from bionemo.llm.utils.datamodule_utils import infer_global_batch_size +from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger + + +class NsysConfig(BaseModel): + """Configuration for nsys profiling.""" + + start_step: int = 0 + end_step: Optional[int] = None + ranks: list[int] = field(default_factory=lambda: [0]) + + +def nemo_logger_factory(experiment_config: ExperimentConfig, wandb_config: Optional[WandbConfig]) -> nl.NeMoLogger: + """Creates and returns a NeMoLogger instance configured based on the provided experiment and wandb configurations. 
+ + Args: + experiment_config (ExperimentConfig): Configuration object containing experiment settings such as + result directory, experiment name, checkpoint settings, and logger preferences. + wandb_config (Optional[WandbConfig]): Optional configuration object for Weights and Biases logging. + + Returns: + nl.NeMoLogger: An instance of NeMoLogger configured with the specified settings. + """ + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=experiment_config.save_last_checkpoint, + monitor=experiment_config.metric_to_monitor_for_checkpoints, + save_top_k=experiment_config.save_top_k, + every_n_train_steps=experiment_config.save_every_n_steps, + always_save_context=True, + ) + + nemo_logger = setup_nemo_lightning_logger( + root_dir=experiment_config.result_dir, + name=experiment_config.experiment_name, + initialize_tensorboard_logger=experiment_config.create_tensorboard_logger, + wandb_config=wandb_config, + ckpt_callback=checkpoint_callback, + ) + return nemo_logger + + +def setup_trainer( + parallel_config: ParallelConfig, + training_config: TrainingConfig, + callbacks=None, + nsys_config: NsysConfig | None = None, +) -> nl.Trainer: + """Set up the trainer for model training using the specified parallel and training configurations. + + Args: + parallel_config (ParallelConfig): Configuration for parallelism, including tensor and pipeline model parallel sizes, + number of devices, and number of nodes. + training_config (TrainingConfig): Configuration for training, including maximum steps, accelerator type, + validation batch limit, validation check interval, and precision. + callbacks (list, optional): List of callback functions to be used during training. Defaults to None, + in which case default callbacks (RichModelSummary and LearningRateMonitor) are used. + nsys_config (NsysConfig, optional): Configuration for nsys profiling. If None, is disabled. + + Returns: + nl.Trainer: Configured trainer object ready for model training. 
+    """
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
+        ddp="megatron",
+        find_unused_parameters=True,
+        ckpt_include_optimizer=True,
+    )
+    if callbacks is None:
+        callbacks = [
+            RichModelSummary(max_depth=4),
+            LearningRateMonitor(),
+        ]
+
+    if training_config.include_perplexity:
+        callbacks.append(PerplexityLoggingCallback())
+
+    if training_config.gc_interval > 0:
+        callbacks.append(
+            nl_callbacks.GarbageCollectionCallback(
+                gc_interval_train=training_config.gc_interval, gc_interval_val=training_config.gc_interval
+            )
+        )
+
+    if nsys_config:
+        if nsys_config.end_step is None:
+            nsys_config.end_step = training_config.max_steps
+        callbacks.append(
+            nl_callbacks.NsysCallback(
+                start_step=nsys_config.start_step,
+                end_step=nsys_config.end_step,
+                ranks=nsys_config.ranks,
+                gen_shape=True,
+            )
+        )
+
+    trainer = nl.Trainer(
+        devices=parallel_config.num_devices,
+        max_steps=training_config.max_steps,
+        accelerator=training_config.accelerator,
+        strategy=strategy,
+        limit_val_batches=training_config.limit_val_batches,
+        val_check_interval=training_config.val_check_interval,
+        num_nodes=parallel_config.num_nodes,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(precision=training_config.precision),
+    )
+    return trainer
+
+
+def train(
+    bionemo_exposed_model_config: ExposedModelConfig,
+    data_config: DataConfig[DataModuleT],
+    parallel_config: ParallelConfig,
+    training_config: TrainingConfig,
+    optim_config: OptimizerSchedulerConfig,
+    experiment_config: ExperimentConfig,
+    wandb_config: Optional[WandbConfig],
+    nsys_config: Optional[NsysConfig] = None,
+    resume_if_exists: bool = True,
+):
+    """Train a BioNeMo model using the provided configurations. Uses the ExposedModelConfig and DataConfig as the primary points of customization for this method.
+
+    Args:
+        bionemo_exposed_model_config (ExposedModelConfig): Configuration for the exposed BioNeMo model.
+        data_config (DataConfig[DataModuleT]): Configuration for the data module.
+        parallel_config (ParallelConfig): Configuration for parallel training.
+        training_config (TrainingConfig): Configuration for training parameters.
+        optim_config (OptimizerSchedulerConfig): Configuration for the optimizer and scheduler.
+        experiment_config (ExperimentConfig): Configuration for the experiment.
+        wandb_config (Optional[WandbConfig]): Configuration for Weights and Biases logging.
+        nsys_config (Optional[NsysConfig], optional): Configuration for nsys profiling. If None, profiling is disabled.
+        resume_if_exists (bool, optional): Flag to resume training if a checkpoint exists. Defaults to True.
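+
+    Example:
+        An illustrative sketch only, assuming ``config`` is an already-parsed ``MainConfig`` instance
+        (e.g., loaded from a recipe-generated JSON file)::
+
+            train(
+                bionemo_exposed_model_config=config.bionemo_model_config,
+                data_config=config.data_config,
+                parallel_config=config.parallel_config,
+                training_config=config.training_config,
+                optim_config=config.optim_config,
+                experiment_config=config.experiment_config,
+                wandb_config=config.wandb_config,
+            )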
+    """
+    bionemo_model_config = bionemo_exposed_model_config.exposed_to_internal_bionemo_model_config()
+    pathlib.Path(data_config.result_dir).mkdir(parents=True, exist_ok=True)
+
+    if experiment_config.save_every_n_steps != training_config.val_check_interval:
+        logging.warning("Mutating experiment_config.save_every_n_steps to be equal to training_config.val_check_interval.")
+        experiment_config.save_every_n_steps = training_config.val_check_interval
+
+    global_batch_size = infer_global_batch_size(
+        micro_batch_size=data_config.micro_batch_size,
+        num_nodes=parallel_config.num_nodes,
+        devices=parallel_config.num_devices,
+        accumulate_grad_batches=parallel_config.accumulate_grad_batches,
+        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
+    )
+
+    data: DataModuleT = data_config.construct_data_module(global_batch_size)
+    # TODO BioBertDataModule or BioBertTokenizer abstractions. We know every DataModuleT used here has data.tokenizer,
+    # although this constraint is not documented.
+
+    # TODO: need an abstraction for LrSchedulerConfig
+    if optim_config.lr_scheduler == "cosine":
+        lr_scheduler = CosineAnnealingScheduler(
+            max_steps=training_config.max_steps,
+            min_lr=optim_config.lr / 100,
+            warmup_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_rampup_frac)),
+            interval=optim_config.interval,
+            monitor=optim_config.monitor,
+            constant_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_hold_frac)),
+        )
+    elif optim_config.lr_scheduler == "warmup_anneal":
+        lr_scheduler = WarmupAnnealDecayHoldScheduler(
+            warmup_steps=optim_config.warmup_steps,
+            max_steps=training_config.max_steps,
+            max_lr=optim_config.lr,
+            min_lr=optim_config.lr / 10.0,
+            anneal_percentage=0.10,
+        )
+    else:
+        raise NotImplementedError(f"Scheduler {optim_config.lr_scheduler} not implemented.")
+
+    optimizer = MegatronOptimizerModule(
+        config=OptimizerConfig(
+            lr=optim_config.lr,
+            optimizer=optim_config.optimizer,
+            use_distributed_optimizer=True,
+            fp16=bionemo_model_config.fp16,
+            bf16=bionemo_model_config.bf16,
+        ),
+        lr_scheduler=lr_scheduler,
+    )
+
+    model: BionemoLightningModule = biobert_lightning_module(
+        config=bionemo_model_config, tokenizer=data.tokenizer, optimizer=optimizer
+    )
+    trainer: nl.Trainer = setup_trainer(parallel_config, training_config, nsys_config=nsys_config)
+    nemo_logger: nl.NeMoLogger = nemo_logger_factory(experiment_config, wandb_config=wandb_config)
+
+    llm.train(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=nemo_logger,
+        resume=resume.AutoResume(
+            resume_if_exists=resume_if_exists,
+            resume_ignore_no_checkpoint=True,
+        ),
+    )
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py
index 93ead69e48..5f2f6b1957 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py
@@ -13,23 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pathlib
-from typing import Any, Dict, List, Optional, Sequence, TypedDict
+from typing import Any, Dict, List, Optional, Sequence

 from nemo.lightning.nemo_logger import NeMoLogger
 from nemo.lightning.pytorch import callbacks as nemo_callbacks
 from nemo.utils import logging
+from pydantic import BaseModel
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger


 __all__: Sequence[str] = (
-    "WandbLoggerOptions",
+    "WandbConfig",
     "setup_nemo_lightning_logger",
 )


-class WandbLoggerOptions(TypedDict):
+class WandbConfig(BaseModel):
     """Note: `name`, which controls the experiment name, is handled by the NeMoLogger, so it is omitted here.
     `directory` is also omitted since it is set by the NeMoLogger.
+
+    Args:
+        entity: The team posting this run (default: your username or your default team).
+        project: The name of the project to which this run will belong.
+        tags: Tags associated with this run.
+        group: A unique string shared by all runs in a given group.
+        offline: Run offline (data can be streamed later to wandb servers).
+        id: Sets the version, mainly used to resume a previous run.
+        anonymous: Enables or explicitly disables anonymous logging.
     """  # noqa: D205

     entity: str  # The team posting this run (default: your username or your default team)
@@ -48,17 +58,17 @@ def setup_nemo_lightning_logger(
     name: str = "default-name",
     root_dir: str | pathlib.Path = "./results",
     initialize_tensorboard_logger: bool = False,
-    wandb_kwargs: Optional[WandbLoggerOptions] = None,
+    wandb_config: Optional[WandbConfig] = None,
     ckpt_callback: Optional[nemo_callbacks.ModelCheckpoint] = None,
     **kwargs: Dict[str, Any],
 ) -> NeMoLogger:
     """Setup the logger for the experiment.

-    Args:
+    Arguments:
         name: The name of the experiment. Results go into `root_dir`/`name`
         root_dir: The root directory to create the `name` directory in for saving run results.
         initialize_tensorboard_logger: Whether to initialize the tensorboard logger.
-        wandb_kwargs: The kwargs for the wandb logger.
+        wandb_config: The remaining configuration options for the wandb logger.
         ckpt_callback: The checkpoint callback to use, must be a child of the pytorch lightning ModelCheckpoint
             callback. NOTE the type annotation in the underlying NeMoCheckpoint constructor is incorrect.
         **kwargs: The kwargs for the NeMoLogger.
@@ -68,8 +78,8 @@ def setup_nemo_lightning_logger(
     """
     # The directory that the logger will save to
     save_dir = pathlib.Path(root_dir) / name
-    if wandb_kwargs is not None:
-        wandb_logger = WandbLogger(save_dir=save_dir, name=name, **wandb_kwargs)
+    if wandb_config is not None:
+        wandb_logger = WandbLogger(save_dir=save_dir, name=name, **wandb_config.model_dump())
     else:
         wandb_logger = None
         logging.warning("WandB is currently turned off.")
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_lr_scheduler.py b/sub-packages/bionemo-llm/tests/bionemo/llm/model/test_lr_scheduler.py
similarity index 96%
rename from sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_lr_scheduler.py
rename to sub-packages/bionemo-llm/tests/bionemo/llm/model/test_lr_scheduler.py
index a0b0883e05..1b5549db00 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_lr_scheduler.py
+++ b/sub-packages/bionemo-llm/tests/bionemo/llm/model/test_lr_scheduler.py
@@ -16,7 +16,7 @@

 import torch

-from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHold, WarmupAnnealDecayHoldScheduler
+from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHold, WarmupAnnealDecayHoldScheduler


 def test_warmup_anneal_decay_hold_scheduler_exists():