# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+import torch
+import torch.backends
import bitblas
from bitblas import tvm as tvm
-from tvm import tl
+from tvm import DataType
+from tvm import tl as TL
+import tvm.tl.language as T
from bitblas.quantization import _tir_packed_to_unsigned_convert
+from bitblas.tl.utils import get_swizzle_layout
+from bitblas.tl.macro_generator import (
+    TensorCoreIntrinEmitterWithLadderTransform,)
+
+from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
+
+torch.manual_seed(0)
+
+
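+# Shared-memory layout helper: apply a swizzled layout only when one row of the buffer
+# spans exactly 512 bits (32 fp16 / 64 int8 elements); otherwise keep the identity layout.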
+def make_swizzle_layout(shared_buf):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    can_swizzle = shape[-1] * DataType(dtype).bits == 512
+    if not can_swizzle:
+        return T.Layout(shape, lambda *args: args)
+
+    def transform_func(i, j):
+        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
+        return [new_warp_i, new_warp_j]
+
+    return T.Layout(shape, transform_func)


def matmul(
@@ -47,13 +73,8 @@ def main(
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)

-                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
-                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
-                        for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
-                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
-                            B_shared[vi, vj] = B[bx * block_N + vi,
-                                                 k * block_K // num_elems_per_byte + vj,]
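+                # Cooperative copy of the quantized B tile into shared memory; thread
+                # binding and vectorization are left to TL's layout inference.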
+                for i, j in T.Parallel(block_N, block_K // num_elems_per_byte):
+                    B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j]

                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
@@ -106,8 +127,8 @@ def run_gemm(
    )
    print(program)

-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    mod, params = TL.lower(program)
+    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)

    out = mod.run_once()

@@ -129,9 +150,296 @@ def ref_program(A, qB):
    mod.assert_allclose(ref_program)


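+# TL kernel builder: 4-bit weight-only GEMM where B is stored in the ladder-transformed
+# layout and the K dimension is additionally split across threadIdx.y (reduce_k); the
+# partial accumulators are merged with a cross-thread allreduce before the write-back.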
+def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+    M,
+    N,
+    K,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    transform_b,
+):
+    assert dtypeAB in [
+        "float16",
+        "int8",
+    ], "Currently only float16 and int8 are supported"
+    assert dtypeC in [
+        "float16",
+        "float32",
+        "int32",
+    ], "Currently only float16, float32 and int32 are supported"
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    micro_size_x = micro_size_y = micro_size_k = 16
+
+    if dtypeC == "int32":
+        micro_size_k = 32
+
+    # This is a debug config
+    block_row_warps = 1
+    block_col_warps = 4
+
+    warp_rows = 1
+    warp_cols = 2
+    warp_row_tiles = micro_size_x * warp_rows
+    warp_col_tiles = micro_size_y * warp_cols
+    shared_scope = "shared.dyn"
+
+    # Pipeline Stage
+    stage = 2
+    reduce_k = 2
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = 32 if dtypeAB == "float16" else 64
+    chunk = block_K // reduce_k
+
+    is_smooth_a = False
+    can_swizzle = block_K * DataType(dtypeAB).bits == 512
+    apply_pad_a = not (is_smooth_a or can_swizzle)
+    pad_factor = 8
+
+    A_shape = (M, K)
+    B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
+               micro_size_k // num_elems_per_byte)
+    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
+    B_shared_shape = (
+        block_N // micro_size_y,
+        block_K // micro_size_k,
+        micro_size_y,
+        micro_size_k // num_elems_per_byte,
+    )
+    C_shared_shape = (
+        block_M // micro_size_x,
+        block_N // micro_size_y,
+        micro_size_x,
+        micro_size_y,
+    )
+
+    warp_size = 32
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    # MMA Wrapper to Auto Generate Code for MMA
+    mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
+        a_dtype=dtypeAB,
+        b_dtype=dtypeAB,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        reduce_k=reduce_k,
+        transform_kind_b=transform_b,
+        num_elems_per_byte=num_elems_per_byte)
+
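+    # Clamp the per-thread vector width so that threads * vec_load_qb never exceeds the
+    # number of storage bytes in one quantized B tile (block_N x block_K // reduce_k).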
+    vec_load_qb = 16
+    if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
+        vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
+                prelude=decode_i4_to_f16) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
+            B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
+            B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            reduced_accum_res = T.alloc_local(0, accum_dtype)
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+            rk = T.thread_binding(0, reduce_k, "threadIdx.y")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+            })
+
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined((K // block_K), num_stages=stage):
+
+                # Load A into shared memory
+                for i, k in T.Parallel(block_M, (block_K // reduce_k)):
+                    vk = rk * (block_K // reduce_k) + k
+                    A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
+
+                # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
+                for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
+                                  (threads * vec_load_qb)):
+                    for v in T.vectorized(0, vec_load_qb):
+                        t = thread_bindings
+                        idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
+                        vkk = idx % (micro_size_k // num_elems_per_byte)
+                        vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
+                        vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
+                            block_K // micro_size_k)
+                        vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
+                              (block_K // micro_size_k)) % (
+                                  block_N // micro_size_y)
+                        B_shared[vj, vk, vjj,
+                                 vkk] = B[bx * (block_N // micro_size_y) + vj,
+                                          ko * (block_K // micro_size_k) + vk, vjj, vkk]
+
+                for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
+
+                    # Load A into fragment
+                    mma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
+                    # Load B into fragment
+                    mma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
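+                    # Dequantize the packed int4 B fragment into fp16 registers with the
+                    # LOP3-based decode_i4u_to_f16 device function provided by the kernel prelude.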
+                    for j in T.serial(warp_cols):
+                        local_size_b = mma_emitter.local_size_b
+                        T.call_extern('handle', 'decode_i4u_to_f16',
+                                      T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
+                                      T.address_of(B_dequantize_local[j * local_size_b]), 8)
+
+                    mma_emitter.mma(A_local, B_dequantize_local, C_local)
+
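+            # Combine the partial accumulators produced by the reduce_k (threadIdx.y) split
+            # with a cross-thread allreduce; only the rk == 0 slice keeps the reduced result.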
+            if reduce_k > 1:
+                for n in T.serial(warp_rows * warp_cols * local_size):
+                    T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            C_local[n],
+                            True,
+                            reduced_accum_res[0],
+                            rk,
+                            dtype="handle",
+                        ))
+                    if rk == 0:
+                        C_local[n] = reduced_accum_res[0]
+
+            if rk == 0:
+                mma_emitter.stmatrix(
+                    C_local,
+                    C_shared,
+                    thread_bindings=thread_bindings,
+                )
+
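+            # Write C back from the tiled C_shared layout; each threadIdx.y slice handles
+            # its own block_N // reduce_k columns of the output tile.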
+            for i, j in T.Parallel(block_M, (block_N // reduce_k)):
+                vj = rk * (block_N // reduce_k) + j
+                C[by * block_M + i,
+                  bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
+                                                i % micro_size_x, vj % micro_size_y]
+
+    return main
+
+
+def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+    M,
+    N,
+    K,
+    in_dtype,
+    dtypeC,
+    accum_dtype,
+    transform_b,
+):
+    matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+        M, N, K, in_dtype, dtypeC, accum_dtype, transform_b)
+
+    mod, params = TL.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+
+    # src_code is the generated cuda source
+    assert src_code is not None
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
+    qB = torch.randint(
+        0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+
+    ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
+        M=N,
+        N=K,
+        transform_kind=transform_b,
+        transpose_matrix=True,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+
+    ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
+
+    lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
+        M=N,
+        N=K,
+        datatype=in_dtype,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+    lop3_permutate = bitblas.ops.LOP3Permutate(
+        config=lop3_permutate_config,
+        target=tvm.target.Target("llvm"),
+    )
+    QLB = ladder_permutate(qB.cpu()).cuda()
+    QLB = lop3_permutate(QLB.cpu()).cuda()
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+
+    mod(A, QLB, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
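+    # Rebuild a dense fp16 reference weight by unpacking the two int4 values stored in
+    # each int8 of qB (low nibble first), so the result can be checked against torch.matmul.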
+    B = (
+        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                    dtype=torch.half).to(torch.half).to(A.device))
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
+
+    # Get Reference Result
+    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
def test_run_dequantize_gemm():
    run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)


+def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
+    assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+        256, 1024, 512, "float16", "float16", "float16", 3)
+
+
if __name__ == "__main__":
    bitblas.testing.main()