From f3f1fb5da0a8a74a5fe274b9eefb44e926c84883 Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 16 Dec 2024 08:06:16 +0000
Subject: [PATCH] opt unpack aligned cvt

---
 src/layer/gemm.cpp        |   8 --
 src/layer/x86/gemm_int8.h | 230 +++++++++++++++++++-------------
 2 files changed, 118 insertions(+), 120 deletions(-)

diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp
index e480183e348..b4a263f9b53 100644
--- a/src/layer/gemm.cpp
+++ b/src/layer/gemm.cpp
@@ -220,8 +220,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A
     const int N = BT_int8.h;
     const int K = A_int8.w; // assert A_int8.w == BT_int8.w
 
-    // NCNN_LOGE("naive ds %f %f", A_int8_scales[0], BT_int8_scale);
-
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int i = 0; i < M; i++)
     {
@@ -232,8 +230,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A
 
         const float descale = 1.f / (A_int8_scales[i] * BT_int8_scale);
 
-        // NCNN_LOGE("descale %f", descale);
-
         for (int j = 0; j < N; j++)
         {
             const signed char* ptrBT = BT_int8.row<const signed char>(j);
@@ -503,8 +499,6 @@ int Gemm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
         float A_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax;
         A_int8_scales[i] = A_int8_scale;
 
-        // NCNN_LOGE("A[%d] absmax %.9f %.9f", i, absmax, A_int8_scale);
-
         signed char* ptrAi = A_int8.row<signed char>(i);
 
         for (int k = 0; k < A_int8.w; k++)
@@ -533,8 +527,6 @@ int Gemm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
         }
     }
 
-    // NCNN_LOGE("B0 absmax %f", absmax);
-
     B_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax;
 
     for (int i = 0; i < B0_int8.h; i++)
diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index 22c11f7cb53..b094420a9ae 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -10872,14 +10872,14 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #endif // __AVX512F__
     for (; jj + 7 < max_jj; jj += 8)
     {
-        __m128i _sum0 = _mm_load_si128((const __m128i*)pp);
-        __m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
-        __m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
-        __m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
-        __m128i _sum4 = _mm_load_si128((const __m128i*)(pp + 16));
-        __m128i _sum5 = _mm_load_si128((const __m128i*)(pp + 20));
-        __m128i _sum6 = _mm_load_si128((const __m128i*)(pp + 24));
-        __m128i _sum7 = _mm_load_si128((const __m128i*)(pp + 28));
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+        __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
+        __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
+        __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));
+        __m128 _f4 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 16)));
+        __m128 _f5 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 20)));
+        __m128 _f6 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 24)));
+        __m128 _f7 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 28)));
 
         // from
         // 00 11 22 33
@@ -10900,40 +10900,40 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
         // 06 16 26 36
         // 07 17 27 37
         {
-            _sum4 = _mm_shuffle_epi32(_sum4, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum5 = _mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum6 = _mm_shuffle_epi32(_sum6, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3));
-            __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum6);
-            __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum6);
-            __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum7);
-            __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum7);
-            __m128i _tmp4 = _mm_unpacklo_epi32(_sum2, _sum4);
-            __m128i _tmp5 = _mm_unpackhi_epi32(_sum2, _sum4);
-            __m128i _tmp6 = _mm_unpacklo_epi32(_sum3, _sum5);
-            __m128i _tmp7 = _mm_unpackhi_epi32(_sum3, _sum5);
-            _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp4);
-            _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp4);
-            _sum2 = _mm_unpacklo_epi64(_tmp5, _tmp1);
-            _sum3 = _mm_unpackhi_epi64(_tmp5, _tmp1);
-            _sum4 = _mm_unpacklo_epi64(_tmp2, _tmp6);
-            _sum5 = _mm_unpackhi_epi64(_tmp2, _tmp6);
-            _sum6 = _mm_unpacklo_epi64(_tmp7, _tmp3);
-            _sum7 = _mm_unpackhi_epi64(_tmp7, _tmp3);
-            _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum5 = _mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3));
-        }
-
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
-        __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
-        __m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale);
-        __m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale);
-        __m128 _f4 = _mm_mul_ps(_mm_cvtepi32_ps(_sum4), _descale);
-        __m128 _f5 = _mm_mul_ps(_mm_cvtepi32_ps(_sum5), _descale);
-        __m128 _f6 = _mm_mul_ps(_mm_cvtepi32_ps(_sum6), _descale);
-        __m128 _f7 = _mm_mul_ps(_mm_cvtepi32_ps(_sum7), _descale);
+            _f4 = _mm_shuffle_ps(_f4, _f4, _MM_SHUFFLE(2, 1, 0, 3));
+            _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3));
+            _f6 = _mm_shuffle_ps(_f6, _f6, _MM_SHUFFLE(2, 1, 0, 3));
+            _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3));
+            __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f6);
+            __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f6);
+            __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f7);
+            __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f7);
+            __m128 _tmp4 = _mm_unpacklo_ps(_f2, _f4);
+            __m128 _tmp5 = _mm_unpackhi_ps(_f2, _f4);
+            __m128 _tmp6 = _mm_unpacklo_ps(_f3, _f5);
+            __m128 _tmp7 = _mm_unpackhi_ps(_f3, _f5);
+            _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4)));
+            _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4)));
+            _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1)));
+            _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1)));
+            _f4 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6)));
+            _f5 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6)));
+            _f6 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3)));
+            _f7 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3)));
+            _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
+            _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
+            _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3));
+            _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3));
+        }
+
+        _f0 = _mm_mul_ps(_f0, _descale);
+        _f1 = _mm_mul_ps(_f1, _descale);
+        _f2 = _mm_mul_ps(_f2, _descale);
+        _f3 = _mm_mul_ps(_f3, _descale);
+        _f4 = _mm_mul_ps(_f4, _descale);
+        _f5 = _mm_mul_ps(_f5, _descale);
+        _f6 = _mm_mul_ps(_f6, _descale);
+        _f7 = _mm_mul_ps(_f7, _descale);
 
         if (pC)
         {
@@ -11143,10 +11143,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #endif // defined(__x86_64__) || defined(_M_X64)
     for (; jj + 3 < max_jj; jj += 4)
     {
-        __m128i _sum0 = _mm_load_si128((const __m128i*)pp);
-        __m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
-        __m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
-        __m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+        __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
+        __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
+        __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));
 
         // from
         // 00 11 22 33
@@ -11159,24 +11159,24 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
         // 02 12 22 32
         // 03 13 23 33
         {
-            _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
-            __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum3);
-            __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum3);
-            __m128i _tmp2 = _mm_unpacklo_epi32(_sum2, _sum1);
-            __m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum1);
-            _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp2);
-            _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp2);
-            _sum2 = _mm_unpacklo_epi64(_tmp3, _tmp1);
-            _sum3 = _mm_unpackhi_epi64(_tmp3, _tmp1);
-            _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
-        }
-
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
-        __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
-        __m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale);
-        __m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale);
+            _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
+            _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
+            __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f3);
+            __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f3);
+            __m128 _tmp2 = _mm_unpacklo_ps(_f2, _f1);
+            __m128 _tmp3 = _mm_unpackhi_ps(_f2, _f1);
+            _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2)));
+            _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2)));
+            _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1)));
+            _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1)));
+            _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
+            _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
+        }
+
+        _f0 = _mm_mul_ps(_f0, _descale);
+        _f1 = _mm_mul_ps(_f1, _descale);
+        _f2 = _mm_mul_ps(_f2, _descale);
+        _f3 = _mm_mul_ps(_f3, _descale);
 
         if (pC)
         {
@@ -11326,8 +11326,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
     }
     for (; jj + 1 < max_jj; jj += 2)
     {
-        __m128i _sum0 = _mm_load_si128((const __m128i*)pp);
-        __m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+        __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
 
         // from
         // 00 11 20 31
@@ -11336,15 +11336,15 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
         // 00 10 20 30
         // 01 11 21 31
         {
-            __m128i _tmp0 = _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0));
-            __m128i _tmp1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1));
-            _sum0 = _mm_unpacklo_epi32(_tmp0, _tmp1);
-            _sum1 = _mm_unpackhi_epi32(_tmp0, _tmp1);
-            _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
+            __m128 _tmp0 = _mm_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0));
+            __m128 _tmp1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1));
+            _f0 = _mm_unpacklo_ps(_tmp0, _tmp1);
+            _f1 = _mm_unpackhi_ps(_tmp0, _tmp1);
+            _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
         }
 
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
-        __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
+        _f0 = _mm_mul_ps(_f0, _descale);
+        _f1 = _mm_mul_ps(_f1, _descale);
 
         if (pC)
         {
@@ -11440,7 +11440,9 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
     }
     for (; jj < max_jj; jj++)
     {
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)), _descale);
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+
+        _f0 = _mm_mul_ps(_f0, _descale);
 
         if (pC)
         {
@@ -11581,22 +11583,22 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #if __AVX512F__
     for (; jj + 15 < max_jj; jj += 16)
     {
-        __m512i _sum0 = _mm512_loadu_si512((const __m512i*)pp);
-        __m512i _sum1 = _mm512_loadu_si512((const __m512i*)(pp + 16));
+        __m512 _f0 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp));
+        __m512 _f1 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 16)));
 
         // 00 11 02 13 04 15 06 17 08 19 0a 1b 0c 1d 0e 1f
         // 01 12 03 10 05 16 07 14 09 1a 0b 18 0d 1e 0f 1c
-        __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1);
-        __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1);
+        __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1);
+        __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f1);
 
-        _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp1);
-        _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp1);
+        _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1)));
+        _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1)));
 
-        _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD);
+        _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3));
 
-        __m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum0), _descale0_avx512);
-        __m512 _f1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum1), _descale1_avx512);
+        _f0 = _mm512_mul_ps(_f0, _descale0_avx512);
+        _f1 = _mm512_mul_ps(_f1, _descale1_avx512);
 
         if (pC)
         {
@@ -11690,35 +11692,35 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #endif // __AVX512F__
     for (; jj + 7 < max_jj; jj += 8)
    {
-        __m128i _sum0 = _mm_load_si128((const __m128i*)pp);
-        __m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
-        __m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
-        __m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+        __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
+        __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
+        __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));
 
         // 00 11 02 13
         // 04 15 06 17
         // 10 01 12 03
         // 14 05 16 07
-        _sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 3, 0, 1));
-        _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 3, 0, 1));
+        _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1));
+        _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1));
 
-        __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum2);
-        __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum2);
-        __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum3);
-        __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum3);
+        __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f2);
+        __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f2);
+        __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f3);
+        __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f3);
 
-        _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
-        _sum1 = _mm_unpacklo_epi64(_tmp2, _tmp3);
-        _sum2 = _mm_unpackhi_epi64(_tmp0, _tmp1);
-        _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
+        _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
+        _f1 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3)));
+        _f2 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
+        _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3)));
 
-        _sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 3, 0, 1));
-        _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 3, 0, 1));
+        _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1));
+        _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1));
 
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale0);
-        __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale0);
-        __m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale1);
-        __m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale1);
+        _f0 = _mm_mul_ps(_f0, _descale0);
+        _f1 = _mm_mul_ps(_f1, _descale0);
+        _f2 = _mm_mul_ps(_f2, _descale1);
+        _f3 = _mm_mul_ps(_f3, _descale1);
 
         if (pC)
         {
@@ -11845,21 +11847,21 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #endif // defined(__x86_64__) || defined(_M_X64)
     for (; jj + 3 < max_jj; jj += 4)
    {
-        __m128i _sum0 = _mm_load_si128((const __m128i*)pp);
-        __m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
+        __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
+        __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
 
         // 00 11 02 13
         // 01 12 03 10
-        __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
-        __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum1);
+        __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f1);
+        __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f1);
 
-        _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
-        _sum1 = _mm_unpackhi_epi64(_tmp1, _tmp0);
+        _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
+        _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp1), _mm_castps_pd(_tmp0)));
 
-        _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 3, 2, 1));
+        _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 3, 2, 1));
 
-        __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale0);
-        __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale1);
+        _f0 = _mm_mul_ps(_f0, _descale0);
+        _f1 = _mm_mul_ps(_f1, _descale1);
 
         if (pC)
         {
@@ -12479,6 +12481,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
 
     // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk);
 
+    // actually we only depend on the global k == 0 condition
+    (void)i;
+    (void)j;
+
     const signed char* pAT = AT_tile;
     const signed char* pBT = BT_tile;
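
Note: every hunk in gemm_int8.h follows one pattern: the int32 accumulators are converted to fp32 with _mm_cvtepi32_ps (or _mm512_cvtepi32_ps) immediately after the aligned load, so the in-register transpose runs on float lanes (_mm_shuffle_ps / _mm_unpacklo_ps, with pd casts standing in for the old epi64 unpacks) and the descale multiply no longer needs a conversion per register. Since shuffles and unpacks only permute lanes, converting before or after them is bit-identical; moving the conversion to load time just shortens the dependency chain ahead of the multiply. The sketch below shows the same convert-early idea on a plain row-major 4x4 tile; it is illustrative only (unpack_tile_4x4, the simple tile layout, _MM_TRANSPOSE4_PS and the harness are not from the patch, whose tiles are stored in skewed diagonal order and transposed with open-coded shuffles):

#include <emmintrin.h> // SSE2; also pulls in _MM_TRANSPOSE4_PS via xmmintrin.h
#include <cstdio>

// Unpack a row-major 4x4 tile of int32 accumulators to fp32:
// convert once at the load, transpose in the float domain, then descale.
static void unpack_tile_4x4(const int* pp, float descale, float* out)
{
    const __m128 _descale = _mm_set1_ps(descale);

    // cvt at load time: one conversion per register, before any shuffle
    __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
    __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
    __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
    __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));

    // plain 4x4 transpose; the patch open-codes skewed variants of this
    _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3);

    // descale directly on the already-converted floats
    _f0 = _mm_mul_ps(_f0, _descale);
    _f1 = _mm_mul_ps(_f1, _descale);
    _f2 = _mm_mul_ps(_f2, _descale);
    _f3 = _mm_mul_ps(_f3, _descale);

    _mm_storeu_ps(out, _f0);
    _mm_storeu_ps(out + 4, _f1);
    _mm_storeu_ps(out + 8, _f2);
    _mm_storeu_ps(out + 12, _f3);
}

int main()
{
    alignas(16) int pp[16]; // _mm_load_si128 requires 16-byte alignment
    for (int i = 0; i < 16; i++)
        pp[i] = i;

    float out[16];
    unpack_tile_4x4(pp, 0.5f, out);

    // prints the transposed, descaled tile: column j of pp becomes row j of out
    for (int i = 0; i < 4; i++)
        printf("%4.1f %4.1f %4.1f %4.1f\n", out[i * 4 + 0], out[i * 4 + 1], out[i * 4 + 2], out[i * 4 + 3]);
    return 0;
}

The old code needed one _mm_cvtepi32_ps per register after the transpose anyway, so converting early costs nothing and lets the aligned load feed the converter directly, which is what the commit title "opt unpack aligned cvt" refers to.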