From 2be95bfaa8b65bddc5c44e98426f8a5d268749e2 Mon Sep 17 00:00:00 2001
From: nihui
Date: Sat, 7 Oct 2023 11:42:52 +0800
Subject: [PATCH] wip

---
 src/layer/x86/convolution_3x3_winograd_int8.h | 213 +++++++++++++++++-
 1 file changed, 212 insertions(+), 1 deletion(-)

diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h
index eb78d18c835..5168e93b6f7 100644
--- a/src/layer/x86/convolution_3x3_winograd_int8.h
+++ b/src/layer/x86/convolution_3x3_winograd_int8.h
@@ -899,6 +899,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             const short* pA = pAT;
 
+#if __AVX512F__
+            __m512i _sum0;
+            __m512i _sum1;
+            __m512i _sum2;
+            __m512i _sum3;
+#else
             __m256i _sum0;
             __m256i _sum1;
             __m256i _sum2;
@@ -907,9 +913,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m256i _sum5;
             __m256i _sum6;
             __m256i _sum7;
+#endif // __AVX512F__
 
             if (k == 0)
             {
+#if __AVX512F__
+                _sum0 = _mm512_setzero_si512();
+                _sum1 = _mm512_setzero_si512();
+                _sum2 = _mm512_setzero_si512();
+                _sum3 = _mm512_setzero_si512();
+#else
                 _sum0 = _mm256_setzero_si256();
                 _sum1 = _mm256_setzero_si256();
                 _sum2 = _mm256_setzero_si256();
@@ -918,9 +931,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum5 = _mm256_setzero_si256();
                 _sum6 = _mm256_setzero_si256();
                 _sum7 = _mm256_setzero_si256();
+#endif // __AVX512F__
             }
             else
             {
+#if __AVX512F__
+                _sum0 = _mm512_loadu_si512((const __m512i*)outptr);
+                _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16));
+                _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32));
+                _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48));
+#else
                 _sum0 = _mm256_load_si256((const __m256i*)outptr);
                 _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8));
                 _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16));
@@ -929,6 +949,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum5 = _mm256_load_si256((const __m256i*)(outptr + 40));
                 _sum6 = _mm256_load_si256((const __m256i*)(outptr + 48));
                 _sum7 = _mm256_load_si256((const __m256i*)(outptr + 56));
+#endif // __AVX512F__
             }
 
             int kk = 0;
@@ -936,6 +957,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
                 __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB);
+#if __AVX512F__
+                __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
+                __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2));
+                __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
+                __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1);
+                __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC);
+
+#if __AVX512VNNI__
+                _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01);
+                _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23);
+                _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01);
+                _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23);
+#else
+                _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01));
+                _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23));
+                _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01));
+                _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23));
+#endif // __AVX512VNNI__
+#else // __AVX512F__
                 __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
                 __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -960,6 +1000,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2));
                 _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3));
 #endif // __AVXVNNI__
+#endif // __AVX512F__
 
                 pA += 16;
                 pB += 16;
@@ -968,9 +1009,26 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m128i _pA = _mm_load_si128((const __m128i*)pA);
                 __m128i _pB = _mm_load_si128((const __m128i*)pB);
+
                 __m256i _pA0 = _mm256_cvtepi16_epi32(_pA);
-                __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB0 = _mm256_cvtepi16_epi32(_pB);
+#if __AVX512F__
+                __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
+                __m512i _pA11 = _mm512_shuffle_i32x4(_pA00, _pA00, _MM_SHUFFLE(2, 3, 0, 1));
+                __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
+                __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1);
+                __m512i _pB23 = _mm512_permutex_epi64(_pB01, _MM_SHUFFLE(2, 3, 0, 1));
+
+                __m512i _s01 = _mm512_mullo_epi32(_pA00, _pB01);
+                __m512i _s23 = _mm512_mullo_epi32(_pA00, _pB23);
+                __m512i _s45 = _mm512_mullo_epi32(_pA11, _pB01);
+                __m512i _s67 = _mm512_mullo_epi32(_pA11, _pB23);
+                _sum0 = _mm512_add_epi32(_sum0, _s01);
+                _sum1 = _mm512_add_epi32(_sum1, _s23);
+                _sum2 = _mm512_add_epi32(_sum2, _s45);
+                _sum3 = _mm512_add_epi32(_sum3, _s67);
+#else
+                __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
                 __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1));
                 __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
@@ -991,11 +1049,61 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum5 = _mm256_add_epi32(_sum5, _s5);
                 _sum6 = _mm256_add_epi32(_sum6, _s6);
                 _sum7 = _mm256_add_epi32(_sum7, _s7);
+#endif // __AVX512F__
 
                 pA += 8;
                 pB += 8;
             }
 
+#if __AVX512F__
+            if (k_end)
+            {
+                // from
+                // 00 11 22 33 44 55 66 77 01 12 23 30 45 56 67 74
+                // 02 13 20 31 46 57 64 75 03 10 21 32 47 54 65 76
+                // 40 51 62 73 04 15 26 37 41 52 63 70 05 16 27 34
+                // 42 53 60 71 06 17 24 35 43 50 61 72 07 14 25 36
+                // to
+                // 00 10 20 30 44 54 64 74 04 14 24 34 40 50 60 70
+                // 01 11 21 31 45 55 65 75 05 15 25 35 41 51 61 71
+                // 02 12 22 32 46 56 66 76 06 16 26 36 42 52 62 72
+                // 03 13 23 33 47 57 67 77 07 17 27 37 43 53 63 73
+                {
+                    __m512i _s0 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(0, 1, 1, 0));
+                    __m512i _s1 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(2, 3, 3, 2));
+                    __m512i _s2 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(0, 1, 1, 0));
+                    __m512i _s3 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(2, 3, 3, 2));
+                    _s1 = _mm512_shuffle_epi32(_s1, _MM_PERM_ADCB);
+                    _s2 = _mm512_shuffle_epi32(_s2, _MM_PERM_BADC);
+                    _s3 = _mm512_shuffle_epi32(_s3, _MM_PERM_CBAD);
+                    __m512i _tmp0 = _mm512_unpacklo_epi32(_s0, _s1);
+                    __m512i _tmp1 = _mm512_unpackhi_epi32(_s0, _s1);
+                    __m512i _tmp2 = _mm512_unpacklo_epi32(_s2, _s3);
+                    __m512i _tmp3 = _mm512_unpackhi_epi32(_s2, _s3);
+                    _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2);
+                    _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2);
+                    _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1);
+                    _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1);
+                    _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD);
+                    _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD);
+                }
+
+                __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 0, 3, 0));
+                __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 0, 3, 0));
+                __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 2, 1, 2));
+                __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 2, 1, 2));
+                _sum0 = _tmp0;
+                _sum1 = _tmp1;
+                _sum2 = _tmp2;
+                _sum3 = _tmp3;
+            }
+
+            _mm512_storeu_si512((__m512i*)outptr, _sum0);
+            _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1);
+            _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2);
+            _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3);
+            outptr += 64;
+#else
             if (k_end)
             {
                 // from
@@ -1072,6 +1180,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             _mm256_store_si256((__m256i*)(outptr + 8 * 6), _sum6);
             _mm256_store_si256((__m256i*)(outptr + 8 * 7), _sum7);
             outptr += 8 * 8;
+#endif // __AVX512F__
         }
         for (; jj + 3 < max_jj; jj += 4)
         {
@@ -1301,6 +1410,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             const short* pA = pAT;
 
+#if __AVX2__
+            __m256i _sum0;
+            __m256i _sum1;
+            __m256i _sum2;
+            __m256i _sum3;
+#else
             __m128i _sum0;
             __m128i _sum1;
             __m128i _sum2;
@@ -1309,9 +1424,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _sum5;
             __m128i _sum6;
             __m128i _sum7;
+#endif
 
             if (k == 0)
            {
+#if __AVX2__
+                _sum0 = _mm256_setzero_si256();
+                _sum1 = _mm256_setzero_si256();
+                _sum2 = _mm256_setzero_si256();
+                _sum3 = _mm256_setzero_si256();
+#else
                 _sum0 = _mm_setzero_si128();
                 _sum1 = _mm_setzero_si128();
                 _sum2 = _mm_setzero_si128();
@@ -1320,9 +1442,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum5 = _mm_setzero_si128();
                 _sum6 = _mm_setzero_si128();
                 _sum7 = _mm_setzero_si128();
+#endif
             }
             else
             {
+#if __AVX2__
+                _sum0 = _mm256_loadu_si256((const __m256i*)outptr);
+                _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8));
+                _sum2 = _mm256_loadu_si256((const __m256i*)(outptr + 16));
+                _sum3 = _mm256_loadu_si256((const __m256i*)(outptr + 24));
+#else
                 _sum0 = _mm_load_si128((const __m128i*)outptr);
                 _sum1 = _mm_load_si128((const __m128i*)(outptr + 4));
                 _sum2 = _mm_load_si128((const __m128i*)(outptr + 8));
@@ -1331,11 +1460,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum5 = _mm_load_si128((const __m128i*)(outptr + 20));
                 _sum6 = _mm_load_si128((const __m128i*)(outptr + 24));
                 _sum7 = _mm_load_si128((const __m128i*)(outptr + 28));
+#endif
             }
 
             int kk = 0;
             for (; kk + 1 < max_kk; kk += 2)
             {
+#if __AVX2__
+                __m128i _pA = _mm_loadu_si128((const __m128i*)pA);
+                __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB);
+                __m256i _pA0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1);
+                __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
+                __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
+
+#if __AVXVNNI__ || __AVX512VNNI__
+                _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0);
+                _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1);
+                _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0);
+                _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1);
+#else
+                _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0));
+                _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1));
+                _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0));
+                _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1));
+#endif
+#else // __AVX2__
                 __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA);
                 __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
                 __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8));
@@ -1362,6 +1511,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2));
                 _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3));
 #endif
+#endif // __AVX2__
 
                 pA += 8;
                 pB += 16;
@@ -1371,6 +1521,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA));
                 __m128i _pB = _mm_loadu_si128((const __m128i*)pB);
 
+#if __AVX2__
+                __m256i _pA0 = _mm256_cvtepi16_epi32(_pA);
+                __m256i _pB0 = _mm256_cvtepi16_epi32(_pB);
+                __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1));
+                __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
+
+                __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0);
+                __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1);
+                __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0);
+                __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1);
+
+                _sum0 = _mm256_add_epi32(_sum0, _s0);
+                _sum1 = _mm256_add_epi32(_sum1, _s1);
+                _sum2 = _mm256_add_epi32(_sum2, _s2);
+                _sum3 = _mm256_add_epi32(_sum3, _s3);
+#else // __AVX2__
 #if __XOP__
                 __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA);
                 __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -1416,11 +1582,55 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _sum6 = _mm_add_epi32(_sum6, _s6);
                 _sum7 = _mm_add_epi32(_sum7, _s7);
 #endif
+#endif // __AVX2__
 
                 pA += 4;
                 pB += 8;
             }
 
+#if __AVX2__
+            if (k_end)
+            {
+                // from
+                // 00 11 22 33 04 15 26 37
+                // 01 12 23 30 05 16 27 34
+                // 20 31 02 13 24 35 06 17
+                // 21 32 03 10 25 36 07 14
+                // to
+                // 00 10 20 30 04 14 24 34
+                // 01 11 21 31 05 15 25 35
+                // 02 12 22 32 06 16 26 36
+                // 03 13 23 33 07 17 27 37
+                {
+                    _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
+                    _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
+                    __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3);
+                    __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3);
+                    __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1);
+                    __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1);
+                    _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2);
+                    _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2);
+                    _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1);
+                    _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1);
+                    _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
+                    _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
+                    _tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0));
+                    _tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0));
+                    _tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1));
+                    _tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1));
+                    _sum0 = _tmp0;
+                    _sum1 = _tmp1;
+                    _sum2 = _tmp2;
+                    _sum3 = _tmp3;
+                }
+            }
+
+            _mm256_storeu_si256((__m256i*)outptr, _sum0);
+            _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1);
+            _mm256_storeu_si256((__m256i*)(outptr + 16), _sum2);
+            _mm256_storeu_si256((__m256i*)(outptr + 24), _sum3);
+            outptr += 32;
+#else // __AVX2__
            if (k_end)
             {
                 // from
@@ -1470,6 +1680,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             _mm_store_si128((__m128i*)(outptr + 24), _sum6);
             _mm_store_si128((__m128i*)(outptr + 28), _sum7);
             outptr += 32;
+#endif // __AVX2__
         }
        for (; jj + 3 < max_jj; jj += 4)
        {
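
Note (not part of the patch): every new kernel above keeps a non-VNNI fallback in which _mm512_dpwssd_epi32 / _mm256_dpwssd_epi32 is replaced by madd_epi16 followed by add_epi32, i.e. multiply adjacent signed 16-bit pairs, sum each pair into a 32-bit lane, then accumulate. The standalone sketch below (my own test values, AVX2 only, hypothetical file demo_madd.cpp) illustrates that pairwise dot-product step against a scalar reference; it is not taken from the ncnn sources.

// demo_madd.cpp - compile with: g++ -mavx2 demo_madd.cpp
#include <immintrin.h>
#include <stdio.h>

int main()
{
    // arbitrary small int16 test values (no overflow concerns)
    short a[16], b[16];
    for (int i = 0; i < 16; i++)
    {
        a[i] = (short)(i - 8);
        b[i] = (short)(3 * i + 1);
    }

    __m256i _pa = _mm256_loadu_si256((const __m256i*)a);
    __m256i _pb = _mm256_loadu_si256((const __m256i*)b);

    // madd_epi16 multiplies adjacent int16 pairs and sums each pair into an int32 lane;
    // adding that result to an accumulator is the fallback used when the VNNI
    // dpwssd instructions are not available
    __m256i _acc = _mm256_setzero_si256();
    _acc = _mm256_add_epi32(_acc, _mm256_madd_epi16(_pa, _pb));

    int out[8];
    _mm256_storeu_si256((__m256i*)out, _acc);

    for (int i = 0; i < 8; i++)
    {
        int ref = a[i * 2] * b[i * 2] + a[i * 2 + 1] * b[i * 2 + 1];
        printf("lane %d: simd=%d ref=%d\n", i, out[i], ref);
    }
    return 0;
}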