diff --git a/megatron/arguments.py b/megatron/arguments.py
index 2ca514f68b..0ca8776eda 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -813,6 +813,9 @@ def _add_training_args(parser):
                        'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
                        'to recompute within each pipeline stage.')
+    group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
+                       help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
+                       dest='clone_scatter_output_in_embedding')
     group.add_argument('--profile', action='store_true',
                        help='Enable nsys profiling. When using this option, nsys '
                        'options should be specified in commandline. An example '
@@ -821,9 +824,9 @@ def _add_training_args(parser):
                        '--capture-range=cudaProfilerApi '
                        '--capture-range-end=stop`.')
     group.add_argument('--profile-step-start', type=int, default=10,
-                       help='Gloable step to start profiling.')
+                       help='Global step to start profiling.')
     group.add_argument('--profile-step-end', type=int, default=12,
-                       help='Gloable step to stop profiling.')
+                       help='Global step to stop profiling.')
     group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                        help='Global ranks to profile.')
     group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the '
diff --git a/megatron/core/models/common/embeddings/language_model_embedding.py b/megatron/core/models/common/embeddings/language_model_embedding.py
index 6fa6efcaf8..40d679d7b1 100644
--- a/megatron/core/models/common/embeddings/language_model_embedding.py
+++ b/megatron/core/models/common/embeddings/language_model_embedding.py
@@ -119,6 +119,11 @@ def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int =
         # Dropout.
         if self.config.sequence_parallel:
             embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            # `scatter_to_sequence_parallel_region` returns a view, which prevents
+            # the original tensor from being garbage collected. Clone to facilitate GC.
+            # Has a small runtime cost (~0.5%).
+            if self.config.clone_scatter_output_in_embedding:
+                embeddings = embeddings.clone()
             with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 6d2dd5f525..adccd4409b 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -123,6 +123,10 @@ class TransformerConfig(ModelParallelConfig):
         fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation
                           in higher precision. Defaults to True.
 
+        # Miscellaneous
+        clone_scatter_output_in_embedding (bool): When set to true, clone the output of scatter_to_sequence_parallel_region
+                                                  in embedding layer to facilitate garbage collection of input.
+
         # Experimental
         normalization (str): Swtich b/w `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily
                              used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`.
@@ -181,6 +185,9 @@ class TransformerConfig(ModelParallelConfig):
     fp8_amax_compute_algo: str = "most_recent"
     fp8_wgrad: bool = True
 
+    # miscellaneous
+    clone_scatter_output_in_embedding: bool = True
+
     # experimental section (TODO: move to apt. section above once stable)
     normalization: bool = "LayerNorm"  # alt value supported by TE:  "RMSNorm"
 
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index 4cbdd2eef5..69bfa2e801 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -178,6 +178,7 @@ def __init__(self,
 
         self.fp32_residual_connection = args.fp32_residual_connection
         self.sequence_parallel = args.sequence_parallel
+        self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding
 
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -234,6 +235,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Dropout.
         if self.sequence_parallel:
             embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            # `scatter_to_sequence_parallel_region` returns a view, which prevents
+            # the original tensor from being garbage collected. Clone to facilitate GC.
+            # Has a small runtime cost (~0.5%).
+            if self.clone_scatter_output_in_embedding:
+                embeddings = embeddings.clone()
             with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
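
The comment added in both embedding forward paths captures the motivation: `scatter_to_sequence_parallel_region` hands each rank back a view of the full-sequence embeddings, and a view keeps its base tensor's storage alive, so the pre-scatter buffer cannot be garbage collected until the view itself goes away. The lines below are a minimal, standalone PyTorch illustration of that view-vs-clone difference; they are not Megatron code, and the tensor shapes are made up.

import torch

# Stand-in for the full-sequence embeddings produced before the scatter.
full = torch.empty(8192, 4096)

# Stand-in for the per-rank slice: a view that shares storage with `full`.
part = full[:1024]
del full   # the large 8192x4096 storage is still kept alive by `part`

# Cloning copies the slice into its own, smaller storage ...
owned = part.clone()
del part   # ... so the large pre-scatter buffer can now actually be freed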
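
On the argument side, the new option combines `action='store_false'` with `dest='clone_scatter_output_in_embedding'`, so cloning stays on by default and passing `--no-clone-scatter-output-in-embedding` switches it off. A minimal argparse-only sketch of that default/override behavior follows; it reproduces only the new flag, not the full Megatron parser, and its help string is paraphrased.

import argparse

# Sketch of the flag's semantics only, not the real _add_training_args parser.
parser = argparse.ArgumentParser()
parser.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
                    dest='clone_scatter_output_in_embedding',
                    help='Disable cloning of the scatter output in the embedding layer.')

assert parser.parse_args([]).clone_scatter_output_in_embedding is True    # default: clone
args = parser.parse_args(['--no-clone-scatter-output-in-embedding'])
assert args.clone_scatter_output_in_embedding is False                    # flag given: skip the clone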