Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Updates with new scaled-mm api (#284)
Browse files Browse the repository at this point in the history
Summary:
This updates the calls to _scaled_mm to the new signature from this PR: pytorch/pytorch#128683

This is needed to unblock inductor work on scaled_mm.

```Shell
❯ ./test/test_everything.sh
	.
	.
	.

test/test_fsdp2/test_fsdp2_eager.py .......                                   [100%]

================================ 7 passed in 27.66s =================================
all tests successful
```

Pull Request resolved: #284

Reviewed By: y-sq

Differential Revision: D58709092

Pulled By: drisspg

fbshipit-source-id: ab330506621e9240f495be965748066d494d7b50
  • Loading branch information
drisspg authored and facebook-github-bot committed Jun 18, 2024
1 parent 1e9add3 commit edae9a3
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 30 deletions.
2 changes: 0 additions & 2 deletions benchmarks/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
LinearType,
Expand Down
6 changes: 5 additions & 1 deletion benchmarks/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def run(n_limit: Optional[int] = None):
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()

def do_matmul(A, B):
return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)
scale_a = torch.tensor([1], device=device)
scale_b = torch.tensor([1], device=device)
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], do_matmul, A, B
Expand Down
7 changes: 3 additions & 4 deletions float8_experimental/float8_aten_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from float8_experimental.float8_utils import tensor_to_amax
from torch.library import Library


Expand All @@ -26,7 +25,7 @@ def mm_float8_emulated(
m2_fp32 = m2.float() / s2
m3_fp32 = torch.mm(m1_fp32, m2_fp32)

return m3_fp32.to(dtype3), tensor_to_amax(m3_fp32)
return m3_fp32.to(dtype3)


#
Expand All @@ -38,7 +37,7 @@ def mm_float8_emulated(
lib = Library("aten", "FRAGMENT")

lib.define(
"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> (Tensor, Tensor)"
"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
)
lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")
Expand All @@ -47,4 +46,4 @@ def mm_float8_emulated(
@torch.library.impl(lib, "mm_float8_emulated", "Meta")
def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
out = torch.mm(m1.float(), m2.float()).to(dtype3)
return out, torch.empty(1, device="meta")
return out
8 changes: 4 additions & 4 deletions float8_experimental/float8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def float8_mm(aten_op, args, kwargs=None):
if mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
tensor_out, amax = addmm_float8_unwrapped(
)
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
b_data,
Expand Down Expand Up @@ -180,9 +180,9 @@ def float8_addmm(aten_op, args, kwargs=None):
if mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
)
return out + bias
tensor_out, amax = addmm_float8_unwrapped(
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
b_data,
Expand Down
18 changes: 9 additions & 9 deletions float8_experimental/float8_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""


from typing import Optional, Tuple
from typing import Optional

import float8_experimental.float8_aten_api # noqa

Expand All @@ -31,7 +31,7 @@ def addmm_float8_unwrapped(
output_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_fast_accum: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
as inputs. This is used to standardize the logic between subclassed and non subclassed
Expand All @@ -41,25 +41,25 @@ def addmm_float8_unwrapped(
b_inverse_scale = b_scale.reciprocal()
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output, output_amax = torch._scaled_mm(
output = torch._scaled_mm(
a_data,
b_data,
out_dtype=output_dtype,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
output += bias
return output, output_amax
output, output_amax = torch._scaled_mm(
return output
output = torch._scaled_mm(
a_data,
b_data,
bias=bias,
out_dtype=output_dtype,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
bias=bias,
scale_result=output_scale,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return output, output_amax
return output
12 changes: 2 additions & 10 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
ScaledMMConfig,
)
from float8_experimental.float8_utils import (
amax_to_scale,
compute_error,
fp8_tensor_statistics,
FP8_TYPES,
Expand Down Expand Up @@ -327,29 +326,22 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)

out_scaled_mm, output_amax_scaled = addmm_float8_unwrapped(
out_scaled_mm = addmm_float8_unwrapped(
a_fp8._data,
a_fp8._scale,
b_fp8._data,
b_fp8._scale,
output_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
out_emulated, output_amax_emulated = torch.ops.aten.mm_float8_emulated(
out_emulated = torch.ops.aten.mm_float8_emulated(
a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
)

if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_emulated = out_emulated.to(compare_type)

out_scaled_mm = out_scaled_mm / amax_to_scale(
output_amax_scaled, input_dtype
)
out_emulated = out_emulated / amax_to_scale(
output_amax_emulated, input_dtype
)

if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
Expand Down

0 comments on commit edae9a3

Please sign in to comment.