scaled_matmul.cu produces wrong results because it loads the scaling factors incorrectly #88

Open
RuiWang1998 opened this issue Feb 12, 2025 · 0 comments


When I tested with the following snippet:

import torch
import thunderkittens as tk  # assuming the ThunderKittens Python bindings expose fp8_gemm_scaled

# FP8 e4m3 representable range
FP8_e4m3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FP8_e4m3_MIN = torch.finfo(torch.float8_e4m3fn).min  # -448.0

M = 128
N = 128
K = 128
slice_ = slice(47, 54)

def to_float8_e4m3fn(x: torch.Tensor):
    # per-row scales: the largest magnitude in each row maps to the FP8 e4m3 max
    scales = x.abs().amax(dim=-1, keepdim=True).float().div(FP8_e4m3_MAX)
    x = x.div(scales).clamp(min=FP8_e4m3_MIN, max=FP8_e4m3_MAX)
    x = x.to(torch.float8_e4m3fn)
    return x, scales

A = torch.randn(M, K, device="cuda").mul(.3)
B = torch.randn(N, K, device="cuda").mul(.3)
C = torch.empty(M, N, device="cuda")

A_fp8, scale_a_inv_s = to_float8_e4m3fn(A)
B_fp8, scale_b_inv_s = to_float8_e4m3fn(B)
tk.fp8_gemm_scaled(A_fp8, B_fp8, C, scale_a_inv_s, scale_b_inv_s)


y = torch._scaled_mm(
    A_fp8,
    B_fp8.T,
    out_dtype=torch.bfloat16,
    scale_a=scale_a_inv_s,    # shape (M, 1)
    scale_b=scale_b_inv_s.T,  # shape (1, N)
    use_fast_accum=True,
)  # bias=bias

y and C produce somewhat similar results, yet y is consistently closer to the reference torch.mm computed without casting the inputs to torch.float8_e4m3fn: only 17-24% of the elements are more accurate in C, rather than the roughly 50% you would expect if the TK kernel and torch._scaled_mm were equally accurate.
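A sketch of how such an element-wise comparison can be made (the reference matmul and the win-rate computation below are illustrative, not taken from any existing test code):

# Full-precision reference: torch.mm on the original (non-FP8) inputs.
ref = torch.mm(A, B.T)
err_C = (C.float() - ref).abs()
err_y = (y.float() - ref).abs()
win_rate_C = (err_C < err_y).float().mean()  # observed ~0.17-0.24, not ~0.5
print(f"C is closer than y on {win_rate_C:.1%} of elements")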

However, if I set

scale_a_inv_s.fill_(1)
scale_b_inv_s.fill_(2)

right after casting A and B to float8_e4m3fn, the numerical accuracy of TK's kernel matches that of torch._scaled_mm. To further illustrate, I used

scale_a_inv_s.normal_(std=10)
scale_b_inv_s.normal_(std=20)

and the results diverge wildly between torch._scaled_mm and TK's version, while torch._scaled_mm still closely follows a BF16 reference computed by casting the scale-adjusted FP8 tensors back to BF16.
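For concreteness, the BF16 reference used in that comparison can be built along these lines (a sketch of "cast the scale-adjusted FP8 tensors back to BF16"; the variable names are illustrative, and both kernels are assumed to be re-run after the scales are overwritten):

# Dequantize A_fp8/B_fp8 with the same (now random) scales, then matmul in BF16.
A_deq = A_fp8.to(torch.bfloat16) * scale_a_inv_s.to(torch.bfloat16)  # (M, K)
B_deq = B_fp8.to(torch.bfloat16) * scale_b_inv_s.to(torch.bfloat16)  # (N, K)
ref_bf16 = torch.mm(A_deq, B_deq.T)                                  # (M, N)
print((y.float() - ref_bf16.float()).abs().max())  # small: _scaled_mm tracks the reference
print((C.float() - ref_bf16.float()).abs().max())  # large: TK's kernel diverges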
