This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
Add more compile compatibility for Float8Tensor ops (#285)
Summary: Pull Request resolved: #285

Reviewed By: vkuzo

Differential Revision: D59068281

Pulled By: drisspg

fbshipit-source-id: 18fa34db74cf60e85ff372ff1091c107119403a0
ani300 authored and facebook-github-bot committed Jun 26, 2024
1 parent 57136bd commit b5a444a
Showing 2 changed files with 93 additions and 0 deletions.
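
For orientation before the diffs: the handlers added in this commit plug into the op-dispatch table that float8_ops.py builds via the `implements` decorator (note the `def decorator(func):` context line in the first hunk below). A minimal sketch of that registration pattern, with names inferred from the diff rather than quoted from the repo:

# Sketch (assumed): each handler is registered in a table keyed by aten op;
# Float8Tensor's __torch_dispatch__ then routes intercepted ops through it.
FLOAT8_OPS_TABLE = {}

def implements(aten_ops):
    def decorator(func):
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = func
        return func
    return decorator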
55 changes: 55 additions & 0 deletions float8_experimental/float8_ops.py
@@ -42,6 +42,9 @@ def decorator(func):
        aten.as_strided.default,
        aten.clone.default,
        aten.detach.default,
        aten.slice.Tensor,
        aten.transpose.int,
        aten.fill_.Scalar,
    ]
)
def float8_desugar_op(aten_op, args, kwargs=None):
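
The body of float8_desugar_op is elided by this hunk. A minimal sketch of what the shared "desugar" handler plausibly does, inferred from the rewrapping pattern visible in the hunks below (an assumption, not the verbatim body): run the aten op on the raw fp8 payload and rewrap the result with the input's scale, original dtype, and mm config unchanged. That invariant is what lets ops that act directly on the raw data without rescaling, such as aten.slice.Tensor, aten.transpose.int, and aten.fill_.Scalar, share a single handler.

# Sketch (assumed): desugar ops touch only the raw bits; metadata is reused.
def float8_desugar_op(aten_op, args, kwargs=None):
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    return Float8Tensor(
        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
    )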
@@ -263,3 +266,55 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
    return Float8Tensor(
        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
    )


@implements([aten.index_put_.default])
def index_put_fp8(aten_op, args, kwargs=None):
    fp8_self = args[0]
    fp8_values = args[2]
    assert isinstance(fp8_self, Float8Tensor)
    assert isinstance(fp8_values, Float8Tensor)
    # In-place writes cannot requantize, so both tensors must share the same
    # scale and dtypes for the raw fp8 bits to stay valid after the write.
    assert fp8_self._scale == fp8_values._scale
    assert fp8_self.dtype == fp8_values.dtype
    assert fp8_self._orig_dtype == fp8_values._orig_dtype

    fp8_data = fp8_self._data
    fp8_values_data = fp8_values._data
    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
    return Float8Tensor(
        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
    )
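
A minimal usage sketch for the new index_put_ support (hedged: it reuses the tensor_to_scale and Float8Tensor.to_float8 helpers exactly as the tests below do, and assumes their import paths):

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

src = torch.rand(16, dtype=torch.bfloat16)
scale = tensor_to_scale(src, torch.float8_e4m3fn)
fp8_src = Float8Tensor.to_float8(src, scale, torch.float8_e4m3fn)

# The destination deliberately shares the same scale: in-place writes cannot
# requantize, so index_put_fp8 asserts that the scales match.
dst = torch.rand(16, 16, dtype=torch.bfloat16)
fp8_dst = Float8Tensor.to_float8(dst, scale, torch.float8_e4m3fn)

idx = torch.randint(0, 16, (16,), dtype=torch.long)
fp8_dst[idx] = fp8_src  # dispatches to index_put_fp8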


@implements([aten.copy_.default])
def copy_fp8(aten_op, args, kwargs=None):
    # For a copy op involving Float8Tensors, only the following combinations are allowed:
    # 1. self is a high-precision (hp) tensor and src is a Float8Tensor:
    #    src is upcast and unscaled before being copied into the hp tensor.
    # 2. self and src are both Float8Tensors:
    #    the copy is allowed only if all Float8Tensor properties match (as with torch.cat).
    # Every other combination is banned because the semantics are not well defined.

    self = args[0]
    src = args[1]

    if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
        src_hp = src.to_original_precision()
        return aten_op(self, src_hp, *args[2:], **kwargs)
    elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
        assert (
            self._orig_dtype == src._orig_dtype
        ), "Expecting both Float8Tensors to be of the same dtype"
        assert (
            self._scale == src._scale
        ), "Expecting both Float8Tensors to have the same scale"
        assert (
            self._mm_config == src._mm_config
        ), "Expecting both Float8Tensors to have the same mm config"
        assert (
            self._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to be of the same dtype"
        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
    else:
        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
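
A short usage sketch of the two permitted copy_ combinations (hedged: it mirrors test_copy_ below and assumes the same helper imports):

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

a = torch.rand(16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)

# Case 1: Float8Tensor -> high-precision tensor; src is upcast and unscaled.
hp = torch.empty(16, dtype=torch.bfloat16)
hp.copy_(fp8_a)

# Case 2: Float8Tensor -> Float8Tensor; legal only when scale, dtypes, and
# mm config all match.
fp8_b = Float8Tensor(
    torch.empty(16, dtype=torch.float8_e4m3fn),
    scale,
    torch.bfloat16,
    fp8_a._mm_config,
)
fp8_b.copy_(fp8_a)

# hp -> fp8 is banned: the fixed destination scale makes requantization
# semantics ambiguous, so copy_fp8 raises RuntimeError.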
38 changes: 38 additions & 0 deletions test/test_base.py
@@ -83,6 +83,44 @@ def test_split_cat(self):
        catted = torch.cat(splits, dim=0)
        assert bitwise_identical(fp8_a, catted)

    def test_index_put(self):
        a = torch.rand(16, dtype=torch.bfloat16)
        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)

        index = torch.randint(0, 15, (16,), dtype=torch.long)

        b = torch.rand(16, 16, dtype=torch.bfloat16)
        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
        # fp8_b deliberately reuses scale_a so its scale matches fp8_a's.
        fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn)
        fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn)

        # Each bad combination gets its own block so all three are executed.
        with self.assertRaises(AssertionError):
            b[index] = fp8_a
        with self.assertRaises(AssertionError):
            fp8_b[index] = a
        with self.assertRaises(AssertionError):
            fp8_b_bad[index] = fp8_a

        fp8_b[index] = fp8_a

    def test_copy_(self):
        a = torch.rand(16, dtype=torch.bfloat16)
        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)

        b = torch.empty(16, dtype=torch.bfloat16)
        b.copy_(fp8_a)  # Should work
        torch.testing.assert_close(b, fp8_a.to_original_precision())
        with self.assertRaises(RuntimeError):
            fp8_a.copy_(b)  # Should fail

        fp8_b = Float8Tensor(
            torch.empty(16, dtype=torch.float8_e4m3fn),
            scale_a,
            torch.bfloat16,
            fp8_a._mm_config,
        )
        fp8_b.copy_(fp8_a)
        torch.testing.assert_close(fp8_a._data, fp8_b._data)


class TestFloat8Linear:
    def _test_linear_impl(
