Skip to content

Commit

Permalink
opt unpack aligned cvt
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 16, 2024
1 parent b69bb0f commit f3f1fb5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 120 deletions.
8 changes: 0 additions & 8 deletions src/layer/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A
const int N = BT_int8.h;
const int K = A_int8.w; // assert A_int8.w == BT_int8.w

// NCNN_LOGE("naive ds %f %f", A_int8_scales[0], BT_int8_scale);

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < M; i++)
{
Expand All @@ -232,8 +230,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A

const float descale = 1.f / (A_int8_scales[i] * BT_int8_scale);

// NCNN_LOGE("descale %f", descale);

for (int j = 0; j < N; j++)
{
const signed char* ptrBT = BT_int8.row<const signed char>(j);
Expand Down Expand Up @@ -503,8 +499,6 @@ int Gemm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
float A_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax;
A_int8_scales[i] = A_int8_scale;

// NCNN_LOGE("A[%d] absmax %.9f %.9f", i, absmax, A_int8_scale);

signed char* ptrAi = A_int8.row<signed char>(i);

for (int k = 0; k < A_int8.w; k++)
Expand Down Expand Up @@ -533,8 +527,6 @@ int Gemm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
}
}

// NCNN_LOGE("B0 absmax %f", absmax);

B_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax;

for (int i = 0; i < B0_int8.h; i++)
Expand Down
230 changes: 118 additions & 112 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -10872,14 +10872,14 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#endif // __AVX512F__
for (; jj + 7 < max_jj; jj += 8)
{
__m128i _sum0 = _mm_load_si128((const __m128i*)pp);
__m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
__m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
__m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
__m128i _sum4 = _mm_load_si128((const __m128i*)(pp + 16));
__m128i _sum5 = _mm_load_si128((const __m128i*)(pp + 20));
__m128i _sum6 = _mm_load_si128((const __m128i*)(pp + 24));
__m128i _sum7 = _mm_load_si128((const __m128i*)(pp + 28));
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
__m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
__m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
__m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));
__m128 _f4 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 16)));
__m128 _f5 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 20)));
__m128 _f6 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 24)));
__m128 _f7 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 28)));

// from
// 00 11 22 33
Expand All @@ -10900,40 +10900,40 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
// 06 16 26 36
// 07 17 27 37
{
_sum4 = _mm_shuffle_epi32(_sum4, _MM_SHUFFLE(2, 1, 0, 3));
_sum5 = _mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3));
_sum6 = _mm_shuffle_epi32(_sum6, _MM_SHUFFLE(2, 1, 0, 3));
_sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3));
__m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum6);
__m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum6);
__m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum7);
__m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum7);
__m128i _tmp4 = _mm_unpacklo_epi32(_sum2, _sum4);
__m128i _tmp5 = _mm_unpackhi_epi32(_sum2, _sum4);
__m128i _tmp6 = _mm_unpacklo_epi32(_sum3, _sum5);
__m128i _tmp7 = _mm_unpackhi_epi32(_sum3, _sum5);
_sum0 = _mm_unpacklo_epi64(_tmp0, _tmp4);
_sum1 = _mm_unpackhi_epi64(_tmp0, _tmp4);
_sum2 = _mm_unpacklo_epi64(_tmp5, _tmp1);
_sum3 = _mm_unpackhi_epi64(_tmp5, _tmp1);
_sum4 = _mm_unpacklo_epi64(_tmp2, _tmp6);
_sum5 = _mm_unpackhi_epi64(_tmp2, _tmp6);
_sum6 = _mm_unpacklo_epi64(_tmp7, _tmp3);
_sum7 = _mm_unpackhi_epi64(_tmp7, _tmp3);
_sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
_sum5 = _mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3));
_sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3));
}

__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
__m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
__m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale);
__m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale);
__m128 _f4 = _mm_mul_ps(_mm_cvtepi32_ps(_sum4), _descale);
__m128 _f5 = _mm_mul_ps(_mm_cvtepi32_ps(_sum5), _descale);
__m128 _f6 = _mm_mul_ps(_mm_cvtepi32_ps(_sum6), _descale);
__m128 _f7 = _mm_mul_ps(_mm_cvtepi32_ps(_sum7), _descale);
_f4 = _mm_shuffle_ps(_f4, _f4, _MM_SHUFFLE(2, 1, 0, 3));
_f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3));
_f6 = _mm_shuffle_ps(_f6, _f6, _MM_SHUFFLE(2, 1, 0, 3));
_f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3));
__m128 _tmp0 = _mm_unpacklo_ps(_f0, _f6);
__m128 _tmp1 = _mm_unpackhi_ps(_f0, _f6);
__m128 _tmp2 = _mm_unpacklo_ps(_f1, _f7);
__m128 _tmp3 = _mm_unpackhi_ps(_f1, _f7);
__m128 _tmp4 = _mm_unpacklo_ps(_f2, _f4);
__m128 _tmp5 = _mm_unpackhi_ps(_f2, _f4);
__m128 _tmp6 = _mm_unpacklo_ps(_f3, _f5);
__m128 _tmp7 = _mm_unpackhi_ps(_f3, _f5);
_f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4)));
_f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4)));
_f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1)));
_f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1)));
_f4 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6)));
_f5 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6)));
_f6 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3)));
_f7 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3)));
_f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
_f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
_f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3));
_f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3));
}

_f0 = _mm_mul_ps(_f0, _descale);
_f1 = _mm_mul_ps(_f1, _descale);
_f2 = _mm_mul_ps(_f2, _descale);
_f3 = _mm_mul_ps(_f3, _descale);
_f4 = _mm_mul_ps(_f4, _descale);
_f5 = _mm_mul_ps(_f5, _descale);
_f6 = _mm_mul_ps(_f6, _descale);
_f7 = _mm_mul_ps(_f7, _descale);

if (pC)
{
Expand Down Expand Up @@ -11143,10 +11143,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
__m128i _sum0 = _mm_load_si128((const __m128i*)pp);
__m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
__m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
__m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
__m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
__m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
__m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));

// from
// 00 11 22 33
Expand All @@ -11159,24 +11159,24 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
// 02 12 22 32
// 03 13 23 33
{
_sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
__m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum3);
__m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum3);
__m128i _tmp2 = _mm_unpacklo_epi32(_sum2, _sum1);
__m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum1);
_sum0 = _mm_unpacklo_epi64(_tmp0, _tmp2);
_sum1 = _mm_unpackhi_epi64(_tmp0, _tmp2);
_sum2 = _mm_unpacklo_epi64(_tmp3, _tmp1);
_sum3 = _mm_unpackhi_epi64(_tmp3, _tmp1);
_sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
_sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3));
}

__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
__m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
__m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale);
__m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale);
_f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
_f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
__m128 _tmp0 = _mm_unpacklo_ps(_f0, _f3);
__m128 _tmp1 = _mm_unpackhi_ps(_f0, _f3);
__m128 _tmp2 = _mm_unpacklo_ps(_f2, _f1);
__m128 _tmp3 = _mm_unpackhi_ps(_f2, _f1);
_f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2)));
_f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2)));
_f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1)));
_f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1)));
_f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
_f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3));
}

_f0 = _mm_mul_ps(_f0, _descale);
_f1 = _mm_mul_ps(_f1, _descale);
_f2 = _mm_mul_ps(_f2, _descale);
_f3 = _mm_mul_ps(_f3, _descale);

if (pC)
{
Expand Down Expand Up @@ -11326,8 +11326,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
}
for (; jj + 1 < max_jj; jj += 2)
{
__m128i _sum0 = _mm_load_si128((const __m128i*)pp);
__m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
__m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));

// from
// 00 11 20 31
Expand All @@ -11336,15 +11336,15 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
// 00 10 20 30
// 01 11 21 31
{
__m128i _tmp0 = _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0));
__m128i _tmp1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1));
_sum0 = _mm_unpacklo_epi32(_tmp0, _tmp1);
_sum1 = _mm_unpackhi_epi32(_tmp0, _tmp1);
_sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3));
__m128 _tmp0 = _mm_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0));
__m128 _tmp1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1));
_f0 = _mm_unpacklo_ps(_tmp0, _tmp1);
_f1 = _mm_unpackhi_ps(_tmp0, _tmp1);
_f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3));
}

__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale);
__m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale);
_f0 = _mm_mul_ps(_f0, _descale);
_f1 = _mm_mul_ps(_f1, _descale);

if (pC)
{
Expand Down Expand Up @@ -11440,7 +11440,9 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
}
for (; jj < max_jj; jj++)
{
__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)), _descale);
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));

_f0 = _mm_mul_ps(_f0, _descale);

if (pC)
{
Expand Down Expand Up @@ -11581,22 +11583,22 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
__m512i _sum0 = _mm512_loadu_si512((const __m512i*)pp);
__m512i _sum1 = _mm512_loadu_si512((const __m512i*)(pp + 16));
__m512 _f0 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp));
__m512 _f1 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 16)));

// 00 11 02 13 04 15 06 17 08 19 0a 1b 0c 1d 0e 1f
// 01 12 03 10 05 16 07 14 09 1a 0b 18 0d 1e 0f 1c

__m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1);
__m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1);
__m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1);
__m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f1);

_sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp1);
_sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp1);
_f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1)));
_f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1)));

_sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD);
_f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3));

__m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum0), _descale0_avx512);
__m512 _f1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum1), _descale1_avx512);
_f0 = _mm512_mul_ps(_f0, _descale0_avx512);
_f1 = _mm512_mul_ps(_f1, _descale1_avx512);

if (pC)
{
Expand Down Expand Up @@ -11690,35 +11692,35 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#endif // __AVX512F__
for (; jj + 7 < max_jj; jj += 8)
{
__m128i _sum0 = _mm_load_si128((const __m128i*)pp);
__m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
__m128i _sum2 = _mm_load_si128((const __m128i*)(pp + 8));
__m128i _sum3 = _mm_load_si128((const __m128i*)(pp + 12));
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
__m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));
__m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8)));
__m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12)));

// 00 11 02 13
// 04 15 06 17
// 10 01 12 03
// 14 05 16 07
_sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 3, 0, 1));
_sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 3, 0, 1));
_f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1));
_f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1));

__m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum2);
__m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum2);
__m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum3);
__m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum3);
__m128 _tmp0 = _mm_unpacklo_ps(_f0, _f2);
__m128 _tmp1 = _mm_unpackhi_ps(_f0, _f2);
__m128 _tmp2 = _mm_unpacklo_ps(_f1, _f3);
__m128 _tmp3 = _mm_unpackhi_ps(_f1, _f3);

_sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum1 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum2 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
_f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
_f1 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3)));
_f2 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
_f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3)));

_sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 3, 0, 1));
_sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 3, 0, 1));
_f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1));
_f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1));

__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale0);
__m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale0);
__m128 _f2 = _mm_mul_ps(_mm_cvtepi32_ps(_sum2), _descale1);
__m128 _f3 = _mm_mul_ps(_mm_cvtepi32_ps(_sum3), _descale1);
_f0 = _mm_mul_ps(_f0, _descale0);
_f1 = _mm_mul_ps(_f1, _descale0);
_f2 = _mm_mul_ps(_f2, _descale1);
_f3 = _mm_mul_ps(_f3, _descale1);

if (pC)
{
Expand Down Expand Up @@ -11845,21 +11847,21 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
__m128i _sum0 = _mm_load_si128((const __m128i*)pp);
__m128i _sum1 = _mm_load_si128((const __m128i*)(pp + 4));
__m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp));
__m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4)));

// 00 11 02 13
// 01 12 03 10
__m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
__m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum1);
__m128 _tmp0 = _mm_unpacklo_ps(_f0, _f1);
__m128 _tmp1 = _mm_unpackhi_ps(_f0, _f1);

_sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum1 = _mm_unpackhi_epi64(_tmp1, _tmp0);
_f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1)));
_f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp1), _mm_castps_pd(_tmp0)));

_sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 3, 2, 1));
_f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 3, 2, 1));

__m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_sum0), _descale0);
__m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_sum1), _descale1);
_f0 = _mm_mul_ps(_f0, _descale0);
_f1 = _mm_mul_ps(_f1, _descale1);

if (pC)
{
Expand Down Expand Up @@ -12479,6 +12481,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,

// NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk);

// actually we only depend on the global k == 0 condition
(void)i;
(void)j;

const signed char* pAT = AT_tile;
const signed char* pBT = BT_tile;

Expand Down

0 comments on commit f3f1fb5

Please sign in to comment.