[Dev] Fix a bug in general matmul ops with zero #79

Merged · 13 commits · Jul 6, 2024
207 changes: 3 additions & 204 deletions bitblas/ops/impl/matmul_dequantize_impl.py
@@ -15,204 +15,6 @@
_tir_packed_to_unsigned_convert_with_zeros,
)

# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:
def __init__(
self,
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
bit=4,
storage_dtype="int8",
source_format="uint",
with_scaling=False,
with_zeros=False,
group_size=-1,
fast_decoding=False,
with_bias=False,
zeros_mode="original",
propagate_a: TransformKind = TransformKind.NonTransform,
propagate_b: TransformKind = TransformKind.NonTransform,
):
self.M = self._validate_dimension(M, "M")
self.N = N
self.K = K
self.in_dtype = in_dtype
self.out_dtype = out_dtype
self.accum_dtype = accum_dtype
self.bit = bit
self.storage_dtype = storage_dtype
self.source_format = source_format
self.with_scaling = with_scaling
self.with_zeros = with_zeros
self.group_size = group_size if group_size != -1 else K
self.fast_decoding = fast_decoding
self.with_bias = with_bias
self.zeros_mode = zeros_mode
self.propagate_a = propagate_a
self.propagate_b = propagate_b

self._validate_bit()
self._validate_layout()

@staticmethod
def _validate_dimension(dim, name):
if not isinstance(dim, int):
return tvm.te.var(name.lower())
return dim

def _validate_bit(self):
if self.bit not in [1, 2, 4, 8]:
raise ValueError(f"Unsupported bit: {self.bit}")

def _validate_layout(self):
if self.layout not in ["nt"]:
raise ValueError(f"Unsupported layout: {self.layout}")

def _create_placeholders(self):
storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
n_float_per_elem = storage_nbit // self.bit

A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype)
B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype)
LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype)
Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype)
Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype)
QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit),
name="QZeros",
dtype=self.storage_dtype)
Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype)
return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem

def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem):
w = None
def decode(n, k):
if self.with_zeros and self.zeros_mode == "quantized":
qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit,
QZeros[k, n // n_float_per_elem],
n % n_float_per_elem,
dtype=self.storage_dtype,
)
w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)(
self.bit,
B[n, k // n_float_per_elem],
k % n_float_per_elem,
qzeros_dequantize,
dtype=self.in_dtype,
)
elif self.source_format == "uint":
if self.bit == 8:
w = B[n, k].astype(self.in_dtype)
w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "int":
if self.bit == 1:
w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
if self.bit == 8:
w = B[n, k].astype(self.in_dtype)
w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "fp":
w = _tir_u32_to_f4_to_f16(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype)
elif self.source_format == "nf":
index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit,
B[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype="int32",
)
w = LUT[index]
else:
raise ValueError(f"Unsupported source_format: {self.source_format}")

group_size = self.group_size
zeros_mode = self.zeros_mode

if not self.with_scaling:
return w

if not self.with_zeros:
return w * Scale[n, k // group_size]

if zeros_mode == "original":
w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
elif zeros_mode == "rescale":
w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
elif zeros_mode == "quantized":
w = w * Scale[n, k // group_size]
else:
raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode))

return w

return te.compute((self.N, self.K), decode, name="B_decode")

def _compute_matmul(self, A, B_decode):
k = te.reduce_axis((0, self.K), name="k")
C = te.compute(
(self.M, self.N),
lambda i, j: te.sum(
A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
name="C",
)
return C

def _convert_dtype(self, tensor):
if self.accum_dtype != self.out_dtype:
return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D")
return tensor

def _apply_bias(self, tensor, Bias):
if self.with_bias:
return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
return tensor

def emit(self):
A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders()
B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem)
C = self._compute_matmul(A, B_decode)
D = self._convert_dtype(C)
last_output = self._apply_bias(D, Bias)

args = [A, B]
if self.source_format == "nf":
args.append(LUT)
if self.with_scaling:
args.append(Scale)
if self.with_zeros:
args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
if self.with_bias:
args.append(Bias)
args.append(last_output)

func = te.create_prim_func(args).with_attr(
"dequantize_info",
{
"B_decode": {
"decode_block": "B_decode",
"fast_decoding": self.fast_decoding,
"source_format": {
"bits": self.bit,
"format": self.source_format,
},
"storage_dtype": self.storage_dtype,
"target_format": self.in_dtype,
"with_zeros": self.with_zeros,
"zeros_mode": self.zeros_mode,
"with_scaling": self.with_scaling,
"group_size": self.group_size,
}
},
)
return tvm.IRModule.from_expr(func)

# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:
@@ -671,8 +473,7 @@ def decode_func(n, k):
else:
args.append(Zeros)
if with_bias:
- E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
- last_output = E
+ last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)

@@ -852,8 +653,7 @@ def decode_func(n, k):
if with_zeros:
args.append(Zeros)
if with_bias:
- E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
- last_output = E
+ last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)

@@ -1052,8 +852,7 @@ def decode_func(n, k):
if with_zeros:
args.append(Zeros)
if with_bias:
- E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
- last_output = E
+ last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)
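
The three hunks above make the same fix: the bias stage now consumes `last_output`, whatever tensor the preceding optional stage produced, instead of referencing `D` directly. Mirroring the emitter's `_convert_dtype` shown above, the `D` tensor is presumably only created when `accum_dtype != out_dtype`, so chaining through `last_output` keeps the bias computation correct for every configuration. Below is a minimal, self-contained sketch of the corrected chaining pattern; it uses only TVM's `te` API, and the shapes, dtypes, and simplified matmul are hypothetical rather than the exact BitBLAS code.

```python
import tvm
from tvm import te

# Hypothetical sizes and flags for illustration only.
M, N, K = 16, 16, 16
accum_dtype, out_dtype = "float16", "float16"  # equal dtypes: no "D" stage is created
with_bias = True

A = te.placeholder((M, K), name="A", dtype="float16")
B = te.placeholder((N, K), name="B", dtype="float16")
Bias = te.placeholder((N,), name="Bias", dtype="float16")

k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")

last_output = C
if accum_dtype != out_dtype:
    # Optional dtype-conversion stage; only materialized for some configurations.
    last_output = te.compute((M, N), lambda i, j: last_output[i, j].astype(out_dtype), name="D")
if with_bias:
    # The fix: read from last_output instead of D, so this stage works
    # whether or not the conversion stage above was created.
    last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")

mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, Bias, last_output]))
```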

41 changes: 18 additions & 23 deletions testing/python/module/test_bitblas_linear.py
@@ -11,16 +11,7 @@
torch.manual_seed(0)
bitblas.set_log_level("DEBUG")

- @pytest.mark.parametrize(
- "m, in_features, out_features, bias",
- [
- (1, 1024, 1024, False),
- (1, 1024, 1024, True),
- (1024, 1024, 1024, True),
- ([1, 1024], 1024, 1024, True),
- ],
- )
- def test_correctness_consistent(m, in_features, out_features, bias):
+ def correctness_consistent(m, in_features, out_features, bias):
linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
linear_bitblas = BitBLASLinear(
in_features,
@@ -48,19 +39,13 @@ def test_correctness_consistent(m, in_features, out_features, bias):
torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)


- @pytest.mark.parametrize(
- "m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode",
- [
- (1, 1024, 1024, False, "uint4", -1, False, False, None),
- (1, 1024, 1024, False, "uint4", -1, False, False, None),
- (1024, 1024, 1024, True, "uint4", -1, False, False, None),
- (1, 1024, 1024, True, "uint2", -1, True, False, None),
- (1, 1024, 1024, True, "uint2", 128, True, True, "original"),
- (1024, 1024, 1024, True, "uint2", 128, True, True, "original"),
- (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"),
- ],
- )
- def test_correctness_weight_only_dequantize(
+ def test_correctness_consistent():
+ correctness_consistent(1, 1024, 1024, False)
+ correctness_consistent(1, 1024, 1024, True)
+ correctness_consistent(1024, 1024, 1024, True)
+ correctness_consistent([1, 1024], 1024, 1024, True)
+
+ def correctness_weight_only_dequantize(
m,
in_features,
out_features,
@@ -169,6 +154,16 @@ def test_correctness_weight_only_dequantize(
torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0)


+ def test_correctness_weight_only_dequantize():
+ correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
+ correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
+ correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
+ correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
+ correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
+ correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
+ correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")
+
+
def profile(model, input_data):
model = model.cuda()
model.eval()
2 changes: 1 addition & 1 deletion testing/python/operators/test_general_matmul_ops.py
@@ -195,7 +195,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
if with_bias:
permuted_inputs.append(bias)
permuted_inputs.append(inputs[2])
- matmul(*permuted_inputs[:2], output=permuted_inputs[-1])
+ matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])
if zeros_mode == "rescale":
torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0)
else:
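
The slice change above fixes how the test invokes the operator: when scaling, zeros, or bias are enabled, `permuted_inputs` holds more than two operands ahead of the preallocated output buffer appended last, so the old `permuted_inputs[:2]` silently dropped the optional operands and left those paths untested. A toy, self-contained illustration of the slicing difference follows; `fake_matmul` and the list contents are hypothetical stand-ins, not the BitBLAS API.

```python
# Toy stand-in for the operator call: records how many positional operands arrive.
def fake_matmul(*operands, output):
    return len(operands)

# Mirrors how the test builds its argument list: optional operands in the middle,
# preallocated output buffer appended last.
permuted_inputs = ["A", "W_compressed", "scale", "zeros", "bias", "out"]

n_before = fake_matmul(*permuted_inputs[:2], output=permuted_inputs[-1])   # old slice
n_after = fake_matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])   # fixed slice

assert n_before == 2  # scale/zeros/bias were silently dropped, so those paths went untested
assert n_after == 5   # every operand except the output buffer is now forwarded
```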