debugging checkpoint resumption #469

Draft · wants to merge 1 commit into base: main
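
This draft wires a --stop-after-steps option through the ESM2 training script and adds a StopAfterStepCallback to bionemo.llm.lightning. Once trainer.global_step reaches the requested value, the callback raises a RuntimeError, so a run can be killed at a known step and then relaunched to check that checkpoint resumption behaves as expected. A hedged, self-contained sketch of that loop follows the diff below.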
@@ -30,7 +30,7 @@
 from bionemo.esm2.data.datamodule import ESMDataModule
 from bionemo.esm2.data.dataset import RandomMaskStrategy
 from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.llm.lightning import PerplexityLoggingCallback
+from bionemo.llm.lightning import PerplexityLoggingCallback, StopAfterStepCallback
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.llm.model.biobert.model import BiobertSpecOption
 from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -90,6 +90,7 @@ def main(
     hidden_size: int = 1280,
     num_attention_heads: int = 20,
     ffn_hidden_size: int = 1280 * 4,
+    stop_after_steps: int | None = None,
 ) -> None:
     """Train an ESM2 model on UR data.

@@ -104,6 +105,7 @@
         max_seq_length (int): maximum sequence length
         result_dir (Path): directory to store results, logs and checkpoints
         num_steps (int): number of steps to train the model for
+        stop_after_steps (int | None): stop after this many steps; useful for debugging checkpoint resumption
         warmup_steps (int): number of steps for warmup phase
         limit_val_batches (int): limit the number of validation global batches to this many
         val_check_interval (int): number of steps to periodically check the validation loss
@@ -201,6 +203,9 @@ def main(
             )
         )

+    if stop_after_steps is not None:
+        callbacks.append(StopAfterStepCallback(stop_after_steps))
+
     trainer = nl.Trainer(
         devices=devices,
         max_steps=num_steps,
@@ -350,6 +355,7 @@ def train_esm2_entrypoint():
         hidden_size=args.hidden_size,
         num_attention_heads=args.num_attention_heads,
         ffn_hidden_size=args.ffn_hidden_size,
+        stop_after_steps=args.stop_after_steps,
     )


@@ -651,6 +657,13 @@ def get_parser():
         default=4 * 1280,
         help="FFN hidden size of the model. Default is 4 * 1280.",
     )
+    parser.add_argument(
+        "--stop-after-steps",
+        type=int,
+        required=False,
+        default=None,
+        help="Stop training after this many global steps by raising an error; used to debug checkpoint resumption. Default is None (disabled).",
+    )
     return parser


sub-packages/bionemo-llm/src/bionemo/llm/lightning.py (9 additions, 0 deletions)

@@ -436,3 +436,12 @@ def on_megatron_reduce_microbatches_end(
             step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
         elif self.log_train and step.trainer.training:
             step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
+
+
+class StopAfterStepCallback(pl.Callback, CallbackMethods):
+    def __init__(self, stop_after_steps: int):
+        self.stop_after_steps = stop_after_steps
+
+    def on_megatron_step_end(self, step, microbatch_outputs, reduced=None) -> None:
+        if step.trainer.global_step >= self.stop_after_steps:
+            raise RuntimeError(f"Stopping after {self.stop_after_steps} steps")
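
The two-run debugging loop this callback enables looks roughly like the sketch below. It is a minimal illustration using plain PyTorch Lightning rather than the Megatron CallbackMethods hooks above; the toy model, checkpoint cadence, and directory name are assumptions made for the example, not part of this change.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class StopAfterStep(pl.Callback):
    """Raise once the trainer reaches a given global step, mimicking StopAfterStepCallback."""

    def __init__(self, stop_after_steps: int):
        self.stop_after_steps = stop_after_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step >= self.stop_after_steps:
            raise RuntimeError(f"Stopping after {self.stop_after_steps} steps")


class ToyModel(pl.LightningModule):
    """Tiny regression model so the example runs in seconds."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=8)


def make_trainer(extra_callbacks=()):
    # Save a "last" checkpoint every few steps so there is always something to resume from.
    checkpointing = pl.callbacks.ModelCheckpoint(save_last=True, every_n_train_steps=5)
    return pl.Trainer(
        max_steps=20,
        default_root_dir="resume_debug",
        enable_progress_bar=False,
        callbacks=[checkpointing, *extra_callbacks],
    )


# First run: crash deterministically partway through training.
first = make_trainer(extra_callbacks=[StopAfterStep(10)])
try:
    first.fit(ToyModel(), data)
except RuntimeError as err:
    print(f"first run stopped as requested: {err}")

# Second run: resume from the last checkpoint the first run wrote and train to max_steps.
second = make_trainer()
second.fit(ToyModel(), data, ckpt_path=first.checkpoint_callback.last_model_path)
print(f"resumed run finished at global step {second.global_step}")

With the training script itself, the analogous loop is one launch with --stop-after-steps set, followed by a relaunch to confirm that training resumes from the saved checkpoint rather than restarting from step zero.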