This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes on top of upstream to get rid of type errors #248

Closed
wants to merge 3 commits
4 changes: 4 additions & 0 deletions float8_experimental/config.py
@@ -19,3 +19,7 @@
# implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
# Only dynamic scaling is supported for now.
enable_fsdp_fp8_all_gather = False

# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
use_fnuz_dtype = False
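
A hedged sketch of how this flag is meant to be consumed (the real aliases are added in the float8_utils.py hunk further down; the function name here is illustrative only):

```python
import torch

import float8_experimental.config as config


# Illustrative only: mirrors the alias selection added to float8_utils.py below.
# ROCm currently exposes only the 'fnuz' float8 variants, so the flag routes
# every cast to those dtypes.
def select_float8_dtypes():
    if config.use_fnuz_dtype:
        return torch.float8_e4m3fnuz, torch.float8_e5m2fnuz
    return torch.float8_e4m3fn, torch.float8_e5m2
```
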
12 changes: 5 additions & 7 deletions float8_experimental/float8_dynamic_linear.py
@@ -22,7 +22,7 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
from torch._prims_common import suggest_memory_format


@@ -46,9 +46,9 @@ def forward(
def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
)
return fp8_tensor, None

@@ -105,10 +105,8 @@ def cast_to_float8_e4m3fn(
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
)
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def cast_to_float8_e5m2_bw(
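
A usage sketch of the dynamic-scaling path with the new aliases; the calls below mirror the ones in this diff, and `e4m3_dtype` resolves to the fnuz variant only when `config.use_fnuz_dtype` is set before import:

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# Dynamic scaling: derive the scale from the tensor itself, then cast.
x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(x, e4m3_dtype)
x_fp8 = Float8Tensor.to_float8(x, scale, e4m3_dtype)
print(x_fp8._data.dtype)  # torch.float8_e4m3fn, or e4m3fnuz under use_fnuz_dtype
```
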
19 changes: 12 additions & 7 deletions float8_experimental/float8_linear.py
@@ -21,7 +21,12 @@
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
from float8_experimental.float8_utils import (
amax_history_to_scale,
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -89,15 +94,15 @@ def backward(ctx, go):
fp8_amax_history_dL_dY,
fp8_scale_dL_dY,
scale_fn_name,
torch.float8_e5m2,
e5m2_dtype,
is_amax_initialized,
reduce_amax=True,
)

fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads
@@ -236,14 +241,14 @@ def cast_x_to_float8(
self.fp8_amax_history_x,
self.fp8_scale_x,
scale_fn_name,
torch.float8_e4m3fn,
e4m3_dtype,
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_x,
self.forward_config,
)
@@ -259,15 +264,15 @@ def cast_w_to_float8(
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
torch.float8_e4m3fn,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
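
For context on why these helpers take the float8 dtype explicitly: the scale is derived from the target dtype's representable maximum, which differs between the 'fn' and 'fnuz' variants. A simplified sketch (the library's `amax_to_scale` additionally clamps the result for float16 origin dtypes):

```python
import torch

EPS = 1e-12  # same guard value as float8_utils.EPS


def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # torch.finfo(torch.float8_e4m3fn).max == 448.0 while
    # torch.finfo(torch.float8_e4m3fnuz).max == 240.0, so mixing up the two
    # variants would silently produce wrong scales.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
```
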
12 changes: 8 additions & 4 deletions float8_experimental/float8_linear_utils.py
@@ -14,7 +14,11 @@
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

from float8_experimental.float8_utils import amax_history_to_scale_stack
from float8_experimental.float8_utils import (
amax_history_to_scale_stack,
e4m3_dtype,
e5m2_dtype,
)
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -298,13 +302,13 @@ def inner_func():

# Calculate the new scales from the updated history stacks
new_x_scales = amax_history_to_scale_stack(
fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
)
new_w_scales = amax_history_to_scale_stack(
fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
)
new_dL_dY_scales = amax_history_to_scale_stack(
fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
)

# Iterate through the layers and update the scales
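
For reference, a hedged sketch of where the sync step that consumes these scale stacks sits in a training loop. Sizes and hyper-parameters are placeholders, and `swap_linear_with_float8_linear` is assumed to take the module plus the `Float8Linear` class, as the test imports in this PR suggest:

```python
import torch

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 32)).cuda()
swap_linear_with_float8_linear(model, Float8Linear)  # assumed in-place swap
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    optimizer.zero_grad()
    x = torch.randn(16, 32, device="cuda")
    model(x).sum().backward()
    # Delayed scaling: amaxes were recorded during fwd/bwd; this call reduces
    # them (across ranks, if distributed) and recomputes the per-tensor scales.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
```
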
8 changes: 6 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -9,7 +9,11 @@
import torch

import torch.distributed._functional_collectives as funcol
from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
from float8_experimental.float8_utils import (
e4m3_dtype,
tensor_to_amax,
to_fp8_saturated,
)
from torch.distributed._tensor import DTensor

aten = torch.ops.aten
@@ -125,7 +129,7 @@ def forward(
ctx,
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype=torch.float8_e4m3fn,
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
mm_config: Optional[ScaledMMConfig] = None,
):
11 changes: 9 additions & 2 deletions float8_experimental/float8_utils.py
@@ -6,6 +6,8 @@

from typing import Literal, Tuple

import float8_experimental.config as config

import torch
import torch.distributed as dist

@@ -16,7 +18,7 @@
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -25,6 +27,11 @@
}


# User defined type for using the individual F8 type based on config
e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):


def fp8_tensor_statistics(
tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
tensor: torch.Tensor, float8_dtype=e4m3_dtype
) -> Tuple[int, ...]:
"""Calculate FP8 tensor stats

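
One consequence of defining the aliases at module scope (and of the `float8_dtype=e4m3_dtype` default in float8_tensor.py above) is that they are fixed at import time, so `use_fnuz_dtype` has to be set before `float8_utils` is first imported. A sketch, assuming nothing has imported that module yet:

```python
import torch

import float8_experimental.config as config

# Assumption: float8_experimental.float8_utils has not been loaded yet;
# e4m3_dtype / e5m2_dtype are bound when that module first imports.
config.use_fnuz_dtype = True

from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype

assert e4m3_dtype == torch.float8_e4m3fnuz
assert e5m2_dtype == torch.float8_e5m2fnuz
```
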
26 changes: 18 additions & 8 deletions test/test_base.py
@@ -30,6 +30,8 @@
)
from float8_experimental.float8_utils import (
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
@@ -51,7 +53,7 @@ class TestFloat8Tensor(unittest.TestCase):
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = FP8_TYPES
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.randn(4, 4, dtype=hp_dtype)
x1_s = tensor_to_scale(x1_hp, lp_dtype)
@@ -60,7 +62,7 @@ def test_preserves_dtype(self) -> None:
self.assertTrue(x3_hp.dtype == hp_dtype)

def test_differentiable_casts(self) -> None:
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = (e4m3_dtype, e5m2_dtype)
for f8_dtype in lp_dtypes:
x = torch.randn(1).requires_grad_()
grad = torch.randn(1)
@@ -73,8 +75,8 @@

def test_split_cat(self):
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
scale = tensor_to_scale(a, e4m3_dtype)
fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)
catted = torch.cat(splits, dim=0)
@@ -313,7 +315,7 @@ class TestScaledMM:
@pytest.mark.parametrize("use_fast_accum", [True, False])
def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
torch.manual_seed(42)
input_dtype = torch.float8_e4m3fn
input_dtype = e4m3_dtype
output_dtype = base_dtype
compare_type = torch.float32

@@ -352,7 +354,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
b = Float8Tensor.to_float8(
x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
@@ -387,7 +389,15 @@ def test_merge_configs(self):


class TestNumerics:
@pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize(
"float8_dtype",
[
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_small_amax_float16(self, float8_dtype):
# If we calculate scale naively with FP8_MAX_POS / amax,
@@ -508,7 +518,7 @@ def __init__(self, dim: int):

def test_fp8_tensor_statistics(self):
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = (e4m3_dtype, e5m2_dtype)
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.ones(4, 4, dtype=hp_dtype)
tensor_len = x1_hp.numel()
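
The truncated comment in test_small_amax_float16 above points at the hazard the new fnuz parametrizations also need to cover: with a float16 origin dtype, the naive scale FP8_MAX / amax can overflow to inf. A worked example (constants from torch.finfo, not taken from the test itself):

```python
import torch

amax = torch.tensor(1e-3, dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
naive_scale = fp8_max / amax                    # 448_000 > 65_504 (float16 max)
print(naive_scale)                              # tensor(inf, dtype=torch.float16)
# test_small_amax_float16 guards against this by checking that the computed
# scale stays finite for float16 inputs instead of using FP8_MAX / amax naively.
```
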
5 changes: 4 additions & 1 deletion test/test_compile.py
@@ -22,6 +22,7 @@
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM

from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend
@@ -116,7 +117,7 @@ def forward(self, x):
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_x,
ScaledMMConfig(),
)
@@ -127,12 +128,14 @@
return x_fp8

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
def test_float8_with_graph_break_in_the_middle(self):
"""Test that having Float8Tensor object at the boundary of a subgraph"""
cnts = CompileCounterWithBackend("inductor")
mod = self.MockLinear(graph_break=True).cuda()
compiled_mod = copy.deepcopy(mod)
compiled_mod = torch.compile(compiled_mod, backend=cnts)
torch.manual_seed(0)
x = torch.randn(16, 16, device="cuda")
y_eager = mod(x)
y_compiled = compiled_mod(x)
10 changes: 5 additions & 5 deletions test/test_dtensor.py
@@ -25,7 +25,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from float8_experimental.float8_utils import tensor_to_scale
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
@@ -64,7 +64,7 @@ def forward(self, x):

def test_scaled_mm(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
world_size = mesh.size()

x_fp32 = torch.rand(size, size, device=device)
@@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):

def test_fp8_redistribute(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
world_size = mesh.size()

x_fp32 = torch.rand(size, size, device=device)
@@ -130,7 +130,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):

def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype

x_fp32 = torch.rand(size, size, device=device)
dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
@@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):

def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype

x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
6 changes: 6 additions & 0 deletions test/test_everything.sh
@@ -2,13 +2,19 @@

# terminate script on first error
set -e
IS_ROCM=$(rocm-smi --version || true)

pytest test/test_base.py
pytest test/test_sam.py
pytest test/test_compile.py

# These tests do not work on ROCm yet
if [ -z "$IS_ROCM" ]
then
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_dtensor.sh
pytest test/test_fsdp2/test_fsdp2_eager.py
fi

echo "all tests successful"
3 changes: 2 additions & 1 deletion test/test_sam.py
@@ -18,7 +18,7 @@
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error
from float8_experimental.float8_utils import compute_error, IS_ROCM
from transformers import SamModel

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
@pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
@pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw(self, data_dtype, linear_type):
model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
# print(model)