When activation checkpointing (AC) is on for Float8Linear, the expected behavior is:
- the forward gemm is recomputed in the backward (currently, it is not being recomputed)
- max(abs(activation)) and max(abs(weight)) are NOT recomputed; their results are tiny scalars, so it is much cheaper to always save and reuse them (currently, it seems at least one of these is being recomputed)

Let's figure out why this isn't what is happening now and what we should do about it; a sketch of the intended save/recompute policy follows below. Note: the reproductions below require #892
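For reference, one way to express this policy is PyTorch's selective activation checkpointing API (`torch.utils.checkpoint.create_selective_checkpoint_contexts`, available in recent PyTorch). The sketch below is illustrative only: the op list is an assumption about which aten reductions Float8Linear's amax computation actually dispatches to.

```python
# Sketch of the intended AC policy via selective activation checkpointing.
# Assumes a recent PyTorch with torch.utils.checkpoint.CheckpointPolicy;
# the op list is a guess at which aten reductions compute the amaxes.
import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# outputs of max(abs(tensor)) are scalars, so always save them;
# which reduction op fires depends on the float8 code path
ops_to_save = [
    torch.ops.aten.max.default,
    torch.ops.aten.amax.default,
]

def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    # everything else, including the fp8 gemm, is recomputed in the backward
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
# usage:
# out = checkpoint(float8_linear, x, use_reentrant=False, context_fn=context_fn)
```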
bfloat16 linear fwd/bwd with activation checkpointing on
repro command
```
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter bfloat16 --enable_activation_checkpointing True
```
trace snippet

we see 1 gemm in the forward and 3 in the backward, as expected: the recomputed forward gemm, the grad_input gemm, and the grad_weight gemm
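To sanity-check the gemm count outside the benchmark script, a minimal standalone sketch (shapes and setup are assumptions, not the benchmark's actual configuration) looks like:

```python
# Standalone sketch of the bfloat16 baseline: checkpoint a single linear
# and count gemm kernels in the profiler output.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

lin = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
    out = checkpoint(lin, x, use_reentrant=False)
    out.sum().backward()

# expected gemms: 1 in fwd; 3 in bwd (recomputed fwd out = x @ W^T,
# grad_x = grad_out @ W, grad_W = grad_out^T @ x)
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))
```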
Float8Linear fwd/bwd with activation checkpointing on
repro command
```
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter float8 --enable_activation_checkpointing True
```
trace snippet

issue 1: there are only two gemms in the backward instead of three, i.e. the forward gemm's output is being saved rather than recomputed
issue 2: there are extra kernels in the backward recomputing max(abs(activation)) and max(abs(weight)), even though those tiny results should have been saved and reused
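A rough way to quantify both issues from a trace is to bucket the profiler's kernel summary by name. The helper below is a hypothetical sketch (kernel-name matching is approximate and varies by GPU and cuBLAS/Triton version); `prof` is a finished `torch.profiler.profile` object like the one in the snippet above.

```python
# Hypothetical diagnostic: bucket CUDA kernels from a profiler run into
# gemm-like and amax/reduction-like kernels. Name matching is approximate.
def summarize_kernels(prof):
    gemm_count, reduction_count = 0, 0
    for evt in prof.key_averages():
        name = evt.key.lower()
        if "gemm" in name or "matmul" in name or "scaled_mm" in name:
            gemm_count += evt.count
        elif "abs" in name or "max" in name or "reduce" in name:
            reduction_count += evt.count
    # for Float8Linear + AC we would expect 4 gemms total (1 fwd, 3 bwd)
    # and no max(abs(...)) reduction kernels in the backward
    print(f"gemm-like kernels: {gemm_count}")
    print(f"amax/reduction-like kernels: {reduction_count}")
```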