This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add config to enable padding on inner dims for scaled_mm inputs #145

Closed
wants to merge 7 commits

Conversation

@drisspg (Contributor) commented Nov 16, 2023

Summary

This adds simple utilities that can be used to enable scaled_mm to work with matrices whose dimensions are not multiples of 16. This is done by padding the inputs to the function.
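
Conceptually, the padding amounts to something like the sketch below: zero-pad the inner (K) dimension of both operands up to the next multiple of 16 before the matmul. This is only an illustration with plain torch ops; pad_dim_to_multiple_of_16 is a hypothetical name, not the exact utility added by this PR.

import torch
import torch.nn.functional as F

def pad_dim_to_multiple_of_16(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Zero-pad dimension `dim` of `x` up to the next multiple of 16.
    pad_amount = (-x.shape[dim]) % 16
    if pad_amount == 0:
        return x
    # F.pad takes (before, after) pairs starting from the last dimension.
    spec = [0, 0] * (x.ndim - 1 - dim) + [0, pad_amount]
    return F.pad(x, tuple(spec))

# K = 253 is not a multiple of 16, which the fp8 matmul rejects,
# so zero-pad the inner dim of both operands before converting to fp8.
a = torch.randn(65, 253)
b = torch.randn(253, 4096)
a_padded = pad_dim_to_multiple_of_16(a, dim=1)  # (65, 256)
b_padded = pad_dim_to_multiple_of_16(b, dim=0)  # (256, 4096)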

This also adds a script that can be used to explore the performance of different shapes. Running

python benchmarks/bench_padding.py

produces output like the following:

**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    FP8 Tops    Ref % Peak    FP8 % Peak
------------------  --------------  ----------  ----------  ------------  ------------
(8193x2501x5008)    torch.bfloat16    5.1e+14     8.17e+14         0.258         0.206
(65x253x4096)       torch.bfloat16    1.07e+13    8.21e+12         0.005         0.002
(1023x1029x2512)    torch.bfloat16    7.08e+13    1.98e+14         0.036         0.05
(4095x511x10000)    torch.bfloat16    9.4e+13     5.52e+14         0.047         0.139
(2047x3073x8192)    torch.bfloat16    1.14e+14    6.16e+14         0.058         0.156
(511x769x7504)      torch.bfloat16    8.37e+13    1.68e+14         0.042         0.043
(127x4097x12288)    torch.bfloat16    8.61e+13    8.55e+13         0.043         0.022
(32769x15x15024)    torch.bfloat16    1.48e+13    3.27e+13         0.007         0.008
(9217x8191x20480)   torch.bfloat16    1.2e+14     1.07e+15         0.061         0.271
(16385x1025x25008)  torch.bfloat16    1.05e+14    8.11e+14         0.053         0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |   402.215  |   251.246  |  1.60088  |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096)      | torch.bfloat16 |    12.5471 |    16.4149 |  0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.7011 |    26.6719 |  2.80074  |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000)   | torch.bfloat16 |   445.42   |    75.8169 |  5.87494  |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192)   | torch.bfloat16 |   901.602  |   167.263  |  5.39033  |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504)     | torch.bfloat16 |    70.5006 |    35.0095 |  2.01376  |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288)   | torch.bfloat16 |   148.589  |   149.542  |  0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024)   | torch.bfloat16 |   996.979  |   451.53   |  2.208    |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480)  | torch.bfloat16 | 25781.6    |  2886.31   |  8.93238  |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 |  8037.08   |  1036.24   |  7.75598  |
+----------------------+----------------+------------+------------+-----------+

An example workflow that this really helps:

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

import float8_experimental
float8_experimental.config.pad_inner_dim = True

swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'

dist.init_process_group(backend='nccl', init_method='env://')

# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")

print("Training complete!")

@facebook-github-bot added the CLA Signed label Nov 16, 2023
@drisspg force-pushed the pad_helper branch 2 times, most recently from b4473ed to 8f50785 on June 16, 2024 22:52
@drisspg requested review from vkuzo and msaroufim on June 16, 2024 22:52
@vkuzo (Contributor) commented Jun 18, 2024

do we have a sense of how promising this approach is in terms of performance? Are there shapes where float8 with padding is going to outperform bfloat16?

@drisspg (Contributor, Author) commented Jun 23, 2024

I updated the benchmark to perform a bigger sweep. The shapes are pretty big, but these should all be cases where M % 16 and N % 16 == 0 and K gets the semantics-preserving padding.
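
For the K padding specifically, the padded entries are zeros, so they contribute nothing to the dot products and the result is unchanged. A quick sanity check with plain torch (illustrative only, not the code in this PR):

import torch
import torch.nn.functional as F

a = torch.randn(64, 253, dtype=torch.bfloat16)   # K = 253 is not a multiple of 16
b = torch.randn(253, 4096, dtype=torch.bfloat16)
a_pad = F.pad(a, (0, 3))        # pad K of a: (64, 256)
b_pad = F.pad(b, (0, 0, 0, 3))  # pad K of b: (256, 4096)
# The extra products are all zero, so the padded matmul matches the original.
torch.testing.assert_close(a @ b, a_pad @ b_pad)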

**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    FP8 Tops    Ref % Peak    FP8 % Peak
------------------  --------------  ----------  ----------  ------------  ------------
(8193x2501x5008)    torch.bfloat16    5.1e+14     8.17e+14         0.258         0.206
(65x253x4096)       torch.bfloat16    1.07e+13    8.21e+12         0.005         0.002
(1023x1029x2512)    torch.bfloat16    7.08e+13    1.98e+14         0.036         0.05
(4095x511x10000)    torch.bfloat16    9.4e+13     5.52e+14         0.047         0.139
(2047x3073x8192)    torch.bfloat16    1.14e+14    6.16e+14         0.058         0.156
(511x769x7504)      torch.bfloat16    8.37e+13    1.68e+14         0.042         0.043
(127x4097x12288)    torch.bfloat16    8.61e+13    8.55e+13         0.043         0.022
(32769x15x15024)    torch.bfloat16    1.48e+13    3.27e+13         0.007         0.008
(9217x8191x20480)   torch.bfloat16    1.2e+14     1.07e+15         0.061         0.271
(16385x1025x25008)  torch.bfloat16    1.05e+14    8.11e+14         0.053         0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |   402.215  |   251.246  |  1.60088  |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096)      | torch.bfloat16 |    12.5471 |    16.4149 |  0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.7011 |    26.6719 |  2.80074  |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000)   | torch.bfloat16 |   445.42   |    75.8169 |  5.87494  |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192)   | torch.bfloat16 |   901.602  |   167.263  |  5.39033  |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504)     | torch.bfloat16 |    70.5006 |    35.0095 |  2.01376  |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288)   | torch.bfloat16 |   148.589  |   149.542  |  0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024)   | torch.bfloat16 |   996.979  |   451.53   |  2.208    |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480)  | torch.bfloat16 | 25781.6    |  2886.31   |  8.93238  |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 |  8037.08   |  1036.24   |  7.75598  |
+----------------------+----------------+------------+------------+-----------+

@drisspg changed the title from "Add utilities for padding and add to bench_padding.py" to "Add config to enable padding on inner dims for scaled_mm inputs" on Jun 24, 2024
@vkuzo (Contributor) commented Jun 24, 2024

+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |   402.215  |   251.246  |  1.60088  |

Definitely optional, but it would also be interesting to break out the performance when the user manually modifies the model to have the right shapes versus using this PR to do it dynamically. This could help users estimate whether manually modifying the shapes is worth their time.
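
(For concreteness, "the right shapes" would mean something like rounding the relevant layer dimensions up to the next multiple of 16 at model-definition time; a rough, hypothetical sketch:)

import torch

def round_up_to_16(n: int) -> int:
    # Round a dimension up to the next multiple of 16.
    return ((n + 15) // 16) * 16

# Define the layer with an already-aligned hidden size up front,
# instead of padding weights/activations dynamically at matmul time.
aligned_hidden = round_up_to_16(2501)          # -> 2512
proj = torch.nn.Linear(5008, aligned_hidden)   # 5008 is already a multiple of 16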

@vkuzo (Contributor) left a comment

looks great!

@drisspg (Contributor, Author) commented Jun 24, 2024

Comparing against the aligned version:

**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    Aligned BF16 Tops    FP8 Tops    Ref % Peak    Aligned BF16 % Peak    FP8 % Peak
------------------  --------------  ----------  -------------------  ----------  ------------  ---------------------  ------------
(8193x2501x5008)    torch.bfloat16    5.09e+14             5.02e+14    8.13e+14         0.257                  0.253         0.205
(65x253x4096)       torch.bfloat16    1.08e+13             1.24e+13    8.2e+12          0.005                  0.006         0.002
(1023x1029x2512)    torch.bfloat16    7.09e+13             1.82e+14    1.99e+14         0.036                  0.092         0.05
(4095x511x10000)    torch.bfloat16    9.41e+13             4.58e+14    5.52e+14         0.048                  0.231         0.14
(2047x3073x8192)    torch.bfloat16    1.15e+14             4.45e+14    6.18e+14         0.058                  0.225         0.156
(511x769x7504)      torch.bfloat16    8.35e+13             1.86e+14    1.68e+14         0.042                  0.094         0.043
(127x4097x12288)    torch.bfloat16    8.62e+13             7.72e+13    8.58e+13         0.044                  0.039         0.022
(32769x15x15024)    torch.bfloat16    1.48e+13             3.35e+13    3.27e+13         0.007                  0.017         0.008
(9217x8191x20480)   torch.bfloat16    1.2e+14              5.84e+14    1.07e+15         0.061                  0.295         0.271
(16385x1025x25008)  torch.bfloat16    1.04e+14             5.22e+14    8.11e+14         0.053                  0.264         0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| Shape                | Ref Dtype      |   Ref Time |   Aligned BF16 Time |   FP8 Time |   Aligned BF16 Speedup |   FP8 Speedup |
+======================+================+============+=====================+============+========================+===============+
| (8193, 2501, 5008)   | torch.bfloat16 |   403.449  |            409.152  |   252.41   |               0.986061 |      1.59839  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (65, 253, 4096)      | torch.bfloat16 |    12.4869 |             10.8703 |    16.4201 |               1.14871  |      0.760461 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.5544 |             29.0095 |    26.5931 |               2.57     |      2.80352  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (4095, 511, 10000)   | torch.bfloat16 |   444.699  |             91.3756 |    75.7621 |               4.86672  |      5.86968  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (2047, 3073, 8192)   | torch.bfloat16 |   899.915  |            231.424  |   166.831  |               3.88861  |      5.39418  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (511, 769, 7504)     | torch.bfloat16 |    70.6476 |             31.6989 |    35.0197 |               2.22871  |      2.01737  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (127, 4097, 12288)   | torch.bfloat16 |   148.296  |            165.574  |   149.074  |               0.895648 |      0.994782 |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (32769, 15, 15024)   | torch.bfloat16 |   996.017  |            440.95   |   451.679  |               2.2588   |      2.20514  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (9217, 8191, 20480)  | torch.bfloat16 | 25802.7    |           5294.79   |  2884.43   |               4.87324  |      8.94554  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+
| (16385, 1025, 25008) | torch.bfloat16 |  8038.73   |           1609.56   |  1036.38   |               4.99436  |      7.75651  |
+----------------------+----------------+------------+---------------------+------------+------------------------+---------------+

Aligning really does help; I think Inductor has some passes somewhere to do this. That being said, there is still some net benefit from float8 for certain shapes, but it is less impressive for sure.

@facebook-github-bot: @drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot: @drisspg merged this pull request in 57136bd.
