[Dev] Issue#24: Fix a bug in repacking AutoGPTQ quantized parameters #57

Merged · 2 commits · Jun 15, 2024
14 changes: 8 additions & 6 deletions python/bitblas/module/__init__.py
@@ -31,12 +31,13 @@ def unpack_qzeros(qzeros, bits):
        device=qzeros.device,
        requires_grad=False,
    )

    for col in range(unpacked_zeros.shape[1]):
        i = col % elems_per_int32
-        unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF
+        unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i))

-    return unpacked_zeros + 1
+    # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
+    # NOTE: It appears that casting after the `unpacked_zeros + 1` is important.
+    return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)


class Linear(nn.Module):
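For readers of this diff, a minimal standalone sketch of the failure mode the new `return` fixes (the values below are illustrative, not taken from the PR): with 4-bit AutoGPTQ zero points, masking before the conventional `+ 1` lets a packed zero of 15 become 16, which no longer fits in 4 bits, whereas masking after the addition wraps it back into range, matching the new `torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)`.

```python
import torch

bits = 4
mask = 2**bits - 1  # 0xF for 4-bit zero points

# One illustrative column of shifted zero-point fields; 15 is the edge case.
raw = torch.tensor([0, 7, 15], dtype=torch.int32)

old = (raw & mask) + 1                  # tensor([ 1,  8, 16]) -> 16 overflows a 4-bit field
new = torch.bitwise_and(raw + 1, mask)  # tensor([ 1,  8,  0]) -> stays within the 4-bit range

print(old.tolist(), new.tolist())
```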
@@ -232,18 +233,19 @@ def forward(self, A, output=None):
        A = A.half()
        # can be lifted to post init.
        self.init_params()

        if output is None:
            output = torch.empty(
                A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
        m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
        A = self.bitblas_matmul.transform_input(A)
        stream = torch.cuda.current_stream()

        A_void = ctypes.c_void_p(A.data_ptr())
        stream_handle = ctypes.c_void_p(stream.cuda_stream)
+        # m is the product of the last n - 1 dimensions of A
-        self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle)
+        self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m,
+                                     stream_handle)

        return output
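As a side note on the wrapped kernel call above, a small sketch of how the `m` argument collapses the leading activation dimensions into a single GEMM row count (the shape is made up for illustration):

```python
import ctypes
import operator
from functools import reduce

# Hypothetical activation shape (batch, seq_len, in_features):
A_shape = (2, 3, 768)

# Product of every dimension except the last, exactly as in forward():
m = reduce(operator.mul, A_shape[:-1], 1)
m_arg = ctypes.c_int32(m)

print(m, m_arg.value)  # 6 6
```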

120 changes: 120 additions & 0 deletions testing/python/operators/test_matmul_dequantize_ops.py
@@ -502,6 +502,7 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
    if with_scaling:
        if group_size == -1:
            group_size = K
+        # Note that scaling defaults to all ones
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
    if with_zeros:
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
@@ -515,6 +516,125 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
    torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-1)
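The new test below packs both weights and zero points with `general_compress`. As a rough mental model only (the helper here is hypothetical, not bitblas' implementation), packing is the inverse of the `unpack_qzeros` loop shown earlier: each int8 storage element holds `8 // bits` values, each shifted into its own bit field along the last axis.

```python
import numpy as np

def pack_low_bits(values: np.ndarray, bits: int) -> np.ndarray:
    """Hypothetical sketch of sub-byte packing; not the bitblas general_compress code."""
    elems_per_byte = 8 // bits
    assert values.shape[-1] % elems_per_byte == 0
    packed = np.zeros(values.shape[:-1] + (values.shape[-1] // elems_per_byte,), dtype=np.uint8)
    for col in range(values.shape[-1]):
        field = (values[..., col].astype(np.int32) & (2**bits - 1)) << (bits * (col % elems_per_byte))
        packed[..., col // elems_per_byte] |= field.astype(np.uint8)
    return packed.astype(np.int8)

# 4-bit example: four values collapse into two int8 storage elements.
print(pack_low_bits(np.array([[1, 15, 0, 7]]), bits=4))  # [[-15 112]]
```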


@pytest.mark.parametrize(
    "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode",
    [
        (1, 768, 768, "float16", "float16", "float16", 2, "int8", "uint", True, False, 128, False,
         False, False, False, "nt", "quantized"),
        (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, 128, False,
         False, False, False, "nt", "quantized"),
    ],
)
def test_matmul_dequantize_torch_forward_with_asym_quantized_zeros(M, N, K, in_dtype, out_dtype, accum_dtype, bit,
                                                                   storage_dtype, source_format, with_scaling, with_zeros,
                                                                   group_size, fast_decoding, with_bias, propagate_a,
                                                                   propagate_b, layout, zeros_mode):
    import torch
    import numpy as np
    torch.random.manual_seed(0)
    from bitblas.quantization.utils import general_compress
    matmul_config = MatmulWeightOnlyDequantizeConfig(
        M=M,
        N=N,
        K=K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        bit=bit,
        storage_dtype=storage_dtype,
        source_format=source_format,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        fast_decoding=fast_decoding,
        with_bias=with_bias,
        propagate_a=propagate_a,
        propagate_b=propagate_b,
        layout=layout,
        zeros_mode=zeros_mode)
    matmul = MatmulWeightOnlyDequantize(
        config=matmul_config,
        target=target,
    )
    if not isinstance(M, int):
        M = 32
    # matmul.hardware_aware_finetune(topk=20)
    input_shape = (M, K)
    weight_shape = (N, K) if layout == "nt" else (K, N)
    output_shape = (M, N)
    scaling_shape = (N, K // group_size)
    zeros_shape = (K // group_size, N)

    input_A = torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5
    max_quantization = 2 ** (bit - 1)
    scaling_matrix = torch.rand(scaling_shape, dtype=torch.float16).cuda()
    zeros_matrix = torch.randint(0, max_quantization, zeros_shape, dtype=torch.int8).cuda()
    bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()

    if source_format == "uint":
        input_W = torch.randint(0, max_quantization, weight_shape, dtype=torch.int8).cuda()
    elif source_format == "int":
        input_W = torch.randint(-max_quantization, max_quantization, weight_shape, dtype=torch.int8).cuda()
    else:
        raise NotImplementedError

    # Now begin bitblas matmul
    input_W_int = input_W.cpu().numpy().astype(np.int8)
    if source_format == "int":
        input_W_int = input_W_int + max_quantization
    qw_np = general_compress(input_W_int, source_bits=bit, storage_dtype=np.int8)
    qw_torch = torch.from_numpy(qw_np).cuda()

    permuted_inputs = []
    # input and weight
    if matmul.input_transform is not None:
        permuted_inputs.append(matmul.input_transform(input_A.cpu()).cuda())
    else:
        permuted_inputs.append(input_A)
    if matmul.weight_transform is not None:
        permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda())
    else:
        permuted_inputs.append(qw_torch)
    # scale
    if with_scaling:
        if group_size == -1:
            group_size = K
        permuted_inputs.append(scaling_matrix)
    # zeros
    if with_zeros:
        if zeros_mode == "quantized":
            original_zeros = zeros_matrix
            qzeros = general_compress(
                original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
            permuted_inputs.append(torch.from_numpy(qzeros).cuda())
        else:
            raise NotImplementedError
    # bias
    if with_bias:
        permuted_inputs.append(bias)
    # output
    permuted_inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())
    matmul(*permuted_inputs)
    bitblas_result = permuted_inputs[-1]

    # Now begin torch matmul
    if with_scaling and with_zeros and zeros_mode == "quantized":
        rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda()
        for i in range(K // group_size):
            for j in range(group_size):
                rescaling_tensor[:, i * group_size + j] = (
                    input_W[:, i * group_size + j].to(torch.float16) - zeros_matrix[i, :]
                ) * scaling_matrix[:, i]
    elif with_scaling:
        rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda()
        for i in range(K // group_size):
            for j in range(group_size):
                rescaling_tensor[:, i * group_size + j] = input_W[:, i * group_size + j].to(torch.float16) * scaling_matrix[:, i]
    ref_result = torch.matmul(input_A, rescaling_tensor.t().to(torch.float16))

    torch.testing.assert_close(bitblas_result, ref_result, rtol=1e-1, atol=1e-1)
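The nested reference loop above amounts to per-group asymmetric dequantization, roughly W_fp16[n, k] = (W_q[n, k] - zeros[k // group_size, n]) * scale[n, k // group_size]. A vectorized sketch of the same computation, written here only as a cross-check under the tensor layouts the test assumes (it is not part of the PR):

```python
import torch

def dequantize_reference(input_W, zeros_matrix, scaling_matrix, group_size):
    # input_W:        (N, K) quantized weights stored as int8
    # zeros_matrix:   (K // group_size, N) zero points
    # scaling_matrix: (N, K // group_size) float16 scales
    N, K = input_W.shape
    w = input_W.to(torch.float16).reshape(N, K // group_size, group_size)
    z = zeros_matrix.t().to(torch.float16).unsqueeze(-1)  # (N, K // group_size, 1)
    s = scaling_matrix.unsqueeze(-1)                       # (N, K // group_size, 1)
    return ((w - z) * s).reshape(N, K)
```

With the test's tensors, `torch.matmul(input_A, dequantize_reference(input_W, zeros_matrix, scaling_matrix, group_size).t())` should reproduce `ref_result` up to float16 rounding.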


@pytest.mark.parametrize(
"M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,zeros_mode",
[