diff --git a/README.md b/README.md
index c765027b8..51a26927a 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ BitBLAS achieves exceptional performance across a variety of computational patte
+  - TensorCore FP16/INT8 GEMM Performance Vs. Vendor Library on A100 and RTX4090
@@ -78,7 +79,6 @@ We are continuously expanding the support matrix. If you have any specific requi
 - [Customization](./docs/ExtendOperatorsWithDSL.md): BitBLAS supports implementing customized mixed-precision DNN operations rather than matrix multiplication with the flexible DSL (TIR Script).
-
 ## Contributing

 This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
@@ -90,4 +90,3 @@ This project has adopted the Microsoft Open Source Code of Conduct. For more inf
 ## Trademarks

 This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.
-
diff --git a/integration/pytorch/test_bitblas_quant_linear.py b/integration/pytorch/test_bitblas_quant_linear.py
index 1db9faa00..b071fb412 100644
--- a/integration/pytorch/test_bitblas_quant_linear.py
+++ b/integration/pytorch/test_bitblas_quant_linear.py
@@ -17,7 +17,7 @@ def gen_quant4(k, n, groupsize=-1):
-    maxq = 2**4 - 1
+    maxq = 2**4
     w = torch.randn((k, n), dtype=torch.half, device="cpu")

     original_w = w.clone()
@@ -75,7 +75,7 @@ def test_quantization_accuracy(m, in_features, out_features, bits, group_size, b
     if group_size == -1:
         group_size = in_features
-    zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32)
+    zeros = torch.full((in_features // group_size, out_features), 8, dtype=torch.int32)
     bitblas_zeros = zeros.clone().T

     cuda_old_linear = CudaOldQuantLinear(
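Note (editor's sketch, not part of the patch): the zero point for 4-bit weights moves from 7 to 8 so that the stored unsigned codes map onto the full signed range. A minimal check, assuming only torch:

```python
# Hypothetical sanity check: with zero point 8, raw uint4 codes 0..15
# decode to v - 8, i.e. the full int4 range [-8, 7]; the old zero point
# of 7 only covered [-7, 8].
import torch

codes = torch.arange(16, dtype=torch.int32)
decoded = codes - 8
assert decoded.min().item() == -8 and decoded.max().item() == 7
```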
diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py
index c9d7e5fd1..bad04ceb3 100644
--- a/python/bitblas/gpu/intrin/lop3.py
+++ b/python/bitblas/gpu/intrin/lop3.py
@@ -5,6 +5,7 @@
 from tvm.script import tir as T
 from typing import Dict, Literal
 from bitblas.quantization import (
+    _tir_packed_int_to_int_convert,
     _tir_packed_to_signed_convert,
     _tir_packed_to_unsigned_convert,
     _tir_packed_to_unsigned_convert_with_zeros,
@@ -19,7 +20,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x000f000f;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
     uint const i4s = *reinterpret_cast<uint *>(_i4s);
 #pragma unroll
     for (int i = 0; i < (N / 2); i++)
@@ -55,7 +56,7 @@
     static constexpr uint BOTTOM_MASK = 0x000f000f;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
     uint const i4s = *reinterpret_cast<uint *>(_i4s);
     T3 const scale_r = *scale;
     uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -97,7 +98,7 @@
     static constexpr uint BOTTOM_MASK = 0x000f000f;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
     uint const i4s = *reinterpret_cast<uint *>(_i4s);
     T3 const scale_r = *scale;
     uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -139,7 +140,7 @@
     static constexpr uint BOTTOM_MASK = 0x000f000f;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
     uint const i4s = *reinterpret_cast<uint *>(_i4s);
     T3 const scale_r = *scale;
     uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -217,7 +218,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00030003;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
     int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
     // decode 2 elems at one time.
     // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -258,7 +259,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00030003;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
     int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
     // decode 2 elems at one time.
     // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -300,7 +301,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00030003;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
     int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
     // decode 2 elems at one time.
     // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -337,7 +338,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00030003;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
     int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
     // decode 2 elems at one time.
     // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -366,15 +367,15 @@
 """

 decode_i1_to_f16 = """
-template <typename T1, typename T2, bool isSigned = false>
-__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
+template <typename T1, typename T2>
+__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
 {
     uint *h = reinterpret_cast<uint *>(B_local_decode);
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00010001;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
     int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
     int i1s = (i1s_i16 & 0x0f);
     i1s |= ((i1s_i16 & 0xf0) << 12);
@@ -392,26 +393,41 @@
 template <typename T1, typename T2>
 __device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
 {
-    decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
-}
+    uint *h = reinterpret_cast<uint *>(B_local_decode);

-template <typename T1, typename T2>
-__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
-{
-    decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint BOTTOM_MASK = 0x00010001;
+    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
+
+    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
+    int i1s = (i1s_i16 & 0x0f);
+    i1s |= ((i1s_i16 & 0xf0) << 12);
+#pragma unroll
+    // decode 2 elems at one time.
+    for (int i = 0; i < (N / 2); i++)
+    {
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+                     : "=r"(h[i])
+                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+        asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+    }
 }
 """

 decode_i1_to_f16_scale = """
-template <typename T1, typename T2, typename T3, bool isSigned = false>
-__device__ void decode_i1b_to_f16_scale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr)
+template <typename T1, typename T2, typename T3>
+__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
 {
     uint *h = reinterpret_cast<uint *>(B_local_decode);
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00010001;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
     // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
     // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
     int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -431,17 +447,41 @@
         asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
     }
 }
+
 template <typename T1, typename T2, typename T3>
 __device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
 {
-    decode_i1b_to_f16_scale<T1, T2, T3, true>(_i1s, B_local_decode, N, scale);
-}
-template <typename T1, typename T2, typename T3>
-__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
-{
-    decode_i1b_to_f16_scale<T1, T2, T3, false>(_i1u, B_local_decode, N, scale);
+    uint *h = reinterpret_cast<uint *>(B_local_decode);
+
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint BOTTOM_MASK = 0x00010001;
+    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
+    // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
+    // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
+
+    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
+    int i1s = (i1s_i16 & 0x0f);
+    i1s |= ((i1s_i16 & 0xf0) << 12);
+    T3 const scale_r = *scale;
+    uint const packed_scales = __pack_half2(scale_r, scale_r);
+#pragma unroll
+    // decode 2 elems at one time.
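Note (editor's sketch, not part of the patch): the MEDIAN_NUM change from 0x6407 to 0x6408 follows directly from the fp16 magic-number decode these kernels use. A 4-bit code x OR-ed into the mantissa of 0x6400 (fp16 1024.0) is the half 1024 + x, so subtracting 0x6408 (fp16 1032.0) yields x - 8, the two's-complement int4 value; the 2-bit constant 0x6402 follows the same pattern (subtract 2 = 2**(2-1)). A numpy verification:

```python
import numpy as np

codes = np.arange(16, dtype=np.uint16)
halves = (codes | 0x6400).view(np.float16)                     # 1024.0 + x for each code x
median = np.array([0x6408], dtype=np.uint16).view(np.float16)  # 1032.0
decoded = (halves - median).astype(np.int32)
assert decoded.tolist() == [x - 8 for x in range(16)]          # full int4 range [-8, 7]
```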
+    for (int i = 0; i < (N / 2); i++)
+    {
+
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+                     : "=r"(h[i])
+                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+        asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
+    }
 }
 """
+
 decode_i1_to_f16_scale_zeros_original = """
 template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
 __device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
 {
@@ -451,7 +491,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00010001;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
     // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
     // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
     int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -491,7 +531,7 @@
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
     static constexpr uint BOTTOM_MASK = 0x00010001;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint MEDIAN_NUM = 0x64006400;
     // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
     // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
     int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -538,12 +578,14 @@
     static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
     static constexpr uint I8s_MAGIC_NUM = 0x00000000;
     static constexpr uint MEDIAN_NUM = 0x00000000;
+    static constexpr uint TRANSFORM_SUBTRACT = 0x01010101;

     for (int i = 0; i < N / 4; i++)
     {
         asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
                      : "=r"(i8s[i])
                      : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
+        i8s[i] = __vsubss4(__vaddss4(i8s[i], i8s[i]), TRANSFORM_SUBTRACT);
     }
 }
@@ -709,7 +751,10 @@ def get_fast_decode_intrin(
     if with_zeros and zeros_mode == "quantized":
         decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)
     elif source_format == "int":
-        decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
+        if source_bit == 1:
+            decode_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit)
+        else:
+            decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
     elif source_format == "uint":
         decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)
     else:
@@ -1379,7 +1424,7 @@ def fast_decode_impl(
 TensorIntrin.register(
     LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
     *get_fast_decode_intrin(
-        source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16),
+        source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
 )

 LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1389,6 +1434,14 @@ def fast_decode_impl(
         source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16),
 )

+LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_")
+TensorIntrin.register(
+    LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
+    *get_fast_decode_intrin(
+        source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
+)
+
+
 LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
 TensorIntrin.register(
     LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
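Note (editor's sketch, not part of the patch): with maxq now 2**(bit - 1), transform_weight's offset into unsigned storage and the decode-side subtraction in _tir_packed_to_signed_convert stay symmetric, so a signed weight survives the round trip. A minimal torch model of the convention, assuming bit = 4:

```python
import torch

bit = 4
maxq = 2**(bit - 1)                     # 8 after this patch
w = torch.randint(-maxq, maxq, (4, 4))  # signed weights in [-8, 7]
stored = w + maxq                       # unsigned codes in [0, 15], as packed
decoded = stored - maxq                 # what the TIR decode path computes
assert torch.equal(decoded, w)
```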
+                w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+                    bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            else:
+                w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+                    bit,
+                    B_reindex[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    dtype=in_dtype,
+                )
         elif source_format == "fp":
             w = _tir_u32_to_f4_to_f16(
                 bit,
@@ -417,12 +430,17 @@ def decode_func(n, k):
                 dtype=in_dtype,
             )
         elif source_format == "int":
-            w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
-                bit,
-                B_reindex[n, k // n_float_per_elem],
-                k % n_float_per_elem,
-                dtype=in_dtype,
-            )
+            # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
+            if bit == 1:
+                w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+                    bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            else:
+                w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+                    bit,
+                    B_reindex[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    dtype=in_dtype,
+                )
         elif source_format == "fp":
             w = _tir_u32_to_f4_to_f16(
                 bit,
diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py
index 227cf61a4..0ca9ab377 100644
--- a/python/bitblas/quantization/__init__.py
+++ b/python/bitblas/quantization/__init__.py
@@ -1,8 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from .quantization import (
-    _tir_packed_int_to_int_to_float,  # noqa: F401
-    _tir_packed_uint_to_uint_to_float,  # noqa: F401
+    _tir_packed_int_to_int_convert,  # noqa: F401
     _tir_packed_to_signed_convert,  # noqa: F401
     _tir_packed_to_unsigned_convert,  # noqa: F401
     _tir_u32_to_f4_to_f16,  # noqa: F401
diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py
index aeecfb874..e390fa640 100644
--- a/python/bitblas/quantization/quantization.py
+++ b/python/bitblas/quantization/quantization.py
@@ -144,7 +144,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):

     def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
         assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
-        max_int_value = (1 << (nbit - 1)) - 1
+        max_int_value = (1 << (nbit - 1))
         return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
             (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)

@@ -173,5 +173,16 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm
     return f_convert

+
+def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
+    storage_dtype = storage_type + str(storage_nbit)
+
+    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+        assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
+        mask = tir.const((1 << nbit) - 1, "int32")
+        unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
+        return tir.Cast(
+            dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
+
+    return f_convert

 # fmt: on
diff --git a/python/bitblas/quantization/utils.py b/python/bitblas/quantization/utils.py
index 3d369afe4..45890c3d8 100644
--- a/python/bitblas/quantization/utils.py
+++ b/python/bitblas/quantization/utils.py
@@ -6,7 +6,7 @@ def gen_quant4(k, n, groupsize=-1):
-    maxq = 2**4 - 1
+    maxq = 2**4
     w = torch.randn((k, n), dtype=torch.half, device="cpu")

     original_w = w.clone()
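Note (editor's sketch, not part of the patch): the new _tir_packed_int_to_int_convert recovers a signed value by shifting the n-bit field to the top of a 32-bit word and arithmetic-shifting it back down. A pure-numpy model of that trick:

```python
import numpy as np

def sign_extend(field: int, nbit: int) -> int:
    # Shift the field into the top bits of a 32-bit word, then use an
    # arithmetic right shift (numpy's int32 >>) to bring it back down.
    top = np.array([(field << (32 - nbit)) & 0xFFFFFFFF], dtype=np.uint32)
    return int(top.view(np.int32)[0] >> (32 - nbit))

# 4-bit codes 0..7 stay non-negative; 8..15 wrap to -8..-1 (two's complement).
assert [sign_extend(v, 4) for v in range(16)] == list(range(8)) + list(range(-8, 0))
```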
diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
index f86ca2cf0..ac24d40a3 100644
--- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp
+++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
@@ -401,6 +401,7 @@ __device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8,
     static constexpr uint BOTTOM_MASK = 0x00010001;
     static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
     static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
     // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
     // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
     int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -415,6 +416,11 @@ __device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8,
                      : "=r"(h[i])
                      : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
         asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+        if constexpr (isSigned)
+        {
+            asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+            asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+        }
         if constexpr (withZeros && ZerosKind == 0)
         {
             asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
@@ -718,8 +724,8 @@ __device__ void decode_i1b_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
     static constexpr uint BOTTOM_MASK = 0x01010101;      // 0x1 -> 0b01 select 0,1
     static constexpr uint I8s_MAGIC_NUM = 0x00000000;
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x00000000 : 0x00000000;
-
+    static constexpr uint TRANSFORM_SUBTRACT = 0x01010101;
+
     for (int i = 0; i < N / 4; i++)
     {
         asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
@@ -728,7 +734,7 @@ __device__ void decode_i1b_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)

         if constexpr (isSigned)
         {
-            i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM);
+            i8s[i] = __vsubss4(__vaddss4(i8s[i], i8s[i]), TRANSFORM_SUBTRACT);
         }
     }
 }
diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
index eda2be206..0d0ebf7d2 100644
--- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
+++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
@@ -300,7 +300,7 @@ TEST(DecodeTest, DecodeInt1ToFloat16)
     cudaCheckLastError(cudaFree(decoded_gpu));
     for (int i = 0; i < N; i++)
     {
-        EXPECT_EQ(in_data[i], int(decoded[i]));
+        EXPECT_EQ(2 * in_data[i] - 1, int(decoded[i]));
     }
     free(ins);
     free(interleaved);
diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
index fe1b1dd71..0a3b45a77 100644
--- a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
+++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
@@ -284,7 +284,7 @@ TEST(DecodeTest, DecodeInt1ToINT8)
     cudaCheckLastError(cudaFree(decoded_gpu));
     for (int i = 0; i < N; i++)
    {
-        EXPECT_EQ(in_data[i], int(decoded[i]));
+        EXPECT_EQ(2 * in_data[i] - 1, int(decoded[i]));
     }
     free(ins);
     free(interleaved);
diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py
index 7d1694f03..edab105ab 100644
--- a/testing/python/module/test_bitblas_linear.py
+++ b/testing/python/module/test_bitblas_linear.py
@@ -100,7 +100,7 @@ def test_correctness_weight_only_dequantize(
         linear_bitblas.bitblas_matmul.bit,
     )
-    maxq = 2**(bit - 1) - 1
+    maxq = 2**(bit - 1)
     zeros = maxq
     if source_format == "uint":
         inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py
index 6eb588d9e..bb0b06719 100644
--- a/testing/python/operators/test_general_matmul_ops.py
+++ b/testing/python/operators/test_general_matmul_ops.py
@@ -161,7 +161,7 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
     inputs = []
     inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
     source_format, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype]
-    maxq = 2**(bit - 1) - 1
+    maxq = 2**(bit - 1)
     zeros = maxq
     if source_format == "uint":
         inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
@@ -261,7 +261,7 @@ def test_matmul_transform_weight(
     output_shape = (M, N)

     _, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype]
-    maxq = 2**(bit - 1) - 1
+    maxq = 2**(bit - 1)

     input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()
     intweight_tensor = torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()
diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py
index 018ad8256..a4a48f267 100644
--- a/testing/python/operators/test_matmul_dequantize_ops.py
+++ b/testing/python/operators/test_matmul_dequantize_ops.py
@@ -29,7 +29,7 @@ def get_codegen_result(ops, target):
         (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True,
          False, "nt", False, False, "original"),
         (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, -1, True,
-         True, "nt", True, True, "original"),
+         True, "nt", False, True, "original"),
     ],
 )
 def test_matmul_dequantize_codegen_default(M, N, K, in_dtype, out_dtype, accum_dtype, bit,
@@ -466,7 +466,7 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
     output_shape = (M, N)
     inputs = []
     inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
-    maxq = 2**(bit - 1) - 1
+    maxq = 2**(bit - 1)
     zeros = maxq
     if source_format == "uint":
         inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
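Note (editor's sketch, not part of the patch): the updated EXPECT_EQ lines encode the new signed int1 semantics. After the lop3 step leaves 1024 + x in each fp16 lane, the kernel subtracts MEDIAN_NUM (1024.0), doubles the lane, and adds TRANSFORM_SUBTRACT (0xbc00 is fp16 -1.0), producing 2*x - 1, i.e. {-1, 1} for raw bits {0, 1}. The same sequence traced in fp16 arithmetic:

```python
import numpy as np

minus_one = np.array([0xbc00], dtype=np.uint16).view(np.float16)[0]  # fp16 -1.0
for x in (0, 1):
    lane = np.float16(1024 + x) - np.float16(1024.0)  # sub.f16x2 with MEDIAN_NUM
    lane = lane + lane                                # add.f16x2: 2 * x
    lane = lane + minus_one                           # add TRANSFORM_SUBTRACT: 2 * x - 1
    assert int(lane) == 2 * x - 1
```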