@@ -98,11 +98,11 @@ layout (constant_id = 12) const uint LOAD_VEC_B_SHIFT = 0;
98
98
#ifdef COOPMAT
99
99
#define SHMEM_STRIDE (BK + 8)
100
100
#else
101
- #define SHMEM_STRIDE (BK + 1)
101
+ #define SHMEM_STRIDE (BK / 2 + 1)
102
102
#endif
103
103
104
- shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
105
- shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104
+ shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
105
+ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
106
106
107
107
#ifdef MUL_MAT_ID
108
108
shared u16vec2 row_ids[3072];
@@ -223,8 +223,8 @@ void main() {
223
223
}
224
224
#else
225
225
ACC_TYPE sums[WMITER * TM * WNITER * TN];
226
- FLOAT_TYPE cache_a[WMITER * TM];
227
- FLOAT_TYPE cache_b[TN];
226
+ FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
227
+ FLOAT_TYPE_VEC2 cache_b[TN];
228
228
229
229
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
230
230
sums[i] = ACC_TYPE(0.0f);
@@ -262,7 +262,7 @@ void main() {
262
262
}
263
263
}
264
264
#else
265
- [[unroll]] for (uint i = 0; i < BK; i++) {
265
+ [[unroll]] for (uint i = 0; i < BK / 2 ; i++) {
266
266
// Load from shared into cache
267
267
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
268
268
[[unroll]] for (uint j = 0; j < TM; j++) {
@@ -278,7 +278,7 @@ void main() {
278
278
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
279
279
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
280
280
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
281
- sums[sums_idx] = fma(ACC_TYPE (cache_a[wsir * TM + cr]), ACC_TYPE (cache_b[cc]), sums[sums_idx] );
281
+ sums[sums_idx] += dot(ACC_TYPE_VEC2 (cache_a[wsir * TM + cr]), ACC_TYPE_VEC2 (cache_b[cc]));
282
282
}
283
283
}
284
284
}
0 commit comments