From 48726bfeb2987b727dff8f52bedc1cff52c33959 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:48:19 -0500 Subject: [PATCH] Gradlib torch extension cmake (#282) * Converted gradlib into a CMake project while using TORCH_LIBRARY bindings rather than pybind11 * Made gradlib a vllm _gradlib_C module * Reusing binding includes from core vllm * The extension is created by the wrapper * Remove gradlib mentions from the dockerfile --- .github/workflows/publish.yml | 14 - .github/workflows/scripts/build.sh | 3 - CMakeLists.txt | 18 + Dockerfile.rocm | 12 +- {gradlib/csrc => csrc/gradlib}/hipbsolgemm.cu | 45 +- csrc/gradlib/ops.h | 27 + {gradlib/csrc => csrc/gradlib}/rocsolgemm.cu | 464 +++++++++--------- csrc/gradlib/torch_bindings.cpp | 18 + gradlib/{gradlib => }/GemmTuner.py | 44 +- gradlib/csrc/grad_funcs.cu | 413 ---------------- gradlib/{gradlib => }/gemm_runner.py | 13 +- gradlib/{gradlib => }/gemm_tuner.py | 7 +- gradlib/gradlib/mm_test.py | 253 ---------- gradlib/setup.py | 164 ------- pyproject.toml | 2 +- setup.py | 1 + vllm/model_executor/layers/tuned_gemm.py | 28 +- 17 files changed, 363 insertions(+), 1163 deletions(-) rename {gradlib/csrc => csrc/gradlib}/hipbsolgemm.cu (92%) create mode 100644 csrc/gradlib/ops.h rename {gradlib/csrc => csrc/gradlib}/rocsolgemm.cu (51%) create mode 100644 csrc/gradlib/torch_bindings.cpp rename gradlib/{gradlib => }/GemmTuner.py (89%) delete mode 100644 gradlib/csrc/grad_funcs.cu rename gradlib/{gradlib => }/gemm_runner.py (85%) rename gradlib/{gradlib => }/gemm_tuner.py (97%) delete mode 100644 gradlib/gradlib/mm_test.py delete mode 100644 gradlib/setup.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f2fcd72b579cb..f3dda4c25c790 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -68,12 +68,8 @@ jobs: bash -x .github/workflows/scripts/build.sh wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) asset_name=${wheel_name//"linux"/"manylinux1"} - gradlib_wheel_name=$(find gradlib/dist -name "*whl" -print0 | xargs -0 -n 1 basename) - gradlib_asset_name=${gradlib_wheel_name//"linux"/"manylinux1"} echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - echo "gradlib_wheel_name=${gradlib_wheel_name}" >> "$GITHUB_ENV" - echo "gradlib_asset_name=${gradlib_asset_name}" >> "$GITHUB_ENV" - name: Upload vllm Release Asset uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 @@ -84,13 +80,3 @@ jobs: asset_path: ./dist/${{ env.wheel_name }} asset_name: ${{ env.asset_name }} asset_content_type: application/* - - name: Upload gradlib Release Asset - uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ needs.release.outputs.upload_url }} - asset_path: ./gradlib/dist/${{ env.gradlib_wheel_name }} - asset_name: ${{ env.gradlib_asset_name }} - asset_content_type: application/* - diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 88a735de40056..f0a4e4baf1ae2 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -18,6 +18,3 @@ export MAX_JOBS=32 # Build $python_executable setup.py bdist_wheel --dist-dir=dist -cd gradlib -$python_executable setup.py bdist_wheel --dist-dir=dist -cd ..
\ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e6e52ea06579..440588241d6a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -508,6 +508,24 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) + + # + # _gradlib_C extension + # + set(VLLM_GRADLIB_EXT_SRC + "csrc/gradlib/torch_bindings.cpp" + "csrc/gradlib/hipbsolgemm.cu" + "csrc/gradlib/rocsolgemm.cu") + + define_gpu_extension_target( + _gradlib_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_GRADLIB_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) endif() # vllm-flash-attn currently only supported on CUDA diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 661f10fcd9b2b..66cbe61905f84 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -148,7 +148,7 @@ FROM scratch AS export_flash_attn_0 FROM export_flash_attn_${BUILD_FA} AS export_flash_attn # ----------------------- -# vLLM (and gradlib) fetch stages +# vLLM fetch stages FROM base AS fetch_vllm_0 ONBUILD COPY ./ vllm/ FROM base AS fetch_vllm_1 @@ -160,7 +160,7 @@ ONBUILD RUN git clone ${VLLM_REPO} \ FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm # ----------------------- -# vLLM (and gradlib) build stages +# vLLM build stages FROM fetch_vllm AS build_vllm ARG COMMON_WORKDIR ARG USE_CYTHON @@ -184,13 +184,9 @@ RUN cd vllm \ && python3 setup.py clean --all \ && if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \ && python3 setup.py bdist_wheel --dist-dir=dist -# Build gradlib -RUN cd vllm/gradlib \ - && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist FROM scratch AS export_vllm ARG COMMON_WORKDIR COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl / -COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/gradlib/dist/*.whl / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/rocm_patch /rocm_patch COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks @@ -265,7 +261,7 @@ RUN if [ ${BUILD_RPD} -eq "1" ]; then \ && make && make install \ && cd hipMarker && python setup.py install ; fi -# Install vLLM (and gradlib) +# Install vLLM # Make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ @@ -277,7 +273,7 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ *"rocm-6.1"*) \ cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6;; \ *) ;; esac \ - && pip uninstall -y vllm gradlib \ + && pip uninstall -y vllm \ && pip install *.whl # Copy over the benchmark scripts as well diff --git a/gradlib/csrc/hipbsolgemm.cu b/csrc/gradlib/hipbsolgemm.cu similarity index 92% rename from gradlib/csrc/hipbsolgemm.cu rename to csrc/gradlib/hipbsolgemm.cu index d5aed5deedc42..f1d0a4c301e76 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/csrc/gradlib/hipbsolgemm.cu @@ -6,7 +6,6 @@ // __HIP_NO_HALF_CONVERSIONS__ #endif #include -#include #include #include #include @@ -119,7 +118,7 @@ std::map dtype_map{ } // namespace // find all hipblaslt solutions for given gemm problem -std::vector hipblasLtMatmul_findallsols_wrapper( +std::vector hipblasLtMatmul_findallsols_wrapper( hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, int m, int n, int k, const void* alpha, const void* a, int lda, const void* b, int ldb, const void* beta, void* c, int ldc, @@ -163,7 +162,7 @@ std::vector 
hipblasLtMatmul_findallsols_wrapper( handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); - std::vector algoIndex; + std::vector algoIndex; int returned_algo_count = heuristicResult.size(); // for (int i = 0; i < returnedAlgoCount; i++) { for (int i = 0; i < returned_algo_count; i++) { @@ -290,12 +289,12 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( } ///////////////////////////////////////////////////////////////////////////////////////////////////////// torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, - const int solution_index, - at::optional bias = at::nullopt, - at::optional out_dtype = at::nullopt, - at::optional scale1 = at::nullopt, - at::optional scale2 = at::nullopt, - at::optional scaleOut = at::nullopt) { + const int64_t solution_index, + at::optional bias, + at::optional out_dtype, + at::optional scale1, + at::optional scale2, + at::optional scaleOut) { auto mat1_strides{mat1.strides()}; auto mat2_strides{mat2.strides()}; auto mat1_sizes{mat1.sizes()}; @@ -309,10 +308,7 @@ torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, "mat1 dim 1 must match mat2 dim 0"); auto inDtype{mat1.options().dtype().toScalarType()}; - auto outDtype{ - out_dtype.has_value() - ? torch::python::detail::py_object_to_dtype(out_dtype.value()) - : inDtype}; + auto outDtype{out_dtype.has_value() ? out_dtype.value() : inDtype}; auto options{at::TensorOptions().dtype(outDtype).device(at::kCUDA)}; auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; @@ -392,10 +388,10 @@ torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, } // find all hipblas solutions and return them to python land -std::vector hipb_findallsols( +std::vector hipb_findallsols( const torch::Tensor& mat1, const torch::Tensor& mat2, at::optional bias = at::nullopt, - at::optional out_dtype = at::nullopt) { + at::optional out_dtype = at::nullopt) { auto mat1_strides{mat1.strides()}; auto mat2_strides{mat2.strides()}; auto mat1_sizes{mat1.sizes()}; @@ -408,10 +404,7 @@ std::vector hipb_findallsols( "mat1 dim 1 must match mat2 dim 0"); auto inType{mat1.options().dtype().toScalarType()}; - auto outType{ - out_dtype.has_value() - ? torch::python::detail::py_object_to_dtype(out_dtype.value()) - : inType}; + auto outType{out_dtype.has_value() ? 
out_dtype.value() : inType}; auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; @@ -504,17 +497,3 @@ void hipb_destroy_extension() { // CHECK_HIP_ERROR(hipEventDestroy(start)); // CHECK_HIP_ERROR(hipEventDestroy(stop)); } - -///////////////////////////////////////////////////////////////////////////////////////////////////////// - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); - m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &hipb_mm, "hipb_mm", py::arg("mat1"), py::arg("mat2"), - py::arg("solution_index"), py::arg("bias") = at::nullopt, - py::arg("out_dtype") = at::nullopt, py::arg("scale1") = at::nullopt, - py::arg("scale2") = at::nullopt, py::arg("scaleOut") = at::nullopt); - m.def("hipb_findallsols", &hipb_findallsols, "hipb_findallsols", - py::arg("mat1"), py::arg("mat2"), py::arg("bias") = at::nullopt, - py::arg("out_dtype") = at::nullopt); -} \ No newline at end of file diff --git a/csrc/gradlib/ops.h b/csrc/gradlib/ops.h new file mode 100644 index 0000000000000..43107ec3c098a --- /dev/null +++ b/csrc/gradlib/ops.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +void hipb_create_extension(); +void hipb_destroy_extension(); +torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, + const int64_t solution_index, + at::optional bias = at::nullopt, + at::optional out_dtype = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt); + +std::vector hipb_findallsols(const torch::Tensor& mat1, + const torch::Tensor& mat2, + at::optional bias, + at::optional out_dtype); + +void rocb_create_extension(); +void rocb_destroy_extension(); +torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int64_t solution_index); + +std::vector RocFindAllSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2); \ No newline at end of file diff --git a/gradlib/csrc/rocsolgemm.cu b/csrc/gradlib/rocsolgemm.cu similarity index 51% rename from gradlib/csrc/rocsolgemm.cu rename to csrc/gradlib/rocsolgemm.cu index d691fcac416a6..81c3775cc55b3 100644 --- a/gradlib/csrc/rocsolgemm.cu +++ b/csrc/gradlib/rocsolgemm.cu @@ -1,15 +1,14 @@ // #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif #define ROCBLAS_NO_DEPRECATED_WARNINGS #define ROCBLAS_BETA_FEATURES_API #include -#include #include #include #include @@ -23,7 +22,7 @@ #include #include -//#include +// #include #include #include @@ -36,95 +35,86 @@ #include - // #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) 
#define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif // #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL // #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif namespace { - rocblas_handle r_handle; - - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; - - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - uint64_t workspace_size = 32*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipblasDatatype_t dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; +rocblas_handle r_handle; + +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? 
+/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +uint64_t workspace_size = 32 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipblasDatatype_t dtype; + + friend auto operator<(const MatMulConfig& left, + const MatMulConfig& right) -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; - // std::map, std::vector> heuristic_map; - std::map heuristic_map; +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; - bool cout_print = true; -} +bool cout_print = true; +} // namespace ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call -*/ + */ /* hipblasStatus_t hipblasLtMatmul_wrapper( hipblasLtHandle_t handle, @@ -146,7 +136,8 @@ hipblasStatus_t hipblasLtMatmul_wrapper( int flag { 0 }; if (dtype == HIPBLAS_R_16F) { // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + // +https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices flag = rocblas_gemm_flags_fp16_alt_impl; } @@ -164,28 +155,31 @@ hipblasStatus_t hipblasLtMatmul_wrapper( CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); } CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, +HIPBLAS_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, +HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); nvtxRangePop(); // if heuristic does not exist in the map, do search and push into the map - auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; if (heuristic_map.count(gemm_key) <= 0) { nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? 
+"N" : "T") << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc +<< "), " << std::endl; } - std::vector heuristicResult(request_solutions); + std::vector +heuristicResult(request_solutions); CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { - std::cout << "less solution found! request: " << request_solutions + preference, request_solutions, heuristicResult.data(), +&returnedAlgoCount)); if((returnedAlgoCount != request_solutions) && cout_print) +{ std::cout << "less solution found! request: " << request_solutions << ", found: " << returnedAlgoCount << std::endl; } @@ -204,7 +198,8 @@ hipblasStatus_t hipblasLtMatmul_wrapper( b, matB, beta, c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c + c, matC, // In case beta != 0, these runs can overwrite the values +in c // since c and d are the same // TODO: allocates separate d memory for these runs &heuristicResult[sol].algo, @@ -221,7 +216,8 @@ hipblasStatus_t hipblasLtMatmul_wrapper( b, matB, beta, c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c + c, matC, // In case beta != 0, these runs can overwrite the values +in c // since c and d are the same // TODO: allocates separate d memory for these runs &heuristicResult[sol].algo, @@ -236,7 +232,8 @@ hipblasStatus_t hipblasLtMatmul_wrapper( eventMs /= bench_iters; if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; + std::cout << " Sol " << sol << ": average time per iter " << +std::to_string(eventMs) << " ms"; } if (bestMs > eventMs) { bestMs = eventMs; @@ -277,43 +274,46 @@ hipblasStatus_t hipblasLtMatmul_wrapper( } */ ///////////////////////////////////////////////////////////////////////////////////////////////////////// -std::vector RocFindAllSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2 - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; +std::vector RocFindAllSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool 
transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; @@ -324,8 +324,8 @@ std::vector RocFindAllSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - float one { 1.0f }; - float zero { 0.0f }; + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; @@ -333,13 +333,13 @@ std::vector RocFindAllSolIdxBlas( int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; rocblas_set_stream(r_handle, current_stream); - uint32_t flags { 0 }; + uint32_t flags{0}; rocblas_datatype abcRtype; if (abcType == at::kHalf) { abcRtype = rocblas_datatype_f16_r; @@ -351,80 +351,91 @@ std::vector RocFindAllSolIdxBlas( assert(false && "Wrong datatype!"); } - #define GEMM_EX_ARGS \ - r_handle, transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, \ - m, n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, ptrC, \ - abcRtype, result_ld, ptrC, abcRtype, result_ld, rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index - - rocblas_int sizeSolve; - //CHECK_ROCBLAS_ERROR( - rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &sizeSolve); - - // Fill array with list of solutions that match type - // Note: some of these may be invalid - std::vector solutionsSolve(sizeSolve); - //CHECK_ROCBLAS_ERROR( - rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, solutionsSolve.data(), &sizeSolve); - - std::vector validSolutions; - for(auto sol : solutionsSolve) { - auto status = rocblas_gemm_ex(r_handle, - transpose_mat1 ? 
rocblas_operation_transpose : rocblas_operation_none, - transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, - m, n, k, - &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, - &zero, ptrC, abcRtype, result_ld, - ptrC, abcRtype, result_ld, - rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index, sol, rocblas_gemm_flags_none); - if (status == rocblas_status_success) { - validSolutions.push_back(sol); - } - } +#define GEMM_EX_ARGS \ + r_handle, \ + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, \ + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, \ + m, n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, \ + ptrC, abcRtype, result_ld, ptrC, abcRtype, result_ld, \ + rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index + + rocblas_int sizeSolve; + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, + &sizeSolve); + + // Fill array with list of solutions that match type + // Note: some of these may be invalid + std::vector solutionsSolve(sizeSolve); + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, + solutionsSolve.data(), &sizeSolve); + + std::vector validSolutions; + for (auto sol : solutionsSolve) { + auto status = rocblas_gemm_ex( + r_handle, + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, + m, n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, + ptrC, abcRtype, result_ld, ptrC, abcRtype, result_ld, + rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index, sol, + rocblas_gemm_flags_none); + if (status == rocblas_status_success) { + validSolutions.push_back(sol); + } + } - return validSolutions; + return validSolutions; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -torch::Tensor RocSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - const int32_t solution_index=0 - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; +torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int64_t solution_index = 0) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected 
mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { @@ -436,14 +447,19 @@ torch::Tensor RocSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one { 1.0f }; - float zero { 0.0f }; + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; @@ -451,8 +467,9 @@ torch::Tensor RocSolIdxBlas( int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 
0 : 1); // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl - // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl; - + // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " + // << result_ld << std::endl; + /* int flag { 0 }; hipblasDatatype_t hipblasType; @@ -466,11 +483,11 @@ torch::Tensor RocSolIdxBlas( assert(false && "Wrong datatype!"); } */ - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; - /* + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + /* CHECK_HIPBLAS_ERROR(hipblasLtMatmul_wrapper( hipblaslt_handle, @@ -486,8 +503,8 @@ torch::Tensor RocSolIdxBlas( current_stream)); */ rocblas_set_stream(r_handle, current_stream); - uint32_t flags { 0 }; - //int32_t solution_index {0}; + uint32_t flags{0}; + // int32_t solution_index {0}; rocblas_datatype abcRtype; if (abcType == at::kHalf) { abcRtype = rocblas_datatype_f16_r; @@ -499,25 +516,22 @@ torch::Tensor RocSolIdxBlas( assert(false && "Wrong datatype!"); } - //CHECK_ROCBLAS_ERROR( - rocblas_gemm_ex(r_handle, - transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, - transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, - m, n, k, - &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, - &zero, ptrC, abcRtype, result_ld, - ptrC, abcRtype, result_ld, - rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index, solution_index, flags); + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex( + r_handle, + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, + transpose_mat2 ? 
rocblas_operation_transpose : rocblas_operation_none, m, + n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, ptrC, + abcRtype, result_ld, ptrC, abcRtype, result_ld, rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, solution_index, flags); //); - return result; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void rocb_create_extension() -{ +void rocb_create_extension() { /* CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); @@ -527,8 +541,9 @@ void rocb_create_extension() CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + CHECK_HIP_ERROR(hipEventCreate(&start)); CHECK_HIP_ERROR(hipEventCreate(&stop)); */ rocblas_create_handle(&r_handle); @@ -536,28 +551,17 @@ void rocb_create_extension() ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void rocb_destroy_extension() -{ - /* - CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - CHECK_HIP_ERROR(hipEventDestroy(event)); - - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); - - CHECK_HIP_ERROR(hipEventDestroy(start)); - CHECK_HIP_ERROR(hipEventDestroy(stop)); */ - rocblas_destroy_handle(r_handle); -} +void rocb_destroy_extension() { + /* + CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + CHECK_HIP_ERROR(hipEventDestroy(event)); -///////////////////////////////////////////////////////////////////////////////////////////////////////// + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("rocb_create_extension", &rocb_create_extension, "create_extension"); - m.def("rocb_destroy_extension", &rocb_destroy_extension, "destroy_extension"); - m.def("rocb_mm", &RocSolIdxBlas, "mm"); - m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols"); + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); */ + rocblas_destroy_handle(r_handle); } diff --git a/csrc/gradlib/torch_bindings.cpp b/csrc/gradlib/torch_bindings.cpp new file mode 100644 index 0000000000000..1818df584c5d6 --- /dev/null +++ b/csrc/gradlib/torch_bindings.cpp @@ -0,0 +1,18 @@ +#include "core/registration.h" +#include "gradlib/ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, gradlib_ops) { + // Gradlib custom ops + + gradlib_ops.def("hipb_create_extension", &hipb_create_extension); + gradlib_ops.def("hipb_destroy_extension", &hipb_destroy_extension); + gradlib_ops.def("hipb_mm", &hipb_mm); + gradlib_ops.def("hipb_findallsols", &hipb_findallsols); + + gradlib_ops.def("rocb_create_extension", &rocb_create_extension); + gradlib_ops.def("rocb_destroy_extension", &rocb_destroy_extension); + gradlib_ops.def("rocb_mm", &RocSolIdxBlas); + gradlib_ops.def("rocb_findallsols", &RocFindAllSolIdxBlas); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/gradlib/gradlib/GemmTuner.py 
b/gradlib/GemmTuner.py similarity index 89% rename from gradlib/gradlib/GemmTuner.py rename to gradlib/GemmTuner.py index 60572b3eecd53..b982f4b13fb9b 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/GemmTuner.py @@ -2,14 +2,11 @@ import random from pathlib import Path -import hipbsolidxgemm import pandas as pd -import rocsolidxgemm import torch import torch.nn.functional as F -rocsolidxgemm.rocb_create_extension() -hipbsolidxgemm.hipb_create_extension() +import vllm._gradlib_C # noqa: F401 rtol = 1e-5 atol = 1 @@ -54,10 +51,9 @@ def __init__(self, m, n, k, bias, indtype, outdtype, rocblas_decode=False): self.rocblas_decode = rocblas_decode def find_hipblas_sols(self): - sols = hipbsolidxgemm.hipb_findallsols(self.inp, - self.weights.t(), - bias=self.bias, - out_dtype=self.outdtype) + sols = torch.ops._gradlib_C.hipb_findallsols(self.inp, + self.weights.t(), + self.bias, self.outdtype) print('M N K bias dtype', self.m, self.n, @@ -82,13 +78,12 @@ def check_gemm_ref(self, libtype, solidx): else: ref = F.linear(self.inp, self.weights, self.bias) if libtype == 'hipblaslt': - c = hipbsolidxgemm.hipb_mm(self.inp, - self.weights.t(), - solidx, - bias=self.bias, - out_dtype=self.outdtype) + c = torch.ops._gradlib_C.hipb_mm(self.inp, self.weights.t(), + solidx, self.bias, self.outdtype, + None, None, None) elif libtype == 'rocblas': - c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx) + c = torch.ops._gradlib_C.rocb_mm(self.inp, self.weights.t(), + solidx) if self.bias is not None: c += self.bias if torch.allclose(c.to(self.outdtype), @@ -110,17 +105,13 @@ def check_gemm_ref(self, libtype, solidx): def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): #print('>>>hipbtime',solidx) for i in range(cold_iters): - hipbsolidxgemm.hipb_mm(self.inp, - self.weights.t(), - solidx, - out_dtype=self.outdtype) + torch.ops._gradlib_C.hipb_mm(self.inp, self.weights.t(), solidx, + None, self.outdtype, None, None, None) self.start.record() for i in range(warm_iters): - hipbsolidxgemm.hipb_mm(self.inp, - self.weights2[random.randint( - 0, self.nb - 1)].t(), - solidx, - out_dtype=self.outdtype) + torch.ops._gradlib_C.hipb_mm( + self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), + solidx, None, self.outdtype, None, None, None) self.end.record() torch.cuda.synchronize() gtime = self.start.elapsed_time(self.end) / warm_iters @@ -151,10 +142,10 @@ def hipb_time_all_sols(self, fast_mode=0, top_sols=0): def rocb_time_sol(self, solidx, cold_iters=2, warm_iters=10): def rocb_mm_bias(inp, w, solidx, bias): - return rocsolidxgemm.rocb_mm(inp, w, solidx) + bias + return torch.ops._gradlib_C.rocb_mm(inp, w, solidx) + bias def rocb_mm_nobias(inp, w, solidx, _): - return rocsolidxgemm.rocb_mm(inp, w, solidx) + return torch.ops._gradlib_C.rocb_mm(inp, w, solidx) rocb_fun = rocb_mm_bias if self.bias is not None else rocb_mm_nobias for _ in range(cold_iters): @@ -173,7 +164,8 @@ def rocb_mm_nobias(inp, w, solidx, _): return gtime def find_rocblas_sols(self): - sols = rocsolidxgemm.rocb_findallsols(self.inp, self.weights.t()) + sols = torch.ops._gradlib_C.rocb_findallsols(self.inp, + self.weights.t()) print('M N K dtype', self.m, self.n, diff --git a/gradlib/csrc/grad_funcs.cu b/gradlib/csrc/grad_funcs.cu deleted file mode 100644 index f6498fb2a3ba7..0000000000000 --- a/gradlib/csrc/grad_funcs.cu +++ /dev/null @@ -1,413 +0,0 @@ -// #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip 
float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif - -#include -#include -#include -#include -#include -#include -#include -#include -// #include -#include -#include -#include -#include - -#include -//#include -#include - -#include -#include -#include -#include -#include -#include -#include "nvToolsExt.h" - -// #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif - -// #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL -// #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif - -#ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } -#endif - -#ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } -#endif - -namespace { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? 
- /*thread_local*/ cudaEvent_t event; - - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - uint64_t workspace_size = 32*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipblasDatatype_t dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; - - // std::map, std::vector> heuristic_map; - std::map heuristic_map; - - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; - - bool cout_print = true; -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * hipBLASLt GEMM call -*/ -hipblasStatus_t hipblasLtMatmul_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipblasDatatype_t dtype, - hipStream_t &stream) -{ - // TODO: flag is not supported for hipblasLt yet - int flag { 0 }; - if (dtype == HIPBLAS_R_16F) { - // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - flag = rocblas_gemm_flags_fp16_alt_impl; - } - - nvtxRangePushA("hipBLASLt variables creation"); - hipblasLtMatrixLayout_t matA, matB, matC; - hipblasLtMatmulDesc_t matmul; - if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); - } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); - } - if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); - } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); - } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - nvtxRangePop(); - - // if heuristic does not exist in the map, do search and push into the map - auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; - if (heuristic_map.count(gemm_key) <= 0) { - nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; - } - std::vector heuristicResult(request_solutions); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { - std::cout << "less solution found! 
request: " << request_solutions - << ", found: " << returnedAlgoCount << std::endl; - } - - if (returnedAlgoCount == 1) { - heuristic_map[gemm_key] = heuristicResult[0]; - } else { - // benchmark requested solutions and pick best one - int bestIndex { -1 }; - double bestMs { std::numeric_limits::max() }; - for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { - // warm up - for (int iter { 0 }; iter < warmup_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter { 0 }; iter < bench_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; - } - if (bestMs > eventMs) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) { - std::cout << " *" << std::endl; - } - } else { - if (cout_print) { - std::cout << std::endl; - } - } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; - } - nvtxRangePop(); - } - - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristic_map[gemm_key].algo, - d_workspace, workspace_size, - stream); - - nvtxRangePushA("hipBLASLt variables deletion"); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - nvtxRangePop(); - - return status; -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////// -torch::Tensor hipBLASLtMm_( - const torch::Tensor& mat1, - const torch::Tensor& mat2) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; - 
-  bool transpose_result = true;
-  bool transpose_mat1;
-  bool transpose_mat2;
-  if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) {
-    transpose_mat2 = false;
-  } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) {
-    transpose_mat2 = true;
-  } else {
-    assert(false && "unusual strides detected, may need to clone a contiguous tensor");
-  }
-  if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) {
-    transpose_mat1 = false;
-  } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) {
-    transpose_mat1 = true;
-  } else {
-    assert(false && "unusual strides detected, may need to clone a contiguous tensor");
-  }
-
-  if (transpose_result) {
-    bool tmp = transpose_mat1;
-    transpose_mat1 = !transpose_mat2;
-    transpose_mat2 = !tmp;
-    mat1_strides = mat2.strides();
-    mat2_strides = mat1.strides();
-    mat1_sizes = mat2.sizes();
-    mat2_sizes = mat1.sizes();
-  }
-  // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl
-  //           << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl
-  //           << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl;
-  // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl
-  //           << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl;
-
-  float one { 1.0f };
-  float zero { 0.0f };
-  int64_t m = mat1_sizes[transpose_result ? 1 : 0];
-  int64_t k = mat1_sizes[transpose_result ? 0 : 1];
-  int64_t n = mat2_sizes[transpose_result ? 0 : 1];
-  int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0];
-  int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0];
-  int64_t result_ld = result.stride(transpose_result ? 0 : 1);
-  // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl
-  //           << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl;
-
-  int flag { 0 };
-  hipblasDatatype_t hipblasType;
-  if (abcType == at::kHalf) {
-    hipblasType = HIPBLAS_R_16F;
-  } else if (abcType == at::kBFloat16) {
-    hipblasType = HIPBLAS_R_16B;
-  } else if (abcType == at::kFloat) {
-    hipblasType = HIPBLAS_R_32F;
-  } else {
-    assert(false && "Wrong datatype!");
-  }
-
-  void *ptrA { static_cast<void *>((transpose_result ? mat2 : mat1).data_ptr()) };
-  void *ptrB { static_cast<void *>((transpose_result ? mat1 : mat2).data_ptr()) };
-  void *ptrC { static_cast<void *>(result.data_ptr()) };
-
-  auto current_stream { torch::hip::getCurrentHIPStream().stream() };
-
-  CHECK_HIPBLAS_ERROR(hipblasLtMatmul_wrapper(
-      hipblaslt_handle,
-      transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-      transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-      m, n, k,
-      &one,
-      ptrA, mat1_ld,
-      ptrB, mat2_ld,
-      &zero,
-      ptrC, result_ld,
-      hipblasType,
-      current_stream));
-
-  return result;
-}
-
-/////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-void create_extension()
-{
-  CHECK_HIP_ERROR(hipStreamCreate(&weight_stream));
-  CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming));
-
-  // hipBLASLt
-  CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle));
-  CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size));
-  CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference));
-  CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute(
-      preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
-
-  CHECK_HIP_ERROR(hipEventCreate(&start));
-  CHECK_HIP_ERROR(hipEventCreate(&stop));
-}
-
-/////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-void destroy_extension()
-{
-  CHECK_HIP_ERROR(hipStreamDestroy(weight_stream));
-  CHECK_HIP_ERROR(hipEventDestroy(event));
-
-  // hipBLASLt
-  CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle));
-  CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference));
-  CHECK_HIP_ERROR(hipFree(d_workspace));
-
-  CHECK_HIP_ERROR(hipEventDestroy(start));
-  CHECK_HIP_ERROR(hipEventDestroy(stop));
-}
-
-/////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
-  m.def("create_extension", &create_extension, "create_extension");
-  m.def("destroy_extension", &destroy_extension, "destroy_extension");
-  m.def("mm", &hipBLASLtMm_, "mm");
-}
diff --git a/gradlib/gradlib/gemm_runner.py b/gradlib/gemm_runner.py
similarity index 85%
rename from gradlib/gradlib/gemm_runner.py
rename to gradlib/gemm_runner.py
index eb49d446bef3f..a38724bfc1e5f 100644
--- a/gradlib/gradlib/gemm_runner.py
+++ b/gradlib/gemm_runner.py
@@ -1,13 +1,13 @@
 import sys
 
-import hipbsolidxgemm
 import pandas as pd
-import rocsolidxgemm
 import torch
 import torch.nn.functional as F
 
-rocsolidxgemm.rocb_create_extension()
-hipbsolidxgemm.hipb_create_extension()
+import vllm._gradlib_C  # noqa: F401
+
+torch.ops._gradlib_C.rocb_create_extension()
+torch.ops._gradlib_C.hipb_create_extension()
 
 
 class TunedGemm:
@@ -38,9 +38,10 @@ def mm(self, inp, weights):
                                          n=inp.shape[0],
                                          k=inp.shape[1])
         if soltype == 1:
-            out = hipbsolidxgemm.hipb_mm(inp, weights.t(), solidx)
+            out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None,
+                                               None, None, None, None)
         elif soltype == 2:
-            out = rocsolidxgemm.rocb_mm(inp, weights.t(), solidx)
+            out = torch.ops._gradlib_C.rocb_mm(inp, weights.t(), solidx)
         else:
             out = F.linear(inp, weights)
         return out
diff --git a/gradlib/gradlib/gemm_tuner.py b/gradlib/gemm_tuner.py
similarity index 97%
rename from gradlib/gradlib/gemm_tuner.py
rename to gradlib/gemm_tuner.py
index 60ceedbec468f..0c0b42ba48ade 100644
--- a/gradlib/gradlib/gemm_tuner.py
+++ b/gradlib/gemm_tuner.py
@@ -4,14 +4,13 @@
 from pathlib import Path
 
 import torch  # isort: split
-import hipbsolidxgemm
 import pandas as pd
-import rocsolidxgemm
+import vllm._gradlib_C  # noqa: F401
 
 from gradlib.GemmTuner import GemmTuner
 
-rocsolidxgemm.rocb_create_extension()
-hipbsolidxgemm.hipb_create_extension()
+torch.ops._gradlib_C.rocb_create_extension()
+torch.ops._gradlib_C.hipb_create_extension()
 
 
 def generate_mk_sets(model_dir, tp=1):
diff --git a/gradlib/gradlib/mm_test.py b/gradlib/gradlib/mm_test.py
deleted file mode 100644
index c06c8a4f18c0c..0000000000000
--- a/gradlib/gradlib/mm_test.py
+++ /dev/null
@@ -1,253 +0,0 @@
-import sys
-
-import hipbsolidxgemm
-import pandas as pd
-#import gradlib
-import rocsolidxgemm
-import torch
-import torch.nn.functional as F
-
-#gradlib.create_extension()
-rocsolidxgemm.rocb_create_extension()
-hipbsolidxgemm.hipb_create_extension()
-
-#m = 128; n = 192 ;k = 256
-#m = 7168; k = 4096*2; n = 256
-#m = int(1024*1.25); k = int(1024*8); n = 1
-#m = 1; k = int(1024*8); n = int(1024*7)
-#m=22016; k=4096 ; n=1
-#m=int(27648/1);k=5120;n=8
-#m=5120;k=13824;n=1
-m = 3 * 5120
-k = 5120
-n = 1
-
-rtol = 1e-5
-atol = 1
-dtype = torch.float16
-
-
-class Gemm:
-
-    def __init__(self, m, n, k, dtype=torch.float16):
-        self.m = m
-        self.k = k
-        self.n = n
-        self.dtype = dtype
-        self.inp = torch.randn((self.n, self.k),
-                               dtype=self.dtype,
-                               device='cuda')
-        self.weights = torch.randn((self.m, self.k),
-                                   dtype=self.dtype,
-                                   device='cuda')
-        self.hipb_sols = []
-        self.rtol = 1e-5
-        self.atol = 1
-        self.cold_iters = 2
-        self.warm_iters = 10
-
-    def find_hipblas_sols(self):
-        sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t())
-        print('M N K', self.m, self.n, self.k, '>>> Total hipb solutions',
-              len(sols))
-        #print(sols)
-        self.hipb_sols = sols
-
-    def hipb_check_gemm_ref(self, user_solidxs=None):
-        ref = F.linear(self.inp, self.weights)
-        solidxs = user_solidxs if user_solidxs is not None else self.hipb_sols
-        if len(solidxs) > 0:
-            for solidx in solidxs:
-                c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx)
-                if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol):
-                    print('>>> Hipb solidx', solidx, 'passed reference test')
-                else:
-                    print('>>> Hipb solidx', solidx, 'FAILED reference test')
-                    print(ref)
-                    print(c)
-
-    def hipb_time_sol(self, solidx):
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        for i in range(self.cold_iters):
-            hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx)
-        start.record()
-        for i in range(self.warm_iters):
-            hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx)
-        end.record()
-        torch.cuda.synchronize()
-        gtime = start.elapsed_time(end) / self.warm_iters
-        #print('>>> Solidx GTime',solidx,gtime,'ms')
-        return gtime
-
-    def hipb_time_all_sols(self):
-        gtimes = {}
-        for solidx in self.hipb_sols:
-            gtimes[solidx] = self.hipb_time_sol(solidx)
-        self.gtimedf = pd.DataFrame.from_dict(
-            gtimes, orient='index',
-            columns=['gtimems']).sort_values(by='gtimems')
-        self.gtimedf.to_csv('/tmp/gtimedf.csv')
-        print(self.gtimedf.head(10))
-
-
-gemmobj = Gemm(m=3 * 5120, n=1, k=5120)
-gemmobj.find_hipblas_sols()
-#gemmobj.hipb_check_gemm_ref()
-#gemmobj.hipb_check_gemm_ref(user_solidxs=[131,8190])
-#gemmobj.hipb_time_sol(gemmobj.hipb_sols[0])
-gemmobj.hipb_time_all_sols()
-gemmobj.hipb_check_gemm_ref(user_solidxs=gemmobj.gtimedf.head(5).index.values)
-
-sys.exit()
-
-
-def splitk_linear(inp, w, splitk=2):
-    wsp = torch.chunk(w, splitk, dim=1)
-    isp = torch.chunk(inp, splitk, dim=1)
-    print('>>>', isp[0].shape, wsp[1].shape)
-    cnew = []
-    for i in range(splitk):
-        cnew.append(F.linear(isp[i], wsp[i]))
-    #cnew1 = F.linear(isp[1],wsp[1])
-    c = cnew[0]
-    for i in range(1, splitk):
-        c.add_(cnew[i])
-    #c = torch.add(cnew0,cnew1)
-
-    return c
-
-
-def splitm_linear(inp, w, splitm=2, splits=None, splitk=1):
-    outputp = []
-    #wsp = torch.chunk(F.pad(weights,(0,0,0,padm)),splitm)
-    if splits is not None:
-        wsp = torch.split(w, splits)
-    else:
-        wsp = torch.chunk(w, splitm)
-    #cout = torch.empty(inp.shape[0], w.shape[0],
-    #                   dtype=inp.dtype,device=inp.device)
-    #csp = torch.chunk(cout,splitm,dim=1)
-
-    for i, _ in enumerate(wsp):
-        #print('>>>wspi',wsp[i].shape)
-        if splitk == 1:
-            outputp.append(F.linear(inp, wsp[i]))
-            #cout[:,i*wsp[i].shape[0]:
-            #    (i+1)*wsp[i].shape[0]] = F.linear(inp, wsp[i])
-            #csp[i].copy_(F.linear(inp, wsp[i]))
-        else:
-            outputp.append(splitk_linear(inp, wsp[i], splitk))
-    c = torch.cat((outputp), dim=1)
-    #print('>>>',c.shape,cout.shape)
-    return c
-
-
-def splitn_linear(inp, w, splitn=2, splits=None):
-    outputp = []
-    if splits is not None:
-        isp = torch.split(inp, splits)
-    else:
-        isp = torch.chunk(inp, splitn)
-    torch.empty(inp.shape[0], w.shape[0], dtype=inp.dtype, device=inp.device)
-    for i, _ in enumerate(isp):
-        outputp.append(F.linear(isp[i], w))
-        #cout[i*isp[i].shape[0]:
-        #    (i+1)*isp[i].shape[0],:] = F.linear(isp[i], w)
-    c = torch.cat((outputp), dim=0)
-    #print('>>>',c.shape,cout.shape)
-    return c
-
-
-nncount = 0
-for _ in range(10):
-    #a = torch.randn((m, k), dtype=dtype, device='cuda')
-    #b = torch.randn((k, n), dtype=dtype, device='cuda')
-    inp = torch.randn((n, k), dtype=dtype, device='cuda')
-    weights = torch.randn((m, k), dtype=dtype, device='cuda')
-    #c = gradlib.mm(inp, weights.t())
-    c = hipbsolidxgemm.hipb_mm(inp, weights.t(), 20053)
-    c = hipbsolidxgemm.hipb_mm(inp, weights.t(), 20053)
-    c = rocsolidxgemm.rocb_mm(inp, weights.t(), 60995)
-    c = rocsolidxgemm.rocb_mm(inp, weights.t(), 60995)
-
-    splitm = 2
-    #padm=2
-    outsp = []
-    #wsp = torch.chunk(F.pad(weights,(0,0,0,padm)),splitm)
-    #wsp = torch.chunk(weights,splitm)
-    #wsp = torch.split(weights,(3*1024,4*1024))
-    #c = torch.empty((n,m),dtype=dtype,device='cuda')
-    #outtup = []
-    #for i,_ in enumerate(wsp):
-    #    print('>>>wspi',wsp[i].shape)
-    #    outsp.append(F.linear(inp, wsp[i]))
-    #    #outtup.append(splitk_linear(inp, wsp[i]))
-    #outsp = [torch.add(a,b) for a,b in outtup]
-    #c = torch.cat((outsp),dim=1)
-    #c = c[:,:-padm]
-    #c = splitm_linear(inp,weights,splitm=4,splits=None,splitk=1)
-    #c = splitn_linear(inp,weights,splitn=2,splits=None)
-
-    #wsp = torch.chunk(weights,2,dim=1)
-    #isp = torch.chunk(inp,2,dim=1)
-    #print('>>>',isp[0].shape,wsp[1].shape)
-    #cnew0 = F.linear(isp[0],wsp[0])
-    #cnew1 = F.linear(isp[1],wsp[1])
-    #c = torch.add(cnew0,cnew1)
-    #c = splitk_linear(inp, weights, splitk=4)
-
-    #torch.cuda.synchronize()
-    ref = F.linear(inp, weights)
-    #ref = torch.matmul(a,b)
-    if torch.allclose(c, ref, atol=atol, rtol=rtol):
-        nncount += 1
-    else:
-        print(ref)
-        print(c)
-'''
-tncount = 0
-for _ in range(10):
-    a = torch.randn((m, k), dtype=dtype, device='cuda')
-    b = torch.randn((n, k), dtype=dtype, device='cuda')
-    c = gradlib.mm(a, b.t())
-    #torch.cuda.synchronize()
-    ref = torch.matmul(a, b.t())
-    if torch.allclose(c, ref, atol=atol, rtol=rtol):
-        tncount += 1
-    else:
-        print(ref)
-        print(c)
-        #torch.save(c-ref, '/tmp/difference.pt')
-        #np.savetxt('my_file.txt', (c-ref).cpu().numpy())
-        dfs = ref - c
-        nz = torch.nonzero(dfs,as_tuple=True)
-        print(nz)
-        print(dfs[nz])
-        print(ref[nz])
-        print(c[nz])
-'''
-'''
-ntcount = 0
-for _ in range(10):
-    a = torch.randn((k, m), dtype=dtype, device='cuda')
-    b = torch.randn((k, n), dtype=dtype, device='cuda')
-    c = gradlib.mm(a.t(), b)
-    #torch.cuda.synchronize()
-    if torch.allclose(c, torch.matmul(a.t(), b), atol=atol, rtol=rtol):
-        ntcount += 1
-
-ttcount = 0
-for _ in range(10):
-    a = torch.randn((k, m), dtype=dtype, device='cuda')
-    b = torch.randn((n, k), dtype=dtype, device='cuda')
-    c = gradlib.mm(a.t(), b.t())
-    torch.cuda.synchronize()
-    if torch.allclose(c, torch.matmul(a.t(), b.t()), atol=atol, rtol=rtol):
-        ttcount += 1
-'''
-print(f"GEMM (m, n, k) = {n}, {m}, {k}")
-print(f"NN GEMMs: pass {nncount}/10, tol={rtol}")
-#print(f"TN GEMMs: pass {tncount}/10, tol={rtol}")
-#print(f"NT GEMMs: pass {ntcount}/10, tol={rtol}")
-#print(f"TT GEMMs: pass {ttcount}/10, tol={rtol}")
diff --git a/gradlib/setup.py b/gradlib/setup.py
deleted file mode 100644
index e90eacfe2a7c2..0000000000000
--- a/gradlib/setup.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import os
-
-import torch
-from setuptools import setup
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-from torch.utils.hipify import hipify_python
-
-this_dir = os.path.dirname(os.path.abspath(__file__))
-#gpus = subprocess.check_output("/opt/rocm/bin/rocminfo").decode(
-#    'UTF-8').split('\n')
-#gpus = list(set([re.search('(gfx94.)', g).group(0)
-#    for g in gpus if 'gfx94' in g]))
-gpus = ['gfx90a', 'gfx940', 'gfx941', 'gfx942']
-#gpus = ['gfx90a','gfx940']
-extra_args = ["--offload-arch=" + g for g in gpus]
-
-#sets_rocm_pytorch = False
-maj_ver, min_ver, *_ = torch.__version__.split('.')
-if int(maj_ver) > 1 or (int(maj_ver) == 1 and int(min_ver) >= 5):
-    from torch.utils.cpp_extension import ROCM_HOME
-    is_rocm_pytorch = bool(torch.version.hip is not None
-                           and ROCM_HOME is not None)
-
-ext_modules = []
-
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, 'include', 'ATen',
-                               'CUDAGenerator.h')):
-    generator_flag = ['-DOLD_GENERATOR']
-
-print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
-
-version_ge_1_1 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
-    version_ge_1_1 = ['-DVERSION_GE_1_1']
-version_ge_1_3 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_ge_1_5 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
-    version_ge_1_5 = ['-DVERSION_GE_1_5']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
-
-include_dirs = [os.path.join(this_dir, 'csrc')]
-
-#if is_rocm_pytorch:
-#    import shutil
-#    with hipify_python.GeneratedFileCleaner(
-#            keep_intermediates=True) as clean_ctx:
-#        hipify_python.hipify(project_directory=this_dir,
-#            output_directory=this_dir, includes="csrc/*",
-#            show_detailed=True,
-#            is_pytorch_extension=True,
-#            clean_ctx=clean_ctx)
-
-if not is_rocm_pytorch:
-    ext_modules.append(
-        CUDAExtension(name='gradlib',
-                      sources=['grad_funcs.cu'],
-                      extra_compile_args={
-                          'cxx': [
-                              '-O3',
-                          ],
-                          'nvcc': [
-                              '-O3', '-U__CUDA_NO_HALF_OPERATORS__',
-                              '-U__CUDA_NO_HALF_CONVERSIONS__',
-                              "--expt-relaxed-constexpr",
-                              "-ftemplate-depth=1024",
-                              '-gencode=arch=compute_70,code=sm_70',
-                              '-gencode=arch=compute_80,code=sm_80',
-                              '-gencode=arch=compute_80,code=compute_80'
-                          ]
-                      }))
-elif is_rocm_pytorch:
-    #if torch.__version__ <= '1.8':
-    hipify_ver = [
-        int(x) for x in torch.utils.hipify.__version__.split(".")
-    ] if hasattr(torch.utils.hipify, "__version__") else [0, 0, 0]
-    if hipify_ver < [1, 0, 0]:
-        with hipify_python.GeneratedFileCleaner(
-                keep_intermediates=True) as clean_ctx:
-            hipify_python.hipify(project_directory=this_dir,
-                                 output_directory=this_dir,
-                                 includes="csrc/*",
-                                 show_detailed=True,
-                                 is_pytorch_extension=True,
-                                 clean_ctx=clean_ctx)
-
-        ext_modules.append(
-            CUDAExtension(name='gradlib',
-                          sources=['./csrc/hip/grad_funcs.hip'],
-                          extra_compile_args={
-                              'cxx': [
-                                  '-O3',
-                              ] + version_dependent_macros,
-                              'nvcc': ['-O3'] + extra_args
-                          }))
-    else:
-        #ext_modules.append(
-        #    CUDAExtension(
-        #        name='gradlib',
-        #        sources=['./csrc/grad_funcs.cu'],
-        #        include_dirs=include_dirs,
-        #        # add additional libraries argument for hipblaslt
-        #        libraries=['hipblaslt'],
-        #        extra_compile_args={
-        #            'cxx': ['-O3',],
-        #            'nvcc':['-O3',
-        #                '-U__CUDA_NO_HALF_OPERATORS__',
-        #                '-U__CUDA_NO_HALF_CONVERSIONS__',
-        #                "-ftemplate-depth=1024"] + extra_args
-        #            }
-        #        )
-        #    )
-        ext_modules.append(
-            CUDAExtension(
-                name='rocsolidxgemm',
-                sources=['./csrc/rocsolgemm.cu'],
-                include_dirs=include_dirs,
-                # add additional libraries argument for hipblaslt
-                libraries=['rocblas'],
-                extra_compile_args={
-                    'cxx': [
-                        '-O3',
-                        '-DLEGACY_HIPBLAS_DIRECT=ON',
-                    ],
-                    'nvcc': [
-                        '-O3',
-                        '-U__CUDA_NO_HALF_OPERATORS__',
-                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                        "-ftemplate-depth=1024",
-                        '-DLEGACY_HIPBLAS_DIRECT=ON',
-                    ] + extra_args
-                }))
-        ext_modules.append(
-            CUDAExtension(
-                name='hipbsolidxgemm',
-                sources=['./csrc/hipbsolgemm.cu'],
-                include_dirs=include_dirs,
-                # add additional libraries argument for hipblaslt
-                libraries=['hipblaslt'],
-                extra_compile_args={
-                    'cxx': [
-                        '-O3',
-                        '-DLEGACY_HIPBLAS_DIRECT=ON',
-                    ],
-                    'nvcc': [
-                        '-O3',
-                        '-U__CUDA_NO_HALF_OPERATORS__',
-                        '-U__CUDA_NO_HALF_CONVERSIONS__',
-                        "-ftemplate-depth=1024",
-                        '-DLEGACY_HIPBLAS_DIRECT=ON',
-                    ] + extra_args
-                }))
-
-setup(name='gradlib',
-      packages=['gradlib'],
-      ext_modules=ext_modules,
-      cmdclass={'build_ext': BuildExtension})
-
-# python setup.py build && cp build/lib*/gradlib* ../
diff --git a/pyproject.toml b/pyproject.toml
index e0ed5556ede02..a8cd65173d8b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,7 +84,7 @@ exclude = [
 
 [tool.codespell]
 ignore-words-list = "dout, te, indicies, subtile"
-skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./gradlib,./csrc/rocm"
+skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./csrc/gradlib,./csrc/rocm"
 
 [tool.isort]
 use_parentheses = true
diff --git a/setup.py b/setup.py
index b936589869e76..9d4bae6dd265e 100644
--- a/setup.py
+++ b/setup.py
@@ -503,6 +503,7 @@ def _read_requirements(filename: str) -> List[str]:
 
 if _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
+    ext_modules.append(CMakeExtension(name="vllm._gradlib_C"))
 
 if _is_cuda():
     ext_modules.append(
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 0595ff83be250..a441ca5def07b 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -4,20 +4,30 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F
-from hipbsolidxgemm import hipb_create_extension, hipb_mm
-from rocsolidxgemm import rocb_create_extension, rocb_mm
 
 from vllm import _custom_ops as ops
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_navi
 
+support_tuned_gemms = False
+if current_platform.is_rocm():
+    import vllm._gradlib_C  # noqa: F401
+    support_tuned_gemms = True
+
+
+def hipb_mm(inp, weights, solidx, bias=None):
+    return torch.ops._gradlib_C.hipb_mm(inp, weights, solidx, bias, None, None,
+                                        None, None)
+
+
+def rocb_mm(inp, weights, solidx):
+    return torch.ops._gradlib_C.rocb_mm(inp, weights, solidx)
+
 
 class TunedGemm:
 
     def __init__(self):
-        #rocb_create_extension()
-        #hipb_create_extension()
         self.extensions_created = False
         self.save_gemm = int(os.environ.get('VLLM_TUNE_GEMM', 0))
         self.untune_path = os.environ.get('VLLM_UNTUNE_FILE',
@@ -54,7 +64,7 @@ def create_ds(self):
                 soltype = 2
             solds[key] = (soltype, int(ds['solidx']))
         self.solids = solds
-        #print('>>>',solds)
+
 
     def query_sol(self, m, n, k, bias, dtype):
         return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))
@@ -81,6 +91,8 @@ def apply_skinny(self, m, n, k, inp_view, weights):
             return None
 
     def mm(self, inp, weights, bias=None):
+        if not support_tuned_gemms:
+            return F.linear(inp, weights, bias)
         # F.Linear can take a 3 dimensional input. vllm
         # uses this for linear units. However, sampler
         # will use torch.matmul with 2 dimensions only
@@ -94,8 +106,8 @@ def mm(self, inp, weights, bias=None):
             inp_view = inp
             batched = False
         if self.extensions_created is False:
-            rocb_create_extension()
-            hipb_create_extension()
+            torch.ops._gradlib_C.rocb_create_extension()
+            torch.ops._gradlib_C.hipb_create_extension()
             self.extensions_created = True
         m = weights.shape[0]
         n = inp_view.shape[0]
@@ -114,7 +126,7 @@ def mm(self, inp, weights, bias=None):
                 return out + bias
             return out
         elif soltype == 1:
-            out = hipb_mm(inp_view, weights.t(), solidx, bias=bias)
+            out = hipb_mm(inp_view, weights.t(), solidx, bias)
         elif soltype == 2:
             out = rocb_mm(inp_view, weights.t(), solidx)
         if bias is not None: