
Commit

wip
nihui committed Oct 7, 2023
1 parent ccbdd9a commit 2be95bf
Showing 1 changed file with 212 additions and 1 deletion.
213 changes: 212 additions & 1 deletion src/layer/x86/convolution_3x3_winograd_int8.h
@@ -899,6 +899,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
const short* pA = pAT;
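// the AVX-512 branch packs the 8x8 int32 tile into four 512-bit accumulators
// (two rotated 8-element groups per register); the AVX2 branch keeps eight 256-bit accumulators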

#if __AVX512F__
__m512i _sum0;
__m512i _sum1;
__m512i _sum2;
__m512i _sum3;
#else
__m256i _sum0;
__m256i _sum1;
__m256i _sum2;
@@ -907,9 +913,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m256i _sum5;
__m256i _sum6;
__m256i _sum7;
#endif // __AVX512F__

if (k == 0)
{
#if __AVX512F__
_sum0 = _mm512_setzero_si512();
_sum1 = _mm512_setzero_si512();
_sum2 = _mm512_setzero_si512();
_sum3 = _mm512_setzero_si512();
#else
_sum0 = _mm256_setzero_si256();
_sum1 = _mm256_setzero_si256();
_sum2 = _mm256_setzero_si256();
@@ -918,9 +931,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum5 = _mm256_setzero_si256();
_sum6 = _mm256_setzero_si256();
_sum7 = _mm256_setzero_si256();
#endif // __AVX512F__
}
else
{
#if __AVX512F__
_sum0 = _mm512_loadu_si512((const __m512i*)outptr);
_sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16));
_sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32));
_sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48));
#else
_sum0 = _mm256_load_si256((const __m256i*)outptr);
_sum1 = _mm256_load_si256((const __m256i*)(outptr + 8));
_sum2 = _mm256_load_si256((const __m256i*)(outptr + 16));
@@ -929,13 +949,33 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum5 = _mm256_load_si256((const __m256i*)(outptr + 40));
_sum6 = _mm256_load_si256((const __m256i*)(outptr + 48));
_sum7 = _mm256_load_si256((const __m256i*)(outptr + 56));
#endif // __AVX512F__
}

int kk = 0;
for (; kk + 1 < max_kk; kk += 2)
{
__m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
__m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB);
#if __AVX512F__
__m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
__m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1);
__m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC);
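// pA holds two k steps for 8 output rows (int16 pairs), broadcast to both 256-bit halves;
// pB holds two k steps for 8 columns, rotated by one and by two 32-bit positions,
// so the four accumulators cover all 64 row/column products.
// with AVX512-VNNI, dpwssd fuses madd_epi16 + add_epi32 into one instruction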

#if __AVX512VNNI__
_sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01);
_sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23);
_sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01);
_sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23);
#else
_sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01));
_sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23));
_sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01));
_sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23));
#endif // __AVX512VNNI__
#else // __AVX512F__
__m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -960,6 +1000,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2));
_sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3));
#endif // __AVXVNNI__
#endif // __AVX512F__

pA += 16;
pB += 16;
@@ -968,9 +1009,26 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m128i _pA = _mm_load_si128((const __m128i*)pA);
__m128i _pB = _mm_load_si128((const __m128i*)pB);

__m256i _pA0 = _mm256_cvtepi16_epi32(_pA);
__m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB0 = _mm256_cvtepi16_epi32(_pB);
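// odd trailing k: only one int16 per lane, so widen to int32 and use mullo_epi32
// instead of madd_epi16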
#if __AVX512F__
__m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
__m512i _pA11 = _mm512_shuffle_i32x4(_pA00, _pA00, _MM_SHUFFLE(2, 3, 0, 1));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1);
__m512i _pB23 = _mm512_permutex_epi64(_pB01, _MM_SHUFFLE(2, 3, 0, 1));

__m512i _s01 = _mm512_mullo_epi32(_pA00, _pB01);
__m512i _s23 = _mm512_mullo_epi32(_pA00, _pB23);
__m512i _s45 = _mm512_mullo_epi32(_pA11, _pB01);
__m512i _s67 = _mm512_mullo_epi32(_pA11, _pB23);
_sum0 = _mm512_add_epi32(_sum0, _s01);
_sum1 = _mm512_add_epi32(_sum1, _s23);
_sum2 = _mm512_add_epi32(_sum2, _s45);
_sum3 = _mm512_add_epi32(_sum3, _s67);
#else
__m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
__m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1));
__m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
@@ -991,11 +1049,61 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum5 = _mm256_add_epi32(_sum5, _s5);
_sum6 = _mm256_add_epi32(_sum6, _s6);
_sum7 = _mm256_add_epi32(_sum7, _s7);
#endif // __AVX512F__

pA += 8;
pB += 8;
}

#if __AVX512F__
if (k_end)
{
// from
// 00 11 22 33 44 55 66 77 01 12 23 30 45 56 67 74
// 02 13 20 31 46 57 64 75 03 10 21 32 47 54 65 76
// 40 51 62 73 04 15 26 37 41 52 63 70 05 16 27 34
// 42 53 60 71 06 17 24 35 43 50 61 72 07 14 25 36
// to
// 00 10 20 30 44 54 64 74 04 14 24 34 40 50 60 70
// 01 11 21 31 45 55 65 75 05 15 25 35 41 51 61 71
// 02 12 22 32 46 56 66 76 06 16 26 36 42 52 62 72
// 03 13 23 33 47 57 67 77 07 17 27 37 43 53 63 73
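// note: the shuffles below regroup the rotated diagonals into column-major order
// (8 consecutive row values per output column) ahead of the 512-bit stores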
{
__m512i _s0 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(0, 1, 1, 0));
__m512i _s1 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(2, 3, 3, 2));
__m512i _s2 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(0, 1, 1, 0));
__m512i _s3 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(2, 3, 3, 2));
_s1 = _mm512_shuffle_epi32(_s1, _MM_PERM_ADCB);
_s2 = _mm512_shuffle_epi32(_s2, _MM_PERM_BADC);
_s3 = _mm512_shuffle_epi32(_s3, _MM_PERM_CBAD);
__m512i _tmp0 = _mm512_unpacklo_epi32(_s0, _s1);
__m512i _tmp1 = _mm512_unpackhi_epi32(_s0, _s1);
__m512i _tmp2 = _mm512_unpacklo_epi32(_s2, _s3);
__m512i _tmp3 = _mm512_unpackhi_epi32(_s2, _s3);
_sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2);
_sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2);
_sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1);
_sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1);
_sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD);
_sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD);
}

__m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 0, 3, 0));
__m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 0, 3, 0));
__m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 2, 1, 2));
__m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 2, 1, 2));
_sum0 = _tmp0;
_sum1 = _tmp1;
_sum2 = _tmp2;
_sum3 = _tmp3;
}

_mm512_storeu_si512((__m512i*)outptr, _sum0);
_mm512_storeu_si512((__m512i*)(outptr + 16), _sum1);
_mm512_storeu_si512((__m512i*)(outptr + 32), _sum2);
_mm512_storeu_si512((__m512i*)(outptr + 48), _sum3);
outptr += 64;
#else
if (k_end)
{
// from
@@ -1072,6 +1180,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_mm256_store_si256((__m256i*)(outptr + 8 * 6), _sum6);
_mm256_store_si256((__m256i*)(outptr + 8 * 7), _sum7);
outptr += 8 * 8;
#endif // __AVX512F__
}
for (; jj + 3 < max_jj; jj += 4)
{
Expand Down Expand Up @@ -1301,6 +1410,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
const short* pA = pAT;

#if __AVX2__
__m256i _sum0;
__m256i _sum1;
__m256i _sum2;
__m256i _sum3;
#else
__m128i _sum0;
__m128i _sum1;
__m128i _sum2;
@@ -1309,9 +1424,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _sum5;
__m128i _sum6;
__m128i _sum7;
#endif

if (k == 0)
{
#if __AVX2__
_sum0 = _mm256_setzero_si256();
_sum1 = _mm256_setzero_si256();
_sum2 = _mm256_setzero_si256();
_sum3 = _mm256_setzero_si256();
#else
_sum0 = _mm_setzero_si128();
_sum1 = _mm_setzero_si128();
_sum2 = _mm_setzero_si128();
@@ -1320,9 +1442,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum5 = _mm_setzero_si128();
_sum6 = _mm_setzero_si128();
_sum7 = _mm_setzero_si128();
#endif
}
else
{
#if __AVX2__
_sum0 = _mm256_loadu_si256((const __m256i*)outptr);
_sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8));
_sum2 = _mm256_loadu_si256((const __m256i*)(outptr + 16));
_sum3 = _mm256_loadu_si256((const __m256i*)(outptr + 24));
#else
_sum0 = _mm_load_si128((const __m128i*)outptr);
_sum1 = _mm_load_si128((const __m128i*)(outptr + 4));
_sum2 = _mm_load_si128((const __m128i*)(outptr + 8));
@@ -1331,11 +1460,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum5 = _mm_load_si128((const __m128i*)(outptr + 20));
_sum6 = _mm_load_si128((const __m128i*)(outptr + 24));
_sum7 = _mm_load_si128((const __m128i*)(outptr + 28));
#endif
}

int kk = 0;
for (; kk + 1 < max_kk; kk += 2)
{
#if __AVX2__
__m128i _pA = _mm_loadu_si128((const __m128i*)pA);
__m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB);
__m256i _pA0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1);
__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
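// pA holds two k steps for 4 output rows, duplicated into both 128-bit halves;
// rotating A by two rows and B by one column yields the four diagonal product patterns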

#if __AVXVNNI__ || __AVX512VNNI__
_sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0);
_sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1);
_sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0);
_sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1);
#else
_sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0));
_sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1));
_sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0));
_sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1));
#endif
#else // __AVX2__
__m128i _pA0 = _mm_loadu_si128((const __m128i*)pA);
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
__m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8));
@@ -1362,6 +1511,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2));
_sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3));
#endif
#endif // __AVX2__

pA += 8;
pB += 16;
@@ -1371,6 +1521,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA));
__m128i _pB = _mm_loadu_si128((const __m128i*)pB);

#if __AVX2__
__m256i _pA0 = _mm256_cvtepi16_epi32(_pA);
__m256i _pB0 = _mm256_cvtepi16_epi32(_pB);
__m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));

__m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0);
__m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1);
__m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0);
__m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1);

_sum0 = _mm256_add_epi32(_sum0, _s0);
_sum1 = _mm256_add_epi32(_sum1, _s1);
_sum2 = _mm256_add_epi32(_sum2, _s2);
_sum3 = _mm256_add_epi32(_sum3, _s3);
#else // __AVX2__
#if __XOP__
__m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA);
__m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -1416,11 +1582,55 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_sum6 = _mm_add_epi32(_sum6, _s6);
_sum7 = _mm_add_epi32(_sum7, _s7);
#endif
#endif // __AVX2__

pA += 4;
pB += 8;
}

#if __AVX2__
if (k_end)
{
// from
// 00 11 22 33 04 15 26 37
// 01 12 23 30 05 16 27 34
// 20 31 02 13 24 35 06 17
// 21 32 03 10 25 36 07 14
// to
// 00 10 20 30 04 14 24 34
// 01 11 21 31 05 15 25 35
// 02 12 22 32 06 16 26 36
// 03 13 23 33 07 17 27 37
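// note: regroup the rotated diagonals so each 128-bit half holds the four row values
// of one output column ahead of the 256-bit stores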
{
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3);
__m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3);
__m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1);
__m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1);
_sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2);
_sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2);
_sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1);
_sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1);
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
_tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0));
_tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0));
_tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1));
_tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1));
_sum0 = _tmp0;
_sum1 = _tmp1;
_sum2 = _tmp2;
_sum3 = _tmp3;
}
}

_mm256_storeu_si256((__m256i*)outptr, _sum0);
_mm256_storeu_si256((__m256i*)(outptr + 8), _sum1);
_mm256_storeu_si256((__m256i*)(outptr + 16), _sum2);
_mm256_storeu_si256((__m256i*)(outptr + 24), _sum3);
outptr += 32;
#else // __AVX2__
if (k_end)
{
// from
@@ -1470,6 +1680,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_mm_store_si128((__m128i*)(outptr + 24), _sum6);
_mm_store_si128((__m128i*)(outptr + 28), _sum7);
outptr += 32;
#endif // __AVX2__
}
for (; jj + 3 < max_jj; jj += 4)
{
