From 0c13ab9989d0830afefeec71ce90f8a923ce8992 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sun, 20 Oct 2024 14:11:50 -0700 Subject: [PATCH] Make sure fake tensor functions return on proper device (#3258) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3258 X-link: https://github.com/facebookresearch/FBGEMM/pull/359 I didnt realize that the device of faketensors matters in abstract functions, but torch.compile will check it in some cases. This small diff adds proper device placement to all fbgemm abstract operators. Reviewed By: jiawenliu64 Differential Revision: D64667681 fbshipit-source-id: 79b36af21cf8ad867d52beeb59f137368f0a48da --- .../gen_ai/gen_ai/quantize_ops.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py index 577a301a1c..3e716affcf 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py @@ -36,6 +36,7 @@ def f8f8bf16_blockwise_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @@ -51,6 +52,7 @@ def f8f8bf16_tensorwise_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @@ -69,6 +71,7 @@ def f8f8bf16_rowwise_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @@ -83,8 +86,8 @@ def quantize_fp8_per_tensor_abstract( fp8_dtype = torch.float8_e4m3fnuz else: fp8_dtype = torch.float8_e4m3fn - output = torch.empty_like(input, dtype=fp8_dtype) - scale = torch.empty([], dtype=torch.bfloat16) + output = torch.empty_like(input, dtype=fp8_dtype, device=input.device) + scale = torch.empty([], dtype=torch.bfloat16, device=input.device) return output, scale @@ -100,8 +103,8 @@ def quantize_fp8_per_row_abstract( fp8_dtype = torch.float8_e4m3fnuz else: fp8_dtype = torch.float8_e4m3fn - output = torch.empty_like(input, dtype=fp8_dtype) - scale = torch.empty([], dtype=torch.bfloat16) + output = torch.empty_like(input, dtype=fp8_dtype, device=input.device) + scale = torch.empty([], dtype=torch.bfloat16, device=input.device) return output, scale @@ -115,8 +118,8 @@ def quantize_fp8_per_col_abstract( fp8_dtype = torch.float8_e4m3fnuz else: fp8_dtype = torch.float8_e4m3fn - output = torch.empty_like(input, dtype=fp8_dtype) - scale = torch.empty([], dtype=torch.bfloat16) + output = torch.empty_like(input, dtype=fp8_dtype, device=input.device) + scale = torch.empty([], dtype=torch.bfloat16, device=input.device) return output, scale @@ -135,6 +138,7 @@ def i8i8bf16_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @torch.library.register_fake("fbgemm::f8f8bf16") @@ -149,6 +153,7 @@ def f8f8bf16_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @torch.library.register_fake("fbgemm::f8f8bf16_cublas") @@ -165,6 +170,7 @@ def f8f8bf16_cublas_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=A.device, ) @torch.library.register_fake("fbgemm::f8f8bf16_rowwise_batched") @@ -182,6 +188,7 @@ def f8f8bf16_rowwise_batched_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @torch.library.register_fake("fbgemm::f8i4bf16_rowwise") @@ -197,6 +204,7 @@ def f8i4bf16_rowwise_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=XQ.device, ) @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise") @@ -211,6 +219,7 @@ def bf16i4bf16_rowwise_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=X.device, ) @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise_batched") @@ -225,4 +234,5 @@ def bf16i4bf16_rowwise_batched_abstract( return torch.empty( [M, N], dtype=torch.bfloat16, + device=X.device, )