From fa24249a614ab093e7040feeee81908f60e7958c Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 25 Nov 2024 15:13:56 -0800 Subject: [PATCH] wip --- fbgemm_gpu/CMakeLists.txt | 2 +- .../ck_extensions/fp8_blockwise_gemm.hip | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 09d57d672e..d8ee96231a 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -195,7 +195,7 @@ if(NOT FBGEMM_CPU_ONLY) add_subdirectory(experimental/gemm) endif() -if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) +if(NOT FBGEMM_CPU_ONLY) # TODO: Re-enable gen_ai for ROCm once ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp # lands into latest ROCm add_subdirectory(experimental/gen_ai) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index 53b8020c6b..48dc6569f4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -30,7 +30,11 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#if (defined(USE_ROCM) && ROCM_VERSION >= 60300) +// NOTE: This source is currently only available in the `develop` branch of CK +// https://github.com/ROCm/composable_kernel #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#endif // Define commonly used types. template @@ -42,6 +46,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; namespace fbgemm_gpu { +#if (defined(USE_ROCM) && ROCM_VERSION >= 60300) template < int BLOCK_SIZE, int MBLOCK, @@ -269,4 +274,20 @@ at::Tensor f8f8bf16_blockwise( } } +#else + +at::Tensor f8f8bf16_blockwise( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + int64_t block_m = 128, + int64_t block_n = 128, + int64_t block_k = 128) { + throw std::runtime_error( + "ROCm version is older than 6.3"); // requires ROCm>=6.3 +} + +#endif + } // namespace fbgemm_gpu