Commit
[PyTorch] Set flags in norm modules for Mcore sequence-parallel support (NVIDIA#1528)

Set flag in norm modules for Mcore sequence-parallel support

Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Mar 1, 2025
1 parent d3efaeb commit 4b523d2
Showing 2 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm.py
@@ -104,6 +104,9 @@ def __init__(
 
         # Flag for sequence parallelism (custom Megatron-LM integration)
         self.sequence_parallel: Optional[bool] = sequence_parallel
+        if sequence_parallel is not None:
+            self.weight.sequence_parallel = sequence_parallel
+            self.bias.sequence_parallel = sequence_parallel
 
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
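
For context, a minimal usage sketch (not part of this commit) showing how the new flag propagates from the module to its affine parameters; the constructor signature is assumed from the surrounding module code:

    import transformer_engine.pytorch as te

    # Constructing the module with sequence_parallel=True should now also mark
    # the weight and bias parameters, which downstream Megatron-Core code can inspect.
    ln = te.LayerNorm(1024, sequence_parallel=True)
    assert ln.weight.sequence_parallel is True
    assert ln.bias.sequence_parallel is True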
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/module/rmsnorm.py
@@ -108,6 +108,8 @@ def __init__(
 
         # Flag for sequence parallelism (custom Megatron-LM integration)
         self.sequence_parallel: Optional[bool] = sequence_parallel
+        if sequence_parallel is not None:
+            self.weight.sequence_parallel = sequence_parallel
 
     def reset_rms_norm_parameters(self) -> None:
         """Deprecated"""
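
The parameter-level flag matters because, under Megatron-Core sequence parallelism, the norm weights are replicated across tensor-parallel ranks while each rank computes gradients from only its slice of the sequence, so those gradients must be summed across the tensor-parallel group. A hedged sketch (not Megatron-Core's actual code) of how such a flag is typically consumed; `tp_group` is a placeholder name for the tensor-parallel process group:

    import torch
    import torch.distributed as dist

    def allreduce_sequence_parallel_grads(model: torch.nn.Module, tp_group) -> None:
        # All-reduce gradients of parameters tagged with `sequence_parallel`,
        # since each tensor-parallel rank only holds a partial gradient.
        for param in model.parameters():
            if getattr(param, "sequence_parallel", False) and param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)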
