Implement megatron-aware perplexity in torchmetrics #525

Open · wants to merge 29 commits into base: main

Commits (29)
f1b0736
update ddp config
sichu2023 Dec 11, 2024
97bbfa3
update MegatronMixedPrecision
sichu2023 Dec 11, 2024
34e12e9
change DistributedDataParallelConfig import
sichu2023 Dec 11, 2024
97a8af4
add grad_reduce_in_fp32=False in MegatronMixedPrecision
sichu2023 Dec 11, 2024
e4f8b54
transfer torchmetrics changes from pstjohn repo
sichu2023 Dec 12, 2024
86bcc8a
add pp last stage in logging
sichu2023 Dec 12, 2024
b2e5f1d
add tp-aware update method
sichu2023 Dec 12, 2024
edf1da3
drop comment on tp-aware normalization
sichu2023 Dec 12, 2024
35b5abf
update comment
sichu2023 Dec 12, 2024
564626c
fix cp error
sichu2023 Dec 12, 2024
bf16c01
drop process_group
sichu2023 Dec 12, 2024
6234c70
add MegatronPerplexityMetric testing
sichu2023 Dec 16, 2024
5db215d
fix metric device
sichu2023 Dec 17, 2024
0a3ae38
fix MegatronPerplexityMetric.update
sichu2023 Dec 17, 2024
0603b4d
clean up test_megatron_perplexity_metric_with_single_microbatch_golde…
sichu2023 Dec 17, 2024
c43d9e4
fix get_random_microbatch
sichu2023 Dec 17, 2024
71f9dc0
add variable length microbatch test
sichu2023 Dec 17, 2024
561d3c1
ruff
sichu2023 Dec 17, 2024
f3cc5d6
add back self.log_{train,val}_ppl
sichu2023 Dec 18, 2024
d03a94e
add back return in {train,validation}_step
sichu2023 Dec 18, 2024
eac8b18
add argparse
sichu2023 Dec 18, 2024
da76178
disable async ckpt save
sichu2023 Dec 19, 2024
bd6ebd3
drop pp support
sichu2023 Dec 19, 2024
7886e86
Revert "update ddp config"
sichu2023 Dec 19, 2024
f2ea013
Revert "update MegatronMixedPrecision"
sichu2023 Dec 19, 2024
80403ae
disable training ppl logging by default
sichu2023 Dec 19, 2024
dcc84ec
remove ppl callback
sichu2023 Dec 19, 2024
fb5ad07
move pp check to train_esm2.py
sichu2023 Dec 19, 2024
78aa7a3
ruff
sichu2023 Dec 19, 2024
train_esm2.py
@@ -30,7 +30,6 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.llm.lightning import PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -81,6 +80,8 @@ def main(
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
metric_to_monitor_for_checkpoints: str = "val_loss",
log_train_ppl: bool = False,
log_val_ppl: bool = True,
save_top_k: int = 2,
nsys_profiling: bool = False,
nsys_start_step: int = 0,
@@ -136,6 +137,8 @@ def main(
save_best_checkpoint (bool): whether to save the best checkpoint
save_last_checkpoint (bool): whether to save the last checkpoint
metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
log_train_ppl (bool): log training perplexity
log_val_ppl (bool): log validation perplexity
save_top_k (int): number of top checkpoints to save
nsys_profiling (bool): whether to enable nsys profiling
nsys_start_step (int): start step for nsys profiling
@@ -189,7 +192,6 @@
)

callbacks = [
PerplexityLoggingCallback(log_train=False, log_val=True),
RichModelSummary(max_depth=4),
LearningRateMonitor(),
nl_callbacks.PreemptionCallback(),
@@ -252,6 +254,9 @@ def main(
if scheduler_num_steps is None:
scheduler_num_steps = num_steps

if (log_train_ppl or log_val_ppl) and pipeline_model_parallel_size > 1:
raise NotImplementedError("Perplexity logging does not support pipeline parallelism yet.")

model = biobert_lightning_module(
esm2_config,
tokenizer=tokenizer,
@@ -272,6 +277,9 @@
anneal_percentage=0.10,
),
),
# perplexity logging
log_train_ppl=log_train_ppl,
log_val_ppl=log_val_ppl,
)

# Configure our custom Checkpointer
@@ -350,6 +358,8 @@ def train_esm2_entrypoint():
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
log_train_ppl=args.log_train_ppl,
log_val_ppl=args.log_val_ppl,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
@@ -586,6 +596,25 @@ def get_parser():
default="val_loss",
help="The metric to monitor for checkpointing.",
)
parser.add_argument(
"--log-train-ppl",
action="store_true",
default=False,
help="Log perplexity during training.",
)
parser.add_argument(
"--log-val-ppl",
action="store_true",
        default=True,
        help="Log perplexity during validation (enabled by default).",
)
parser.add_argument(
"--no-log-val-ppl",
action="store_false",
dest="log_val_ppl",
default=True,
help="Disable logging perplexity during validation.",
)
parser.add_argument(
"--save-top-k",
type=int,
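The three flags above are meant to make training perplexity opt-in (`--log-train-ppl`) while leaving validation perplexity on unless `--no-log-val-ppl` is passed. A self-contained sketch of that flag pattern, using a throwaway parser rather than the real `get_parser()` (which has many required data and cluster arguments); only the flag names are taken from the diff:

```python
import argparse

# Stand-in parser that mirrors only the perplexity flags added above; everything else is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--log-train-ppl", action="store_true", default=False,
                    help="Log perplexity during training.")
parser.add_argument("--log-val-ppl", action="store_true", dest="log_val_ppl", default=True,
                    help="Log perplexity during validation (enabled by default).")
parser.add_argument("--no-log-val-ppl", action="store_false", dest="log_val_ppl",
                    help="Disable logging perplexity during validation.")

print(parser.parse_args([]))                    # Namespace(log_train_ppl=False, log_val_ppl=True)
print(parser.parse_args(["--log-train-ppl"]))   # opt in to training perplexity
print(parser.parse_args(["--no-log-val-ppl"]))  # opt out of validation perplexity
```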
58 changes: 56 additions & 2 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -17,6 +17,7 @@

import lightning.pytorch as pl
import torch.distributed
import torchmetrics.text
from megatron.core import parallel_state
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from nemo.lightning import io as nlio
@@ -210,6 +211,33 @@ def predict_loss_reduction(self) -> PassthroughLossReduction:
"""


class MegatronPerplexityMetric(torchmetrics.text.Perplexity):
    """Tensor-parallel-aware perplexity metric.

    Extends `torchmetrics.text.Perplexity` by computing the per-token loss with the
    TP-aware `unreduced_token_loss_fn`, so it works when logits are sharded across the
    vocabulary dimension. `compute()` still returns `exp(total_log_probs / count)`.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the metric; `cross_entropy_loss_fusion` may be passed as a keyword argument."""
if parallel_state.get_context_parallel_world_size() > 1:
raise NotImplementedError(f"{self.__class__} does not support context parallelism yet.")

self.cross_entropy_loss_fusion = kwargs.pop("cross_entropy_loss_fusion", False)
super().__init__(*args, **kwargs)

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
"""Update state with predictions and targets under tensor parallelism."""
unreduced_token_loss = unreduced_token_loss_fn( # TP-aware log prob function
preds.clone().transpose(0, 1).contiguous(),
target.clone(),
cross_entropy_loss_fusion=self.cross_entropy_loss_fusion,
) # (b, s)

if self.ignore_index is not None:
mask = target.ne(self.ignore_index)
target = target.where(target != self.ignore_index, torch.tensor(0, device=target.device))
else:
mask = torch.ones_like(target, dtype=torch.bool)
unreduced_token_loss = unreduced_token_loss[mask]

self.total_log_probs += unreduced_token_loss.sum()
self.count += mask.sum()


class BionemoLightningModule(
Generic[MegatronModelType, MegatronLossType],
pl.LightningModule,
@@ -227,6 +255,8 @@ def __init__(
# TODO: Add transformer_layer_spec when we update mcore
optimizer: MegatronOptimizerModule,
model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
log_train_ppl: bool = False,
log_val_ppl: bool = False,
**model_construct_args,
) -> None:
"""Constructor.
@@ -242,6 +272,8 @@ def __init__(
model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
`configure_model` method.
model_transform: Optional. The model transform function.
log_train_ppl (bool): Log training perplexity.
log_val_ppl (bool): Log validation perplexity.
**model_construct_args: Optional. Arguments necessary for the supplied model configuration's
`configure_model` method, which will make an instance of the model.
"""
@@ -258,6 +290,14 @@ def __init__(
self._forward_step = forward_step
self.model_transform = model_transform

        # All scaling of the internal states cancels out in the final formula exp(total_log_probs / count),
        # so the per-device partial sums can safely be added across ranks (see the numeric sketch after this file's diff).
self.log_train_ppl = log_train_ppl
self.log_val_ppl = log_val_ppl
if log_train_ppl:
self.train_ppl = MegatronPerplexityMetric(ignore_index=-100)
if log_val_ppl:
self.valid_ppl = MegatronPerplexityMetric(ignore_index=-100)

def configure_model(self) -> None:
"""Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

@@ -304,11 +344,25 @@ def forward_step(self, batch) -> Tensor:

def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
        logits = outputs["token_logits"].transpose(0, 1)  # [s, b, v] -> [b, s, v]

if self.log_train_ppl and parallel_state.is_pipeline_last_stage():
self.train_ppl(logits, batch["labels"])
self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False)

return outputs

def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
        logits = outputs["token_logits"].transpose(0, 1)  # [s, b, v] -> [b, s, v]

if self.log_val_ppl and parallel_state.is_pipeline_last_stage():
self.valid_ppl(logits, batch["labels"])
self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True)

return outputs

def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""Alias for forward_step."""
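The comment in `BionemoLightningModule.__init__` above notes that all scaling cancels out in `exp(total_log_probs / count)`, so the two metric states can be reduced with a plain sum (torchmetrics registers both `Perplexity` states with a sum reduction). A minimal numeric sketch of that claim with plain CPU tensors; the values and the two-way shard split are made up for illustration and do not touch the Megatron code path:

```python
import torch

torch.manual_seed(0)
nll = torch.rand(8, 128)          # per-token negative log-likelihood, shape [b, s]
mask = torch.rand(8, 128) > 0.25  # True where a token counts toward perplexity

# Perplexity over the full batch in one shot.
global_ppl = torch.exp(nll[mask].sum() / mask.sum())

# "Shard" the batch the way data parallelism would, then simply add the two states.
shards = [(nll[:4], mask[:4]), (nll[4:], mask[4:])]
total_log_probs = sum(n[m].sum() for n, m in shards)
count = sum(m.sum() for _, m in shards)
sharded_ppl = torch.exp(total_log_probs / count)

torch.testing.assert_close(global_ppl, sharded_ppl)
```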
75 changes: 74 additions & 1 deletion sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py
@@ -24,7 +24,7 @@
from torchmetrics.text import Perplexity

from bionemo.llm import lightning as bnptl
from bionemo.llm.lightning import PerplexityLoggingCallback, batch_collator, get_dtype_device
from bionemo.llm.lightning import MegatronPerplexityMetric, PerplexityLoggingCallback, batch_collator, get_dtype_device
from bionemo.testing import megatron_parallel_state_utils
from bionemo.testing.lightning import get_random_microbatch

@@ -186,6 +186,79 @@ def test_mixin_strategy_contract_get_loss_reduction():
assert isinstance(strategy_reduction_function(mixin), bnptl.PassthroughLossReduction)


def test_megatron_perplexity_metric_with_single_microbatch_golden_value_without_parallelism(seed: int = 42):
"""Test PerplexityLoggingCallback with a single microbatch without parallelism"""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
# setup test input
microbatch_size, max_sequence_length, vocab_size = 1, 1024, 2
microbatch_outputs = [get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed)]

# setup metric
megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device())
metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device())

# compute values
for microbatch_output in microbatch_outputs:
megatron_ppl_metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
ppl_value = megatron_ppl_metric.compute()
ppl_golden_value = metric.compute()

torch.testing.assert_close(
ppl_value,
ppl_golden_value,
)


def test_megatron_perplexity_metric_with_variable_length_microbatches_golden_value_without_parallelism(
    seed: int = 42,
):
    """Test MegatronPerplexityMetric against the torchmetrics Perplexity golden value with variable-length microbatches and no parallelism."""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
# setup test input
microbatch_size, max_sequence_length, vocab_size = 2, 1024, 2
microbatch_outputs = [
get_random_microbatch(microbatch_size, max_sequence_length // 2, vocab_size, seed),
get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed),
]

# setup metric
megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device())
metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device())

# compute values
for microbatch_output in microbatch_outputs:
megatron_ppl_metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
ppl_value = megatron_ppl_metric.compute()
ppl_golden_value = metric.compute()

torch.testing.assert_close(
ppl_value,
ppl_golden_value,
)


def test_perplexity_logging_callback_with_single_microbatch_golden_value_without_parallelism(seed: int = 42):
"""Test PerplexityLoggingCallback with a single microbatch without parallelism"""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
10 changes: 7 additions & 3 deletions sub-packages/bionemo-testing/src/bionemo/testing/lightning.py
@@ -20,7 +20,11 @@


def get_random_microbatch(
microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int
microbatch_size: int,
max_sequence_length: int,
vocab_size: int,
seed: int,
mask_index: int = -100,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Generate random microbatches for testing.

@@ -35,7 +39,7 @@ def get_random_microbatch(
device=torch.cuda.current_device(),
) # [b s]
loss_mask = torch.randint(
low=1,
low=0,
high=1 + 1,
size=(microbatch_size, max_sequence_length),
dtype=torch.long,
@@ -45,7 +49,7 @@ def get_random_microbatch(
token_logits = torch.rand(
max_sequence_length, microbatch_size, vocab_size, device=torch.cuda.current_device(), generator=generator
) # [s b v]
labels[loss_mask == 0] = -100 # propagate masking to labels
labels[loss_mask == 0] = mask_index # propagate masking to labels
microbatch_output = {
"batch": {"labels": labels, "loss_mask": loss_mask},
"forward_out": {"token_logits": token_logits},
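The new `mask_index` parameter defaults to `-100`, matching the `ignore_index=-100` that `BionemoLightningModule` passes to `MegatronPerplexityMetric`, so masked positions in the generated labels are ignored by the metric. A small sketch of that masking convention with hypothetical label values; it mirrors, but is not, the logic in `MegatronPerplexityMetric.update`:

```python
import torch

ignore_index = -100  # matches MegatronPerplexityMetric(ignore_index=-100) and the mask_index default above
target = torch.tensor([[5, -100, 2, -100]])  # hypothetical labels with two masked positions

mask = target.ne(ignore_index)  # tensor([[ True, False,  True, False]])
# Replace ignored positions with a valid class id so downstream indexing stays in range.
safe_target = target.where(mask, torch.tensor(0, device=target.device))

print(safe_target)  # tensor([[5, 0, 2, 0]])
print(mask.sum())   # tensor(2) -> this is what accumulates into the metric's `count` state
```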