Add config to enable padding on inner dims for scaled_mm inputs #145
Conversation
Do we have a sense of how promising this approach is in terms of performance? Are there shapes where float8 with padding is going to outperform bfloat16?
I updated the benchmark to perform a bigger sweep. The shapes are pretty big, and these should all be cases where M % 16 and N % 16 == 0 and the semantics-preserving padding of K kicks in.

**************************************TOPs**************************************
Shape Ref Dtype Ref Tops FP8 Tops Ref % Peak FP8 % Peak
------------------ -------------- ---------- ---------- ------------ ------------
(8193x2501x5008) torch.bfloat16 5.1e+14 8.17e+14 0.258 0.206
(65x253x4096) torch.bfloat16 1.07e+13 8.21e+12 0.005 0.002
(1023x1029x2512) torch.bfloat16 7.08e+13 1.98e+14 0.036 0.05
(4095x511x10000) torch.bfloat16 9.4e+13 5.52e+14 0.047 0.139
(2047x3073x8192) torch.bfloat16 1.14e+14 6.16e+14 0.058 0.156
(511x769x7504) torch.bfloat16 8.37e+13 1.68e+14 0.042 0.043
(127x4097x12288) torch.bfloat16 8.61e+13 8.55e+13 0.043 0.022
(32769x15x15024) torch.bfloat16 1.48e+13 3.27e+13 0.007 0.008
(9217x8191x20480) torch.bfloat16 1.2e+14 1.07e+15 0.061 0.271
(16385x1025x25008) torch.bfloat16 1.05e+14 8.11e+14 0.053 0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape | Ref Dtype | Ref Time | FP8 Time | Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008) | torch.bfloat16 | 402.215 | 251.246 | 1.60088 |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096) | torch.bfloat16 | 12.5471 | 16.4149 | 0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512) | torch.bfloat16 | 74.7011 | 26.6719 | 2.80074 |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000) | torch.bfloat16 | 445.42 | 75.8169 | 5.87494 |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192) | torch.bfloat16 | 901.602 | 167.263 | 5.39033 |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504) | torch.bfloat16 | 70.5006 | 35.0095 | 2.01376 |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288) | torch.bfloat16 | 148.589 | 149.542 | 0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024) | torch.bfloat16 | 996.979 | 451.53 | 2.208 |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480) | torch.bfloat16 | 25781.6 | 2886.31 | 8.93238 |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 | 8037.08 | 1036.24 | 7.75598 |
+----------------------+----------------+------------+------------+-----------+
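As a sanity check on how the columns above relate (assuming the reported times are in microseconds, which reproduces the first row): the Tops columns are 2*M*K*N divided by the runtime, and Speedup is Ref Time over FP8 Time. A minimal sketch:

```python
# Sanity check for the table above. Assumption: the reported times are in
# microseconds (this reproduces the first row's numbers).

def matmul_ops_per_sec(m: int, k: int, n: int, time_us: float) -> float:
    """An (m, k) x (k, n) matmul does 2*m*k*n ops; divide by the runtime in seconds."""
    return 2 * m * k * n / (time_us * 1e-6)

m, k, n = 8193, 2501, 5008
ref_time_us, fp8_time_us = 402.215, 251.246

print(f"{matmul_ops_per_sec(m, k, n, ref_time_us):.3g}")  # ~5.1e+14, the Ref Tops column
print(f"{matmul_ops_per_sec(m, k, n, fp8_time_us):.3g}")  # ~8.17e+14, the FP8 Tops column
print(f"{ref_time_us / fp8_time_us:.5g}")                 # ~1.6009, the Speedup column
```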
Definitely optional, but it would also be interesting to break out the performance if the user manually modified the model to have the right shapes versus using this PR to do it dynamically. This could help the user estimate whether manually modifying the shapes is worth their time.
looks great!
Comparing against the aligned version:

**************************************TOPs**************************************
Shape Ref Dtype Ref Tops Aligned BF16 Tops FP8 Tops Ref % Peak Aligned BF16 % Peak FP8 % Peak
------------------ -------------- ---------- ------------------- ---------- ------------ --------------------- ------------
(8193x2501x5008) torch.bfloat16 5.09e+14 5.02e+14 8.13e+14 0.257 0.253 0.205
(65x253x4096) torch.bfloat16 1.08e+13 1.24e+13 8.2e+12 0.005 0.006 0.002
(1023x1029x2512) torch.bfloat16 7.09e+13 1.82e+14 1.99e+14 0.036 0.092 0.05
(4095x511x10000) torch.bfloat16 9.41e+13 4.58e+14 5.52e+14 0.048 0.231 0.14
(2047x3073x8192) torch.bfloat16 1.15e+14 4.45e+14 6.18e+14 0.058 0.225 0.156
(511x769x7504) torch.bfloat16 8.35e+13 1.86e+14 1.68e+14 0.042 0.094 0.043
(127x4097x12288) torch.bfloat16 8.62e+13 7.72e+13 8.58e+13 0.044 0.039 0.022
(32769x15x15024) torch.bfloat16 1.48e+13 3.35e+13 3.27e+13 0.007 0.017 0.008
(9217x8191x20480) torch.bfloat16 1.2e+14 5.84e+14 1.07e+15 0.061 0.295 0.271
(16385x1025x25008) torch.bfloat16 1.04e+14 5.22e+14 8.11e+14 0.053 0.264 0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| Shape | Ref Dtype | Ref Time | Aligned BF16 Time | FP8 Time | Aligned BF16 Speedup | FP8 Speedup |
+======================+================+============+=====================+============+========================+===============+
| (8193, 2501, 5008) | torch.bfloat16 | 403.449 | 409.152 | 252.41 | 0.986061 | 1.59839 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (65, 253, 4096) | torch.bfloat16 | 12.4869 | 10.8703 | 16.4201 | 1.14871 | 0.760461 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (1023, 1029, 2512) | torch.bfloat16 | 74.5544 | 29.0095 | 26.5931 | 2.57 | 2.80352 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (4095, 511, 10000) | torch.bfloat16 | 444.699 | 91.3756 | 75.7621 | 4.86672 | 5.86968 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (2047, 3073, 8192) | torch.bfloat16 | 899.915 | 231.424 | 166.831 | 3.88861 | 5.39418 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (511, 769, 7504) | torch.bfloat16 | 70.6476 | 31.6989 | 35.0197 | 2.22871 | 2.01737 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (127, 4097, 12288) | torch.bfloat16 | 148.296 | 165.574 | 149.074 | 0.895648 | 0.994782 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (32769, 15, 15024) | torch.bfloat16 | 996.017 | 440.95 | 451.679 | 2.2588 | 2.20514 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (9217, 8191, 20480) | torch.bfloat16 | 25802.7 | 5294.79 | 2884.43 | 4.87324 | 8.94554 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (16385, 1025, 25008) | torch.bfloat16 | 8038.73 | 1609.56 | 1036.38 | 4.99436 | 7.75651 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+

Aligning really does help; I think Inductor has some passes somewhere to do this. That being said, there is still some net benefit from float8 for certain shapes, but it is less impressive for sure.
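For reference, a hypothetical sketch of how an "aligned" baseline shape could be derived, assuming "Aligned BF16" means each dimension rounded up to the next multiple of 16 (the actual baseline construction presumably lives in benchmarks/bench_padding.py):

```python
# Hypothetical: round each dimension up to the next multiple of 16 to build an
# "aligned" bf16 baseline shape. The real benchmark script may differ.

def round_up(x: int, multiple: int = 16) -> int:
    return ((x + multiple - 1) // multiple) * multiple

shape = (8193, 2501, 5008)
aligned = tuple(round_up(d) for d in shape)
print(aligned)  # (8208, 2512, 5008); 5008 is already a multiple of 16
```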
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This adds simple utilities that enable scaled_mm to work with matrices whose dimensions are not multiples of 16. This is done by padding the inputs to the function.
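A minimal sketch of the padding idea (not the exact utility added in this PR; the helper below is illustrative): zero-pad the inner (K) dimension of both operands up to a multiple of 16. The padded region is all zeros, so it contributes nothing to the dot products and the result matches the unpadded matmul.

```python
import torch
import torch.nn.functional as F

def pad_inner_dim(a: torch.Tensor, b: torch.Tensor, multiple: int = 16):
    """Zero-pad a (M, K) and b (K, N) along K so that K becomes a multiple of `multiple`."""
    k = a.shape[1]
    pad = (-k) % multiple
    if pad == 0:
        return a, b
    a_padded = F.pad(a, (0, pad))        # (M, K) -> (M, K + pad)
    b_padded = F.pad(b, (0, 0, 0, pad))  # (K, N) -> (K + pad, N)
    return a_padded, b_padded

a = torch.randn(64, 253)   # K = 253 is not a multiple of 16
b = torch.randn(253, 128)
a_p, b_p = pad_inner_dim(a, b)
assert a_p.shape[1] % 16 == 0 and b_p.shape[0] % 16 == 0
torch.testing.assert_close(a @ b, a_p @ b_p)  # zero padding preserves the result
```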
This also adds a script that can be used to explore the performance of different shapes. By running
python benchmarks/bench_padding.py
You can produce an example like this:
Example workflows that this really helps