
Commit a912e02

[Test] Add Thread Level Macro Dequantize Gemm Test Cases (#194)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* bug fix for matrix support
* lint fix
* dispatch tensor core based on shapes
* update install commands
* import scripts
* remove shared mem hack
* revert change for swizzling
* bug fix
* tl examples
* Enhance Swizzle
* lint fix
* test fix
* lint fix
* optimize layout
* update tl utils.
* macro optimization
* test fix
* gemm_ss
* doc fix
* lint fix
* lint fix
* remove debug print
* remove debug print
* vectorization init
* lint fix
* prelude update
* update tvm
* bug fix for reduce_k with shared memory
* bug fix
* bug fix
* Enhance Macro Generation
* Lift Layout to reduce load time
* lint fix
* test fix
* red fix
* tile lang macro example
* tile lang macro example
* optimize the macro generator related items
* lint fix
* Tile Lang Test with Dynamic Symbolic
* more test case with block level programming
* all dynamic test case
* simplify the test case for dequantize gemm.
* dequant gemm update.
1 parent 3f6d516 commit a912e02

File tree

1 file changed: +318 -10 lines changed


testing/python/tilelang/test_tilelang_dequantize_gemm.py

Lines changed: 318 additions & 10 deletions
@@ -1,9 +1,35 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+import torch
+import torch.backends
 import bitblas
 from bitblas import tvm as tvm
-from tvm import tl
+from tvm import DataType
+from tvm import tl as TL
+import tvm.tl.language as T
 from bitblas.quantization import _tir_packed_to_unsigned_convert
+from bitblas.tl.utils import get_swizzle_layout
+from bitblas.tl.macro_generator import (
+    TensorCoreIntrinEmitterWithLadderTransform,)
+
+from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
+
+torch.manual_seed(0)
+
+
+def make_swizzle_layout(shared_buf):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    can_swizzle = shape[-1] * DataType(dtype).bits == 512
+    if not can_swizzle:
+        return T.Layout(shape, lambda *args: args)
+
+    def transform_func(i, j):
+        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
+        return [new_warp_i, new_warp_j]
+
+    return T.Layout(shape, transform_func)


 def matmul(
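
For reference, the can_swizzle condition in the new make_swizzle_layout helper fires exactly when one shared-memory row spans 512 bits, so a float16 buffer needs a last dimension of 32 elements and an int8 buffer needs 64; anything narrower falls back to the identity layout. A minimal standalone sketch of that arithmetic (plain Python, no tvm.tl required; the helper name row_bits is illustrative and not part of the test):

    # Illustrative check of the can_swizzle predicate used by make_swizzle_layout:
    # a shared buffer is swizzled only when its innermost dimension spans 512 bits.
    def row_bits(last_dim_elems: int, dtype_bits: int) -> int:
        return last_dim_elems * dtype_bits

    assert row_bits(32, 16) == 512   # float16: 32 elements per row -> swizzled
    assert row_bits(64, 8) == 512    # int8: 64 elements per row -> swizzled
    assert row_bits(16, 16) != 512   # shorter rows keep the identity layout
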
@@ -47,13 +73,8 @@ def main(
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(A[by * block_M, k * block_K], A_shared)

-                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
-                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
-                        for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
-                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
-                            B_shared[vi, vj] = B[bx * block_N + vi,
-                                                 k * block_K // num_elems_per_byte + vj,]
+                for i, j in T.Parallel(block_N, block_K // num_elems_per_byte):
+                    B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j]

             for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                 for t in T.thread_binding(0, threads, thread="threadIdx.x"):
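
The hunk above swaps the hand-written threadIdx.x/vectorized copy of the quantized B tile for a single T.Parallel loop. Both forms cover the same (row, column) index space of B_shared; below is a small host-side sketch that replays the two index calculations, using values that mirror the run_gemm case in this file (block_N=128, block_K=32, 128 threads) and assuming 4-bit weights packed two per int8 byte. The numbers are only for illustration:

    # Replay the index math of the removed manual copy and of the new T.Parallel copy
    # and check that they touch the same (vi, vj) entries of B_shared.
    block_N, block_K, num_elems_per_byte, threads = 128, 32, 2, 128

    manual = set()
    for i in range(block_N * block_K // num_elems_per_byte // (threads * 16)):
        for t in range(threads):
            for v in range(16):
                flat = i * threads * 16 + t * 16 + v
                manual.add((flat // (block_K // num_elems_per_byte),
                            flat % (block_K // num_elems_per_byte)))

    parallel = {(i, j) for i in range(block_N) for j in range(block_K // num_elems_per_byte)}
    assert manual == parallel  # same tile coverage, expressed declaratively
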
@@ -106,8 +127,8 @@ def run_gemm(
     )
     print(program)

-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    mod, params = TL.lower(program)
+    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)

     out = mod.run_once()

@@ -129,9 +150,296 @@ def ref_program(A, qB):
     mod.assert_allclose(ref_program)


+def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+    M,
+    N,
+    K,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    transform_b,
+):
+    assert dtypeAB in [
+        "float16",
+        "int8",
+    ], "Currently only float16 and int8 are supported"
+    assert dtypeC in [
+        "float16",
+        "float32",
+        "int32",
+    ], "Currently only float16, float32 and int32 are supported"
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    micro_size_x = micro_size_y = micro_size_k = 16
+
+    if dtypeC == "int32":
+        micro_size_k = 32
+
+    # This is a debug config
+    block_row_warps = 1
+    block_col_warps = 4
+
+    warp_rows = 1
+    warp_cols = 2
+    warp_row_tiles = micro_size_x * warp_rows
+    warp_col_tiles = micro_size_y * warp_cols
+    shared_scope = "shared.dyn"
+
+    # Pipeline Stage
+    stage = 2
+    reduce_k = 2
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = 32 if dtypeAB == "float16" else 64
+    chunk = block_K // reduce_k
+
+    is_smooth_a = False
+    can_swizzle = block_K * DataType(dtypeAB).bits == 512
+    apply_pad_a = not (is_smooth_a or can_swizzle)
+    pad_factor = 8
+
+    A_shape = (M, K)
+    B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
+               micro_size_k // num_elems_per_byte)
+    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
+    B_shared_shape = (
+        block_N // micro_size_y,
+        block_K // micro_size_k,
+        micro_size_y,
+        micro_size_k // num_elems_per_byte,
+    )
+    C_shared_shape = (
+        block_M // micro_size_x,
+        block_N // micro_size_y,
+        micro_size_x,
+        micro_size_y,
+    )
+
+    warp_size = 32
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    # MMA Wrapper to Auto Generate Code for MMA
+    mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
+        a_dtype=dtypeAB,
+        b_dtype=dtypeAB,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        reduce_k=reduce_k,
+        transform_kind_b=transform_b,
+        num_elems_per_byte=num_elems_per_byte)
+
+    vec_load_qb = 16
+    if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
+        vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
+                prelude=decode_i4_to_f16) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
+            B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
+            B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            reduced_accum_res = T.alloc_local(0, accum_dtype)
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+            rk = T.thread_binding(0, reduce_k, "threadIdx.y")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+            })
+
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined((K // block_K), num_stages=stage):
+
+                # Load A into shared memory
+                for i, k in T.Parallel(block_M, (block_K // reduce_k)):
+                    vk = rk * (block_K // reduce_k) + k
+                    A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
+
+                # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
+                for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
+                                  (threads * vec_load_qb)):
+                    for v in T.vectorized(0, vec_load_qb):
+                        t = thread_bindings
+                        idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
+                        vkk = idx % (micro_size_k // num_elems_per_byte)
+                        vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
+                        vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
+                            block_K // micro_size_k)
+                        vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
+                              (block_K // micro_size_k)) % (
+                                  block_N // micro_size_y)
+                        B_shared[vj, vk, vjj,
+                                 vkk] = B[bx * (block_N // micro_size_y) + vj,
+                                          ko * (block_K // micro_size_k) + vk, vjj, vkk]
+
+                for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
+
+                    # Load A into fragment
+                    mma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
+                    # Load B into fragment
+                    mma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
+                    for j in T.serial(warp_cols):
+                        local_size_b = mma_emitter.local_size_b
+                        T.call_extern('handle', 'decode_i4u_to_f16',
+                                      T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
+                                      T.address_of(B_dequantize_local[j * local_size_b]), 8)
+
+                    mma_emitter.mma(A_local, B_dequantize_local, C_local)
+
+            if reduce_k > 1:
+                for n in T.serial(warp_rows * warp_cols * local_size):
+                    T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            C_local[n],
+                            True,
+                            reduced_accum_res[0],
+                            rk,
+                            dtype="handle",
+                        ))
+                    if rk == 0:
+                        C_local[n] = reduced_accum_res[0]
+
+            if rk == 0:
+                mma_emitter.stmatrix(
+                    C_local,
+                    C_shared,
+                    thread_bindings=thread_bindings,
+                )
+
+            for i, j in T.Parallel(block_M, (block_N // reduce_k)):
+                vj = rk * (block_N // reduce_k) + j
+                C[by * block_M + i,
+                  bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
+                                                i % micro_size_x, vj % micro_size_y]
+
+    return main
+
+
+def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+    M,
+    N,
+    K,
+    in_dtype,
+    dtypeC,
+    accum_dtype,
+    transform_b,
+):
+    matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+        M, N, K, in_dtype, dtypeC, accum_dtype, transform_b)
+
+    mod, params = TL.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+
+    # src_code is the generated cuda source
+    assert src_code is not None
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
+    qB = torch.randint(
+        0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+
+    ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
+        M=N,
+        N=K,
+        transform_kind=transform_b,
+        transpose_matrix=True,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+
+    ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
+
+    lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
+        M=N,
+        N=K,
+        datatype=in_dtype,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+    lop3_permutate = bitblas.ops.LOP3Permutate(
+        config=lop3_permutate_config,
+        target=tvm.target.Target("llvm"),
+    )
+    QLB = ladder_permutate(qB.cpu()).cuda()
+    QLB = lop3_permutate(QLB.cpu()).cuda()
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+
+    mod(A, QLB, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
+    B = (
+        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                    dtype=torch.half).to(torch.half).to(A.device))
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
+
+    # Get Reference Result
+    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
 def test_run_dequantize_gemm():
     run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)


+def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
+    assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+        256, 1024, 512, "float16", "float16", "float16", 3)
+
+
 if __name__ == "__main__":
     bitblas.testing.main()
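
A minimal sketch of how the new cases can be exercised locally, assuming a CUDA-capable GPU and a BitBLAS build that ships the tvm.tl runtime (the import below presumes the test file's directory is on sys.path):

    # Run the whole file (it ends in bitblas.testing.main()):
    #   python testing/python/tilelang/test_tilelang_dequantize_gemm.py
    # or invoke the two cases directly:
    import test_tilelang_dequantize_gemm as dequant_tests

    dequant_tests.test_run_dequantize_gemm()
    dequant_tests.test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()
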
