From c6f3ca5a838380784ca5cf340b0db6fb3e92ae39 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 8 Jul 2024 14:21:50 +0900 Subject: [PATCH] [Dev] Append Efficient CUDA test for low precision batch decoding (#80) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as version below python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability --- THIRDPARTYNOTICES.txt | 204 +++++ testing/cpp/CMakeLists.txt | 1 + .../cpp/efficient_i4_cuda_impl/CMakeLists.txt | 20 + .../efficient_i4_cuda_impl/efficient_i4.cu | 391 +++++++++ .../cpp/efficient_i4_cuda_impl/i4matmul.hpp | 826 ++++++++++++++++++ .../param_permutate.cpp | 89 ++ testing/python/module/test_bitblas_linear.py | 3 +- .../test_int4b_fp16_convert.py | 87 +- 8 files changed, 1563 insertions(+), 58 deletions(-) create mode 100644 testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt create mode 100644 testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu create mode 100644 testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp create mode 100644 testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index f377e67bb..d959effbb 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -206,3 +206,207 @@ Notice for apache/tvm limitations under the License. ------------------------------------------------------------------------------------ +Notice for IST-DASLab/marlin/ +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +------------------------------------------------------------------------------------ diff --git a/testing/cpp/CMakeLists.txt b/testing/cpp/CMakeLists.txt index cf8eb0d3a..b92fa8da7 100644 --- a/testing/cpp/CMakeLists.txt +++ b/testing/cpp/CMakeLists.txt @@ -12,4 +12,5 @@ find_package(GTest REQUIRED) include_directories(${GTEST_INCLUDE_DIRS}) +add_subdirectory(efficient_i4_cuda_impl) add_subdirectory(lop3_type_conversion) diff --git a/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt new file mode 100644 index 000000000..36ffdf548 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+function (ADD_CUDA_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cu) + set_target_properties(${name} PROPERTIES CUDA_ARCHITECTURES 80) + # add flags + target_compile_options(${name} PRIVATE --expt-relaxed-constexpr) + set_target_properties(${name} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${name} gtest gtest_main) +endfunction(ADD_CUDA_TEST_EXECUTABLE) + +ADD_CUDA_TEST_EXECUTABLE(efficient_i4) + +function (ADD_CPP_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cpp) + target_link_libraries(${name} gtest gtest_main pthread) +endfunction(ADD_CPP_TEST_EXECUTABLE) + +ADD_CPP_TEST_EXECUTABLE(param_permutate) diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu new file mode 100644 index 000000000..257f49a31 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "i4matmul.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +void general_compress(const int8_t *lowbit, int8_t *compressed, const int nbit, const int N, bool isSigned = false) +{ + int zero_point = isSigned ? ((1 << (nbit - 1)) - 1) : 0; + const int nbit_per_byte = 8 / nbit; + + for (int i = 0; i < N / nbit_per_byte; i++) + { + compressed[i] = 0; + for (int j = 0; j < nbit_per_byte; j++) + { + compressed[i] |= ((lowbit[nbit_per_byte * i + j] + zero_point) << (nbit * j)); + } + } +} + + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +void weight_pre_process(const int8_t *lowbit, int8_t *compressed, const int nbit, const int K, const int N) +{ + int8_t* tmp1 = new int8_t[K * N]; + const int maxq = 15; + auto [perm, scale_perm, scale_perm_single] = get_perms(); + const int tile_size = 16; + // transform the lowbit matrix to the compressed matrix + for (int i = 0; i < (K / tile_size); i += 1) + { + for (int j = 0; j < (N / tile_size); 
j += 1) + { + for (int k = 0; k < tile_size; k++) + { + for (int l = 0; l < tile_size; l++) + { + int idx_target = i * N * tile_size + j * tile_size * tile_size + k * tile_size + l; + int idx_source = (i * tile_size + k) * N + j * tile_size + l; + tmp1[idx_target] = lowbit[idx_source] + (maxq + 1) / 2; + } + } + } + } + // print the first 10 of tmp2 + printf("tmp1\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp1[i]); + } + printf(" ... "); + for (int i = K * N - 10; i < K * N; i++) + { + printf("%d ", tmp1[i]); + } + printf("\n"); + // permute the matrix + int32_t* tmp2 = new int32_t[K * N]; + const int perm_size = perm.size(); + for (int i = 0; i < (N * K / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + perm[j]; + tmp2[idx_target] = (int32_t)tmp1[idx_source]; + } + } + // print the first 10 of tmp2 + printf("tmp2\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp2[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp2[i]); + } + printf("\n"); + // compress + int32_t* tmp3 = new int32_t[K * N / (32 / nbit)]; + // set zero + for (int i = 0; i < K * N / (32 / nbit); i++) + { + tmp3[i] = 0; + } + for (int i = 0; i < (K / tile_size); i++) + { + for (int j = 0; j < (N * tile_size / 8); j++) + { + for (int k = 0; k < 8; k++) + { + int idx_target = i * N * tile_size / 8 + j; + int idx_source = i * N * tile_size + j * 8 + k; + tmp3[idx_target] |= (tmp2[idx_source] << (nbit * (k % 8))); + } + } + } + // print the first 10 of tmp3 + printf("tmp3\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp3[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp3[i]); + } + printf("\n"); + // copy tmp3 to compressed + for (int i = 0; i < K * N / (32 / nbit); i++) + { + ((int32_t *)(compressed))[i] = tmp3[i]; + } +} + +void scale_pre_process(const half *scale, half *scale_perm, const int K, const int N, int group_size) +{ + auto [perm, scale_perm_group, scale_perm_single] = get_perms(); + if (group_size == -1) + group_size = K; + if (group_size == K){ + const int perm_size = scale_perm_single.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_single[j]; + if (idx_target < 10){ + printf("idx_target = %d, idx_source = %d\n", idx_target, idx_source); + } + scale_perm[idx_target] = scale[idx_source]; + } + } + } + else{ + const int perm_size = scale_perm_group.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_group[j]; + scale_perm[idx_target] = scale[idx_source]; + } + } + } + // print the first 10 of tmp2 + printf("scale_perm\n"); + for (int i = 0; i < 10; i++) + { + printf("%f ", (float)scale_perm[i]); + } + printf(" ... 
"); + for (int i = K * N / group_size - 10; i < K * N / group_size; i++) + { + printf("%f ", (float)scale_perm[i]); + } +} + +TEST(EfficientI4MatmulTest, GEMVTest) +{ + const int prom_m = 1; + const int prom_n = 256; + const int prom_k = 256; + const int bits = 4; + const int group_size = prom_k; + + half* A = new half[prom_m * prom_k]; + int8_t* B = new int8_t[prom_k * prom_n]; + int8_t* qB_interleave = new int8_t[prom_k * prom_n / (8 / bits)]; + half* C = new half[prom_m * prom_n]; + half* s = new half[prom_n * (prom_k / group_size)]; + half* s_perm = new half[prom_n * (prom_k / group_size)]; + + // Initialize A and B + for (int i = 0; i < prom_m * prom_k; i++) + { + A[i] = __float2half(rand() / (float)RAND_MAX); + } + for (int i = 0; i < prom_k * prom_n; i++) + { + B[i] = rand() % 4 - 2; + } + for (int i = 0; i < prom_k * prom_n / group_size; i++) + { + // s[i] = __float2half(0.1); + s[i] = __float2half(rand() / (float)RAND_MAX); + } + + weight_pre_process(B, qB_interleave, bits, prom_k, prom_n); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%d ", B[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + // print interleave of B + for (int i = 0; i < 10; i++) + { + printf("%d ", qB_interleave[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of qb_interleave + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of B + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + printf("\n"); + // print last 10 of s + for (int i = prom_n * (prom_k / group_size) - 10; i < prom_n * (prom_k / group_size); i++) + { + printf("%f ", __half2float(s[i])); + } + printf("\n"); + scale_pre_process(s, s_perm, prom_k, prom_n, group_size); + // define cuda variables + float* d_workspace = nullptr; + cudaCheckLastError(cudaMalloc((void**)&d_workspace, prom_n * prom_k * 16 * sizeof(float))); + + half* d_A; + int8_t* d_qB; + half* d_C; + half* d_s; + cudaCheckLastError(cudaMalloc((void**)&d_A, prom_m * prom_k * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_qB, prom_k * prom_n / (8 / bits) * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void**)&d_C, prom_m * prom_n * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_s, prom_n * (prom_k / group_size) * sizeof(half))); + // copy A and B to device + cudaCheckLastError(cudaMemcpy(d_A, A, prom_m * prom_k * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_qB, qB_interleave, prom_n * prom_k / (8 / bits) * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_s, s_perm, prom_n * (prom_k / group_size) * sizeof(half), cudaMemcpyHostToDevice)); + + // allocate workspace + // call the kernel + int ret = marlin_cuda(d_A, d_qB, d_C, d_s, prom_m, prom_n, prom_k, d_workspace, group_size == prom_k? -1: group_size); + printf("ret = %d\n", ret); + + // copy C back to host + cudaCheckLastError(cudaMemcpy(C, d_C, prom_m * prom_n * sizeof(half), cudaMemcpyDeviceToHost)); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(C[i])); + } + printf(" ... 
"); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(C[i])); + } + printf("\n"); + + // ref calculation + float* ref_C = new float[prom_m * prom_n]; + // zero fill + for (int i = 0; i < prom_m * prom_n; i++) + { + ref_C[i] = __float2half(0.0); + } + // + for (int i = 0; i < prom_m; i++) + { + for (int j = 0; j < prom_n; j++) + { + ref_C[i * prom_n + j] = __float2half(0.0); + for (int k = 0; k < prom_k; k++) + { + ref_C[i * prom_n + j] += float(A[i * prom_k + k]) * (float(B[k * prom_n + j]) * float(s[(k / group_size) * prom_n + j])); + } + } + } + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf("\n"); + + // check the result + for (int i = 0; i < prom_m * prom_n; i++) + { + EXPECT_NEAR(__half2float(C[i]), __half2float(ref_C[i]), 1e-1); + } + + // free memory + delete[] A; + delete[] B; + delete[] C; + cudaCheckLastError(cudaFree(d_A)); + cudaCheckLastError(cudaFree(d_qB)); + cudaCheckLastError(cudaFree(d_C)); +} diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp new file mode 100644 index 000000000..a12a57dcd --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -0,0 +1,826 @@ +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. 
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 
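+  // The constants below are bit patterns of packed fp16 pairs: SUB = 1032.0 = 1024 + 8 removes the
+  // 0x6400 exponent offset and the zero point from the low nibbles, while MUL = 0.0625 = 1/16 and
+  // ADD = -72.0 = -(1024/16 + 8) apply the equivalent shift and correction to the high nibbles.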
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. 
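+  // In the example above, iters = ceildiv(3 * 3, 5) = 2, so threadblock 0 covers the first two
+  // k-tiles of column 0, threadblock 1 finishes column 0 and wraps to the top of column 1, and so
+  // on; each stripe walks down the k dimension before moving on to the next column slice.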
+ + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // 
number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. 
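+  // A sketch of the swizzle (assuming a_gl_rd_delta_o == 8 and row < 8): operator precedence groups
+  // transform_a(i) as (a_gl_rd_delta_o * row + i % a_gl_rd_delta_o) ^ row, i.e. i ^ row, so the
+  // 16-byte block in column `col` of a row is stored at column `col ^ row` and eight consecutive
+  // rows of any fixed column land in eight distinct bank groups.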
+ int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticable drop in performance. 
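+    // Round `pipe` down to the pipeline stage at which the current group's scales were fetched
+    // before selecting the scale fragment for this sub-tile.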
+ if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. 
+ auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( + reinterpret_cast<__half*>(&c_red)[j] + ); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = __float2half( + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + ); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. 
+ auto write_result = [&] () { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*) sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. 
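+    // The current slice is exhausted: drain the async-copy pipeline, then reduce partial results
+    // and write them back before advancing to the next column slice (if any).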
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) + cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin< \ + THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ + ><<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int marlin_cuda( + const void* A, + const void* B, + void* C, + void* s, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? 
-1 : groupsize / 16; + int blocks = sms; + + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + const int4* s_ptr = (const int4*) s; + + int cols = prob_n / thread_n; + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. + if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } + + return ret; +} + + +#endif diff --git a/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp new file mode 100644 index 000000000..64248b3d1 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include +#include + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +TEST(EfficientI4MatmulTest, ParamPermutate) +{ + auto [perm, scale_perm, scale_perm_single] = get_perms(); + + std::cout << "perm: "; + for (int i = 0; i < 10; ++i) { + std::cout << perm[i] << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm: "; + for (const auto& val : scale_perm) { + std::cout << val << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm_single: "; + for (const auto& 
val : scale_perm_single) { + std::cout << val << " "; + } + std::cout << std::endl; +} diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eee08c93c..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -6,11 +6,11 @@ import time import numpy as np import torch.nn as nn -import pytest torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -45,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 92b0e0788..3a58a47e1 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -5,7 +5,6 @@ import torch import numpy as np from tvm.script import tir as T -import numpy as np def general_compress_to_int8(lowprecision_weight, source_bits=4): @@ -21,9 +20,7 @@ def general_compress_to_int8(lowprecision_weight, source_bits=4): ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( - source_bits * k - ) + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) return int8_weight @@ -44,25 +41,25 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): if nbits == 1 and target_dtype == "int8": # special handling for 1b interleave - n16_weight = new_qweight & np.int32(0xF0F00F0F) - n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 - n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 - n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 - n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + n16_weight = new_qweight & np.int32(np.uint32(0xF0F00F0F)) + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 12 return n16_weight.view(np.int8) elif nbits == 2 and target_dtype == "float16": - n8_weight = new_qweight & np.int32(0xFF0000FF) - n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 - n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + n8_weight = new_qweight & np.int32(np.uint32(0xFF0000FF)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000FF00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00FF0000))) >> 16) << 8 return n8_weight.view(np.int8) elif nbits == 1 and target_dtype == "float16": - n8_weight = new_qweight & 0xF000000F - n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 - n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 - n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 - n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 - n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 - n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + n8_weight = new_qweight & np.int32(np.uint32(0xF000000F)) + n8_weight 
|= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 8 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00000F00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00F00000))) >> 20) << 12 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 20 return new_qweight.view(np.int8) @@ -80,17 +77,11 @@ def interleave_weight(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32 with T.block("B"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) @T.prim_func - def interleave_weight_f16_2b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_2b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -98,12 +89,8 @@ def interleave_weight_f16_2b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -114,9 +101,7 @@ def interleave_weight_f16_2b( B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] @T.prim_func - def interleave_weight_f16_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -128,12 +113,8 @@ def interleave_weight_f16_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -152,13 +133,10 @@ def interleave_weight_f16_1b( | B_tmp_4[v0, v1] | B_tmp_5[v0, v1] | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1] - ) + | B_tmp_7[v0, v1]) @T.prim_func - def interleave_weight_int8_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_int8_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -168,12 +146,8 @@ def 
interleave_weight_int8_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -188,8 +162,7 @@ def interleave_weight_int8_1b( | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - ) + | B_tmp_5[v0, v1]) if target_dtype == "float16" and bits == 2: return interleave_weight_f16_2b @@ -207,7 +180,7 @@ def test_lop3_interleave_weight(): K = 16 target_dtype = "float16" torch.manual_seed(0) - uint_max = 2 ** (source_nbits) - 1 + uint_max = 2**(source_nbits) - 1 raw_data = torch.randint(0, uint_max, (N, K), dtype=torch.int8).cpu().numpy() compressed_b = general_compress_to_int8(raw_data, source_nbits) interleaved_weight = interleave_weight(compressed_b, source_nbits, target_dtype)
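
For context (not part of the patch itself): the sketch below shows one way the marlin_cuda entry point from i4matmul.hpp above might be driven from a standalone host program. The shapes (m = 16, n = 4096, k = 4096), per-column scales (groupsize = -1), buffer sizes, the include path, and the over-allocated zero-initialized lock workspace are illustrative assumptions, not code from this change.

// Illustrative host-side sketch only, under the assumptions noted above.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>
#include "i4matmul.hpp"  // assumed include path for the header added by this patch

int main() {
    const int m = 16, n = 4096, k = 4096;  // n and k must divide evenly by thread_n / thread_k
    void *A, *B, *C, *s, *workspace;
    cudaMalloc(&A, (size_t)m * k * sizeof(half));   // fp16 activations
    cudaMalloc(&B, (size_t)k * n / 2);              // int4-packed weights (two values per byte)
    cudaMalloc(&C, (size_t)m * n * sizeof(half));   // fp16 output
    cudaMalloc(&s, (size_t)n * sizeof(half));       // one fp16 scale per output column (groupsize = -1)
    // Locks used for the global reduction; they must start zeroed. Over-allocated for safety.
    const size_t lock_bytes = (size_t)(n / 128) * 16 * sizeof(int);
    cudaMalloc(&workspace, lock_bytes);
    cudaMemset(workspace, 0, lock_bytes);

    int err = marlin_cuda(A, B, C, s, m, n, k, workspace, /*groupsize=*/-1);
    if (err == ERR_PROB_SHAPE)
        printf("unsupported problem shape\n");
    else if (err == ERR_KERN_SHAPE)
        printf("no kernel instantiation matches this tile configuration\n");
    cudaDeviceSynchronize();
    // A real test would fill A/B/s with data, dequantize on the host,
    // and compare C against an fp16 reference GEMM.
    return 0;
}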