-
Notifications
You must be signed in to change notification settings - Fork 152
/
hgemm_mma_stage.cu
2582 lines (2395 loc) · 111 KB
/
hgemm_mma_stage.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <mma.h>
#include <torch/types.h>
#include <torch/extension.h>
using namespace nvcuda;
// Number of lanes per warp assumed by the warp_id / lane_id arithmetic in this file.
#define WARP_SIZE 32
// Function qualifier shorthands for device-only and host+device inline helpers.
#define DEVICE_INLINE __device__ inline
#define HOST_DEVICE_INLINE __device__ __host__ inline
// Vector reinterpretation helpers: take the address of `value`, reinterpret it
// as a pointer to the named vector type and yield the first element as an
// lvalue, so the macro can be used on either side of an assignment.
// NOTE(review): the address of `value` presumably has to satisfy the alignment
// of the target vector type (e.g. 16 bytes for int4/float4) - confirm at call sites.
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
// Width-named load/store helpers: move 32/64/128 bits at once by viewing the
// storage as half2 / float2 / float4 respectively.
#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// gmem -> smem
// Inline-PTX wrappers around the cp.async family. COMMIT_GROUP closes the
// current group of issued copies; WAIT_ALL / WAIT_GROUP(n) wait on committed
// groups. `n` is bound with the "n" constraint, so it must be a compile-time
// integer constant.
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
// `dst` is bound as a 32-bit register operand ("r"); the kernels in this file
// pass shared-memory addresses derived from __cvta_generic_to_shared. `src` is
// bound as a 64-bit operand ("l") and receives a global pointer. `bytes` is an
// immediate ("n") and must be a compile-time constant.
#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
// smem -> gmem: requires sm_90 or higher.
// NOTE(review): CP_ASYNC_BULK binds `dst` as a 32-bit "r" operand and `src` as
// a 64-bit "l" operand, which looks swapped for a shared -> global copy, and it
// reuses the non-bulk ".L2::128B" qualifier and an immediate size. None of the
// bulk macros are used in the kernels visible here - verify against the PTX ISA
// before relying on them.
#define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::)
#define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::)
#define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n))
#define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
// ldmatrix
// Warp-wide loads of one / two / four 8x8 b16 matrices from shared memory into
// 32-bit registers (R, R0..R3 are outputs). `addr` is a 32-bit shared-memory
// address supplied per lane. The _T variants add the .trans qualifier.
#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
// stmatrix: requires sm_90 or higher.
// Store counterparts of the ldmatrix wrappers: write one / two / four 8x8 b16
// matrices from 32-bit registers to the shared-memory address `addr`.
#define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
#define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
#define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
#define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
#define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
#define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
// mma m16n8k16
// D = A * B + C on one 16x8x16 tile with f16 inputs and f16 accumulators.
// Register operands per lane: RD0..RD1 (outputs), RA0..RA3 (A fragment),
// RB0..RB1 (B fragment), RC0..RC1 (accumulator input). Callers in this file
// pass the same registers for D and C to accumulate in place.
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
// Divide a by b and round the quotient up by one whenever the division leaves
// a remainder. Used in this file to count how many K tiles cover a dimension.
HOST_DEVICE_INLINE
int div_ceil(int a, int b) {
  const int quotient = a / b;
  const bool has_remainder = (a % b) != 0;
  return has_remainder ? quotient + 1 : quotient;
}
// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle
//
// Half-precision GEMM tile kernel built on mma.sync m16n8k16 with f16
// accumulators. Each thread block produces one BM x BN tile of C
// (128 x 128 with the default template arguments) by streaming BK = MMA_K
// wide slices of A and B through a K_STAGE-deep cp.async pipeline held in
// static shared memory.
//
// Layout inside a block (defaults): 256 threads = 8 warps arranged 2 (M) x 4 (N);
// every warp owns WARP_TILE_M x WARP_TILE_N = 4x4 MMA tiles, i.e. 64 x 32 of C.
//
// Parameters:
//   A: row-major [M][K] input (indexed as A[m * K + k]).
//   B: row-major [K][N] input (indexed as B[k * N + n]).
//   C: row-major [M][N] output (indexed as C[m * N + n]).
//   M, N, K: problem sizes. M is not read by the kernel body.
//
// Launch expectations:
//   blockIdx.y selects the row tile. blockIdx.x selects the column tile and,
//   when BLOCK_SWIZZLE is true, is offset by blockIdx.z * gridDim.x.
//   The load index math (tid / 2, tid / 16) covers a full tile with 256 threads.
//
// NOTE(review): no bounds checks are performed on any load or store, so the
// kernel presumably requires M % BM == 0, N % BN == 0, K % BK == 0, 16-byte
// aligned rows, and at least K_STAGE - 1 K tiles - confirm against the launcher.
template<const int MMA_M=16,
         const int MMA_N=8,
         const int MMA_K=16,
         const int MMA_TILE_M=2,
         const int MMA_TILE_N=4,
         const int WARP_TILE_M=4,
         const int WARP_TILE_N=4,
         const int A_PAD=0,
         const int B_PAD=0,
         const int K_STAGE=2,
         const bool BLOCK_SWIZZLE=true>
__global__ void __launch_bounds__(256)
hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  // BLOCK_SWIZZLE 0/1 control use block swizzle or not.
  const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
  const int by = blockIdx.y;
  // Number of BK-wide slices along K that have to be accumulated.
  const int NUM_K_TILES = div_ceil(K, MMA_K);
  constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
  constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
  constexpr int BK = MMA_K; // 16
  // One A tile and one B tile per pipeline stage (sizes below are per stage).
  __shared__ half s_a[K_STAGE][BM][BK+A_PAD]; // 128*16*2=4KB
  __shared__ half s_b[K_STAGE][BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB
  // Element (half) strides between consecutive stages of s_a / s_b.
  constexpr int s_a_stage_offset = BM * (BK + A_PAD);
  constexpr int s_b_stage_offset = BK * (BN + B_PAD);
  const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
  const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
  const int lane_id = tid % WARP_SIZE; // 0~31
  const int warp_m = warp_id % 2; // 0,1
  const int warp_n = warp_id / 2; // 0,1,2,3
  // Per-thread coordinates of the 8-half (16-byte) chunk this thread copies
  // for each tile: two threads per A row, sixteen threads per B row.
  int load_smem_a_m = tid / 2; // row 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
  int load_smem_b_k = tid / 16; // row 0~15
  int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  // Accumulator fragments: two 32-bit registers (four halves) per MMA tile,
  // zero-initialised before the K loop.
  uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      RC[i][j][0] = 0;
      RC[i][j][1] = 0;
    }
  }
  // 32-bit shared-state-space base addresses used to build cp.async targets.
  uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a);
  uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b);
  // Pipeline prologue: issue the first K_STAGE-1 tiles, one commit group per
  // tile, with tile k landing in stage k.
  #pragma unroll
  for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1
    // k * WMMA_K, WMMA_K=16 -> (k << 4)
    int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    uint32_t load_smem_a_ptr = (
      smem_a_base_ptr + (k * s_a_stage_offset +
                         load_smem_a_m * (BK + A_PAD) +
                         load_smem_a_k) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);
    uint32_t load_smem_b_ptr = (
      smem_b_base_ptr + (k * s_b_stage_offset +
                         load_smem_b_k * (BN + B_PAD) +
                         load_smem_b_n) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
    CP_ASYNC_COMMIT_GROUP();
  }
  // Allow at most K_STAGE-2 groups to remain in flight, i.e. the oldest tile
  // (tile 0) has landed; the barrier makes it visible to every warp.
  CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2
  __syncthreads();
  // Main loop: iteration k issues the load of tile k and computes on the tile
  // that was issued K_STAGE-1 iterations earlier.
  #pragma unroll
  for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) {
    // gmem -> smem
    // s2/4 can use bitwise ops but s3 can not, so, we use mod
    // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3
    // s3: (k + 1) % 3
    // smem_sel: stage consumed by the MMAs in this iteration.
    // smem_sel_next: stage filled by the cp.async issued in this iteration.
    int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2...
    int smem_sel_next = k % K_STAGE;  // s3 k 2->2, k 3->0, k 4->1...
    int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    uint32_t load_smem_a_ptr = (
      smem_a_base_ptr + (smem_sel_next * s_a_stage_offset +
                         load_smem_a_m * (BK + A_PAD) +
                         load_smem_a_k) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);
    uint32_t load_smem_b_ptr = (
      smem_b_base_ptr + (smem_sel_next * s_b_stage_offset +
                         load_smem_b_k * (BN + B_PAD) +
                         load_smem_b_n) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
    CP_ASYNC_COMMIT_GROUP();
    // ldmatrix for s_a, ldmatrix.trans for s_b.
    uint32_t RA[WARP_TILE_M][4];
    uint32_t RB[WARP_TILE_N][2];
    // smem -> reg
    // A fragments: lane_id % 16 picks the row inside the 16-row MMA tile and
    // lane_id / 16 picks the left or right 8-column half.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
      int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
      int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
      uint32_t lane_smem_a_ptr = __cvta_generic_to_shared(
        &s_a[smem_sel][lane_smem_a_m][lane_smem_a_k]);
      LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
    }
    // B fragments: lane_id % 16 picks the k row; the transposing ldmatrix
    // variant is used because s_b is stored [k][n].
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
      int lane_smem_b_k = lane_id % 16;  // 0~15
      int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
      uint32_t lane_smem_b_ptr = __cvta_generic_to_shared(
        &s_b[smem_sel][lane_smem_b_k][lane_smem_b_n]);
      LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
    }
    // MMA compute
    // Accumulate in place: RC is both the C input and the D output.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        HMMA16816(RC[i][j][0], RC[i][j][1],
                  RA[i][0], RA[i][1], RA[i][2], RA[i][3],
                  RB[j][0], RB[j][1],
                  RC[i][j][0], RC[i][j][1]);
      }
    }
    // Wait until the next tile to be consumed has landed, then barrier so
    // (a) it is visible block-wide and (b) every warp is done reading the
    // stage that the next iteration's cp.async will overwrite.
    CP_ASYNC_WAIT_GROUP(K_STAGE-2);
    __syncthreads();
  }
  // make sure all memory issues ready.
  // With K_STAGE == 2 the wait at the end of the loop is already wait_group 0,
  // so this extra drain is only needed for deeper pipelines.
  if ((K_STAGE - 2) > 0) {
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
  }
  // processing last (K_STAGE-1) k iters.
  // These tiles were loaded by the main loop (or the prologue) but not yet
  // multiplied; tile t lives in stage t % K_STAGE.
  {
    #pragma unroll
    for (int k = 0; k < (K_STAGE - 1); k++) {
      int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE);
      // ldmatrix for s_a, ldmatrix.trans for s_b.
      uint32_t RA[WARP_TILE_M][4];
      uint32_t RB[WARP_TILE_N][2];
      // smem -> reg
      #pragma unroll
      for (int i = 0; i < WARP_TILE_M; ++i) {
        int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
        int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
        int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
        uint32_t lane_smem_a_ptr = __cvta_generic_to_shared(
          &s_a[stage_sel][lane_smem_a_m][lane_smem_a_k]);
        LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
      }
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
        int lane_smem_b_k = lane_id % 16;  // 0~15
        int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
        uint32_t lane_smem_b_ptr = __cvta_generic_to_shared(
          &s_b[stage_sel][lane_smem_b_k][lane_smem_b_n]);
        LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
      }
      // MMA compute
      #pragma unroll
      for (int i = 0; i < WARP_TILE_M; ++i) {
        #pragma unroll
        for (int j = 0; j < WARP_TILE_N; ++j) {
          HMMA16816(RC[i][j][0], RC[i][j][1],
                    RA[i][0], RA[i][1], RA[i][2], RA[i][3],
                    RB[j][0], RB[j][1],
                    RC[i][j][0], RC[i][j][1]);
        }
      }
    }
  }
  // reg -> gmem, MMA_MxMMA_N=16x8
  // Each lane writes its two accumulator registers directly: register 0 goes
  // to row lane_id / 4, register 1 to the row 8 below, both at column pair
  // (lane_id % 4) * 2. No bounds check is applied to the store addresses.
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
      int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
      // mapping lane smem index -> global index.
      // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
      // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
      // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
      int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
      int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2;
      int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
      int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
      LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]);
      LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]);
    }
  }
}
// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem
//
// Half-precision GEMM tile kernel built on mma.sync m16n8k16 with f16
// accumulators. Same tiling and K_STAGE-deep cp.async pipeline as the static
// shared-memory variant, but the A/B stage buffers live in dynamic shared
// memory, so the launch must provide at least
//   K_STAGE * (BM * (BK + A_PAD) + BK * (BN + B_PAD)) * sizeof(half)
// bytes (s_a occupies the first part, s_b follows it).
//
// Parameters:
//   A: row-major [M][K] input (indexed as A[m * K + k]).
//   B: row-major [K][N] input (indexed as B[k * N + n]).
//   C: row-major [M][N] output (indexed as C[m * N + n]).
//   M, N, K: problem sizes. M is not read by the kernel body.
//
// Template switches:
//   BLOCK_SWIZZLE: when true, blockIdx.z * gridDim.x is added to blockIdx.x to
//     form the column-tile index.
//   COLLECTIVE_STORE: only consulted when compiling for sm_90 or newer, where
//     it selects the stmatrix based epilogue over the plain per-lane store.
//     For older targets the warp-shuffle based 128-bit store is always used.
//
// Launch expectations: blockIdx.y selects the row tile; the load index math
// (tid / 2, tid / 16) covers a full tile with 256 threads.
//
// NOTE(review): no bounds checks are performed on any load or store, so the
// kernel presumably requires M % BM == 0, N % BN == 0, K % BK == 0, 16-byte
// aligned rows, and at least K_STAGE - 1 K tiles - confirm against the launcher.
template<const int MMA_M=16,
         const int MMA_N=8,
         const int MMA_K=16,
         const int MMA_TILE_M=2,
         const int MMA_TILE_N=4,
         const int WARP_TILE_M=4,
         const int WARP_TILE_N=4,
         const int A_PAD=0,
         const int B_PAD=0,
         const int K_STAGE=2,
         const bool BLOCK_SWIZZLE=true,
         const bool COLLECTIVE_STORE=false>
__global__ void __launch_bounds__(256)
hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  // BLOCK_SWIZZLE 0/1 control use block swizzle or not.
  // COLLECTIVE_STORE true/false control use stmatrix or not.
  const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
  const int by = blockIdx.y;
  // Number of BK-wide slices along K that have to be accumulated.
  const int NUM_K_TILES = div_ceil(K, MMA_K);
  constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
  constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
  constexpr int BK = MMA_K; // 16
  // Dynamic shared memory: all K_STAGE A tiles first, then all K_STAGE B tiles.
  extern __shared__ half smem[];
  half* s_a = smem;
  half* s_b = smem + K_STAGE * BM * (BK + A_PAD);
  // Element (half) strides between consecutive stages of s_a / s_b.
  constexpr int s_a_stage_offset = BM * (BK + A_PAD);
  constexpr int s_b_stage_offset = BK * (BN + B_PAD);
  const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
  const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
  const int lane_id = tid % WARP_SIZE; // 0~31
  const int warp_m = warp_id % 2; // 0,1
  const int warp_n = warp_id / 2; // 0,1,2,3
  // Per-thread coordinates of the 8-half (16-byte) chunk this thread copies
  // for each tile: two threads per A row, sixteen threads per B row.
  int load_smem_a_m = tid / 2; // row 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
  int load_smem_b_k = tid / 16; // row 0~15
  int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  // Accumulator fragments: two 32-bit registers (four halves) per MMA tile.
  uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      RC[i][j][0] = 0;
      RC[i][j][1] = 0;
    }
  }
  // 32-bit shared-state-space base addresses; every cp.async / ldmatrix
  // address below is derived from them with plain byte arithmetic.
  uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a);
  uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b);
  // Pipeline prologue: issue the first K_STAGE-1 tiles, one commit group per
  // tile, with tile k landing in stage k.
  #pragma unroll
  for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1
    // k * WMMA_K, WMMA_K=16 -> (k << 4)
    int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    uint32_t load_smem_a_ptr = (
      smem_a_base_ptr + (k * s_a_stage_offset +
                         load_smem_a_m * (BK + A_PAD) +
                         load_smem_a_k) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);
    uint32_t load_smem_b_ptr = (
      smem_b_base_ptr + (k * s_b_stage_offset +
                         load_smem_b_k * (BN + B_PAD) +
                         load_smem_b_n) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
    CP_ASYNC_COMMIT_GROUP();
  }
  // Allow at most K_STAGE-2 groups to remain in flight, i.e. the oldest tile
  // (tile 0) has landed; the barrier makes it visible to every warp.
  CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2
  __syncthreads();
  // Main loop: iteration k issues the load of tile k and computes on the tile
  // that was issued K_STAGE-1 iterations earlier.
  #pragma unroll
  for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) {
    // gmem -> smem
    // s2/4 can use bitwise ops but s3 can not, so, we use mod
    // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3
    // s3: (k + 1) % 3
    // smem_sel: stage consumed by the MMAs in this iteration.
    // smem_sel_next: stage filled by the cp.async issued in this iteration.
    int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2...
    int smem_sel_next = k % K_STAGE;  // s3 k 2->2, k 3->0, k 4->1...
    int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    uint32_t load_smem_a_ptr = (
      smem_a_base_ptr + (smem_sel_next * s_a_stage_offset +
                         load_smem_a_m * (BK + A_PAD) +
                         load_smem_a_k) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);
    uint32_t load_smem_b_ptr = (
      smem_b_base_ptr + (smem_sel_next * s_b_stage_offset +
                         load_smem_b_k * (BN + B_PAD) +
                         load_smem_b_n) * sizeof(half)
    );
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
    CP_ASYNC_COMMIT_GROUP();
    uint32_t RA[WARP_TILE_M][4];
    uint32_t RB[WARP_TILE_N][2];
    // ldmatrix for s_a, ldmatrix.trans for s_b.
    // smem -> reg
    // A fragments: lane_id % 16 picks the row inside the 16-row MMA tile and
    // lane_id / 16 picks the left or right 8-column half.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
      int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
      int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
      uint32_t lane_smem_a_ptr = (
        smem_a_base_ptr + (smem_sel * s_a_stage_offset +
                           lane_smem_a_m * (BK + A_PAD) +
                           lane_smem_a_k) * sizeof(half)
      );
      LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
    }
    // B fragments: lane_id % 16 picks the k row; the transposing ldmatrix
    // variant is used because s_b is stored [k][n].
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
      int lane_smem_b_k = lane_id % 16;  // 0~15
      int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
      uint32_t lane_smem_b_ptr = (
        smem_b_base_ptr + (smem_sel * s_b_stage_offset +
                           lane_smem_b_k * (BN + B_PAD) +
                           lane_smem_b_n) * sizeof(half)
      );
      LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
    }
    // MMA compute
    // Accumulate in place: RC is both the C input and the D output.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        HMMA16816(RC[i][j][0], RC[i][j][1],
                  RA[i][0], RA[i][1], RA[i][2], RA[i][3],
                  RB[j][0], RB[j][1],
                  RC[i][j][0], RC[i][j][1]);
      }
    }
    // Wait until the next tile to be consumed has landed, then barrier so
    // (a) it is visible block-wide and (b) every warp is done reading the
    // stage that the next iteration's cp.async will overwrite.
    CP_ASYNC_WAIT_GROUP(K_STAGE-2);
    __syncthreads();
  }
  // make sure all memory issues ready.
  // With K_STAGE == 2 the wait at the end of the loop is already wait_group 0,
  // so this extra drain is only needed for deeper pipelines.
  if ((K_STAGE - 2) > 0) {
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
  }
  // processing last (K_STAGE-1) k iters.
  // These tiles were loaded but not yet multiplied; tile t lives in stage
  // t % K_STAGE.
  {
    #pragma unroll
    for (int k = 0; k < (K_STAGE - 1); k++) {
      uint32_t RA[WARP_TILE_M][4];
      uint32_t RB[WARP_TILE_N][2];
      int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE);
      // ldmatrix for s_a, ldmatrix.trans for s_b.
      // smem -> reg
      #pragma unroll
      for (int i = 0; i < WARP_TILE_M; ++i) {
        int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
        int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
        int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
        uint32_t lane_smem_a_ptr = (
          smem_a_base_ptr + (stage_sel * s_a_stage_offset +
                             lane_smem_a_m * (BK + A_PAD) +
                             lane_smem_a_k) * sizeof(half)
        );
        LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
      }
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
        int lane_smem_b_k = lane_id % 16;  // 0~15
        int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
        uint32_t lane_smem_b_ptr = (
          smem_b_base_ptr + (stage_sel * s_b_stage_offset +
                             lane_smem_b_k * (BN + B_PAD) +
                             lane_smem_b_n) * sizeof(half)
        );
        LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
      }
      // MMA compute
      #pragma unroll
      for (int i = 0; i < WARP_TILE_M; ++i) {
        #pragma unroll
        for (int j = 0; j < WARP_TILE_N; ++j) {
          HMMA16816(RC[i][j][0], RC[i][j][1],
                    RA[i][0], RA[i][1], RA[i][2], RA[i][3],
                    RB[j][0], RB[j][1],
                    RC[i][j][0], RC[i][j][1]);
        }
      }
    }
  }
  // __CUDA_ARCH__ is a three-digit value (e.g. 800 for sm_80, 900 for sm_90),
  // so the sm_90 gate has to compare against 900. Comparing against 90 is true
  // for every device target and would compile the stmatrix path everywhere.
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  if (COLLECTIVE_STORE) {
    // The following code has not been tested because I do not have a GPU with sm>=90
    // reg -> smem(stmatrix) -> gmem(cp.async.bulk), MMA_MxMMA_N=16x8
    // NOTE: need [MMA_M][MMA_N] per warp to avoid overlap between warps.
    // 16-byte alignment is requested because each 8-half row is read back
    // below as a single 128-bit value.
    __shared__ __align__(16) half s_c[MMA_TILE_M][MMA_TILE_N][MMA_M][MMA_N]; // (2*4)*16*8*2=2KB
    uint32_t smem_c_base_ptr = __cvta_generic_to_shared(&s_c[warp_m][warp_n]);
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        // store (i,j) warp tile -> smem c, 16x8
        uint32_t lane_smem_c_ptr = (
          smem_c_base_ptr + (lane_id % 16) * MMA_N * sizeof(half)); // (0~15)*8
        STMATRIX_X2(lane_smem_c_ptr, RC[i][j][0], RC[i][j][1]);
        // The row a lane reads back below holds data contributed by other
        // lanes, so order the stmatrix writes before the reads.
        __syncwarp();
        // smem -> gmem, may use cp.async.bulk.global.share::cta?
        int store_warp_gmem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
        int store_warp_gmem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
        int store_lane_gmem_c_m = by * BM + store_warp_gmem_c_m;
        int store_lane_gmem_c_n = bx * BN + store_warp_gmem_c_n;
        // send 16 memory issues with 128 bits within lower half lanes.
        // TODO: use cp.async.bulk and wait outside the inner loop.
        if (lane_id < 16) {
          int store_gmem_c_addr = (store_lane_gmem_c_m + lane_id) * N + store_lane_gmem_c_n;
          LDST128BITS(C[store_gmem_c_addr]) = LDST128BITS(
            s_c[warp_m][warp_n][lane_id][0]);
        }
        // All reads of this warp's s_c slot must finish before the next
        // (i, j) tile overwrites it.
        __syncwarp();
      }
    }
  } else {
    // reg -> gmem, MMA_MxMMA_N=16x8
    // Each lane writes its two accumulator registers directly: register 0
    // goes to row lane_id / 4, register 1 to the row 8 below.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
        int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
        int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
        int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2;
        int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
        int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
        LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]);
        LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]);
      }
    }
  }
  #else
  #warning "stmatrix need sm>=90, force use __shfl_sync for collective store!"
  {
    // Warp-shuffle collective store: every fourth lane gathers the registers
    // of its three right-hand neighbours, giving it one complete 8-half row
    // of a 16x8 tile, which it then writes with a single 128-bit store.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half.
      // thus, we only need 8 memory issues with 128 bits after shfl_sync.
      // may reuse RA[4][4] as RC0 ? only new RC1[4][4].
      // alignas(16): each RC0[j] / RC1[j] row is accessed as one 128-bit value.
      alignas(16) uint32_t RC0[WARP_TILE_N][4];
      alignas(16) uint32_t RC1[WARP_TILE_N][4];
      // The shuffles are executed by all 32 lanes (full mask); only the
      // results held by lanes with lane_id % 4 == 0 are used afterwards.
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half.
        // thus, we only need 8 memory issues with 128 bits after shfl_sync.
        RC0[j][0] = RC[i][j][0];
        RC1[j][0] = RC[i][j][1];
        RC0[j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1);
        RC0[j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2);
        RC0[j][3] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 3);
        RC1[j][1] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 1);
        RC1[j][2] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 2);
        RC1[j][3] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 3);
      }
      if (lane_id % 4 == 0) {
        int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
        int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
        #pragma unroll
        for (int j = 0; j < WARP_TILE_N; ++j) {
          int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
          int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n;
          int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
          int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
          LDST128BITS(C[store_gmem_c_addr_0]) = LDST128BITS(RC0[j][0]);
          LDST128BITS(C[store_gmem_c_addr_1]) = LDST128BITS(RC1[j][0]);
        }
      }
    }
  }
  #endif
}
// In order to reduce bank conflicts, the K dimension of a block tile
// (16x2=32) is split into halves of 16, and the halves are stacked along
// the stage dimension. For example, with stages=3 and warp_tile_k=2 the
// A tile is stored as [3*2][BM][16].
// 128x128, mma2x4, warp4x4(64,32,32), stages, block swizzle, dsmem,
// k32 with reg double buffers
template<const int MMA_M=16,
const int MMA_N=8,
const int MMA_K=16,
const int MMA_TILE_M=2,
const int MMA_TILE_N=4,
const int WARP_TILE_M=4,
const int WARP_TILE_N=4,
const int WARP_TILE_K=2,
const int A_PAD=0,
const int B_PAD=0,
const int K_STAGE=2,
const bool BLOCK_SWIZZLE=true,
const bool WARP_SWIZZLE=true>
__global__ void __launch_bounds__(256)
hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
const half* __restrict__ A, const half* __restrict__ B, half* __restrict__ C,
int M, int N, int K) {
// BLOCK_SWIZZLE 0/1 control use block swizzle or not.
const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
const int by = blockIdx.y;
const int NUM_K_TILES = div_ceil(K, MMA_K * WARP_TILE_K);
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
constexpr int BK = MMA_K; // 16x2=32
extern __shared__ half smem[];
half* s_a = smem;
half* s_b = smem + K_STAGE * BM * (BK + A_PAD) * WARP_TILE_K;
constexpr int s_a_stage_offset = BM * (BK + A_PAD); // 128x16
constexpr int s_b_stage_offset = BK * (BN + B_PAD); // 16x128
constexpr int s_a_mma_k_store_offset = K_STAGE * BM * (BK + A_PAD);
constexpr int s_b_mma_k_store_offset = K_STAGE * BK * (BN + B_PAD);
const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
const int lane_id = tid % WARP_SIZE; // 0~31
const int warp_m = warp_id % 2; // 0,1
const int warp_n = warp_id / 2; // 0,1,2,3
int load_smem_a_m = tid / 2; // row 0~127
int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
int load_smem_b_k = tid / 16; // row 0~15
int load_smem_b_n = (tid % 16) * 8; // col 0,8,16,...
int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
RC[i][j][0] = 0;
RC[i][j][1] = 0;
}
}
uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a);
uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b);
#pragma unroll
for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1
// k * WMMA_K, WMMA_K=16 -> (k << 4)
int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a
int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b
int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
uint32_t load_smem_a_ptr = (
smem_a_base_ptr + (k * s_a_stage_offset +
load_smem_a_m * (BK + A_PAD) +
load_smem_a_k) * sizeof(half)
);
CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0
uint32_t load_smem_a_mma_k_ptr = (
smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
(k * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) +
load_smem_a_k) * sizeof(half)
);
CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1
uint32_t load_smem_b_ptr = (
smem_b_base_ptr + (k * s_b_stage_offset +
load_smem_b_k * (BN + B_PAD) +
load_smem_b_n) * sizeof(half)
);
CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k;
int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n;
uint32_t load_smem_b_mma_k_ptr = (
smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
(k * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) +
load_smem_b_n) * sizeof(half)
);
CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16);
CP_ASYNC_COMMIT_GROUP();
}
CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2
__syncthreads();
uint32_t RA[2][WARP_TILE_M][4];
uint32_t RB[2][WARP_TILE_N][2];
int reg_store_idx = 0;
int reg_load_idx = 1;
{
// ldmatrix for s_a, ldmatrix.trans for s_b.
// smem -> reg buffers 0, first MMA_K, 0~15
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
uint32_t lane_smem_a_ptr = (
smem_a_base_ptr +
(0 * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
lane_smem_a_k) * sizeof(half)
);
LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
lane_smem_a_ptr);
}
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
uint32_t lane_smem_b_ptr = (
smem_b_base_ptr +
(0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
lane_smem_b_n) * sizeof(half)
);
// may use .x4.trans to load 4 matrix for reg double buffers at once?
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
lane_smem_b_ptr);
}
}
#pragma unroll
for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) {
reg_store_idx ^= 1; // 0->1
reg_load_idx ^= 1; // 1->0
int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2...
int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1...
// stage gmem -> smem
int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a
int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b
int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
uint32_t load_smem_a_ptr = (
smem_a_base_ptr + (smem_sel_next * s_a_stage_offset +
load_smem_a_m * (BK + A_PAD) +
load_smem_a_k) * sizeof(half)
);
CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0
uint32_t load_smem_a_mma_k_ptr = (
smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
(smem_sel_next * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) +
load_smem_a_k) * sizeof(half)
);
CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1
uint32_t load_smem_b_ptr = (
smem_b_base_ptr + (smem_sel_next * s_b_stage_offset +
load_smem_b_k * (BN + B_PAD) +
load_smem_b_n) * sizeof(half)
);
CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k;
int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n;
uint32_t load_smem_b_mma_k_ptr = (
smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
(smem_sel_next * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) +
load_smem_b_n) * sizeof(half)
);
CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16);
CP_ASYNC_COMMIT_GROUP();
// ldmatrix for s_a, ldmatrix.trans for s_b.
// smem -> reg buffers 1, second MMA_K, 16~31
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
uint32_t lane_smem_a_ptr = (
smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
(smem_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
lane_smem_a_k) * sizeof(half)
);
LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
lane_smem_a_ptr);
}
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
int lane_smem_b_k = lane_id % 16; // 0~15
int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
uint32_t lane_smem_b_ptr = (
smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
(smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
lane_smem_b_n) * sizeof(half)
);
// may use .x4.trans to load 4 matrix for reg double buffers at once?
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
lane_smem_b_ptr);
}
// MMA compute, first MMA_K
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
// Warp swizzle: Right -> Left -> Right -> Left
int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
RC[i][j_s][0], RC[i][j_s][1]);
}
}
reg_store_idx ^= 1; // 1 -> 0
reg_load_idx ^= 1; // 0 -> 1
// MMA compute, second MMA_K
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
// Warp swizzle: Right -> Left -> Right -> Left
int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
RC[i][j_s][0], RC[i][j_s][1]);
}
}
CP_ASYNC_WAIT_GROUP(K_STAGE-2);
__syncthreads();
// load next k iters to reg buffers.
// smem -> reg buffers 0, first MMA_K, 0~15
// int smem_sel_reg = (k + 2) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2
int smem_sel_reg = (smem_sel + 1) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
uint32_t lane_smem_a_ptr = (
smem_a_base_ptr + (smem_sel_reg * s_a_stage_offset +
lane_smem_a_m * (BK + A_PAD) +
lane_smem_a_k) * sizeof(half)
);
LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
lane_smem_a_ptr);
}
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
uint32_t lane_smem_b_ptr = (
smem_b_base_ptr + (smem_sel_reg * s_b_stage_offset +
lane_smem_b_k * (BN + B_PAD) +
lane_smem_b_n) * sizeof(half)
);
// may use .x4.trans to load 4 matrix for reg double buffers at once?
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
lane_smem_b_ptr);
}
}
// make sure all memory issues ready.
if ((K_STAGE - 2) > 0) {
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
// processing last (K_STAGE-1) k iters.
{
#pragma unroll
for (int k = 0; k < (K_STAGE - 1); k++) {
reg_store_idx ^= 1; // 0->1
reg_load_idx ^= 1; // 1->0
int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE);
// ldmatrix for s_a, ldmatrix.trans for s_b.
// smem -> reg buffers 1, second MMA_K
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
uint32_t lane_smem_a_ptr = (
smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
(stage_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
lane_smem_a_k) * sizeof(half)
);
LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
lane_smem_a_ptr);
}
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
int lane_smem_b_k = lane_id % 16; // 0~15
int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
uint32_t lane_smem_b_ptr = (
smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
(stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
lane_smem_b_n) * sizeof(half)
);
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
lane_smem_b_ptr);
}
// MMA compute, first MMA_K
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
// Warp swizzle: Right -> Left -> Right -> Left
int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
RC[i][j_s][0], RC[i][j_s][1]);
}
}
reg_store_idx ^= 1; // 1 -> 0
reg_load_idx ^= 1; // 0 -> 1
// MMA compute, second MMA_K
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
// Warp swizzle: Right -> Left -> Right -> Left
int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
RC[i][j_s][0], RC[i][j_s][1]);
}
}
// load next k iters to reg buffers.
// smem -> reg buffers 0, first MMA_K, 0~15
// int stage_sel_reg = ((NUM_K_TILES - K_STAGE + k) % K_STAGE);
int stage_sel_reg = (stage_sel + 1) % K_STAGE;
#pragma unroll
for (int i = 0; i < WARP_TILE_M; ++i) {
int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
uint32_t lane_smem_a_ptr = (
smem_a_base_ptr + (stage_sel_reg * s_a_stage_offset +
lane_smem_a_m * (BK + A_PAD) +
lane_smem_a_k) * sizeof(half)
);
LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
lane_smem_a_ptr);
}
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
uint32_t lane_smem_b_ptr = (
smem_b_base_ptr + (stage_sel_reg * s_b_stage_offset +
lane_smem_b_k * (BN + B_PAD) +
lane_smem_b_n) * sizeof(half)
);
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
lane_smem_b_ptr);
}
}
}
// collective store with reg reuse & warp shuffle
for (int i = 0; i < WARP_TILE_M; ++i) {
// reuse RA[2][4][4] reg here, this may boost 0.3~0.5 TFLOPS up.
// may not put 'if' in N loop, it will crash the 'pragma unroll' hint ?
#pragma unroll
for (int j = 0; j < WARP_TILE_N; ++j) {
// How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half.
// thus, we only need 8 memory issues with 128 bits after shfl_sync.
RA[0][j][0] = RC[i][j][0];
RA[1][j][0] = RC[i][j][1];
RA[0][j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1);
RA[0][j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2);