callbacks and bf16 grad (#11985)
* callbacks and bf16 grad

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* callbacks list extend

Signed-off-by: Malay Nagda <[email protected]>

* grad dtype and garbage collection callback

Signed-off-by: Malay Nagda <[email protected]>

---------

Signed-off-by: Malay Nagda <[email protected]>
Signed-off-by: malay-nagda <[email protected]>
Co-authored-by: malay-nagda <[email protected]>
malay-nagda and malay-nagda authored Jan 30, 2025
1 parent 7692802 commit 78f445f
Showing 20 changed files with 238 additions and 246 deletions.
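
The diffs below apply the same pattern in each recipe: make sure `recipe.trainer.callbacks` exists, extend it with a `GarbageCollectionCallback` and a `MegatronCommOverlapCallback`, and enable bf16 gradient reduction by turning off the fp32 fallback. A consolidated sketch of that pattern follows; `add_performance_callbacks` is a hypothetical helper name, and the `nemo_run` alias `run` is assumed to match what the recipe modules use.

```python
import nemo_run as run

from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback


def add_performance_callbacks(recipe: run.Partial) -> run.Partial:
    """Hypothetical helper mirroring the pretrain_performance_optimizations changes below."""
    # A freshly built recipe may have no callbacks list yet.
    if not recipe.trainer.callbacks:
        recipe.trainer.callbacks = []

    recipe.trainer.callbacks.extend(
        [
            # Trigger garbage collection at a fixed step interval during train and val.
            run.Config(GarbageCollectionCallback, gc_interval_train=100, gc_interval_val=100),
            # Per-recipe options (tp_comm_overlap_cfg, wgrad_deferral_limit, ...) vary in the diffs below.
            run.Config(MegatronCommOverlapCallback, tp_comm_overlap=True),
        ]
    )

    # "bf16 grad": reduce gradients in bf16 instead of falling back to fp32.
    recipe.trainer.plugins.grad_reduce_in_fp32 = False
    return recipe
```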
41 changes: 25 additions & 16 deletions nemo/collections/llm/recipes/gpt3_175b.py
@@ -31,6 +31,7 @@
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
)
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

@@ -216,22 +217,30 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
It may not be suitable for all hardware configurations or use cases.
"""

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing
align_param_gather=True,
)
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
# 'overlap_param_gather_with_optimizer_step' is set automatically. Added here for user's knowledge
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing
)
recipe.trainer.callbacks.extend(
[
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe
40 changes: 24 additions & 16 deletions nemo/collections/llm/recipes/llama31_405b.py
@@ -222,23 +222,31 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
It may not be suitable for all hardware configurations or use cases.
"""

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing
align_param_gather=True,
)
garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
# 'overlap_param_gather_with_optimizer_step' is set automatically. Added here for user's knowledge
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing
)
recipe.trainer.callbacks.extend(
[
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe

@@ -360,7 +368,7 @@ def finetune_performance_optimizations(
It may not be suitable for all hardware configurations or use cases.
"""

if not hasattr(recipe.trainer, "callbacks"):
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

if peft_scheme is None or peft_scheme.lower() == 'none':
46 changes: 30 additions & 16 deletions nemo/collections/llm/recipes/llama3_70b.py
@@ -219,24 +219,32 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
It may not be suitable for all hardware configurations or use cases.
"""

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=22,
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing.
align_param_gather=True,
)
garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=22,
# 'overlap_param_gather_with_optimizer_step' is set automatically. Added here for user's knowledge
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing.
)
recipe.trainer.callbacks.extend(
[
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe


@@ -358,7 +366,7 @@ def finetune_performance_optimizations(
It may not be suitable for all hardware configurations or use cases.
"""

if not hasattr(recipe.trainer, "callbacks"):
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

if peft_scheme is None or peft_scheme.lower() == 'none':
@@ -387,6 +395,12 @@ def finetune_performance_optimizations(
recipe.trainer.strategy.pipeline_model_parallel_size = 4
recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5
recipe.peft.target_modules = ['linear_qkv']
recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
)

recipe.trainer.strategy.sequence_parallel = True

39 changes: 27 additions & 12 deletions nemo/collections/llm/recipes/llama3_8b.py
@@ -218,12 +218,27 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
recipe.trainer.callbacks.extend(
[
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe


@@ -338,7 +353,7 @@ def finetune_performance_optimizations(
"""
recipe.trainer.strategy.tensor_model_parallel_size = 1

if not hasattr(recipe.trainer, "callbacks"):
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

if peft_scheme is None or peft_scheme.lower() == 'none':
@@ -351,15 +366,15 @@
overlap_param_gather=True,
average_in_collective=True,
)
recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
)
else:
recipe.peft.target_modules = ['linear_qkv']

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
)
recipe.trainer.callbacks.append(run.Config(TimingCallback))
recipe.trainer.callbacks.append(
run.Config(
36 changes: 23 additions & 13 deletions nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -29,6 +29,7 @@
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback
@@ -213,27 +214,36 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
It may not be suitable for all hardware configurations or use cases.
"""

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.
if not recipe.trainer.callbacks:
recipe.trainer.callbacks = []

garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
# 'overlap_param_gather_with_optimizer_step' is set automatically. Added here for user's knowledge
overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
)
recipe.trainer.callbacks.extend(
[
run.Config(
MegatronTokenDropCallback,
),
run.Config(
MegatronCommOverlapCallback,
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing
align_param_gather=True,
),
run.Config(MegatronTokenDropCallback),
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.strategy.expert_model_parallel_size = 1
recipe.trainer.strategy.tensor_model_parallel_size = 8
recipe.trainer.strategy.sequence_parallel = True
recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe


27 changes: 16 additions & 11 deletions nemo/collections/llm/recipes/mixtral_8x7b.py
@@ -29,6 +29,7 @@
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback
@@ -210,25 +211,29 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
It may not be suitable for all hardware configurations or use cases.
"""

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.

garbage_collection_callback = run.Config(
GarbageCollectionCallback,
gc_interval_train=100,
gc_interval_val=100,
)
mcomm_overlap_callback = run.Config(
MegatronCommOverlapCallback,
# 'overlap_param_gather_with_optimizer_step' is set automatically. Added here for user's knowledge
overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing.
)
recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(
MegatronCommOverlapCallback,
overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing.
align_param_gather=True,
),
garbage_collection_callback,
mcomm_overlap_callback,
]
)

recipe.trainer.strategy.expert_model_parallel_size = 1
recipe.trainer.strategy.tensor_model_parallel_size = 8
recipe.trainer.strategy.sequence_parallel = True
recipe.trainer.plugins.grad_reduce_in_fp32 = False

return recipe


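
For reference, a hypothetical usage sketch (not part of this commit) showing where the new callbacks end up; the `pretrain_recipe` arguments below are assumptions and may not match the module's actual signature.

```python
from nemo.collections.llm.recipes import llama3_8b

# Assumed entry point and arguments; check the recipe module for the real signature.
recipe = llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)
recipe = llama3_8b.pretrain_performance_optimizations(recipe)

# After this commit the trainer config carries GarbageCollectionCallback and
# MegatronCommOverlapCallback entries, and gradients are reduced in bf16.
print(recipe.trainer.callbacks)
print(recipe.trainer.plugins.grad_reduce_in_fp32)  # expected: False
```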