
we should ensure activation checkpointing with Float8Linear behaves optimally #893

Open
@vkuzo


When activation checkpointing (AC) is on for Float8Linear, I would expect:

  1. the forward gemm to be recomputed in the backward (it is currently not being recomputed)
  2. max(abs(activation)) and max(abs(weight)) NOT to be recomputed; they are tiny, so it is much better to always save and reuse them (it seems that one of these is currently being recomputed)
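The two expectations amount to a save-or-recompute decision per forward op. The toy model below is a minimal sketch of that decision, assuming a hypothetical, simplified op list for a Float8Linear forward (the op labels, `Policy`, `desired_policy`, and `backward_ops` are illustrative names, not torchao or PyTorch APIs, and the model omits the scaling of grad_output). It shows that saving only the amax results yields 3 backward gemms and no amax recomputation.

```python
from enum import Enum


class Policy(Enum):
    SAVE = "save"            # keep the op's output from the forward pass
    RECOMPUTE = "recompute"  # rerun the op when the backward pass needs it


# Hypothetical, simplified op list for a Float8Linear forward.
# Real torchao kernels and op boundaries differ.
FORWARD_OPS = [
    "amax(activation)",
    "amax(weight)",
    "cast_to_float8(activation)",
    "cast_to_float8(weight)",
    "gemm(output)",
]

# Gradient gemms that run in the backward regardless of AC.
GRADIENT_OPS = ["gemm(grad_input)", "gemm(grad_weight)"]


def desired_policy(op: str) -> Policy:
    """Save the tiny amax results; recompute everything else, including the gemm."""
    return Policy.SAVE if op.startswith("amax(") else Policy.RECOMPUTE


def backward_ops(policy) -> list[str]:
    """Ops executed in the backward: recomputed forward ops, then gradient gemms."""
    recomputed = [op for op in FORWARD_OPS if policy(op) is Policy.RECOMPUTE]
    return recomputed + GRADIENT_OPS


ops = backward_ops(desired_policy)
print(sum(op.startswith("gemm(") for op in ops))     # 3
print([op for op in ops if op.startswith("amax(")])  # []
```

In real PyTorch code, this kind of per-op decision would be expressed through the selective activation checkpointing support in `torch.utils.checkpoint`; the toy model only illustrates the intended outcome.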

Let's figure out why this isn't happening now and what we should do about it. Note: the reproductions below require #892.

bfloat16 linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter bfloat16 --enable_activation_checkpointing True

trace snippet

[Screenshot 2024-09-16 at 2:50:54 PM: profiler trace, not reproduced]

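We see 1 gemm in the forward and 3 in the backward, as expected: with AC the forward gemm is recomputed in the backward, in addition to the two gemms that compute the gradients with respect to the input and the weight.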
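The expected gemm counts follow from simple arithmetic on a single linear layer y = x @ W.T. The sketch below spells it out; `linear_gemm_counts` is a hypothetical helper for illustration only.

```python
def linear_gemm_counts(activation_checkpointing: bool) -> tuple[int, int]:
    """Return (forward, backward) gemm counts for one linear layer y = x @ W.T."""
    forward = 1  # y = x @ W.T
    backward = 2  # grad_x = grad_y @ W and grad_W = grad_y.T @ x
    if activation_checkpointing:
        backward += 1  # the checkpointed forward, including y = x @ W.T, is rerun
    return forward, backward


print(linear_gemm_counts(activation_checkpointing=True))   # (1, 3)
print(linear_gemm_counts(activation_checkpointing=False))  # (1, 2)
```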

Float8Linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter float8 --enable_activation_checkpointing True

trace snippet

[Screenshot 2024-09-16 at 3:05:37 PM: profiler trace, not reproduced]

Issue 1: there are only two gemms in the backward instead of three (the forward gemm is not recomputed).
Issue 2: there are extra kernels in the backward that recompute max(abs(activation)) and max(abs(weight)).
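The two issues can be phrased as an acceptance check on the backward pass. The sketch below assumes the backward kernels from a trace have already been hand-labeled with strings such as "gemm" or "amax(weight)"; these labels and the `check_backward` helper are hypothetical and do not correspond to profiler or torchao kernel names.

```python
EXPECTED_BACKWARD_GEMMS = 3  # recomputed forward gemm + grad_input + grad_weight
FORBIDDEN_RECOMPUTES = ("amax(activation)", "amax(weight)")


def check_backward(ops: list[str]) -> list[str]:
    """Describe each way the labeled backward ops miss the expected AC behavior."""
    problems = []
    n_gemm = sum(op.startswith("gemm") for op in ops)
    if n_gemm != EXPECTED_BACKWARD_GEMMS:
        problems.append(f"expected {EXPECTED_BACKWARD_GEMMS} gemms in backward, found {n_gemm}")
    recomputed = [op for op in ops if op in FORBIDDEN_RECOMPUTES]
    if recomputed:
        problems.append("recomputed in backward: " + ", ".join(recomputed))
    return problems


# Hypothetical labels mirroring issues 1 and 2, not kernel names from the trace.
observed = ["amax(activation)", "amax(weight)", "gemm", "gemm"]
for problem in check_backward(observed):
    print(problem)
```

With the observed labels above, both problems are reported; a backward containing three gemms and no amax recomputation returns an empty list.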
