@@ -294,6 +294,106 @@ def ref_program(A, B):
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+def tl_matmul_block_all_dynamic(
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    num_stages,
+    threads,
+):
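+    # M, N, K are symbolic (tvm.te.var), so the kernel is compiled once and
+    # handles arbitrary runtime problem sizes.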
+    M = tvm.te.var("m")
+    N = tvm.te.var("n")
+    K = tvm.te.var("k")
+
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+    import tvm.tl.language as T
+
+    @T.prim_func
+    def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB),
+             C: T.Buffer((M, N), dtypeC)):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
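+            # Pipelined mainloop over K tiles: global-to-shared copies are
+            # overlapped with the GEMM across num_stages stages.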
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
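+            # Epilogue: write the accumulated tile back to global memory.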
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def assert_tl_matmul_block_all_dynamic_correctness(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = tl_matmul_block_all_dynamic(
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        dtypeAB,
+        dtypeC,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
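+    # Lower the TL program into a runnable module plus its buffer parameters.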
+    mod, params = TL.lower(program)
+
+    # Allocate operands in the layout the kernel expects: transposed inputs
+    # are stored pre-transposed so their shapes match A_shape/B_shape above
+    # (the original always used (M, K)/(N, K), which only works for N == K).
+    A = torch.rand(*((K, M) if trans_A else (M, K)), device="cuda", dtype=getattr(torch, dtypeAB))
+    B = torch.rand(*((N, K) if trans_B else (K, N)), device="cuda", dtype=getattr(torch, dtypeAB))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC))
+
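+    # Wrap the lowered module in a Profiler so it can be called directly on
+    # the torch tensors above.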
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+    mod(A, B, C)
+    # Print the generated device source for inspection.
+    print(mod.mod.imported_modules[0].get_source())
+
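+    # Reference implementation on the same operands: PyTorch matmul in
+    # float32, cast down to the output dtype.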
+    def ref_program(A, B):
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(getattr(torch, dtypeC))
+        return C
+
+    # Get Reference Result
+    ref_c = ref_program(A, B)
+
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
def test_assert_tl_matmul_macro():
    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
@@ -309,5 +409,14 @@ def test_assert_tl_matmul_block():
                                       64, 64, 32)


+def test_assert_tl_matmul_block_all_dynamic():
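+    # One block-aligned case plus M = 67 and M = 36, which are not multiples
+    # of block_M = 64, to exercise dynamic-shape boundary handling.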
+    assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
+                                                   "float16", "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16",
+                                                   "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
+                                                   "float16", 64, 64, 32)
+
+
if __name__ == "__main__":
    bitblas.testing.main()