diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py index 7e92fff65..04310aeab 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py @@ -259,7 +259,7 @@ def forward( return loss_for_microbatch * cp_size, {"avg": reduced_loss} -def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = True) -> Tensor: +def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = False) -> Tensor: """Computes the unreduced token loss given the logits and labels without regard to the loss mask. WARNING: This function does not apply a loss mask. Also, it does inplace operation on the inputs.