[wip] add axiswise granularity to Float8Tensor
Summary:

This PR adds the axiswise scaling granularity to `Float8Tensor` and
ensures that basic ops like transpose and `torch._scaled_mm` work as
expected.

A future PR will add integration with `Float8Linear`.

Test Plan:

TODO

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 163452e97ed6a26fa5dcba01c36f49eb744484a6
Pull Request resolved: #351
vkuzo committed Jul 28, 2024
1 parent 1a543f4 commit b150085
Showing 6 changed files with 235 additions and 21 deletions.
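Before the per-file diffs, here is a minimal, self-contained sketch of what axiswise scaling means for a float8 cast. This is illustration only, not code from this repo: the scale formula and the e4m3 dtype are standard conventions assumed for the example, and the `torch._scaled_mm` call mentioned in the summary is omitted because its handling of non-scalar scales varies across PyTorch versions.

import torch

x = torch.randn(3, 5, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# axiswise: one amax (and one scale) per slice along dim 0,
# so the scale has shape [1, 5] instead of being a 0-dim scalar
amax = torch.amax(torch.abs(x), dim=0, keepdim=True)
scale = fp8_max / amax.float()

x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)  # shape [3, 5]

# "basic ops like transpose" must move the scale together with the data,
# so the size-1 dim of the scale keeps tracking the same axis
x_fp8_t, scale_t = x_fp8.t(), scale.t()  # shapes [5, 3] and [5, 1]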
12 changes: 12 additions & 0 deletions float8_experimental/config.py
@@ -21,6 +21,18 @@ def short_str(self):
return "dyn"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing that
# axis to size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
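To make the two new enum values concrete, a tiny shape-only illustration (not repo code) of the amax each granularity produces for a [3, 5] input; the scaling factor is then derived from this amax:

import torch

x = torch.randn(3, 5)

# TENSORWISE: a single reduction over the whole tensor -> one scaling factor
amax_tensorwise = torch.max(torch.abs(x))                      # 0-dim scalar

# AXISWISE: the reduced axis collapses to size 1 -> one factor per slice
amax_axiswise = torch.amax(torch.abs(x), dim=1, keepdim=True)  # shape [3, 1]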
75 changes: 72 additions & 3 deletions float8_experimental/float8_ops.py
@@ -19,6 +19,15 @@
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def _assert_tensorwise_scale(aten_op, scale):
assert (
# TODO(future PR): figure out why tensorwise scaling can have
# both rank 0 and rank 1
len(scale.shape)
in (0, 1)
), f"{aten_op} with axiswise scaling is not supported yet"


def implements(aten_ops):
"""Register aten ops to the float8 op table"""

@@ -32,18 +41,16 @@ def decorator(func):

@implements(
[
aten.view.default,
aten._unsafe_view.default,
aten.t.default,
aten.as_strided.default,
aten.clone.default,
aten.detach.default,
aten.slice.Tensor,
aten.transpose.int,
aten.fill_.Scalar,
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(
new_data,
@@ -54,8 +61,61 @@ def float8_desugar_op(aten_op, args, kwargs=None):
)


@implements(
[
aten.t.default,
aten.transpose.int,
]
)
def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
return Float8Tensor(
new_data,
new_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
)


@implements([aten.view.default])
def float8_view(aten_op, args, kwargs=None):
if len(args[0]._scale.shape) < 2:
# tensorwise scaling
return float8_desugar_op(aten_op, args, kwargs)

t, new_shape = args[0], args[1]
# for now, only support reshaping to [-1, dim] or [dim, -1]
if len(new_shape) == 2:
if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1:
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale = aten_op(t._scale, [1, -1], **kwargs)
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
)
elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1:
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale = aten_op(t._scale, [-1, 1], **kwargs)
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
)
raise AssertionError(
f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet."
)


@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

def make_float8(data):
@@ -101,6 +161,7 @@ def float8_cat(aten_op, args, kwargs=None):
assert (
chunk._gemm_input_role is gemm_input_role
), "Expecting all chunks to have the same gemm_input_role as a result of a split"
_assert_tensorwise_scale(aten_op, chunk._scale)
chunk_data.append(chunk._data.view(torch.uint8))

new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -117,6 +178,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
"addmm" -> out
"hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)

def unwrap(x):
if isinstance(x, Float8Tensor):
@@ -229,6 +291,7 @@ def float8_addmm(aten_op, args, kwargs=None):

@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
return args[0].shape == args[1].shape


@@ -238,6 +301,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
when the input is a Float8Tensor, presenting as a fp32
tensor.
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert isinstance(args[0], Float8Tensor)
assert (
len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +329,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
"""
override funcol with FP8 handling
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(
fp8_input, Float8Tensor
@@ -284,6 +349,7 @@

@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
def wait_tensor_fp8(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(fp8_input, Float8Tensor)

@@ -304,6 +370,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_values = args[2]
assert isinstance(fp8_self, Float8Tensor)
assert isinstance(fp8_values, Float8Tensor)
_assert_tensorwise_scale(aten_op, fp8_self._scale)
assert fp8_self._scale == fp8_values._scale
assert fp8_self.dtype == fp8_values.dtype
assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +401,10 @@ def copy_fp8(aten_op, args, kwargs=None):

if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
src_hp = src.to_original_precision()
_assert_tensorwise_scale(aten_op, src._scale)
return aten_op(self, src_hp, *args[2:], **kwargs)
elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
_assert_tensorwise_scale(aten_op, src._scale)
assert (
self._orig_dtype == src._orig_dtype
), "Expecting both Float8Tensors to be of the same dtype"
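The new handlers above only manipulate shapes, so their behavior can be sketched with plain tensors. A standalone illustration (assumed shapes, not repo code) of what `float8_desugar_data_and_scale` and `float8_view` do for an axiswise-scaled tensor whose `_data` is [2, 3, 4] with the scale computed along the last dim:

import torch

data = torch.randn(2, 3, 4)
# stand-in for _scale with the axiswise shape (actual values don't matter here)
scale = torch.amax(torch.abs(data), dim=-1, keepdim=True)  # shape [2, 3, 1]

# aten.t / aten.transpose: the same op is applied to data and scale, so the
# size-1 dim of the scale keeps tracking the same logical axis of the data
data_t = data.transpose(0, 1)    # [3, 2, 4]
scale_t = scale.transpose(0, 1)  # [3, 2, 1]

# aten.view: only [-1, last_dim] (or [first_dim, -1]) reshapes are accepted;
# the scale is viewed with the matching [-1, 1] (or [1, -1]) pattern
data_2d = data.view(-1, data.shape[-1])  # [6, 4]
scale_2d = scale.view(-1, 1)             # [6, 1]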
14 changes: 13 additions & 1 deletion float8_experimental/float8_scaling_utils.py
@@ -12,6 +12,8 @@

import torch

from float8_experimental.config import ScalingGranularity

from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic(
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -49,10 +53,18 @@
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
scale = tensor_to_scale(
hp_tensor,
float8_dtype,
reduce_amax,
scaling_granularity,
axiswise_dim,
)
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
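A hedged usage sketch of the extended `hp_tensor_to_float8_dynamic` above. The leading positional parameters are cut off in the hunk, so the `(hp_tensor, float8_dtype, linear_mm_config, ...)` ordering and the default-constructed `LinearMMConfig` below are assumptions for illustration, not an exact API reference:

import torch
from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import LinearMMConfig

hp_tensor = torch.randn(16, 32, dtype=torch.bfloat16)

fp8_tensor = hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),  # assumed default construction; real callers pass the config used by Float8Linear
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,   # reduce amax over the last dim -> scale of shape [16, 1]
)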
13 changes: 6 additions & 7 deletions float8_experimental/float8_tensor.py
@@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor):
* `_data`: the underlying e4m3 or e5m2 data
* `_scale`: the scale used to scale the original fp32 tensor. We multiply
by scale to go from fp32 range to fp8 range, and divide by scale to go
from fp8 range to fp32 range.
from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible
with `_data`. For example:
- if scaling is tensorwise, `_scale` is a scalar tensor
- if scaling is axiswise and _data.shape is [3, 5], `_scale` could have
shape [1, 5] or [5, 1]. The dim of the non-one entry defines the scaling
axis.
* `_orig_dtype`: the original dtype of the tensor used to create this
tensor.
* `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -279,12 +284,6 @@ def __new__(
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)

self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
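With the `scale.numel() == 1` assert removed, `_scale` only needs to broadcast against `_data` as described in the docstring above. A small sketch with plain tensors (not a real `Float8Tensor`) of that shape contract:

import torch

data = torch.randn(3, 5)         # stands in for _data
scale = torch.rand(1, 5) + 0.5   # axiswise scale: one factor per column

scaled = data * scale            # multiply by scale: original range -> fp8 range
recovered = scaled / scale       # divide by scale: fp8 range -> original range
assert torch.allclose(recovered, data)

# a [5, 1] scale would instead apply one factor per row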
25 changes: 20 additions & 5 deletions float8_experimental/float8_utils.py
@@ -4,12 +4,13 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config

import torch
import torch.distributed as dist
from float8_experimental.config import ScalingGranularity

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -100,8 +101,18 @@ def amax_history_to_scale_stack(


@torch.no_grad()
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
amax = torch.max(torch.abs(x))
def tensor_to_amax(
x: torch.Tensor,
reduce_amax: bool = False,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> torch.Tensor:
if scaling_granularity is ScalingGranularity.TENSORWISE:
amax = torch.max(torch.abs(x))
else:
assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
assert axiswise_dim is not None, "unsupported"
amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
@@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
x: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> torch.Tensor:
amax = tensor_to_amax(x, reduce_amax=reduce_amax)
amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim)
return amax_to_scale(amax, float8_dtype, x.dtype)


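A quick sketch of calling the extended helpers above with the new arguments. The dtype is just an example; the numeric value of the scale comes from `amax_to_scale`, which this PR does not change:

import torch
from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_utils import tensor_to_amax, tensor_to_scale

x = torch.randn(16, 32)

amax = tensor_to_amax(
    x,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
scale = tensor_to_scale(
    x,
    torch.float8_e4m3fn,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
print(amax.shape, scale.shape)  # expected: torch.Size([16, 1]) for both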