From da3e8a1f916d83e22f06f0def115822ed7587b38 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Fri, 14 Jun 2024 16:43:32 -0700
Subject: [PATCH] updates with new scaled-mm api

---
 benchmarks/bench_linear_float8.py        |  2 --
 benchmarks/bench_matmul.py               |  6 +++++-
 float8_experimental/float8_aten_api.py   |  7 +++----
 float8_experimental/float8_ops.py        |  8 ++++----
 float8_experimental/float8_python_api.py | 18 +++++++++---------
 test/test_base.py                        | 12 ++----------
 6 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py
index df21c30f..019cc9e4 100644
--- a/benchmarks/bench_linear_float8.py
+++ b/benchmarks/bench_linear_float8.py
@@ -14,8 +14,6 @@
 import torch
 import torch.utils.benchmark as benchmark
 
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     LinearType,
diff --git a/benchmarks/bench_matmul.py b/benchmarks/bench_matmul.py
index 8629ef9c..967267d5 100644
--- a/benchmarks/bench_matmul.py
+++ b/benchmarks/bench_matmul.py
@@ -101,7 +101,11 @@ def run(n_limit: Optional[int] = None):
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
 
         def do_matmul(A, B):
-            return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)
+            scale_a = torch.tensor([1], device=device)
+            scale_b = torch.tensor([1], device=device)
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
             tops, dtype_to_peak_tops[d1], do_matmul, A, B
diff --git a/float8_experimental/float8_aten_api.py b/float8_experimental/float8_aten_api.py
index 174fcc22..41d5083d 100644
--- a/float8_experimental/float8_aten_api.py
+++ b/float8_experimental/float8_aten_api.py
@@ -10,7 +10,6 @@
 
 import torch
 
-from float8_experimental.float8_utils import tensor_to_amax
 from torch.library import Library
 
 
@@ -26,7 +25,7 @@ def mm_float8_emulated(
     m2_fp32 = m2.float() / s2
     m3_fp32 = torch.mm(m1_fp32, m2_fp32)
 
-    return m3_fp32.to(dtype3), tensor_to_amax(m3_fp32)
+    return m3_fp32.to(dtype3)
 
 
 #
@@ -38,7 +37,7 @@ def mm_float8_emulated(
 lib = Library("aten", "FRAGMENT")
 
 lib.define(
-    "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> (Tensor, Tensor)"
+    "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
 )
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")
@@ -47,4 +46,4 @@ def mm_float8_emulated(
 @torch.library.impl(lib, "mm_float8_emulated", "Meta")
 def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
     out = torch.mm(m1.float(), m2.float()).to(dtype3)
-    return out, torch.empty(1, device="meta")
+    return out
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 9f48d2db..ffe6491a 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -147,8 +147,8 @@ def float8_mm(aten_op, args, kwargs=None):
     if mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
-        )[0]
-    tensor_out, amax = addmm_float8_unwrapped(
+        )
+    tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
         b_data,
@@ -180,9 +180,9 @@ def float8_addmm(aten_op, args, kwargs=None):
     if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
-        )[0]
+        )
         return out + bias
-    tensor_out, amax = addmm_float8_unwrapped(
+    tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
         b_data,
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 6cb406d4..3b752d71 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -10,7 +10,7 @@
 """
 
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import float8_experimental.float8_aten_api  # noqa
 
@@ -31,7 +31,7 @@ def addmm_float8_unwrapped(
     output_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     use_fast_accum: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     """
     This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
     as inputs. This is used to standardize the logic between subclassed and non subclassed
@@ -41,25 +41,25 @@ def addmm_float8_unwrapped(
     b_inverse_scale = b_scale.reciprocal()
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
-        output, output_amax = torch._scaled_mm(
+        output = torch._scaled_mm(
             a_data,
             b_data,
-            out_dtype=output_dtype,
             scale_a=a_inverse_scale,
             scale_b=b_inverse_scale,
             scale_result=output_scale,
+            out_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
         output += bias
-        return output, output_amax
-    output, output_amax = torch._scaled_mm(
+        return output
+    output = torch._scaled_mm(
         a_data,
         b_data,
-        bias=bias,
-        out_dtype=output_dtype,
         scale_a=a_inverse_scale,
         scale_b=b_inverse_scale,
+        bias=bias,
         scale_result=output_scale,
+        out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
-    return output, output_amax
+    return output
diff --git a/test/test_base.py b/test/test_base.py
index 6e7a34cc..371e044f 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -29,7 +29,6 @@
     ScaledMMConfig,
 )
 from float8_experimental.float8_utils import (
-    amax_to_scale,
     compute_error,
     fp8_tensor_statistics,
     FP8_TYPES,
@@ -327,7 +326,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
         b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
 
-        out_scaled_mm, output_amax_scaled = addmm_float8_unwrapped(
+        out_scaled_mm = addmm_float8_unwrapped(
             a_fp8._data,
             a_fp8._scale,
             b_fp8._data,
@@ -335,7 +334,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             output_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
-        out_emulated, output_amax_emulated = torch.ops.aten.mm_float8_emulated(
+        out_emulated = torch.ops.aten.mm_float8_emulated(
             a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
         )
 
@@ -343,13 +342,6 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             out_scaled_mm = out_scaled_mm.to(compare_type)
             out_emulated = out_emulated.to(compare_type)
 
-            out_scaled_mm = out_scaled_mm / amax_to_scale(
-                output_amax_scaled, input_dtype
-            )
-            out_emulated = out_emulated / amax_to_scale(
-                output_amax_emulated, input_dtype
-            )
-
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
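
Note: the hunks above all follow from the same torch._scaled_mm API change: scale_a and scale_b are now passed explicitly as tensor arguments, and the op returns a single output tensor rather than an (output, amax) tuple, so the amax plumbing (tensor_to_amax, amax_to_scale) is dropped. Below is a minimal standalone sketch of the new call shape, not part of the patch; the shapes and unit scales are illustrative placeholders and it assumes an fp8-capable CUDA GPU.

    import torch

    device = "cuda"
    M, K, N = 16, 32, 16

    # fp8 operands: mat1 row-major, mat2 column-major, as torch._scaled_mm expects.
    A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

    # Per-tensor dequantization scales as fp32 tensors; 1.0 is a placeholder value.
    scale_a = torch.tensor(1.0, device=device)
    scale_b = torch.tensor(1.0, device=device)

    # Pre-patch call shape: out, amax = torch._scaled_mm(A, B, out_dtype=..., ...)
    # Post-patch call shape: scales are passed in and a single tensor is returned.
    out = torch._scaled_mm(
        A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False
    )
    print(out.shape)  # torch.Size([16, 16])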