Commit

Fix FSDP in transformer4.38 (huggingface#812)
libinta authored Mar 17, 2024
1 parent e29aacc commit 6e36e18
Showing 1 changed file with 5 additions and 1 deletion.
optimum/habana/transformers/trainer.py: 5 additions & 1 deletion
@@ -1173,7 +1173,11 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
             if is_accelerate_available() and self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
                 grad_norm = model.get_global_grad_norm()
             else:
-                grad_norm = _grad_norm.item() if _grad_norm is not None else None
+                grad_norm = (
+                    _grad_norm.item()
+                    if (_grad_norm is not None and self.accelerator.distributed_type != GaudiDistributedType.FSDP)
+                    else None
+                )
 
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm
