diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
index 8f7b046a6d..3ebf19f068 100644
--- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
+++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
@@ -279,6 +279,7 @@ def main(
         seq_length=seq_length,
         bias_dropout_fusion=False,
         bias_activation_fusion=False,
+        defer_embedding_wgrad_compute=pipeline_model_parallel_size > 1,
         params_dtype=get_autocast_dtype(precision),
         pipeline_dtype=get_autocast_dtype(precision),
         autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
index 283a43f73b..68eac670d4 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
@@ -199,7 +199,6 @@ def __init__(  # noqa: D107
 
         # megatron core pipelining currently depends on model type
         self.model_type = ModelType.encoder_or_decoder
-
        # Embeddings.
         if self.pre_process:
             self.register_buffer(
@@ -235,6 +234,21 @@ def __init__(  # noqa: D107
         # Output
         if post_process:
             # TODO: Make sure you are passing in the mpu_vocab_size properly
+            if self.config.defer_embedding_wgrad_compute:
+                # The embedding activation buffer preserves a reference to the input activations
+                # of the final embedding projection layer GEMM. It will hold the activations for
+                # all the micro-batches of a global batch for the last pipeline stage. Once we are
+                # done with all the back props for all the microbatches for the last pipeline stage,
+                # it will be in the pipeline flush stage. During this pipeline flush we use the
+                # input activations stored in embedding activation buffer and gradient outputs
+                # stored in gradient buffer to calculate the weight gradients for the embedding
+                # final linear layer.
+                self.embedding_activation_buffer = []
+                self.grad_output_buffer = []
+            else:
+                self.embedding_activation_buffer = None
+                self.grad_output_buffer = None
+
             self.lm_head = BertLMHead(
                 config.hidden_size,
                 config,
@@ -250,6 +264,8 @@ def __init__(  # noqa: D107
                 skip_bias_add=False,
                 gather_output=not self.parallel_output,
                 skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
             )
 
             self.binary_head = None
@@ -342,6 +358,7 @@ def forward(
         tokentype_ids: Optional[Tensor] = None,
         lm_labels: Optional[Tensor] = None,
         inference_params: Any | None = None,
+        runtime_gather_output: Optional[bool] = None,
     ) -> BioBertOutput | Tensor:
         """Forward function of BERT model
 
@@ -420,6 +437,7 @@ def forward(
 
             hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
             if not self.skip_logits:
+                # TODO add , runtime_gather_output=runtime_gather_output once supported in ColumnParallelLinear
                 logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
             else:
                 logits = None
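
For context (not part of the patch): a minimal, framework-free sketch of what the buffers wired into the output layer above are for. The buffer names mirror the ones in the diff, but the helpers `forward_collect`, `backward_collect`, and `deferred_wgrad` are hypothetical illustrations of the deferred weight-gradient idea, not Megatron-Core APIs: activations and output gradients are stashed per micro-batch, and the weight gradient of the final projection GEMM is computed once, during the pipeline flush.

```python
import torch

# Hypothetical stand-ins for the buffers the patch passes to ColumnParallelLinear.
embedding_activation_buffer: list[torch.Tensor] = []
grad_output_buffer: list[torch.Tensor] = []


def forward_collect(weight: torch.Tensor, activations: torch.Tensor) -> torch.Tensor:
    """Run the output-layer GEMM and keep a reference to its input activations."""
    embedding_activation_buffer.append(activations)
    return activations @ weight.t()  # [tokens, hidden] @ [hidden, vocab] -> [tokens, vocab]


def backward_collect(grad_output: torch.Tensor) -> None:
    """Stash the gradient w.r.t. the GEMM output instead of computing wgrad immediately."""
    grad_output_buffer.append(grad_output)


def deferred_wgrad(weight: torch.Tensor) -> torch.Tensor:
    """At pipeline flush, accumulate the weight gradient over all buffered micro-batches."""
    wgrad = torch.zeros_like(weight)
    for act, gout in zip(embedding_activation_buffer, grad_output_buffer):
        # dL/dW = grad_output^T @ activations, summed over micro-batches
        wgrad += gout.reshape(-1, gout.shape[-1]).t() @ act.reshape(-1, act.shape[-1])
    embedding_activation_buffer.clear()
    grad_output_buffer.clear()
    return wgrad


if __name__ == "__main__":
    # Sanity check: deferred accumulation matches the per-micro-batch wgrad.
    torch.manual_seed(0)
    weight = torch.randn(32, 16)  # [vocab, hidden]
    ref = torch.zeros_like(weight)
    for _ in range(4):  # four micro-batches
        act = torch.randn(8, 16)
        out = forward_collect(weight, act)
        gout = torch.randn_like(out)
        backward_collect(gout)
        ref += gout.t() @ act  # immediate wgrad, for comparison
    assert torch.allclose(deferred_wgrad(weight), ref, atol=1e-5)
```

This also shows why the training-script change only enables the option when `pipeline_model_parallel_size > 1`: the deferral trades extra activation memory on the last pipeline stage for overlapping the embedding wgrad GEMM with the pipeline flush, which only pays off when there is a flush to hide it in.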