diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py index e29c9de0..f353228a 100644 --- a/python/bitblas/module/__init__.py +++ b/python/bitblas/module/__init__.py @@ -31,12 +31,13 @@ def unpack_qzeros(qzeros, bits): device=qzeros.device, requires_grad=False, ) - for col in range(unpacked_zeros.shape[1]): i = col % elems_per_int32 - unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) - return unpacked_zeros + 1 + # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 + # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. + return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) class Linear(nn.Module): @@ -232,18 +233,19 @@ def forward(self, A, output=None): A = A.half() # can be lifted to post init. self.init_params() - + if output is None: output = torch.empty( A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) A = self.bitblas_matmul.transform_input(A) stream = torch.cuda.current_stream() - + A_void = ctypes.c_void_p(A.data_ptr()) stream_handle = ctypes.c_void_p(stream.cuda_stream) # m is the product of the last n - 1 dimensions of A - self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle) + self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, + stream_handle) return output diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py index dddafc98..12fc8364 100644 --- a/testing/python/operators/test_matmul_dequantize_ops.py +++ b/testing/python/operators/test_matmul_dequantize_ops.py @@ -502,6 +502,7 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty if with_scaling: if group_size == -1: group_size = K + # Note that scaling is default to all 1 permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) if with_zeros: permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) @@ -515,6 +516,125 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-1) +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode", + [ + (1, 768, 768, "float16", "float16", "float16", 2, "int8", "uint", True, False, 128, False, + False, False, False, "nt", "quantized"), + (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, 128, False, + False, False, False, "nt", "quantized"), + ], +) +def test_matmul_dequantize_torch_forward_with_asym_quantized_zeros(M, N, K, in_dtype, out_dtype, accum_dtype, bit, + storage_dtype, source_format, with_scaling, with_zeros, + group_size, fast_decoding, with_bias, propagate_a, + propagate_b, layout, zeros_mode): + import torch + import numpy as np + torch.random.manual_seed(0) + from bitblas.quantization.utils import general_compress + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + zeros_mode=zeros_mode) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + if not isinstance(M, int): + M = int(32) + # matmul.hardware_aware_finetune(topk=20) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + scaling_shape = (N, K // group_size) + zeros_shape = (K // group_size, N) + + input_A = torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5 + max_quantization = 2 ** (bit - 1) + scaling_matrix = torch.rand(scaling_shape, dtype=torch.float16).cuda() + zeros_matrix = torch.randint(0, max_quantization, zeros_shape, dtype=torch.int8).cuda() + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + + if source_format == "uint": + input_W = torch.randint(0, max_quantization, weight_shape, dtype=torch.int8).cuda() + elif source_format == "int": + input_W = torch.randint(-max_quantization, max_quantization, weight_shape, dtype=torch.int8).cuda() + else: + raise NotImplementedError + + # Now begin bitblas matmul + input_W_int = input_W.cpu().numpy().astype(np.int8) + if source_format == "int": + input_W_int = input_W_int + max_quantization + qw_np = general_compress(input_W_int, source_bits=bit, storage_dtype=np.int8) + qw_torch = torch.from_numpy(qw_np).cuda() + + permuted_inputs = [] + # input and weight + if matmul.input_transform is not None: + permuted_inputs.append(matmul.input_transform(input_A.cpu()).cuda()) + else: + permuted_inputs.append(input_A) + if matmul.weight_transform is not None: + permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) + else: + permuted_inputs.append(qw_torch) + # scale + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(scaling_matrix) + # zeros + if with_zeros: + if zeros_mode == "quantized": + original_zeros = zeros_matrix + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + # bias + if with_bias: + permuted_inputs.append(bias) + # output + permuted_inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + matmul(*permuted_inputs) + bitblas_result = permuted_inputs[-1] + + # Now begin torch matmul + if with_scaling and with_zeros and zeros_mode == "quantized": + rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda() + for i in range(K // group_size): + for j in range(group_size): + rescaling_tensor[:, i * group_size + j] = ( + input_W[:, i * group_size + j].to(torch.float16) - zeros_matrix[i, :] + ) * scaling_matrix[:, i] + elif with_scaling: + rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda() + for i in range(K // group_size): + for j in range(group_size): + rescaling_tensor[:, i * group_size + j] = input_W[:, i * group_size + j].to(torch.float16) * scaling_matrix[:, i] + ref_result = torch.matmul(input_A, rescaling_tensor.t().to(torch.float16)) + + torch.testing.assert_close(bitblas_result, ref_result, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,zeros_mode", [