Add config to enable padding on inner dims for scaled_mm inputs (#145)
Summary:
This adds simple utilities that can be used to enable `scaled_mm` to work with matrices whose dimensions are not multiples of 16. This is done by padding the inputs to the function.
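
For example, here is a minimal sketch of what the padding helper does (the helper is `pad_tensor_for_matmul`, added in `float8_utils.py` below; the shapes are illustrative):

```Python
import torch
from float8_experimental.float8_utils import pad_tensor_for_matmul

# K = 253 is not a multiple of 16, so torch._scaled_mm would reject this shape.
a = torch.randn(65, 253, dtype=torch.bfloat16)
b = torch.randn(253, 4096, dtype=torch.bfloat16)

# Zero-pad the shared inner dimension up to the next multiple of 16.
a_pad = pad_tensor_for_matmul(a, dims=1)  # (65, 256)
b_pad = pad_tensor_for_matmul(b, dims=0)  # (256, 4096)

# Padding K with zeros does not change the product.
torch.testing.assert_close(a @ b, a_pad @ b_pad)
```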

This also adds a script that can be used to explore the performance of different shapes. Running
`python benchmarks/bench_padding.py`
produces output like the following:

```Shell
**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    FP8 Tops    Ref % Peak    FP8 % Peak
------------------  --------------  ----------  ----------  ------------  ------------
(8193x2501x5008)    torch.bfloat16    5.1e+14     8.17e+14         0.258         0.206
(65x253x4096)       torch.bfloat16    1.07e+13    8.21e+12         0.005         0.002
(1023x1029x2512)    torch.bfloat16    7.08e+13    1.98e+14         0.036         0.05
(4095x511x10000)    torch.bfloat16    9.4e+13     5.52e+14         0.047         0.139
(2047x3073x8192)    torch.bfloat16    1.14e+14    6.16e+14         0.058         0.156
(511x769x7504)      torch.bfloat16    8.37e+13    1.68e+14         0.042         0.043
(127x4097x12288)    torch.bfloat16    8.61e+13    8.55e+13         0.043         0.022
(32769x15x15024)    torch.bfloat16    1.48e+13    3.27e+13         0.007         0.008
(9217x8191x20480)   torch.bfloat16    1.2e+14     1.07e+15         0.061         0.271
(16385x1025x25008)  torch.bfloat16    1.05e+14    8.11e+14         0.053         0.205
*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |   402.215  |   251.246  |  1.60088  |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096)      | torch.bfloat16 |    12.5471 |    16.4149 |  0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.7011 |    26.6719 |  2.80074  |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000)   | torch.bfloat16 |   445.42   |    75.8169 |  5.87494  |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192)   | torch.bfloat16 |   901.602  |   167.263  |  5.39033  |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504)     | torch.bfloat16 |    70.5006 |    35.0095 |  2.01376  |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288)   | torch.bfloat16 |   148.589  |   149.542  |  0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024)   | torch.bfloat16 |   996.979  |   451.53   |  2.208    |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480)  | torch.bfloat16 | 25781.6    |  2886.31   |  8.93238  |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 |  8037.08   |  1036.24   |  7.75598  |
+----------------------+----------------+------------+------------+-----------+
```
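
The table entries follow directly from the formulas in `benchmarks/bench_padding.py` (shown further down); a small worked example for the first shape:

```Python
M, K, N = 8193, 2501, 5008
flops = 2 * M * K * N                        # ~2.05e11 floating-point ops per matmul
ref_time_us, fp8_time_us = 402.215, 251.246  # microseconds, from the speed table

ref_tops_sec = flops / (ref_time_us * 1e-6)  # ~5.1e14  -> "Ref Tops"
ref_pct_peak = ref_tops_sec / 1979e12        # ~0.258   -> fraction of H100 bf16 peak
fp8_tops_sec = flops / (fp8_time_us * 1e-6)  # ~8.17e14 -> "FP8 Tops"
fp8_pct_peak = fp8_tops_sec / 3958e12        # ~0.206   -> fraction of H100 fp8 peak
speedup = ref_time_us / fp8_time_us          # ~1.60    -> "Speedup"
```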

## Example workflows where this really helps
```Python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

import float8_experimental
float8_experimental.config.pad_inner_dim = True

swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'

dist.init_process_group(backend='nccl', init_method='env://')

# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")

print("Training complete!")
```

Pull Request resolved: #145

Reviewed By: vkuzo

Differential Revision: D58958442

Pulled By: drisspg

fbshipit-source-id: 5a4c8661e974699ce3f83748fca1ce1f0ad65d3b
drisspg authored and facebook-github-bot committed Jun 24, 2024
1 parent d4ade87 commit 57136bd
Showing 9 changed files with 366 additions and 9 deletions.
204 changes: 204 additions & 0 deletions benchmarks/bench_padding.py
@@ -0,0 +1,204 @@
from dataclasses import dataclass
from typing import Optional

import fire

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_usec(f, *args, **kwargs):
    no_args = lambda: f(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def get_tops_info(tops, time, peak_tops):
    time_sec = time / 1e6
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return tops_sec, pct_top_peak


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

    a_config = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )
    b_config = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )

    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)

    return a_fp8 @ b_fp8


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
    # Breaks with compile due to trying to pad on fp8 dtype
    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
    A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy

    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

    A_pad = A_pad.to(fp8_dtype)  # mem copy
    B_pad = B_pad.to(fp8_dtype)  # mem copy

    B_pad = B_pad.t().contiguous().t()  # mem copy

    return torch._scaled_mm(
        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
    )


def do_hp_matmul(A, B):
    return torch.matmul(A, B)


def do_aligned_bf16_matmul(A, B):
    A_pad = pad_tensor_for_matmul(A, dims=1)
    B_pad = pad_tensor_for_matmul(B, dims=0)
    return torch.matmul(A_pad, B_pad)


@dataclass
class Experiment_config:
    M: int
    K: int
    N: int
    output_dtype: torch.dtype
    fp8_dtype: torch.dtype

    def __iter__(self):
        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))


def gen_configs():
    shapes = shapes = [
        (8193, 2501, 5008),
        (65, 253, 4096),
        (1023, 1029, 2512),
        (4095, 511, 10000),
        (2047, 3073, 8192),
        (511, 769, 7504),
        (127, 4097, 12288),
        (32769, 15, 15024),
        (9217, 8191, 20480),
        (16385, 1025, 25008),
    ]
    output_dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn
    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
    device = "cuda"
    experiments = gen_configs()
    results = []
    tops_table = []
    tops_headers = [
        "Shape",
        "Ref Dtype",
        "Ref Tops",
        "Aligned BF16 Tops",
        "FP8 Tops",
        "Ref % Peak",
        "Aligned BF16 % Peak",
        "FP8 % Peak",
    ]

    for experiment in tqdm(experiments):
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K

        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
        B_base = torch.rand(K, N, device=device, dtype=output_dtype)

        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
        aligned_bf16_func = (
            torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul
        )
        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
        aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base)
        fp8_time = benchmark_fn_in_usec(
            fp8_func, A_base, B_base, fp8_dtype, output_dtype
        )

        ref_tops_sec, ref_pct_top_peak = get_tops_info(
            tops, ref_time, dtype_to_peak_tops[output_dtype]
        )
        aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info(
            tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16]
        )
        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
        )
        tops_table.append(
            [
                f"({M}x{K}x{N})",
                f"{output_dtype}",
                f"{ref_tops_sec:.2E}",
                f"{aligned_bf16_tops_sec:.2E}",
                f"{fp8_tops_sec:.2E}",
                f"{ref_pct_top_peak:.3f}",
                f"{aligned_bf16_pct_top_peak:.3f}",
                f"{fp8_pct_top_peak:.3f}",
            ]
        )
        results.append(
            [
                (M, K, N),
                output_dtype,
                ref_time,
                aligned_bf16_time,
                fp8_time,
                ref_time / aligned_bf16_time,
                ref_time / fp8_time,
            ]
        )

    print("TOPs".center(80, "*"))
    print(tabulate(tops_table, headers=tops_headers))
    print("Speed Results".center(80, "*"))
    headers = [
        "Shape",
        "Ref Dtype",
        "Ref Time",
        "Aligned BF16 Time",
        "FP8 Time",
        "Aligned BF16 Speedup",
        "FP8 Speedup",
    ]
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    fire.Fire(run)
6 changes: 6 additions & 0 deletions float8_experimental/config.py
@@ -23,3 +23,9 @@
# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
use_fnuz_dtype = False

# If True, then prior to performing the fp8 scaled matmul we will pad the
# inner dimension of a (dim 1) and b (dim 0) with 0s. This is needed for
# torch._scaled_mm, which has the strong constraint that for dims M, N, K,
# both N and K must be multiples of 16. This can cause a memory spike,
# however, so we keep it off by default.
pad_inner_dim = False
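
As a quick illustration of the new flag, here is a hedged sketch that mirrors the workflow above (the layer sizes are arbitrary, and a CUDA device with fp8 support is assumed):

```Python
import torch
import float8_experimental
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Opt in to zero-padding of the inner dim before torch._scaled_mm is called.
float8_experimental.config.pad_inner_dim = True

# in_features = 253 is not a multiple of 16; without the flag the fp8 matmul is rejected.
model = torch.nn.Sequential(torch.nn.Linear(253, 4096)).to("cuda").to(torch.bfloat16)
swap_linear_with_float8_linear(model, Float8DynamicLinear)

x = torch.randn(65, 253, device="cuda", dtype=torch.bfloat16)
y = model(x)
```
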
15 changes: 13 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -88,8 +88,19 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
            "bias": False,
        }
        new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+
+        new_mod.forward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=not bool(emulate),
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=False,
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
        if config.enable_fsdp_fp8_all_gather:
            new_mod.weight = nn.Parameter(
                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
8 changes: 6 additions & 2 deletions float8_experimental/float8_linear.py
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.create_buffers()
        # Defines the behavior of the matmul in the forward and backward
        # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
        return new_mod
13 changes: 12 additions & 1 deletion float8_experimental/float8_ops.py
@@ -13,7 +13,8 @@
    merge_mm_configs,
    ScaledMMConfig,
)
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul

from torch.utils._pytree import tree_map

aten = torch.ops.aten
@@ -121,6 +122,16 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
    a_scale = a._scale
    b_data = b._data

+    if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
    if not is_row_major(a_data.stride()):
        a_data = a_data.contiguous()
    if is_row_major(b_data.stride()):
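
The padding branch above can be exercised directly through `Float8Tensor`; here is a hedged sketch modeled on `do_fp8_matmul` from the new benchmark script (assumes a CUDA device with fp8 support):

```Python
import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

# Both operands opt in, so preprocess_addmm pads the inner dim (K = 253 -> 256)
# before dispatching to torch._scaled_mm.
cfg = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=True)
scale = torch.tensor([1.0], device="cuda")  # fp32 scale, as in the benchmark

a = torch.randn(65, 253, device="cuda", dtype=torch.bfloat16)
b = torch.randn(253, 4096, device="cuda", dtype=torch.bfloat16)
a_fp8 = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn, mm_config=cfg)
b_fp8 = Float8Tensor.to_float8(b, scale, torch.float8_e4m3fn, mm_config=cfg)

out = a_fp8 @ b_fp8  # with fp8_output=False the result comes back in bf16
```
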
1 change: 0 additions & 1 deletion float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@
to simplify the product code.
"""


from typing import Optional

import float8_experimental.float8_aten_api # noqa
6 changes: 4 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -23,10 +23,11 @@
# emulate: whether to emulate the matmuls in fp32
# use_fast_accum: whether to use the fast-accumulation option for scaled_mm
# fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
-    ["emulate", "use_fast_accum", "fp8_output"],
-    defaults=[False, False, False],
+    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+    defaults=[False, False, False, False],
)


@@ -48,6 +49,7 @@ def merge_mm_configs(
        emulate=a_mm_config.emulate,
        use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
        fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
    )


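Note that `merge_mm_configs` AND-s `use_fast_accum`, `fp8_output`, and the new `pad_inner_dim` flag, so padding only happens when both operands request it; a hedged sketch (the positional call signature is assumed from the field names above):

```Python
from float8_experimental.float8_tensor import ScaledMMConfig, merge_mm_configs

fwd = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=True)
bwd = ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)

merged = merge_mm_configs(fwd, bwd)
assert merged.pad_inner_dim is False  # one side opting out disables padding for the mm
```
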
68 changes: 67 additions & 1 deletion float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Literal, Tuple
+from typing import Iterable, Literal, Tuple, Union

import float8_experimental.config as config

@@ -179,3 +179,69 @@ def fp8_tensor_statistics(
def is_row_major(stride):
    assert len(stride) == 2, "is_row_major only supports 2D tensors"
    return stride[0] > stride[1] and stride[1] == 1


def _get_min_alignment(size: int, alignment_value: int) -> int:
    """
    Returns the smallest multiple of `alignment_value` that is greater than or equal to `size`.

    Args:
        size: The size of the data to be aligned.
        alignment_value: The alignment value to be used.

    Returns:
        int: The smallest multiple of `alignment_value` that is greater than or equal to `size`.

    Usage:
    ```
    >>> _get_min_alignment(10, 8)
    16
    ```
    """
    if size % alignment_value == 0:
        return size
    return (1 + (size // alignment_value)) * alignment_value


def pad_tensor_for_matmul(
    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
) -> torch.Tensor:
    """
    Pads a 2D tensor with zeros so that the given dimensions are multiples of 16, which is required by `torch._scaled_mm`.

    Args:
        tensor: The tensor to pad.
        dims: The dimension(s) along which to pad; 0, 1, or both.

    Returns:
        torch.Tensor: The padded tensor.

    Usage:
    ```
    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
    torch.Size([16, 10])
    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
    torch.Size([10, 16])
    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
    torch.Size([16, 16])
    ```
    """
    assert tensor.dim() == 2
    dim1, dim2 = tensor.shape

    if isinstance(dims, int):
        dims = (dims,)

    # Calculate aligned dimensions based on the specified dims
    dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
    dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2

    # Check if padding is needed for either dimension
    if dim1 == dim1_aligned and dim2 == dim2_aligned:
        return tensor

    # Calculate padding values for both dimensions
    pad_dim1 = dim1_aligned - dim1
    pad_dim2 = dim2_aligned - dim2

    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))