From e87f005368070f750766068752727c97a35cdd66 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 13:35:53 -0700 Subject: [PATCH 1/4] [wip] add axiswise granularity to Float8Tensor Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/config.py | 12 +++++ float8_experimental/float8_ops.py | 40 ++++++++++++++++- float8_experimental/float8_python_api.py | 8 ++++ float8_experimental/float8_scaling_utils.py | 14 +++++- float8_experimental/float8_tensor.py | 13 +++--- float8_experimental/float8_utils.py | 30 ++++++++++--- test/test_base.py | 50 ++++++++++++++++++++- 7 files changed, 151 insertions(+), 16 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 5d1bf9f..217fca1 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -21,6 +21,18 @@ def short_str(self): return "dyn" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. + AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 2a11726..588d48a 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -19,6 +19,15 @@ FLOAT8_OPS_TABLE: Dict[Any, Any] = {} +def _assert_tensorwise_scale(aten_op, scale): + assert ( + # TODO(future PR): figure out why tensorwise scaling can have + # both rank 0 and rank 1 + len(scale.shape) + in (0, 1) + ), f"{aten_op} with axiswise scaling is not supported yet" + + def implements(aten_ops): """Register aten ops to the float8 op table""" @@ -34,16 +43,15 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, ] ) def float8_desugar_op(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( new_data, @@ -54,8 +62,27 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.t.default, + aten.transpose.int, + ] +) +def float8_desugar_data_and_scale(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) def make_float8(data): @@ -101,6 +128,7 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._gemm_input_role is gemm_input_role ), "Expecting all chunks to have the same gemm_input_role as a result of a split" + _assert_tensorwise_scale(aten_op, chunk._scale) chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) @@ -117,6 +145,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): 
"addmm" -> out "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" """ + _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): if isinstance(x, Float8Tensor): @@ -229,6 +258,7 @@ def float8_addmm(aten_op, args, kwargs=None): @implements([aten.is_same_size.default]) def float8_is_same_size(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) return args[0].shape == args[1].shape @@ -238,6 +268,7 @@ def autocast_to_copy(aten_op, args, kwargs=None): when the input is a Float8Tensor, presenting as a fp32 tensor. """ + _assert_tensorwise_scale(aten_op, args[0]._scale) assert isinstance(args[0], Float8Tensor) assert ( len(kwargs) == 1 and "dtype" in kwargs @@ -265,6 +296,7 @@ def allgather_fp8(aten_op, args, kwargs=None): """ override funcol with FP8 handling """ + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance( fp8_input, Float8Tensor @@ -284,6 +316,7 @@ def allgather_fp8(aten_op, args, kwargs=None): @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) def wait_tensor_fp8(aten_op, args, kwargs=None): + _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] assert isinstance(fp8_input, Float8Tensor) @@ -304,6 +337,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype assert fp8_self._orig_dtype == fp8_values._orig_dtype @@ -334,8 +368,10 @@ def copy_fp8(aten_op, args, kwargs=None): if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): src_hp = src.to_original_precision() + _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + _assert_tensorwise_scale(aten_op, src._scale) assert ( self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index d8aa081..001eff4 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -38,6 +38,14 @@ def addmm_float8_unwrapped( """ a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() + + # TODO: should we change torch._scaled_mm? + # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank + # 2. Translate to this format. + # TODO: audit if we need to make this more generic for various shapes. 
+ a_inverse_scale = a_inverse_scale.squeeze() + b_inverse_scale = b_inverse_scale.squeeze() + if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index ce6422f..06c93c1 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -12,6 +12,8 @@ import torch +from float8_experimental.config import ScalingGranularity + from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -49,10 +53,18 @@ def hp_tensor_to_float8_dynamic( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale( + hp_tensor, + float8_dtype, + reduce_amax, + scaling_granularity, + axiswise_dim, + ) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 641f972..22c2a32 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor): * `_data`: the underlying e4m3 or e5m2 data * `_scale`: the scale used to scale the original fp32 tensor. We multiply by scale to go from fp32 range to fp8 range, and divide by scale to go - from fp8 range to fp32 range. + from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible + with `_data`. For example: + - if scaling is tensorwise, `_scale` is a scalar tensor + - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have + shape [1, 5] or [5, 1]. The dim of the non-one entry defines the scaling + axis. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. * `_emulate`: if true using fp32 emulation for the matmuls, helpful @@ -279,12 +284,6 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568e..26fde8a 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import float8_experimental.config as config import torch import torch.distributed as dist +from float8_experimental.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -100,8 +101,23 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, +) -> torch.Tensor: + if scaling_granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + assert axiswise_dim is not None, "unsupported" + + # convert from axiswise_dim (dim to keep) to + # dim as the input to the `torch.amax` function (tuple of dims to reduce) + dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) + + amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -114,9 +130,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + axiswise_dim: Optional[int] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/test/test_base.py b/test/test_base.py index 4e0c685..38fed52 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -16,7 +16,12 @@ import torch import torch.nn as nn -from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType +from float8_experimental.config import ( + CastConfig, + Float8LinearConfig, + ScalingGranularity, + ScalingType, +) from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -24,6 +29,7 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_python_api import addmm_float8_unwrapped +from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -143,6 +149,48 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) + def test_axiswise_dynamic_cast(self): + a = torch.randn(16, 32, dtype=torch.bfloat16) + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + # print(a_fp8) + # print(a_fp8.to_original_precision()) + # print(a_fp8.t()) + b = a_fp8.t() + # TODO check numerical accuracy + + def test_axiswise_gemm(self): + a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, 
+ linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + ) + c = torch.mm(a_fp8, b_fp8.t()) + print(c) + # TODO check numerical accuracy + class TestFloat8Linear: def _test_linear_impl( From c4c9ae82a1d5c536a25d505ec67edcc4ccb6b8a4 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 16:30:19 -0700 Subject: [PATCH 2/4] Update on "[wip] add axiswise granularity to Float8Tensor" Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_ops.py | 15 +++++- float8_experimental/float8_python_api.py | 7 --- float8_experimental/float8_utils.py | 4 +- test/test_base.py | 61 ++++++++++++++++++------ 4 files changed, 62 insertions(+), 25 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 588d48a..da5aec4 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -41,7 +41,7 @@ def decorator(func): @implements( [ - aten.view.default, + # aten.view.default, aten._unsafe_view.default, aten.as_strided.default, aten.clone.default, @@ -79,6 +79,19 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None): args[0]._gemm_input_role, ) +@implements([aten.view.default]) +def float8_view(aten_op, args, kwargs=None): + if len(args[0]._scale.shape) < 2: + # tensorwise scaling + return float8_desugar_op(aten_op, *args, **kwargs) + print('args', args) + print('kwargs', kwargs) + tensor, new_shape = args[0], args[1] + + # for now, only support reshaping to [-1, *dims] or [*dims, -1] + if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1): + return float8_desugar_data_and_scale(aten_op, *args, **kwargs) + raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.") @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index 001eff4..59edd8d 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -39,13 +39,6 @@ def addmm_float8_unwrapped( a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() - # TODO: should we change torch._scaled_mm? - # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank - # 2. Translate to this format. - # TODO: audit if we need to make this more generic for various shapes. 
- a_inverse_scale = a_inverse_scale.squeeze() - b_inverse_scale = b_inverse_scale.squeeze() - if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 26fde8a..500d05e 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -115,9 +115,9 @@ def tensor_to_amax( # convert from axiswise_dim (dim to keep) to # dim as the input to the `torch.amax` function (tuple of dims to reduce) - dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) + # dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) - amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True) + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will diff --git a/test/test_base.py b/test/test_base.py index 38fed52..2825fc5 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -63,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: return True -class TestFloat8Tensor(unittest.TestCase): +class TestFloat8Tensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -73,7 +73,7 @@ def test_preserves_dtype(self) -> None: x1_s = tensor_to_scale(x1_hp, lp_dtype) x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() - self.assertTrue(x3_hp.dtype == hp_dtype) + assert x3_hp.dtype == hp_dtype def test_differentiable_casts(self) -> None: lp_dtypes = (e4m3_dtype, e5m2_dtype) @@ -108,7 +108,7 @@ def test_index_put(self): fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): b[index] = fp8_a fp8_b[index] = a fp8_b_bad[index] = fp8_a @@ -122,7 +122,7 @@ def test_copy_(self): b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work torch.testing.assert_close(b, fp8_a.to_original_precision()) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail fp8_b = Float8Tensor( @@ -149,9 +149,33 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) - def test_axiswise_dynamic_cast(self): - a = torch.randn(16, 32, dtype=torch.bfloat16) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("dim_name", ["first", "last"]) + def test_axiswise_dynamic_cast(self, shape, dim_name): + a = torch.randn(*shape, dtype=torch.bfloat16) + + if dim_name == "first": + dim = 0 + elif dim_name == "last": + dim = len(a.shape) - 1 + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=dim, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + + # TODO(next) make this work + def test_axiswise_reshape(self): + a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, @@ -159,11 +183,15 @@ def test_axiswise_dynamic_cast(self): 
scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=0, ) - # print(a_fp8) - # print(a_fp8.to_original_precision()) - # print(a_fp8.t()) - b = a_fp8.t() - # TODO check numerical accuracy + # a_fp8._data.shape is (3, 5, 7) + # a_fp8._scale.shape is (1, 5, 7) + print(a_fp8._scale.shape) + + # reshape to (3, 5 * 7) + # a_fp8._scale.shape should be (1, 5 * 7) + a_fp8_r = a_fp8.reshape(3, -1) + print(a_fp8_r._scale.shape) + def test_axiswise_gemm(self): a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") @@ -177,7 +205,7 @@ def test_axiswise_gemm(self): linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=0, + axiswise_dim=1, ) b_fp8 = hp_tensor_to_float8_dynamic( b, @@ -185,10 +213,13 @@ def test_axiswise_gemm(self): linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=0, + axiswise_dim=1, ) - c = torch.mm(a_fp8, b_fp8.t()) - print(c) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + print(c_fp8_compute) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + print('sqnr', sqnr) # TODO check numerical accuracy From 10520662e4b39516ca881d1a95456bc00236e947 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Sun, 28 Jul 2024 09:20:00 -0700 Subject: [PATCH 3/4] Update on "[wip] add axiswise granularity to Float8Tensor" Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_ops.py | 40 +++++++++++++----- float8_experimental/float8_python_api.py | 1 - test/test_base.py | 54 ++++++++++++++++++------ 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index da5aec4..3f3af10 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -41,7 +41,6 @@ def decorator(func): @implements( [ - # aten.view.default, aten._unsafe_view.default, aten.as_strided.default, aten.clone.default, @@ -79,19 +78,40 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None): args[0]._gemm_input_role, ) + @implements([aten.view.default]) def float8_view(aten_op, args, kwargs=None): if len(args[0]._scale.shape) < 2: # tensorwise scaling - return float8_desugar_op(aten_op, *args, **kwargs) - print('args', args) - print('kwargs', kwargs) - tensor, new_shape = args[0], args[1] - - # for now, only support reshaping to [-1, *dims] or [*dims, -1] - if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1): - return float8_desugar_data_and_scale(aten_op, *args, **kwargs) - raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.") + return float8_desugar_op(aten_op, args, kwargs) + + t, new_shape = args[0], args[1] + # for now, only support reshaping to [-1, dim] or [dim, -1] + if len(new_shape) == 2: + if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale = aten_op(t._scale, [1, -1], **kwargs) + return Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + ) + elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1: + new_data = aten_op(t._data, new_shape, **kwargs) + new_scale = aten_op(t._scale, [-1, 1], **kwargs) + return 
Float8Tensor( + new_data, + new_scale, + t._orig_dtype, + t._linear_mm_config, + t._gemm_input_role, + ) + raise AssertionError( + f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet." + ) + @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index 59edd8d..d8aa081 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -38,7 +38,6 @@ def addmm_float8_unwrapped( """ a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() - if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/test/test_base.py b/test/test_base.py index 2825fc5..739d6b0 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -171,27 +171,57 @@ def test_axiswise_dynamic_cast(self, shape, dim_name): sqnr = compute_error(a, a_dq) assert sqnr >= 25.0 - # TODO(next) make this work def test_axiswise_reshape(self): a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") linear_mm_config = LinearMMConfig() - a_fp8 = hp_tensor_to_float8_dynamic( + # if we scale across dim0, we can only reshape to [3, -1] + a_fp8_d0 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=0, ) - # a_fp8._data.shape is (3, 5, 7) - # a_fp8._scale.shape is (1, 5, 7) - print(a_fp8._scale.shape) + assert list(a_fp8_d0._data.shape) == [3, 5, 7] + assert list(a_fp8_d0._scale.shape) == [1, 5, 7] + + a_fp8_d0_r = a_fp8_d0.reshape(3, -1) + assert list(a_fp8_d0_r.shape) == [3, 5 * 7] + assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7] + # verify numerics did not change + assert torch.allclose( + a_fp8_d0.to_original_precision(), + a_fp8_d0_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(AssertionError): + a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7) - # reshape to (3, 5 * 7) - # a_fp8._scale.shape should be (1, 5 * 7) - a_fp8_r = a_fp8.reshape(3, -1) - print(a_fp8_r._scale.shape) - + # if we scale across dim2, we can only reshape to [-1, 7] + a_fp8_d2 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=2, + ) + assert list(a_fp8_d2._data.shape) == [3, 5, 7] + assert list(a_fp8_d2._scale.shape) == [3, 5, 1] + + a_fp8_d2_r = a_fp8_d2.reshape(-1, 7) + assert list(a_fp8_d2_r.shape) == [3 * 5, 7] + assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1] + # verify numerics did not change + assert torch.allclose( + a_fp8_d2.to_original_precision(), + a_fp8_d2_r.to_original_precision().reshape(3, 5, 7), + atol=0, + rtol=0, + ) + with pytest.raises(AssertionError): + a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) def test_axiswise_gemm(self): a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") @@ -216,11 +246,9 @@ def test_axiswise_gemm(self): axiswise_dim=1, ) c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) - print(c_fp8_compute) c_ref = torch.mm(a, b.t()) sqnr = compute_error(c_ref, c_fp8_compute) - print('sqnr', sqnr) - # TODO check numerical accuracy + assert sqnr >= 25.0 class TestFloat8Linear: From 29fdaacb9b0056f1d114525c5cea2e3acee046bb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Sun, 28 Jul 2024 09:21:31 -0700 Subject: [PATCH 4/4] Update on "[wip] add axiswise granularity to Float8Tensor" Summary: This PR adds 
the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 500d05e..fdd9189 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -112,11 +112,6 @@ def tensor_to_amax( else: assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" assert axiswise_dim is not None, "unsupported" - - # convert from axiswise_dim (dim to keep) to - # dim as the input to the `torch.amax` function (tuple of dims to reduce) - # dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) - amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it.
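
Editor's note (not part of the commits above): a rough end-to-end sketch of the axiswise API added by this stack, mirroring test_axiswise_gemm from PATCH 3/4. It assumes a CUDA device with float8 and rowwise-scaled torch._scaled_mm support; the sqnr helper below is inlined for illustration as a stand-in for the compute_error utility the tests import.

    import torch

    from float8_experimental.config import ScalingGranularity
    from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
    from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig

    def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
        # signal-to-quantization-noise ratio, in dB
        return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - actual))

    a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
    cfg = LinearMMConfig()

    # After PATCH 2/4, axiswise_dim is the dim reduced over when computing amax,
    # so axiswise_dim=1 yields scales of shape (16, 1) and (64, 1), i.e. one
    # scale per row, matching what rowwise-scaled torch._scaled_mm expects.
    a_fp8 = hp_tensor_to_float8_dynamic(
        a,
        torch.float8_e4m3fn,
        cfg,
        gemm_input_role=GemmInputRole.INPUT,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=1,
    )
    b_fp8 = hp_tensor_to_float8_dynamic(
        b,
        torch.float8_e4m3fn,
        cfg,
        gemm_input_role=GemmInputRole.WEIGHT,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=1,
    )

    # .t() transposes both _data and _scale (float8_desugar_data_and_scale),
    # and the mm on two Float8Tensors dispatches to torch._scaled_mm.
    c_fp8 = torch.mm(a_fp8, b_fp8.t())
    assert sqnr(torch.mm(a, b.t()), c_fp8) >= 25.0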