When AC is on for Float8Linear, what I would expect is:
the forward gemm is recomputed in the backward (it is not being recomputed now)
max(abs(activation)) and max(abs(weight)) are NOT recomputed; it is much better to always reuse them since they are tiny (it seems like one of these is being recomputed now)
Let's figure out why this isn't what is happening now and what we should do about it. Note: the reproductions below require #892.
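To make the cost argument concrete, here is a rough sketch of the ops a dynamically scaled float8 linear forward performs. This is a simplified emulation for illustration, not the actual torchao Float8Linear code: the gemm output is large (and cheap to recompute from the saved inputs), while each max(abs(...)) is a reduction down to a single scalar, so keeping it around is nearly free.

```python
import torch

E4M3_MAX = 448.0  # max representable value of torch.float8_e4m3fn

def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # scale so that the largest observed value maps to the float8 max
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def float8_linear_forward_sketch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    x_amax = x.abs().max()  # tiny reduction kernel, scalar output
    w_amax = w.abs().max()  # tiny reduction kernel, scalar output
    x_scale = amax_to_scale(x_amax)
    w_scale = amax_to_scale(w_amax)
    x_fp8 = (x * x_scale).to(torch.float8_e4m3fn)
    w_fp8 = (w * w_scale).to(torch.float8_e4m3fn)
    # on supported hardware this is a single torch._scaled_mm call; emulated in
    # bfloat16 here so the sketch runs on any device
    out = x_fp8.to(torch.bfloat16) @ w_fp8.to(torch.bfloat16).t()
    return out / (x_scale * w_scale)
```

Under plain activation checkpointing, recomputing this forward in the backward re-runs both the gemm and the two amax reductions; the expectation above is that only the gemm should be re-run.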
bfloat16 linear fwd/bwd with activation checkpointing on
repro command
trace snippet
we see 1 gemm in the forward and 3 in the backward (the recomputed forward gemm plus the grad_input and grad_weight gemms), as expected

Float8Linear fwd/bwd with activation checkpointing on
repro command
trace snippet
issue 1: there are only two gemms in the backward instead of three
issue 2: there are some extra kernels in the backward which are recomputing max(abs(activation)) and max(abs(weight))
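The actual repro commands and trace snippets are collapsed above (and require #892). For readers who just want to see the kernel breakdown for themselves, here is a generic, illustrative profiling harness; it is not the original repro, and the module and sizes are placeholders:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def profile_checkpointed_fwd_bwd(module: nn.Module, x: torch.Tensor) -> None:
    # one fwd/bwd of a checkpointed module under the profiler, so the gemm and
    # max(abs(...)) kernels that run in the backward show up in the printed table
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(activities=activities) as prof:
        y = checkpoint(module, x, use_reentrant=False)
        y.sum().backward()
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=25))

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # swap nn.Linear for a module converted to torchao float8 to get the Float8Linear trace
    m = nn.Linear(2048, 2048, bias=False, device=device, dtype=torch.bfloat16)
    x = torch.randn(64, 2048, device=device, dtype=torch.bfloat16, requires_grad=True)
    profile_checkpointed_fwd_bwd(m, x)
```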
the torch._scaled_mm behavior seems fine
the max(abs(tensor)) behavior seems suboptimal, and we can do better with custom AC settings. I wrote up pytorch/torchtitan#580 with initial findings and will follow up with more after the conferences this week.
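For concreteness, here is a minimal sketch of the kind of custom AC settings referred to above, using PyTorch's selective activation checkpointing API (torch.utils.checkpoint.create_selective_checkpoint_contexts, available in recent PyTorch). The op choice is an assumption for illustration: the tensorwise amax is computed as max(abs(tensor)), which typically shows up as aten.abs followed by a full aten.max reduction, and the max output is a single scalar. The exact policy proposed in pytorch/torchtitan#580 may differ.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# assumed op to always keep from the forward: the scalar result of the full
# max reduction used for float8 scaling (verify against a real trace)
_ops_to_always_save = {
    torch.ops.aten.max.default,
}

def _policy_fn(ctx, op, *args, **kwargs):
    # keep the tiny scalar max(abs(...)) results from the forward, recompute
    # everything else (including the gemms) in the backward
    if op in _ops_to_always_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

_context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)

def checkpointed_forward(module, *args):
    # use_reentrant=False is required when passing context_fn
    return checkpoint(module, *args, use_reentrant=False, context_fn=_context_fn)
```

Frameworks that already use selective AC with a list of ops to save (as torchtitan does) could likely get the same effect by adding this op to that list; pytorch/torchtitan#580 tracks that.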