Skip to content

Commit fa24249

Browse files
committed
wip
1 parent 2cac703 commit fa24249

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

fbgemm_gpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ if(NOT FBGEMM_CPU_ONLY)
195195
add_subdirectory(experimental/gemm)
196196
endif()
197197

198-
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
198+
if(NOT FBGEMM_CPU_ONLY)
199199
# TODO: Re-enable gen_ai for ROCm once ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp
200200
# lands into latest ROCm
201201
add_subdirectory(experimental/gen_ai)

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
#include "ck/library/utility/host_tensor_generator.hpp"
3131
#include "ck/library/utility/literals.hpp"
3232

33+
#if (defined(USE_ROCM) && ROCM_VERSION >= 60300)
34+
// NOTE: This source is currently only available in the `develop` branch of CK
35+
// https://github.com/ROCm/composable_kernel
3336
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
37+
#endif
3438

3539
// Define commonly used types.
3640
template <ck::index_t... Is>
@@ -42,6 +46,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
4246

4347
namespace fbgemm_gpu {
4448

49+
#if (defined(USE_ROCM) && ROCM_VERSION >= 60300)
4550
template <
4651
int BLOCK_SIZE,
4752
int MBLOCK,
@@ -269,4 +274,20 @@ at::Tensor f8f8bf16_blockwise(
269274
}
270275
}
271276

277+
#else
278+
279+
at::Tensor f8f8bf16_blockwise(
280+
at::Tensor XQ,
281+
at::Tensor WQ,
282+
at::Tensor x_scale,
283+
at::Tensor w_scale,
284+
int64_t block_m = 128,
285+
int64_t block_n = 128,
286+
int64_t block_k = 128) {
287+
throw std::runtime_error(
288+
"ROCm version is older than 6.3"); // requires ROCm>=6.3
289+
}
290+
291+
#endif
292+
272293
} // namespace fbgemm_gpu

0 commit comments

Comments
 (0)