diff --git a/README.md b/README.md
index c765027b8..51a26927a 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ BitBLAS achieves exceptional performance across a variety of computational patte
+
- TensorCore FP16/INT8 GEMM Performance Vs. Vendor Library on A100 and RTX4090
@@ -78,7 +79,6 @@ We are continuously expanding the support matrix. If you have any specific requi
- [Customization](./docs/ExtendOperatorsWithDSL.md): BitBLAS supports implementing customized mixed-precision DNN operations beyond matrix multiplication with the flexible DSL (TIR Script).
-
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
@@ -90,4 +90,3 @@ This project has adopted the Microsoft Open Source Code of Conduct. For more inf
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.
-
diff --git a/integration/pytorch/test_bitblas_quant_linear.py b/integration/pytorch/test_bitblas_quant_linear.py
index 1db9faa00..b071fb412 100644
--- a/integration/pytorch/test_bitblas_quant_linear.py
+++ b/integration/pytorch/test_bitblas_quant_linear.py
@@ -17,7 +17,7 @@
def gen_quant4(k, n, groupsize=-1):
- maxq = 2**4 - 1
+ maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")
original_w = w.clone()
@@ -75,7 +75,7 @@ def test_quantization_accuracy(m, in_features, out_features, bits, group_size, b
if group_size == -1:
group_size = in_features
- zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32)
+ zeros = torch.full((in_features // group_size, out_features), 8, dtype=torch.int32)
bitblas_zeros = zeros.clone().T
cuda_old_linear = CudaOldQuantLinear(
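
A quick sanity sketch (editorial note, not part of the diff; the scale value is arbitrary) of what the two hunks above change: with `maxq = 2**4` and the zero point moved from 7 to 8, dequantization `(q - zeros) * scale` maps the 4-bit codes `q in [0, 15]` onto the signed range `[-8, 7] * scale`.

```python
import torch

q = torch.arange(16, dtype=torch.int32)      # all 4-bit codes 0..15
scale = 0.5                                  # arbitrary example scale
w = (q - 8).to(torch.float16) * scale        # zero point 8, as in the hunk above
assert w.min().item() == -8 * scale and w.max().item() == 7 * scale
```
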
diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py
index c9d7e5fd1..bad04ceb3 100644
--- a/python/bitblas/gpu/intrin/lop3.py
+++ b/python/bitblas/gpu/intrin/lop3.py
@@ -5,6 +5,7 @@
from tvm.script import tir as T
from typing import Dict, Literal
from bitblas.quantization import (
+ _tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_packed_to_unsigned_convert_with_zeros,
@@ -19,7 +20,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
@@ -55,7 +56,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -97,7 +98,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -139,7 +140,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
-    // Minus 7 to scale the value to signed
-    static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
+    // Minus 8 to scale the value to signed
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -217,7 +218,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -258,7 +259,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -300,7 +301,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -337,7 +338,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -366,15 +367,15 @@
"""
decode_i1_to_f16 = """
-template <typename T1, typename T2, bool isSigned = false>
-__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
+template <typename T1, typename T2>
+__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
@@ -392,26 +393,41 @@
template <typename T1, typename T2>
__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
-    decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
-}
-template <typename T1, typename T2>
-__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
-{
-    decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
+    uint *h = reinterpret_cast<uint *>(B_local_decode);
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+ static constexpr uint BOTTOM_MASK = 0x00010001;
+ static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // fp16 -1.0 per half; applies 2x - 1 to map {0, 1} to {-1, +1}
+
+    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
+ int i1s = (i1s_i16 & 0x0f);
+ i1s |= ((i1s_i16 & 0xf0) << 12);
+#pragma unroll
+ // decode 2 elems at one time.
+ for (int i = 0; i < (N / 2); i++)
+ {
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+ : "=r"(h[i])
+ : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+ asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+ asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+ asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+ }
}
"""
decode_i1_to_f16_scale = """
-template <typename T1, typename T2, typename T3, bool isSigned = false>
-__device__ void decode_i1b_to_f16_scale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr)
+template <typename T1, typename T2, typename T3>
+__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -431,17 +447,41 @@
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
+
template <typename T1, typename T2, typename T3>
__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
-    decode_i1b_to_f16_scale<T1, T2, T3, true>(_i1s, B_local_decode, N, scale);
-}
-template <typename T1, typename T2, typename T3>
-__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
-{
-    decode_i1b_to_f16_scale<T1, T2, T3, false>(_i1u, B_local_decode, N, scale);
+    uint *h = reinterpret_cast<uint *>(B_local_decode);
+
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+ static constexpr uint BOTTOM_MASK = 0x00010001;
+ static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // fp16 -1.0 per half; applies 2x - 1 to map {0, 1} to {-1, +1}
+ // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
+ // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
+
+    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
+ int i1s = (i1s_i16 & 0x0f);
+ i1s |= ((i1s_i16 & 0xf0) << 12);
+ T3 const scale_r = *scale;
+ uint const packed_scales = __pack_half2(scale_r, scale_r);
+#pragma unroll
+ // decode 2 elems at one time.
+ for (int i = 0; i < (N / 2); i++)
+ {
+
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+ : "=r"(h[i])
+ : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+ asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+ asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+ asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
+ }
}
"""
+
decode_i1_to_f16_scale_zeros_original = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
@@ -451,7 +491,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -491,7 +531,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+ static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -538,12 +578,14 @@
static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
static constexpr uint MEDIAN_NUM = 0x00000000;
+    static constexpr uint TRANSFORM_SUBTRACT = 0x01010101; // one per int8 lane; applies 2x - 1 to map {0, 1} to {-1, +1}
for (int i = 0; i < N / 4; i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
+ i8s[i] = __vsubss4(__vaddss4(i8s[i], i8s[i]), TRANSFORM_SUBTRACT);
}
}
@@ -709,7 +751,10 @@ def get_fast_decode_intrin(
if with_zeros and zeros_mode == "quantized":
decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)
elif source_format == "int":
- decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
+ if source_bit == 1:
+ decode_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit)
+ else:
+ decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
elif source_format == "uint":
decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)
else:
@@ -1379,7 +1424,7 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
- source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16),
+ source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
)
LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1389,6 +1434,14 @@ def fast_decode_impl(
source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16),
)
+LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_")
+TensorIntrin.register(
+ LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
+ *get_fast_decode_intrin(
+ source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
+)
+
+
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
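
An editorial sketch (not part of the diff) of why the `MEDIAN_NUM` constants above move: `FP16_TOP_MAGIC_NUM` 0x6400 is fp16 1024.0, so a 4-bit code `v` lop3'd into its mantissa reads back as fp16(1024 + v). Subtracting 0x6408 (fp16 1032.0) recenters that to `v - 8`, giving the signed int4 range [-8, 7] instead of the [-7, 8] produced by the old 0x6407 (fp16 1031.0); the 2-bit constants move from 0x6401 to 0x6402 for the same reason, giving [-2, 1].

```python
import torch

def fp16_bits(u: int) -> torch.Tensor:
    # Reinterpret a 16-bit pattern as a single fp16 value.
    return torch.tensor([u], dtype=torch.int16).view(torch.float16)

for v in range(16):                        # all 4-bit codes
    packed = fp16_bits(0x6400 | v)         # fp16(1024 + v): v sits in the mantissa
    decoded = packed - fp16_bits(0x6408)   # subtract MEDIAN_NUM = fp16(1032.0)
    assert decoded.item() == v - 8         # maps [0, 15] onto [-8, 7]
```
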
diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index fbdf8058e..78e0cbe95 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -371,7 +371,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
- maxq = 2**(bit - 1) - 1
+ maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
else:
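
A minimal sketch (editorial; `bit = 4` assumed) of the updated `transform_weight` mapping for the int source format: clamp to `[-maxq, maxq]` with `maxq = 2**(bit - 1)`, then shift by `+maxq` into unsigned storage codes.

```python
import torch

bit = 4
maxq = 2**(bit - 1)                                 # 8
w = torch.tensor([-9, -8, -1, 0, 7], dtype=torch.int32)
stored = torch.clamp(w, -maxq, maxq).int() + maxq   # clamp, then recenter
assert stored.tolist() == [0, 0, 7, 8, 15]          # -9 clamps to -8 -> code 0
```
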
diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py
index f0f59e035..28e9ae42b 100644
--- a/python/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -7,6 +7,7 @@
from bitblas.ops.operator import TransformKind
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.quantization import (
+ _tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
@@ -76,8 +77,13 @@ def decode_func(n, k):
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "int":
- w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
- bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+ if bit == 1:
+ # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+ bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+ else:
+ w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+ bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
@@ -91,6 +97,8 @@ def decode_func(n, k):
else:
raise ValueError("Unsupported source_format: {}".format(source_format))
+
+
if not with_scaling:
return w
@@ -236,12 +244,17 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "int":
- w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
- bit,
- B_reindex[n, k // n_float_per_elem],
- k % n_float_per_elem,
- dtype=in_dtype,
- )
+ if bit == 1:
+ # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+ bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+ else:
+ w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+ bit,
+ B_reindex[n, k // n_float_per_elem],
+ k % n_float_per_elem,
+ dtype=in_dtype,
+ )
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit,
@@ -417,12 +430,17 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "int":
- w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
- bit,
- B_reindex[n, k // n_float_per_elem],
- k % n_float_per_elem,
- dtype=in_dtype,
- )
+ # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
+ if bit == 1:
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
+ bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+ else:
+ w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
+ bit,
+ B_reindex[n, k // n_float_per_elem],
+ k % n_float_per_elem,
+ dtype=in_dtype,
+ )
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit,
diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py
index 227cf61a4..0ca9ab377 100644
--- a/python/bitblas/quantization/__init__.py
+++ b/python/bitblas/quantization/__init__.py
@@ -1,8 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .quantization import (
- _tir_packed_int_to_int_to_float, # noqa: F401
- _tir_packed_uint_to_uint_to_float, # noqa: F401
+ _tir_packed_int_to_int_convert, # noqa: F401
_tir_packed_to_signed_convert, # noqa: F401
_tir_packed_to_unsigned_convert, # noqa: F401
_tir_u32_to_f4_to_f16, # noqa: F401
diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py
index aeecfb874..e390fa640 100644
--- a/python/bitblas/quantization/quantization.py
+++ b/python/bitblas/quantization/quantization.py
@@ -144,7 +144,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
- max_int_value = (1 << (nbit - 1)) - 1
+ max_int_value = (1 << (nbit - 1))
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
@@ -173,5 +173,16 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm
return f_convert
+def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
+ storage_dtype = storage_type + str(storage_nbit)
+
+ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+ assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
+ mask = tir.const((1 << nbit) - 1, "int32")
+ unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
+ return tir.Cast(
+ dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
+
+ return f_convert
# fmt: on
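
A pure-Python rendering (editorial; 32-bit two's-complement arithmetic emulated with masks) of the shift pair inside the new `_tir_packed_int_to_int_convert`: the `nbit` field is moved to the top of a 32-bit word and arithmetically shifted back down, which sign-extends it.

```python
def sign_extend_field(val: int, pos: int, nbit: int) -> int:
    # Extract the nbit-wide field at position pos, then sign-extend it the way
    # the TIR expression does: (field << (32 - nbit)) >> (32 - nbit).
    field = (val >> (pos * nbit)) & ((1 << nbit) - 1)
    shifted = (field << (32 - nbit)) & 0xFFFFFFFF
    # Emulate an arithmetic right shift of a 32-bit two's-complement word.
    return (shifted >> (32 - nbit)) - ((1 << nbit) if shifted & 0x80000000 else 0)

assert sign_extend_field(0b0111, 0, 4) == 7    # 0x7 stays +7
assert sign_extend_field(0b1001, 0, 4) == -7   # 0x9 sign-extends to -7
assert sign_extend_field(0b10, 1, 1) == -1     # a set 1-bit field reads as -1
```
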
diff --git a/python/bitblas/quantization/utils.py b/python/bitblas/quantization/utils.py
index 3d369afe4..45890c3d8 100644
--- a/python/bitblas/quantization/utils.py
+++ b/python/bitblas/quantization/utils.py
@@ -6,7 +6,7 @@
def gen_quant4(k, n, groupsize=-1):
- maxq = 2**4 - 1
+ maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")
original_w = w.clone()
diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
index f86ca2cf0..ac24d40a3 100644
--- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp
+++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
@@ -401,6 +401,7 @@ __device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8,
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
+    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // fp16 -1.0 per half; applies 2x - 1 to map {0, 1} to {-1, +1}
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -415,6 +416,11 @@ __device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8,
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
+ if constexpr (isSigned)
+ {
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
+ }
if constexpr (withZeros && ZerosKind == 0)
{
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
@@ -718,8 +724,8 @@ __device__ void decode_i1b_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
- static constexpr uint MEDIAN_NUM = isSigned ? 0x00000000 : 0x00000000;
-
+    static constexpr uint TRANSFORM_SUBTRACT = 0x01010101; // one per int8 lane; applies 2x - 1 to map {0, 1} to {-1, +1}
+
for (int i = 0; i < N / 4; i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
@@ -728,7 +734,7 @@ __device__ void decode_i1b_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
if constexpr (isSigned)
{
- i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM);
+ i8s[i] = __vsubss4(__vaddss4(i8s[i], i8s[i]), TRANSFORM_SUBTRACT);
}
}
}
diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
index eda2be206..0d0ebf7d2 100644
--- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
+++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
@@ -300,7 +300,7 @@ TEST(DecodeTest, DecodeInt1ToFloat16)
cudaCheckLastError(cudaFree(decoded_gpu));
for (int i = 0; i < N; i++)
{
- EXPECT_EQ(in_data[i], int(decoded[i]));
+ EXPECT_EQ(2 * in_data[i] - 1, int(decoded[i]));
}
free(ins);
free(interleaved);
diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
index fe1b1dd71..0a3b45a77 100644
--- a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
+++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu
@@ -284,7 +284,7 @@ TEST(DecodeTest, DecodeInt1ToINT8)
cudaCheckLastError(cudaFree(decoded_gpu));
for (int i = 0; i < N; i++)
{
- EXPECT_EQ(in_data[i], int(decoded[i]));
+ EXPECT_EQ(2 * in_data[i] - 1, int(decoded[i]));
}
free(ins);
free(interleaved);
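
For the two `EXPECT_EQ` updates above, a one-line editorial reference of the new expectation: a stored bit `b` now decodes to `2*b - 1` (so {0, 1} maps to {-1, +1}) rather than to `b` itself.

```python
def expected_i1_decode(bits):
    # Mirrors the updated assertions: decoded[i] == 2 * in_data[i] - 1.
    return [2 * b - 1 for b in bits]

assert expected_i1_decode([0, 1, 1, 0]) == [-1, 1, 1, -1]
```
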
diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py
index 7d1694f03..edab105ab 100644
--- a/testing/python/module/test_bitblas_linear.py
+++ b/testing/python/module/test_bitblas_linear.py
@@ -100,7 +100,7 @@ def test_correctness_weight_only_dequantize(
linear_bitblas.bitblas_matmul.bit,
)
- maxq = 2**(bit - 1) - 1
+ maxq = 2**(bit - 1)
zeros = maxq
if source_format == "uint":
inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py
index 6eb588d9e..bb0b06719 100644
--- a/testing/python/operators/test_general_matmul_ops.py
+++ b/testing/python/operators/test_general_matmul_ops.py
@@ -161,7 +161,7 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
source_format, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype]
- maxq = 2**(bit - 1) - 1
+ maxq = 2**(bit - 1)
zeros = maxq
if source_format == "uint":
inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
@@ -261,7 +261,7 @@ def test_matmul_transform_weight(
output_shape = (M, N)
_, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype]
- maxq = 2**(bit - 1) - 1
+ maxq = 2**(bit - 1)
input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()
intweight_tensor = torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()
diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py
index 018ad8256..a4a48f267 100644
--- a/testing/python/operators/test_matmul_dequantize_ops.py
+++ b/testing/python/operators/test_matmul_dequantize_ops.py
@@ -29,7 +29,7 @@ def get_codegen_result(ops, target):
(1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True,
False, "nt", False, False, "original"),
(1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, -1, True,
- True, "nt", True, True, "original"),
+ True, "nt", False, True, "original"),
],
)
def test_matmul_dequantize_codegen_default(M, N, K, in_dtype, out_dtype, accum_dtype, bit,
@@ -466,7 +466,7 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
output_shape = (M, N)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
- maxq = 2**(bit - 1) - 1
+    maxq = 2**(bit - 1)
zeros = maxq
if source_format == "uint":
inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())