diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
index 2fa9e4a68f..646f88feaa 100644
--- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
+++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
@@ -25,6 +25,7 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Sequence, Type, get_args
 
+import torch
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
 from nemo import lightning as nl
@@ -91,6 +92,7 @@ def main(
     log_every_n_steps: int = 50,
     gc_interval: int = 0,
     aligned_megatron_ddp: bool = False,
+    recompilation_check: bool = False,
     # TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
 ) -> None:
     """Train a Geneformer model on single cell data.
@@ -143,6 +145,8 @@ def main(
             at this requested interval of train/val steps. This will likely slow down single GPU runs.
         aligned_megatron_ddp (bool): if activated, this will activate a number of communication optimizations that are
             good for clusters. This will likely slow down single node runs though.
+        recompilation_check (bool): enable a recompilation check (only do this on a small run) to verify that fused
+            GPU kernels are not being regularly recompiled for a particular model/settings, which is very expensive.
     """
     # Create the result directory if it does not exist.
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -273,6 +277,8 @@ def main(
         ffn_hidden_size=512,
         num_attention_heads=4,
         seq_length=seq_length,
+        bias_dropout_fusion=False,
+        bias_activation_fusion=False,
         params_dtype=get_autocast_dtype(precision),
         pipeline_dtype=get_autocast_dtype(precision),
         autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
@@ -325,6 +331,11 @@ def main(
         wandb_config=wandb_options,
         ckpt_callback=checkpoint_callback,
     )
+    if recompilation_check:
+        """This is _very_ useful for debugging slow forward passes. Check that your fused kernels are not
+        getting recompiled. Once verified, turn this off again.
+        """
+        torch._dynamo.config.error_on_recompile = True
     llm.train(
         model=model,
         data=data,
@@ -599,6 +610,13 @@ def config_class_type(desc: str) -> Type[BioBertConfig]:
         help="By default param overlap/etc is disabled in megatron, this enables all of those settings. This is probably "
         "good for cluster performance.",
     )
+    parser.add_argument(
+        "--recompilation-check",
+        action="store_true",
+        default=False,
+        help="Activate this and make sure a small training loop runs; this tells you that your settings are not "
+        "triggering regular recompilations, which can be very expensive for fused GPU kernels.",
+    )
     return parser
 
 
@@ -648,6 +666,7 @@ def entrypoint():
         log_every_n_steps=args.log_every_n_steps,
         gc_interval=args.gc_interval,
         aligned_megatron_ddp=args.aligned_megatron_ddp,
+        recompilation_check=args.recompilation_check,
     )
 
 
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
index 7e92fff658..04310aeabf 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
@@ -259,7 +259,7 @@ def forward(
         return loss_for_microbatch * cp_size, {"avg": reduced_loss}
 
 
-def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = True) -> Tensor:
+def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = False) -> Tensor:
     """Computes the unreduced token loss given the logits and labels without regard to the loss mask.
 
     WARNING: This function does not apply a loss mask. Also, it does inplace operation on the inputs.
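For context on what the new `--recompilation-check` flag turns on: setting `torch._dynamo.config.error_on_recompile = True` makes `torch.compile` raise as soon as a compiled graph would need to be recompiled, so a short run fails loudly instead of silently re-tracing fused kernels. The sketch below is not part of this PR; `fused_bias_gelu` is an illustrative stand-in for any compiled function, used only to show the behavior the check relies on.

```python
import torch

# What --recompilation-check enables inside main(): fail instead of recompiling.
torch._dynamo.config.error_on_recompile = True


@torch.compile
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Stand-in for a fused op; any compiled function works for this demo.
    return torch.nn.functional.gelu(x + bias)


x, b = torch.randn(4, 8), torch.randn(8)
fused_bias_gelu(x, b)  # first call compiles once, which is expected

try:
    # A dtype change invalidates the guards the compiled graph was specialized
    # on, forcing a recompile, which now raises instead of silently slowing down.
    fused_bias_gelu(x.double(), b.double())
except Exception as err:
    print(f"recompilation detected: {err}")
```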
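On the `loss.py` change: `unreduced_token_loss_fn` now defaults to the unfused cross-entropy path. As a conceptual sketch only (plain PyTorch, not the bionemo/Megatron vocab-parallel implementation), "unreduced token loss" means per-token cross entropy with no reduction and no loss mask applied; the shapes below are arbitrary illustrative values.

```python
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 8, 128
logits = torch.randn(batch, seq, vocab)         # hypothetical model output
labels = torch.randint(0, vocab, (batch, seq))  # hypothetical target token ids

# Per-token loss with no reduction; masking/averaging is left to the caller,
# mirroring the "does not apply a loss mask" warning in the docstring.
per_token_loss = F.cross_entropy(
    logits.reshape(-1, vocab), labels.reshape(-1), reduction="none"
).reshape(batch, seq)
print(per_token_loss.shape)  # torch.Size([2, 8])
```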