Updated defaults
jstjohn committed Nov 7, 2024
1 parent 0664aed · commit a7084b5
Showing 2 changed files with 20 additions and 1 deletion.
@@ -279,6 +279,7 @@ def main(
seq_length=seq_length,
bias_dropout_fusion=False,
bias_activation_fusion=False,
defer_embedding_wgrad_compute=pipeline_model_parallel_size > 1,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
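The new `defer_embedding_wgrad_compute` default is tied to pipeline parallelism: as the comment added in `model.py` below explains, the deferred weight-gradient GEMM runs during the pipeline flush, so with a single pipeline stage there is no flush to absorb it and the flag stays off. A small illustrative check of the gating rule (example values, not taken from the repo):

```python
# Illustrative only: mirrors the gating expression used at the call site above.
def defer_default(pipeline_model_parallel_size: int) -> bool:
    # Deferral only pays off when a pipeline flush exists to absorb the deferred GEMM.
    return pipeline_model_parallel_size > 1


assert defer_default(1) is False  # single stage: wgrad is computed in the normal backward pass
assert defer_default(4) is True   # multiple stages: wgrad is deferred to the pipeline flush
```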
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py (19 additions & 1 deletion)
@@ -199,7 +199,6 @@ def __init__( # noqa: D107

# megatron core pipelining currently depends on model type
self.model_type = ModelType.encoder_or_decoder

# Embeddings.
if self.pre_process:
self.register_buffer(
@@ -235,6 +234,21 @@ def __init__( # noqa: D107
# Output
if post_process:
# TODO: Make sure you are passing in the mpu_vocab_size properly
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None

self.lm_head = BertLMHead(
config.hidden_size,
config,
@@ -250,6 +264,8 @@ def __init__( # noqa: D107
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)

self.binary_head = None
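For intuition, the two buffers passed to the output `ColumnParallelLinear` let it skip the weight-gradient GEMM in each micro-batch's backward pass and instead replay every micro-batch at once during the pipeline flush. A minimal, framework-free sketch of that pattern (illustrative names, not Megatron-LM's internal API):

```python
import torch

# Stand-ins for the buffers created in __init__ above.
embedding_activation_buffer: list[torch.Tensor] = []
grad_output_buffer: list[torch.Tensor] = []


def on_forward(activation: torch.Tensor) -> None:
    # Keep a reference to the final projection's GEMM input for this micro-batch.
    embedding_activation_buffer.append(activation)


def on_backward(grad_output: torch.Tensor) -> None:
    # Store the matching output gradient; the weight-gradient GEMM is skipped here.
    grad_output_buffer.append(grad_output)


def on_pipeline_flush(weight: torch.Tensor) -> None:
    # During the pipeline flush, compute grad_W = sum over micro-batches of
    # grad_output^T @ activation in one pass over the buffered tensors.
    grad_w = torch.zeros_like(weight)
    for act, grad_out in zip(embedding_activation_buffer, grad_output_buffer):
        grad_w += grad_out.flatten(0, -2).T @ act.flatten(0, -2)
    weight.grad = grad_w
    embedding_activation_buffer.clear()
    grad_output_buffer.clear()
```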
@@ -342,6 +358,7 @@ def forward(
tokentype_ids: Optional[Tensor] = None,
lm_labels: Optional[Tensor] = None,
inference_params: Any | None = None,
runtime_gather_output: Optional[bool] = None,
) -> BioBertOutput | Tensor:
"""Forward function of BERT model
@@ -420,6 +437,7 @@ def forward(

hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
if not self.skip_logits:
# TODO add , runtime_gather_output=runtime_gather_output once supported in ColumnParallelLinear
logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
else:
logits = None
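The forward signature also gains `runtime_gather_output`, although, per the TODO above, it is not yet forwarded to `self.output_layer` because the pinned `ColumnParallelLinear` does not accept it. Conceptually, such a flag lets the caller decide at call time whether the vocab-sharded logits are all-gathered across the tensor-parallel group. A rough sketch of the assumed semantics (not Megatron-LM's implementation):

```python
import torch
import torch.distributed as dist


def maybe_gather_logits(
    logits_shard: torch.Tensor,
    runtime_gather_output: bool,
    tp_group: dist.ProcessGroup,
) -> torch.Tensor:
    # Assumed semantics: when False, each rank keeps only its vocabulary shard;
    # when True, shards are all-gathered so every rank sees the full vocab dimension.
    if not runtime_gather_output:
        return logits_shard
    shards = [torch.empty_like(logits_shard) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, logits_shard, group=tp_group)
    return torch.cat(shards, dim=-1)
```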
