float8 training with rowwise scaling #889

Open
vkuzo opened this issue Sep 16, 2024 · 9 comments

@vkuzo
Contributor

vkuzo commented Sep 16, 2024

This is a brain dump of what is missing from torchao.float8 to support training with rowwise scaling, to help if someone wants to jump in to build this.

already done

  • torch._scaled_mm supports rowwise scaling
  • inductor supports rowwise scaled gemms, in max-autotune mode (I haven't personally tested this yet)
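
As a rough illustration of what a rowwise-scaled gemm computes, here is a pure-Python emulation. It is a sketch under stated assumptions, not how `torch._scaled_mm` is implemented: `fake_e4m3`, `rowwise_scales` and `scaled_mm_rowwise` are hypothetical helper names, and the float8 cast is approximated by saturating at 448 and keeping 4 significant bits.

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def fake_e4m3(x: float) -> float:
    """Crude stand-in for a float8 e4m3 cast: saturate, then keep 4 significant bits.

    Subnormals and the exact exponent range are ignored; illustration only.
    """
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    if x == 0.0:
        return 0.0
    mantissa, exponent = math.frexp(x)  # x == mantissa * 2**exponent, 0.5 <= |mantissa| < 1
    return math.ldexp(round(mantissa * 16) / 16, exponent)


def rowwise_scales(m):
    """One scale per row, mapping that row's largest magnitude onto E4M3_MAX."""
    return [E4M3_MAX / max(max(abs(v) for v in row), 1e-12) for row in m]


def scaled_mm_rowwise(a, b_t):
    """Emulate a rowwise-scaled gemm computing a @ b_t.T.

    a is (M, K) with one scale per row. b_t is (N, K), i.e. b stored transposed,
    so its per-row scales are also constant along the contraction dimension K.
    """
    sa, sb = rowwise_scales(a), rowwise_scales(b_t)
    a_q = [[fake_e4m3(v * s) for v in row] for row, s in zip(a, sa)]
    b_q = [[fake_e4m3(v * s) for v in row] for row, s in zip(b_t, sb)]
    # Each output element is rescaled by the product of its row scale and column scale.
    return [
        [sum(x * y for x, y in zip(ra, rb)) / (sa[i] * sb[j]) for j, rb in enumerate(b_q)]
        for i, ra in enumerate(a_q)
    ]


a = [[1.0, 2.0], [100.0, -300.0]]   # rows with very different magnitudes
b_t = [[0.5, 0.25], [1.0, -1.0]]
print(scaled_mm_rowwise(a, b_t))    # close to the exact product [[1, -1], [-25, 400]]
```

The small-magnitude row of `a` keeps its precision because it gets its own scale; with a single tensorwise scale it would share the range with the row containing -300.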

needed

  1. we need Float8Tensor to work with rowwise scales. There is an unlanded PR on float8_experimental that does this ([wip] add axiswise granularity to Float8Tensor pytorch-labs/float8_experimental#352); we just never got the time to land it. You can reuse that PR or do something similar. Note that #819 ([Float8Quant] Add rowwise scaling option to float8 dyanmic quant) landed recently, adding float8 rowwise scaling to inference, so being consistent with it where applicable would be nice.
  2. we need Float8Linear to be configurable with rowwise scales for each argument, and the scaling needs to respect the config, validated by tests + benchmarks. This would require changes to torchao/float8/config.py and torchao/float8/float8_linear.py.
  3. after (1) and (2), we could make each gemm configurable to enable leaving some of them in high precision
  4. performance fixes throughout torchao.float8 and inductor, if needed based on how well inductor generates the scaling code
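
Items (2) and (3) could be captured by a config shaped roughly like the sketch below. All names here (`Granularity`, `GemmConfig`, `LinearScalingConfig`) are hypothetical and chosen for illustration; they are not the existing torchao.float8 config API.

```python
from dataclasses import dataclass, field
from enum import Enum


class Granularity(Enum):  # hypothetical name
    TENSORWISE = "tensorwise"
    ROWWISE = "rowwise"


@dataclass(frozen=True)
class GemmConfig:  # hypothetical name
    lhs: Granularity = Granularity.TENSORWISE  # scaling of the left operand
    rhs: Granularity = Granularity.TENSORWISE  # scaling of the right operand
    use_float8: bool = True                    # False leaves this gemm in high precision


@dataclass(frozen=True)
class LinearScalingConfig:  # hypothetical name
    # The three gemms of a linear layer, each configurable independently.
    output: GemmConfig = field(default_factory=GemmConfig)       # input @ weight.T
    grad_input: GemmConfig = field(default_factory=GemmConfig)   # grad_output @ weight
    grad_weight: GemmConfig = field(default_factory=GemmConfig)  # grad_output.T @ input


# Example recipe: rowwise scaling in the forward, grad_weight kept in high precision.
cfg = LinearScalingConfig(
    output=GemmConfig(lhs=Granularity.ROWWISE, rhs=Granularity.ROWWISE),
    grad_weight=GemmConfig(use_float8=False),
)
```

Keeping one entry per gemm rather than one per tensor makes it possible to express recipes where the same tensor is scaled differently in the forward and backward gemms.
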
@gau-nernst
Collaborator

I was thinking about similar topics for int8 training too. Just curious, how do you plan to handle backward pass for grad_input = grad_output @ weight? Will you do column-wise scaling for weight instead (or equivalently, row-wise scaling for weight.T)? This is actually my approach for int8 training, which seems to work fine, but there will be an issue if we try to do quantization in FSDP pre-all-gather: the scaling axis is different for forward and backward -> need different behavior for forward and backward, which I think is not possible with the current FSDP2 API?
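
To make the axis mismatch concrete: a per-row or per-column scale can only be factored out of a matmul if it is constant along the contraction dimension. The pure-Python sketch below (helper names are hypothetical) computes the amax values the weight would need in each direction, assuming a weight of shape (N, K), a forward of `input @ weight.T` and a backward of `grad_output @ weight`.

```python
def amax_per_row(m):
    """Largest magnitude in each row of a 2D list."""
    return [max(abs(v) for v in row) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


weight = [[1.0, -8.0, 0.5],
          [0.25, 2.0, -4.0]]  # shape (N=2, K=3)

# Forward: input @ weight.T contracts over K, so each of the N rows of
# weight gets its own scale.
fwd_amax = amax_per_row(weight)             # [8.0, 4.0]

# Backward: grad_output @ weight contracts over N, so each of the K columns of
# weight (each row of weight.T) gets its own scale.
bwd_amax = amax_per_row(transpose(weight))  # [1.0, 8.0, 4.0]

print(fwd_amax, bwd_amax)
```

The two sets of scales differ in both length and values, so a weight quantized once for the forward cannot simply be transposed and reused in the backward.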

@vkuzo
Contributor Author

vkuzo commented Sep 16, 2024

> Just curious, how do you plan to handle backward pass for grad_input = grad_output @ weight?

Well, at a high level, I want torchao.float8 to support all the possible knobs so we can experiment with finding the best recipes in a data driven way.

If I had to guess today what I expect to work well with FSDP, I'd say either tensorwise or blockwise scaling of the weight, so we can scale once and transpose without rescaling. In the future it would be great to have scaled float8 gemm support where one of the arguments is scaled tensorwise/blockwise (the weight), and the other rowwise.
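
The "scale once and transpose without rescaling" property of tensorwise scaling can be checked with a few lines of pure Python. This is a sketch: `quantize_tensorwise` is a hypothetical helper, and `round()` stands in for the real float8 cast.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def transpose(m):
    return [list(col) for col in zip(*m)]


def quantize_tensorwise(m):
    """Scale the whole tensor by a single factor; round() is a stand-in for the float8 cast."""
    amax = max(abs(v) for row in m for v in row)
    scale = E4M3_MAX / amax
    return [[round(v * scale) for v in row] for row in m], scale


weight = [[1.0, -8.0, 0.5],
          [0.25, 2.0, -4.0]]

q, scale = quantize_tensorwise(weight)
q_t, scale_t = quantize_tensorwise(transpose(weight))

# A single scale does not depend on layout, so quantizing then transposing
# gives the same result as transposing then quantizing.
assert scale == scale_t
assert transpose(q) == q_t
```

With rowwise scaling the same check fails in general, because the scales themselves are tied to one axis.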

@awgu
Contributor

awgu commented Sep 16, 2024

@gau-nernst I will say that supporting row-wise scaling with FSDP2 pre-all-gather is painful, and I am not sure if we should ever do it (at least with FSDP2 -- maybe for some other FSDP implementations, it can work).

I have some more thoughts, but for one, if the user enables activation checkpointing, then the backward all-gather should now all-gather both for weight and weight.T, in which case all-gathering in bf16 is probably simpler.

@lw

lw commented Sep 16, 2024

The recipe we've been working on does row-wise scaling of the weights post-all-gather (hence the comms happen in bf16), and tensor-wise scaling (+transposition) of the weights in the backward (leveraging the scaling factor of the forward to avoid recomputing a full amax from scratch). Ideally we would have done real column-wise scaling of weights in the backward but that was too hard.
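
One plausible reading of "leveraging the scaling factor of the forward" (the comment does not give details, so this is an assumption): if the forward already produced one scale per row, the tensorwise amax is the maximum of the row amaxes, so the tensorwise scale is the minimum of the rowwise scales. A pure-Python sketch with hypothetical helper names:

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def rowwise_scales(m):
    """Forward-pass scales: one per row, scale = E4M3_MAX / amax(row)."""
    return [E4M3_MAX / max(abs(v) for v in row) for row in m]


def tensorwise_scale_from_rowwise(scales):
    """Derive the tensorwise scale from existing rowwise scales.

    amax(row) = E4M3_MAX / scale(row) and amax(tensor) = max over rows of amax(row),
    so the tensorwise scale is the smallest rowwise scale. This reduces over
    N values instead of over all N * K elements.
    """
    return min(scales)


weight = [[1.0, -8.0, 0.5],
          [0.25, 2.0, -4.0]]

fwd_scales = rowwise_scales(weight)                       # [56.0, 112.0]
bwd_scale = tensorwise_scale_from_rowwise(fwd_scales)     # 56.0 == 448 / 8
print(fwd_scales, bwd_scale)
```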

@lw

lw commented Sep 16, 2024

BTW, I plan to tackle some of these points in the next few days. I'll report here with where I get to.

@lw self-assigned this Sep 16, 2024

@jerryzh168
Contributor

jerryzh168 commented Sep 18, 2024

@vkuzo if we enable training for AffineQuantizedTensor (with float8) this will be supported automatically I think? since affine quantized tensor covers all granularities, is this correct? cc @andrewor14

@lw

lw commented Sep 18, 2024

@jerryzh168 @vkuzo I'd appreciate clarity on this ASAP to avoid investing too much time in this if it's unnecessary, thanks!

@vkuzo
Contributor Author

vkuzo commented Sep 18, 2024

adding rowwise scaling to Float8Tensor

~hours to days of work, IMO

> @vkuzo if we enable training for AffineQuantizedTensor (with float8) this will be supported automatically I think?

~weeks of alignment and work IMO, if we include aligning everyone that this is in fact what we want (I personally am not convinced just yet, AffineQuantizedTensor seems a bit too complicated for what float8 needs right now), and getting float8 training to the same level of robustness there as it is in torchao.float8.

Given the timeline estimates above, I would just add rowwise scaling to torchao.float8 now and deal with the potential unification separately. #894 is a great place to discuss unification.

@jerryzh168
Contributor

@vkuzo OK sounds good, @lw please feel free to start adding the support for Float8Tensor
