Merge branch 'memory_fixes' into 'main'
Clone output of view in _split_along_first_dim

See merge request ADLR/megatron-lm!937
ericharper committed Nov 18, 2023
2 parents 480433f + 19afb90 commit 9290c73
Showing 4 changed files with 23 additions and 2 deletions.
7 changes: 5 additions & 2 deletions megatron/arguments.py
@@ -813,6 +813,9 @@ def _add_training_args(parser):
                        'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
                        'to recompute within each pipeline stage.')
+    group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
+                       help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
+                       dest='clone_scatter_output_in_embedding')
     group.add_argument('--profile', action='store_true',
                        help='Enable nsys profiling. When using this option, nsys '
                        'options should be specified in commandline. An example '
@@ -821,9 +824,9 @@ def _add_training_args(parser):
                        '--capture-range=cudaProfilerApi '
                        '--capture-range-end=stop`.')
     group.add_argument('--profile-step-start', type=int, default=10,
-                       help='Gloable step to start profiling.')
+                       help='Global step to start profiling.')
     group.add_argument('--profile-step-end', type=int, default=12,
-                       help='Gloable step to stop profiling.')
+                       help='Global step to stop profiling.')
     group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                        help='Global ranks to profile.')
     group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the '
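The new flag uses argparse's negative-flag pattern: action='store_false' together with an explicit dest means the destination defaults to True, and only passing --no-clone-scatter-output-in-embedding turns it off. A minimal, self-contained sketch of that pattern (registered on a bare parser for brevity; the surrounding group setup from arguments.py is not reproduced):

import argparse

parser = argparse.ArgumentParser()
# Same flag name and dest as in the diff above.
parser.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
                    dest='clone_scatter_output_in_embedding',
                    help='Disable cloning of the scatter output in the embedding layer.')

# store_false makes the destination default to True ...
assert parser.parse_args([]).clone_scatter_output_in_embedding is True
# ... and passing the flag flips it to False.
args = parser.parse_args(['--no-clone-scatter-output-in-embedding'])
assert args.clone_scatter_output_in_embedding is False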
@@ -119,6 +119,11 @@ def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int =
         # Dropout.
         if self.config.sequence_parallel:
             embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            # `scatter_to_sequence_parallel_region` returns a view, which prevents
+            # the original tensor from being garbage collected. Clone to facilitate GC.
+            # Has a small runtime cost (~0.5%).
+            if self.config.clone_scatter_output_in_embedding:
+                embeddings = embeddings.clone()
             with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
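The comment added above is the heart of the change: scatter_to_sequence_parallel_region hands back a view, a view shares its base tensor's storage, and so holding only the small per-rank slice still keeps the full-sequence activation allocated. A standalone sketch of that effect on a recent PyTorch (a plain slice stands in for the sequence-parallel scatter; sizes are illustrative):

import torch

def scatter_stand_in(full, rank, world_size):
    # Hypothetical stand-in for scatter_to_sequence_parallel_region:
    # return this rank's slice of the first (sequence) dimension as a view.
    chunk = full.shape[0] // world_size
    return full[rank * chunk:(rank + 1) * chunk]

full = torch.randn(8192, 4, 1024)                  # full-sequence embeddings, 128 MiB in fp32
shard = scatter_stand_in(full, rank=0, world_size=8)
print(shard.data_ptr() == full.data_ptr())         # True: the shard is a view into full
del full
print(shard.untyped_storage().nbytes() // 2**20)   # 128: the whole base storage is still alive

shard = shard.clone()                              # what clone_scatter_output_in_embedding enables
print(shard.untyped_storage().nbytes() // 2**20)   # 16: only this rank's copy remains allocated

The extra copy of the already-sharded embedding output is the source of the small (~0.5%) runtime cost mentioned in the comment.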
7 changes: 7 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -123,6 +123,10 @@ class TransformerConfig(ModelParallelConfig):
         fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation in higher precision.
                           Defaults to True.
+        # Miscellaneous
+        clone_scatter_output_in_embedding (bool): When set to true, clone the output of scatter_to_sequence_parallel_region
+                                                  in embedding layer to facilitate garbage collection of input.
         # Experimental
         normalization (str): Swtich b/w `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily
                              used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`.
@@ -181,6 +185,9 @@ class TransformerConfig(ModelParallelConfig):
     fp8_amax_compute_algo: str = "most_recent"
     fp8_wgrad: bool = True
 
+    # miscellaneous
+    clone_scatter_output_in_embedding: bool = True
+
     # experimental section (TODO: move to apt. section above once stable)
     normalization: bool = "LayerNorm" # alt value supported by TE: "RMSNorm"
 
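On the megatron.core side, the behavior is controlled by this new config field rather than by the argparse flag. A hedged usage sketch (the other constructor arguments are just a typical minimum for TransformerConfig and may vary between megatron.core versions):

from megatron.core.transformer.transformer_config import TransformerConfig

# Default after this commit: the scatter output in the embedding is cloned.
config = TransformerConfig(num_layers=2, hidden_size=1024, num_attention_heads=16)
print(config.clone_scatter_output_in_embedding)    # True

# Opting out skips the extra copy (saves the ~0.5% runtime cost, at the price of
# keeping the full-sequence tensor alive until the view is released).
config = TransformerConfig(num_layers=2, hidden_size=1024, num_attention_heads=16,
                           clone_scatter_output_in_embedding=False)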
6 changes: 6 additions & 0 deletions megatron/model/language_model.py
@@ -178,6 +178,7 @@ def __init__(self,
 
         self.fp32_residual_connection = args.fp32_residual_connection
         self.sequence_parallel = args.sequence_parallel
+        self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
 
@@ -234,6 +235,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Dropout.
         if self.sequence_parallel:
             embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            # `scatter_to_sequence_parallel_region` returns a view, which prevents
+            # the original tensor from being garbage collected. Clone to facilitate GC.
+            # Has a small runtime cost (~0.5%).
+            if self.clone_scatter_output_in_embedding:
+                embeddings = embeddings.clone()
             with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
