Add config to enable padding on inner dims for scaled_mm inputs (#145)
Summary: This adds simple utilities that let `scaled_mm` work with matrices whose inner dimensions are not multiples of 16, by padding the inputs to the function. It also adds a script for exploring the performance of different shapes. Running `python benchmarks/bench_padding.py` produces output like this (times are in microseconds):

```Shell
**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    FP8 Tops    Ref % Peak    FP8 % Peak
------------------  --------------  ----------  ----------  ------------  ------------
(8193x2501x5008)    torch.bfloat16     5.1e+14    8.17e+14         0.258         0.206
(65x253x4096)       torch.bfloat16    1.07e+13    8.21e+12         0.005         0.002
(1023x1029x2512)    torch.bfloat16    7.08e+13    1.98e+14         0.036          0.05
(4095x511x10000)    torch.bfloat16     9.4e+13    5.52e+14         0.047         0.139
(2047x3073x8192)    torch.bfloat16    1.14e+14    6.16e+14         0.058         0.156
(511x769x7504)      torch.bfloat16    8.37e+13    1.68e+14         0.042         0.043
(127x4097x12288)    torch.bfloat16    8.61e+13    8.55e+13         0.043         0.022
(32769x15x15024)    torch.bfloat16    1.48e+13    3.27e+13         0.007         0.008
(9217x8191x20480)   torch.bfloat16     1.2e+14    1.07e+15         0.061         0.271
(16385x1025x25008)  torch.bfloat16    1.05e+14    8.11e+14         0.053         0.205

*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |    402.215 |    251.246 |   1.60088 |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096)      | torch.bfloat16 |    12.5471 |    16.4149 |  0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.7011 |    26.6719 |   2.80074 |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000)   | torch.bfloat16 |     445.42 |    75.8169 |   5.87494 |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192)   | torch.bfloat16 |    901.602 |    167.263 |   5.39033 |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504)     | torch.bfloat16 |    70.5006 |    35.0095 |   2.01376 |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288)   | torch.bfloat16 |    148.589 |    149.542 |  0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024)   | torch.bfloat16 |    996.979 |     451.53 |     2.208 |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480)  | torch.bfloat16 |    25781.6 |    2886.31 |   8.93238 |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 |    8037.08 |    1036.24 |   7.75598 |
+----------------------+----------------+------------+------------+-----------+
```
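For intuition, here is a minimal sketch of the padding idea itself (illustrative only; `pad_inner_dim` below is a hypothetical helper, not the utility added in this PR): zero-padding the shared inner dimension of both operands up to the next multiple of 16 leaves the matmul result unchanged, which is what allows `scaled_mm` to be used on otherwise unsupported shapes.

```Python
import torch
import torch.nn.functional as F

def pad_inner_dim(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the last dimension of `x` up to the next multiple of `multiple`.
    Illustrative helper only; not the API added by this PR."""
    k = x.shape[-1]
    pad = (multiple - k % multiple) % multiple
    # F.pad takes (left, right) padding amounts for the last dimension.
    return F.pad(x, (0, pad)) if pad else x

a = torch.randn(65, 253)          # inner dim K = 253 is not a multiple of 16
b = torch.randn(253, 4096)

a_pad = pad_inner_dim(a)          # (65, 256)
b_pad = pad_inner_dim(b.t()).t()  # pad b along K as well -> (256, 4096)

# Padding K with zeros does not change the product, so the padded operands
# can be fed to the fp8 matmul without affecting numerics.
assert torch.allclose(a @ b, a_pad @ b_pad, atol=1e-5)
```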
## Example workflows that this really helps

```Python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
import float8_experimental

float8_experimental.config.pad_inner_dim = True
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
dist.init_process_group(backend='nccl', init_method='env://')
# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True,
                                   truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")
print("Training complete!")
```

Pull Request resolved: #145

Reviewed By: vkuzo

Differential Revision: D58958442

Pulled By: drisspg

fbshipit-source-id: 5a4c8661e974699ce3f83748fca1ce1f0ad65d3b