Questions: Clarifying the use of FP8 for Training #99

Closed
jon-chuang opened this issue Sep 10, 2023 · 2 comments

Comments

@jon-chuang

jon-chuang commented Sep 10, 2023

@tocean @wkcn

In line with the investigation in NVIDIA/TransformerEngine#424, it would be great to get insights from the team at Microsoft on using FP8 in aspects of training beyond matmul.

Questions

1. Performance

The repo only mentions training accuracy and memory savings. However, the kernels may not be very optimized and the majority are implemented in Torch. I guess that performance is still unexplored.

2. Weight Update

  • Is the weight update applied while the backward pass is running (on-the-fly), or is it applied after the entire backward pass is complete?
    • It seems that one can get memory savings from the on-the-fly approach.
    • Is there a possibility of using the CUDA graphs API to efficiently schedule the weight updates concurrently, even if they do not saturate the GPU? Furthermore, one should be able to batch uneven-sized weight updates into a single kernel invocation based on an outstanding_weight_updates_bytes threshold (see the sketch after this list).
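
A minimal sketch of the batching/threshold idea, using plain high-precision SGD updates and a hypothetical flush threshold (this is not MS-AMP's API or its FP8 machinery; stream/graph scheduling is omitted):

import torch

# Hypothetical flush threshold; name and value are illustrative only.
OUTSTANDING_WEIGHT_UPDATES_BYTES = 64 * 1024 * 1024

def flush_updates(params, grads, lr):
    # One fused (foreach) launch for all pending updates: w -= lr * g.
    torch._foreach_add_(list(params), list(grads), alpha=-lr)

def on_grad_ready(param, grad, pending, lr):
    # Called as each gradient becomes available during the backward pass.
    pending.append((param, grad))
    pending_bytes = sum(p.numel() * p.element_size() for p, _ in pending)
    if pending_bytes >= OUTSTANDING_WEIGHT_UPDATES_BYTES:
        params, grads = zip(*pending)
        flush_updates(params, grads, lr)
        pending.clear()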

3. More Accurate Scaling Factors

Is there a way to maintain a more accurate amax by estimating it as follows?

  • For e.g. the naive SGD case:
    • scaling_factor_weights_t = amax_weights_t-1 + amax_grad_t - this is an accurate upper bound (no a priori knowledge needed; see the sketch after this list)
    • amax_weights_t = max(abs(weights_t)) - this is only used for the next iteration
  • For the Adam optimizer:
    • Utilizing e5m2 might help with the dynamic range of v (same dynamic range as FP16).
    • Storing sqrt_v rather than v may help the precision. Update rule: see appendix.
      • Intuition: sqrt halves the dynamic range in bits (2^16 -> 2^8, 2^-16 -> 2^-8). Hence we perform the sqrt in fp32/fp16 and quantize that to fp8, thus preserving the dynamic range.
    • A more rigorous analysis is needed here.
  • If it is possible to better estimate scaling_factor_weights_t, then it may be possible to use more of the dynamic range. Hence, storing the weights as FP8 (rather than FP16 as in the MS-AMP repo) might be possible.
    • Since the Adam optimizer is momentum-based, the effect of per-batch deviations of amax is more bounded.
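
A minimal sketch of the SGD-case bound referenced above, writing the learning rate explicitly (function and argument names are illustrative; fp8_max = 448 assumes e4m3 weights):

def sgd_weight_scaling_factor(amax_weights_prev, amax_grad, lr, fp8_max=448.0):
    # w_t = w_{t-1} - lr * g_t, hence max|w_t| <= max|w_{t-1}| + lr * max|g_t|:
    # an a-priori upper bound that needs no extra pass over the updated weights.
    amax_weights_upper = amax_weights_prev + lr * amax_grad
    return fp8_max / amax_weights_upper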

4. Adaptive Precision

Has it been explored to use lower precision (FP8) at a high learning rate (earlier epochs) and higher precision (e.g. FP32, FP16) at a lower learning rate (later epochs)?

Appendix

Update Rule for sqrt_v_fp8

# Targeting 448 = 2^8 * (1 + 3/4) uses more of the fp8e5m2 dynamic range while
# leaving a margin of 7 exponent bits below the format maximum (57344 = 448 * 2^7).
scaling_factor = amax_sqrt_v_prev / 448.0

# Dequantize: recover v in fp32 by squaring the rescaled sqrt_v values.
v_fp32 = (sqrt_v_fp8.to(torch.float32) * scaling_factor) ** 2
v_new = beta_2 * v_fp32 + (1 - beta_2) * grad_sq
sqrt_v_fp8 = (torch.sqrt(v_new) / scaling_factor).to(torch.float8_e5m2)

# End of loop: record amax of sqrt(v) for the next iteration's scaling factor.
amax_sqrt_v_new = torch.sqrt(v_new.max())

Notes:

  1. If amax_sqrt_v_prev = 448.0, then the scaling factor is 1. This is captured by the margin argument of compute_scaling_factor:
    def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
        ...
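
For reference, a sketch of the usual power-of-two delayed-scaling rule that such a helper implements (my reading of the standard recipe; not necessarily MS-AMP's exact implementation):

import torch

def compute_scaling_factor_sketch(amax: torch.Tensor, scale: torch.Tensor,
                                  fp_max: float, margin: int) -> torch.Tensor:
    # Choose a power-of-two scale so that amax * scale sits 2^margin below fp_max.
    exp = torch.floor(torch.log2(fp_max / amax)) - margin
    new_scale = torch.pow(2.0, exp)
    # Fall back to the previous scale when amax is zero or non-finite.
    new_scale = torch.where((amax > 0) & torch.isfinite(amax), new_scale, scale)
    return new_scale
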
@wkcn
Contributor

wkcn commented Oct 12, 2023

Hi @jon-chuang, I am sorry for the late reply.
Thanks for your attention to our work!

1. Performance

The repo only mentions training accuracy and memory savings. However, the kernels may not be very optimized and the majority are implemented in Torch. I guess that performance is still unexplored.

Yes. Our first step is to apply the FP8 format as much as possible to reduce the memory footprint while maintaining accuracy; the second step is to optimize performance in MS-AMP. MS-AMP can be combined with TransformerEngine to invoke optimized operators in TE. (Related PR: #98)
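
For reference, the basic MS-AMP entry point looks roughly like this (a usage sketch based on the project README; the TE integration details live in #98 and may differ):

import torch
import msamp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Wrap model and optimizer so that supported layers and optimizer states
# use FP8/FP16 with scaling factors.
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")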

2. Weight Update

  • Is the weight update applied while the backward pass is running (on-the-fly), or is it applied after the entire backward pass is complete?

    a) It seems that one can get memory savings from the on-the-fly approach.
    b) Is there a possibility of using the CUDA graphs API to efficiently schedule the weight updates concurrently, even if they do not saturate the GPU? Furthermore, one should be able to batch uneven-sized weight updates into a single kernel invocation based on an outstanding_weight_updates_bytes threshold.

a) It is applied after the entire backward pass is complete. The FP8 weights are updated in the optimizer. (https://github.com/Azure/MS-AMP/blob/main/msamp/optim/adamw.py#L193).

b) Good idea. I tried using an additional CUDA stream for the weight update, but it did not achieve the desired acceleration, probably because my implementation was not optimal : ) However, I still believe that scheduling weight updates concurrently is effective, since the weight update does not affect the backpropagation computation.

It is possible to update multiple FP8 weights in a single CUDA kernel, but note that an FP8 tensor with a scaling factor must be treated as a whole: the maximum absolute value of the entire tensor has to be computed before quantizing a high-precision tensor to an FP8 tensor (sketched below).
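
A minimal sketch of that per-tensor constraint, using plain PyTorch FP8 dtypes rather than MS-AMP's internal kernels (names and the e4m3 maximum are illustrative):

import torch

FP8_E4M3_MAX = 448.0  # illustrative e4m3 maximum

def quantize_tensors_for_fused_update(hp_tensors):
    # Each tensor keeps its own amax/scaling factor and is quantized as a whole,
    # even if the subsequent weight-update kernel is fused across all tensors.
    fp8_tensors, scales = [], []
    for t in hp_tensors:
        amax = t.abs().max()
        scale = FP8_E4M3_MAX / amax.clamp(min=1e-12)
        fp8_tensors.append((t * scale).to(torch.float8_e4m3fn))
        scales.append(scale)
    return fp8_tensors, scales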

3. More Accurate Scaling Factors

Is there a way to maintain a more accurate amax by estimating it as follows?

  • For e.g. the naive SGD case:
    • scaling_factor_weights_t = amax_weights_t-1 + amax_grad_t - this is an accurate upper bound (no a priori knowledge needed)
    • amax_weights_t = max(abs(weights_t)) - this is only used for the next iteration
  • For the Adam optimizer:
    • Utilizing e5m2 might help with the dynamic range of v (same dynamic range as FP16).
    • Storing sqrt_v rather than v may help the precision. Update rule: see appendix.
      • Intuition: sqrt halves the dynamic range in bits (2^16 -> 2^8, 2^-16 -> 2^-8). Hence we perform the sqrt in fp32/fp16 and quantize that to fp8, thus preserving the dynamic range.
    • A more rigorous analysis is needed here.
  • If it is possible to better estimate scaling_factor_weights_t, then it may be possible to use more of the dynamic range. Hence, storing the weights as FP8 (rather than FP16 as in the MS-AMP repo) might be possible.
    • Since the Adam optimizer is momentum-based, the effect of per-batch deviations of amax is more bounded.
  • Accurate amax
    In MS-AMP, the weight update has two stages. In the first stage, we pre-compute the updated high-precision weight to obtain its amax. In the second stage, we re-compute the updated high-precision weight and quantize it into FP8 with a scaling factor. So the amax is accurate, but it increases the computational overhead. I agree that using amax_weights_t = max(abs(weights_t)) for the next iteration is an effective approach. (A sketch of the two-stage scheme follows this list.)

  • Storing the weights as FP8
    In MS-AMP, the master weight is FP16 with a scaling factor, and the weight is FP8 with a scaling factor. The choice of FP16 as the master weight is primarily for precision rather than range.

  • Utilizing sqrt(value) and e5m2
    I believe this is a good approach and worth trying.
    Sqrt can reduce the dynamic range, but it cannot address the issue of low precision.
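
A minimal sketch of the two-stage scheme described above (names are illustrative; the real implementation is in msamp/optim/adamw.py linked earlier):

import torch

FP8_E4M3_MAX = 448.0  # illustrative e4m3 maximum

def two_stage_fp8_weight_update(master_weight, update):
    # Stage 1: pre-compute the updated high-precision weight only to obtain an exact amax.
    amax = (master_weight + update).abs().max()
    scale = FP8_E4M3_MAX / amax.clamp(min=1e-12)
    # Stage 2: re-compute the updated weight and quantize it with the exact scaling factor.
    weight_fp8 = ((master_weight + update) * scale).to(torch.float8_e4m3fn)
    return weight_fp8, scale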

4. Adaptive Precision

Has it been explored to use lower precision (FP8) at a high learning rate (earlier epochs) and higher precision (e.g. FP32, FP16) at a lower learning rate (later epochs)?

No. This approach requires reserving enough memory in earlier epochs to store the high-precision weights used in later stages, which may not be as efficient as keeping high-precision weights and using low-bit computation throughout.

@tocean
Contributor

tocean commented Aug 2, 2024

Closing this issue since there has been no activity for more than 9 months.

@tocean tocean closed this as completed Aug 2, 2024