-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmatrix_kernel16_6_block.c
106 lines (96 loc) · 4.95 KB
/
matrix_kernel16_6_block.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// 6x16 kernel with blocking
// Requires AVX2 and FMA
// See a full description at: http://www.cs.utexas.edu/users/flame/pubs/blis3_ipdps14.pdf
//
// Accumulates a 6x16 tile of C += A*B, where A is M-by-K and B is K-by-N,
// both row-major. The tile's top-left corner is C[ic+ir][jc+jr]; the inner
// product runs over kc values of k starting at pc.
//
// NOTE(review): plain C99 `inline` emits no external definition; if the
// compiler ever declines to inline a call from another translation unit this
// fails to link. `static inline` is the usual fix — left unchanged here to
// avoid altering linkage for existing callers.
inline void matrix_kernel16_6_block( const afloat * __restrict__ A,
              const afloat * __restrict__ B, afloat * C,
              const int M, const int N, const int K,
              const int jc, const int nc,
              const int pc, const int kc,
              const int ic, const int mc,
              const int jr, const int nr,
              const int ir, const int mr
              )
{
    __m256 mB0;     // __m256 is the 256-bit SP vector type introduced with AVX
    __m256 mB1;     // (AVX-512, in 2015, widened this to 512 bits, etc.)
    __m256 mA0;
    __m256 mA1;
    // Chose kernel size 6x16
    // - 16 because SIMD width is 8*32 (so must be multiple of 8)
    // - Also overall 16 registers
    // - Number of registers depends on AVX, AVX-2 or AVX-512
    // - So having 6x16 means 6x2 registers used for C block
    // - This leaves 4 for sections of A and B (needed to do fma)
    // - To use SIMD, need to store in registers
    // - Note: Intel paper uses 30x8, not 6x16
    __m256 result0_0 = _mm256_set1_ps(0); // Broadcast 32-bit (SP) 0 to all 8 elements
    __m256 result1_0 = _mm256_set1_ps(0);
    __m256 result2_0 = _mm256_set1_ps(0);
    __m256 result3_0 = _mm256_set1_ps(0);
    __m256 result4_0 = _mm256_set1_ps(0);
    __m256 result5_0 = _mm256_set1_ps(0);
    __m256 result0_1 = _mm256_set1_ps(0);
    __m256 result1_1 = _mm256_set1_ps(0);
    __m256 result2_1 = _mm256_set1_ps(0);
    __m256 result3_1 = _mm256_set1_ps(0);
    __m256 result4_1 = _mm256_set1_ps(0);
    __m256 result5_1 = _mm256_set1_ps(0);
    // This is the same for loop as in naive implementation, except now instead of the k indexing
    // a single dot product of 2 vectors of size k (a row of A and a col of B),
    // the k is indexing 6 rows of A and 16 cols of B
    // Since the SIMD width is 8 (256 bits), need to do 12 fmas here
    for(int k=0; k<kc; ++k)
    {
        // Prefetch the k+1'th row of B. Gives ~10% speedup.
        // On the final iteration (k == kc-1) this targets one row past the
        // current panel; x86 prefetch instructions never fault, so this is
        // harmless even if the address is past the end of B.
        __builtin_prefetch(&B[N*(k+1+pc)+jc+jr+8*0]);
        __builtin_prefetch(&B[N*(k+1+pc)+jc+jr+8*1]);
        // Load the k'th row of the B block (load twice since in total, it's 16 floats)
        mB0 = _mm256_load_ps(&B[N*(k+pc)+jc+jr+8*0]);
        mB1 = _mm256_load_ps(&B[N*(k+pc)+jc+jr+8*1]);
        // Load a single value for the k'th col of A
        // In total, we need to do this 6 times (col of A has height 6)
        // Note: the B addresses above must be aligned on a 32-byte boundary
        mA0 = _mm256_set1_ps(A[k+pc+(ic+ir+0)*K]);  // Broadcast float @ A's col k, row m+0
        mA1 = _mm256_set1_ps(A[k+pc+(ic+ir+1)*K]);  // Broadcast float @ A's col k, row m+1
        // Now we have the 16 floats of B in mB0|mB1, and the 2 floats
        // of A broadcast in mA0 and mA1.
        result0_0 = _mm256_fmadd_ps(mB0,mA0,result0_0); // result = arg1 .* arg2 .+ arg3
        result0_1 = _mm256_fmadd_ps(mB1,mA0,result0_1);
        result1_0 = _mm256_fmadd_ps(mB0,mA1,result1_0);
        result1_1 = _mm256_fmadd_ps(mB1,mA1,result1_1);
        // result0_0 now contains the accumulated result, up to this k,
        // of row 0 and cols 0-7; result0_1 of row 0, cols 8-15;
        // result1_0 of row 1, cols 0-7; result1_1 of row 1, cols 8-15.
        // Repeat for the other 4 rows
        mA0 = _mm256_set1_ps(A[k+pc+(ic+ir+2)*K]);
        mA1 = _mm256_set1_ps(A[k+pc+(ic+ir+3)*K]);
        result2_0 = _mm256_fmadd_ps(mB0,mA0,result2_0);
        result2_1 = _mm256_fmadd_ps(mB1,mA0,result2_1);
        result3_0 = _mm256_fmadd_ps(mB0,mA1,result3_0);
        result3_1 = _mm256_fmadd_ps(mB1,mA1,result3_1);
        mA0 = _mm256_set1_ps(A[k+pc+(ic+ir+4)*K]);
        mA1 = _mm256_set1_ps(A[k+pc+(ic+ir+5)*K]);
        result4_0 = _mm256_fmadd_ps(mB0,mA0,result4_0);
        result4_1 = _mm256_fmadd_ps(mB1,mA0,result4_1);
        result5_0 = _mm256_fmadd_ps(mB0,mA1,result5_0);
        result5_1 = _mm256_fmadd_ps(mB1,mA1,result5_1);
    }
    // Write the accumulators back: C += result.
    // The previous code did `*((__m256*)&C[...]) += result;`, which accesses
    // float storage through an incompatible __m256 lvalue — a strict-aliasing
    // violation (undefined behavior in C), even though it typically compiled
    // to the same aligned load/store. The intrinsics below generate identical
    // 32-byte-aligned accesses but are defined to alias any element type.
    // Like the original, each C row start must be 32-byte aligned (switch to
    // _mm256_loadu_ps/_mm256_storeu_ps if that cannot be guaranteed).
    afloat *c;
    c = &C[(ic+ir+0)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result0_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result0_1));
    c = &C[(ic+ir+1)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result1_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result1_1));
    c = &C[(ic+ir+2)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result2_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result2_1));
    c = &C[(ic+ir+3)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result3_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result3_1));
    c = &C[(ic+ir+4)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result4_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result4_1));
    c = &C[(ic+ir+5)*N+jc+jr];
    _mm256_store_ps(c + 0*8, _mm256_add_ps(_mm256_load_ps(c + 0*8), result5_0));
    _mm256_store_ps(c + 1*8, _mm256_add_ps(_mm256_load_ps(c + 1*8), result5_1));
}