diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index 59ed1a5f372..193696f8e04 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -2485,26 +2485,29 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
 
         const float scale0 = scales[i + ii];
         const float scale1 = scales[i + ii + 1];
+#if __SSE2__
+        __m128 _scales0 = _mm_set1_ps(scale0);
+        __m128 _scales1 = _mm_set1_ps(scale1);
+        __m128 _scales0011 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_scales0), _mm_castps_pd(_scales1)));
+#endif // __SSE2__
 
         // if (elempack == 1)
         {
             int kk = 0;
 #if __SSE2__
-#if __AVX512VNNI__ || __AVXVNNI__
-#if !__AVXVNNIINT8__
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
            int w_shift0 = 0;
            int w_shift1 = 0;
-#endif // !__AVXVNNIINT8__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale0);
-                pp[1] = float2int8(p0[1] * scale0);
-                pp[2] = float2int8(p0[2] * scale0);
-                pp[3] = float2int8(p0[3] * scale0);
-                pp[4] = float2int8(p0[A_hstep] * scale1);
-                pp[5] = float2int8(p0[A_hstep + 1] * scale1);
-                pp[6] = float2int8(p0[A_hstep + 2] * scale1);
-                pp[7] = float2int8(p0[A_hstep + 3] * scale1);
+                __m128 _p0 = _mm_load_ps(p0);
+                __m128 _p1 = _mm_load_ps(p0 + A_hstep);
+                _p0 = _mm_mul_ps(_p0, _scales0);
+                _p1 = _mm_mul_ps(_p1, _scales1);
+#if __AVX512VNNI__ || __AVXVNNI__
+                int64_t v = float2int8_sse(_p0, _p1);
+                *(int64_t*)pp = v;
 #if !__AVXVNNIINT8__
                 w_shift0 += pp[0];
                 w_shift0 += pp[1];
@@ -2515,24 +2518,31 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
                 w_shift1 += pp[6];
                 w_shift1 += pp[7];
 #endif // !__AVXVNNIINT8__
+#else // __AVX512VNNI__ || __AVXVNNI__
+                __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                int64_t v = float2int8_sse(_t0, _t1);
+                *(int64_t*)pp = v;
+#endif // __AVX512VNNI__ || __AVXVNNI__
                 pp += 8;
                 p0 += 4;
             }
-#if !__AVXVNNIINT8__
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             if (max_kk >= 4)
             {
                 ((int*)pp)[0] = w_shift0 * 127;
                 ((int*)pp)[1] = w_shift1 * 127;
                 pp += 8;
             }
-#endif // !__AVXVNNIINT8__
-#endif // __AVX512VNNI__ || __AVXVNNI__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             for (; kk + 1 < max_kk; kk += 2)
             {
-                pp[0] = float2int8(p0[0] * scale0);
-                pp[1] = float2int8(p0[1] * scale0);
-                pp[2] = float2int8(p0[A_hstep] * scale1);
-                pp[3] = float2int8(p0[A_hstep + 1] * scale1);
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep)));
+                __m128 _p = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                _p = _mm_mul_ps(_p, _scales0011);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
                 pp += 4;
                 p0 += 2;
             }
@@ -2551,37 +2561,40 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
         const float* p0 = (const float*)A + (i + ii) * A_hstep + k;
 
         const float scale = scales[i + ii];
+#if __SSE2__
+        __m128 _scale = _mm_set1_ps(scale);
+#endif // __SSE2__
 
         // if (elempack == 1)
         {
             int kk = 0;
-#if __AVX512VNNI__ || __AVXVNNI__
-#if !__AVXVNNIINT8__
+#if __SSE2__
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             int w_shift = 0;
-#endif // !__AVXVNNIINT8__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[2] * scale);
-                pp[3] = float2int8(p0[3] * scale);
-#if !__AVXVNNIINT8__
+                __m128 _p = _mm_load_ps(p0);
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 w_shift += pp[0];
                 w_shift += pp[1];
                 w_shift += pp[2];
                 w_shift += pp[3];
-#endif // !__AVXVNNIINT8__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 pp += 4;
                 p0 += 4;
             }
-#if !__AVXVNNIINT8__
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             if (max_kk >= 4)
             {
                 ((int*)pp)[0] = w_shift * 127;
                 pp += 4;
             }
-#endif // !__AVXVNNIINT8__
-#endif // __AVX512VNNI__ || __AVXVNNI__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
+#endif // __SSE2__
             for (; kk < max_kk; kk++)
             {
                 pp[0] = float2int8(p0[0] * scale);
@@ -4703,21 +4716,28 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
             int kk = 0;
 #if __SSE2__
-#if __AVX512VNNI__ || __AVXVNNI__
+            __m128 _scales0 = _mm_set1_ps(scale0);
+            __m128 _scales1 = _mm_set1_ps(scale1);
+            __m128 _scales0011 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_scales0), _mm_castps_pd(_scales1)));
 #if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             int w_shift0 = 0;
             int w_shift1 = 0;
 #endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale0);
-                pp[1] = float2int8(p0[A_hstep + 0] * scale0);
-                pp[2] = float2int8(p0[A_hstep * 2 + 0] * scale0);
-                pp[3] = float2int8(p0[A_hstep * 3 + 0] * scale0);
-                pp[4] = float2int8(p0[1] * scale1);
-                pp[5] = float2int8(p0[A_hstep + 1] * scale1);
-                pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1);
-                pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1);
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep)));
+                __m128 _p2 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 2)));
+                __m128 _p3 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 3)));
+                __m128 _p01 = _mm_unpacklo_ps(_p0, _p1);
+                __m128 _p23 = _mm_unpacklo_ps(_p2, _p3);
+                _p0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23)));
+                _p1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23)));
+                _p0 = _mm_mul_ps(_p0, _scales0);
+                _p1 = _mm_mul_ps(_p1, _scales1);
+#if __AVX512VNNI__ || __AVXVNNI__
+                int64_t v = float2int8_sse(_p0, _p1);
+                *(int64_t*)pp = v;
 #if !__AVXVNNIINT8__
                 w_shift0 += pp[0];
                 w_shift0 += pp[1];
@@ -4728,6 +4748,12 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
                 w_shift1 += pp[6];
                 w_shift1 += pp[7];
 #endif // !__AVXVNNIINT8__
+#else // __AVX512VNNI__ || __AVXVNNI__
+                __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                int64_t v = float2int8_sse(_t0, _t1);
+                *(int64_t*)pp = v;
+#endif // __AVX512VNNI__ || __AVXVNNI__
                 pp += 8;
                 p0 += A_hstep * 4;
             }
@@ -4739,13 +4765,14 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
                 pp += 8;
             }
 #endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
-#endif // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 1 < max_kk; kk += 2)
             {
-                pp[0] = float2int8(p0[0] * scale0);
-                pp[1] = float2int8(p0[A_hstep + 0] * scale0);
-                pp[2] = float2int8(p0[1] * scale1);
-                pp[3] = float2int8(p0[A_hstep + 1] * scale1);
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep)));
+                __m128 _p = _mm_unpacklo_ps(_p0, _p1);
+                _p = _mm_mul_ps(_p, _scales0011);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
                 pp += 4;
                 p0 += A_hstep * 2;
             }
@@ -4873,22 +4900,30 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
         if (elempack == 1)
         {
             int kk = 0;
-#if __AVX512VNNI__ || __AVXVNNI__
+#if __SSE2__
+            __m128 _scale = _mm_set1_ps(scales[i + ii]);
 #if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             int w_shift = 0;
 #endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[A_hstep] * scale);
-                pp[2] = float2int8(p0[A_hstep * 2] * scale);
-                pp[3] = float2int8(p0[A_hstep * 3] * scale);
-#if !__AVXVNNIINT8__
+#if __AVX2__
+                __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3);
+                _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep));
+
+                __m128 _p = _mm_i32gather_ps(p0, _vindex, sizeof(float));
+#else
+                __m128 _p = _mm_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3]);
+#endif
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 w_shift += pp[0];
                 w_shift += pp[1];
                 w_shift += pp[2];
                 w_shift += pp[3];
-#endif // !__AVXVNNIINT8__
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 pp += 4;
                 p0 += A_hstep * 4;
             }
@@ -4899,7 +4934,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
                 pp += 4;
             }
 #endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
-#endif // __AVX512VNNI__ || __AVXVNNI__
+#endif // __SSE2__
             for (; kk < max_kk; kk++)
             {
                 pp[0] = float2int8(p0[0] * scale);
@@ -5953,42 +5988,50 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
     {
         const float* p0 = (const float*)B + (j + jj) * B_hstep + k;
 
+#if __SSE2__
+        __m128 _scale = _mm_set1_ps(scale);
+#endif // __SSE2__
+
         // if (elempack == 1)
         {
             int kk = 0;
 #if __SSE2__
-#if __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 3 < max_kk; kk += 4)
             {
-#if __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[2] * scale);
-                pp[3] = float2int8(p0[3] * scale);
-                pp[4] = float2int8(p0[B_hstep] * scale);
-                pp[5] = float2int8(p0[B_hstep + 1] * scale);
-                pp[6] = float2int8(p0[B_hstep + 2] * scale);
-                pp[7] = float2int8(p0[B_hstep + 3] * scale);
-#else // __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[1] * scale) + 127;
-                pp[2] = float2int8(p0[2] * scale) + 127;
-                pp[3] = float2int8(p0[3] * scale) + 127;
-                pp[4] = float2int8(p0[B_hstep] * scale) + 127;
-                pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127;
-                pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127;
-                pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127;
-#endif // __AVXVNNIINT8__
+                __m128 _p0 = _mm_load_ps(p0);
+                __m128 _p1 = _mm_load_ps(p0 + B_hstep);
+                _p0 = _mm_mul_ps(_p0, _scale);
+                _p1 = _mm_mul_ps(_p1, _scale);
+#if __AVX512VNNI__ || __AVXVNNI__
+                int64_t v = float2int8_sse(_p0, _p1);
+                *(int64_t*)pp = v;
+#if !__AVXVNNIINT8__
+                pp[0] += 127;
+                pp[1] += 127;
+                pp[2] += 127;
+                pp[3] += 127;
+                pp[4] += 127;
+                pp[5] += 127;
+                pp[6] += 127;
+                pp[7] += 127;
+#endif // !__AVXVNNIINT8__
+#else // __AVX512VNNI__ || __AVXVNNI__
+                __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                int64_t v = float2int8_sse(_t0, _t1);
+                *(int64_t*)pp = v;
+#endif // __AVX512VNNI__ || __AVXVNNI__
                 pp += 8;
                 p0 += 4;
             }
-#endif // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 1 < max_kk; kk += 2)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[B_hstep] * scale);
-                pp[3] = float2int8(p0[B_hstep + 1] * scale);
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep)));
+                __m128 _p = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
                 pp += 4;
                 p0 += 2;
             }
@@ -6006,27 +6049,30 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
     {
         const float* p0 = (const float*)B + (j + jj) * B_hstep + k;
 
+#if __SSE2__
+        __m128 _scale = _mm_set1_ps(scale);
+#endif // __SSE2__
+
         // if (elempack == 1)
        {
            int kk = 0;
-#if __AVX512VNNI__ || __AVXVNNI__
+#if __SSE2__
            for (; kk + 3 < max_kk; kk += 4)
            {
-#if __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[2] * scale);
-                pp[3] = float2int8(p0[3] * scale);
-#else // __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[1] * scale) + 127;
-                pp[2] = float2int8(p0[2] * scale) + 127;
-                pp[3] = float2int8(p0[3] * scale) + 127;
-#endif // __AVXVNNIINT8__
+                __m128 _p = _mm_load_ps(p0);
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
+                pp[0] += 127;
+                pp[1] += 127;
+                pp[2] += 127;
+                pp[3] += 127;
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 pp += 4;
                 p0 += 4;
             }
-#endif // __AVX512VNNI__ || __AVXVNNI__
+#endif // __SSE2__
             for (; kk < max_kk; kk++)
            {
                 pp[0] = float2int8(p0[0] * scale);
@@ -7272,38 +7318,49 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
         {
             int kk = 0;
 #if __SSE2__
-#if __AVX512VNNI__ || __AVXVNNI__
+            __m128 _scale = _mm_set1_ps(scale);
             for (; kk + 3 < max_kk; kk += 4)
             {
-#if __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[B_hstep + 0] * scale);
-                pp[2] = float2int8(p0[B_hstep * 2 + 0] * scale);
-                pp[3] = float2int8(p0[B_hstep * 3 + 0] * scale);
-                pp[4] = float2int8(p0[1] * scale);
-                pp[5] = float2int8(p0[B_hstep + 1] * scale);
-                pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale);
-                pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale);
-#else // __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[B_hstep + 0] * scale) + 127;
-                pp[2] = float2int8(p0[B_hstep * 2 + 0] * scale) + 127;
-                pp[3] = float2int8(p0[B_hstep * 3 + 0] * scale) + 127;
-                pp[4] = float2int8(p0[1] * scale) + 127;
-                pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127;
-                pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127;
-                pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127;
-#endif // __AVXVNNIINT8__
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep)));
+                __m128 _p2 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 2)));
+                __m128 _p3 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 3)));
+                __m128 _p01 = _mm_unpacklo_ps(_p0, _p1);
+                __m128 _p23 = _mm_unpacklo_ps(_p2, _p3);
+                _p0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23)));
+                _p1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23)));
+                _p0 = _mm_mul_ps(_p0, _scale);
+                _p1 = _mm_mul_ps(_p1, _scale);
+#if __AVX512VNNI__ || __AVXVNNI__
+                int64_t v = float2int8_sse(_p0, _p1);
+                *(int64_t*)pp = v;
+#if !__AVXVNNIINT8__
+                pp[0] += 127;
+                pp[1] += 127;
+                pp[2] += 127;
+                pp[3] += 127;
+                pp[4] += 127;
+                pp[5] += 127;
+                pp[6] += 127;
+                pp[7] += 127;
+#endif // !__AVXVNNIINT8__
+#else // __AVX512VNNI__ || __AVXVNNI__
+                __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
+                int64_t v = float2int8_sse(_t0, _t1);
+                *(int64_t*)pp = v;
+#endif // __AVX512VNNI__ || __AVXVNNI__
                 pp += 8;
                 p0 += B_hstep * 4;
             }
-#endif // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 1 < max_kk; kk += 2)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[B_hstep + 0] * scale);
-                pp[2] = float2int8(p0[1] * scale);
-                pp[3] = float2int8(p0[B_hstep + 1] * scale);
+                __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0));
+                __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep)));
+                __m128 _p = _mm_unpacklo_ps(_p0, _p1);
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
                 pp += 4;
                 p0 += B_hstep * 2;
             }
@@ -7403,24 +7460,31 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
         if (elempack == 1)
        {
            int kk = 0;
-#if __AVX512VNNI__ || __AVXVNNI__
+#if __SSE2__
+            __m128 _scale = _mm_set1_ps(scale);
            for (; kk + 3 < max_kk; kk += 4)
            {
-#if __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[B_hstep] * scale);
-                pp[2] = float2int8(p0[B_hstep * 2] * scale);
-                pp[3] = float2int8(p0[B_hstep * 3] * scale);
-#else // __AVXVNNIINT8__
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[B_hstep] * scale) + 127;
-                pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127;
-                pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127;
-#endif // __AVXVNNIINT8__
+#if __AVX2__
+                __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3);
+                _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep));
+
+                __m128 _p = _mm_i32gather_ps(p0, _vindex, sizeof(float));
+#else
+                __m128 _p = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]);
+#endif
+                _p = _mm_mul_ps(_p, _scale);
+                int32_t v = float2int8_sse(_p);
+                *(int32_t*)pp = v;
+#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
+                pp[0] += 127;
+                pp[1] += 127;
+                pp[2] += 127;
+                pp[3] += 127;
+#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__)
                 pp += 4;
                 p0 += B_hstep * 4;
             }
-#endif // __AVX512VNNI__ || __AVXVNNI__
+#endif // __SSE2__
             for (; kk < max_kk; kk++)
            {
                 pp[0] = float2int8(p0[0] * scale);
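
Across these hunks the scalar float2int8(x * scale) stores are replaced by a vector load, a multiply with a broadcast scale, and a call to ncnn's float2int8_sse helper, with _mm_unpacklo_pd/_mm_unpackhi_pd on ps-cast registers interleaving two quantized rows in 64-bit pairs for the non-VNNI layout. The standalone sketch below reproduces that idiom; float2int8_sse_sketch is a stand-in written for illustration only, and its round-to-nearest and [-127, 127] saturation behaviour are assumptions about the real float2int8_sse defined elsewhere in gemm_int8.h, which may differ.

// illustrative sketch only, not the ncnn implementation
#include <emmintrin.h> // SSE2
#include <cstdint>
#include <cstdio>
#include <cstring>

// assumed behaviour: round 8 floats to nearest, saturate to [-127, 127],
// pack the 8 resulting int8 values into an int64_t
static int64_t float2int8_sse_sketch(__m128 lo, __m128 hi)
{
    __m128i i0 = _mm_cvtps_epi32(lo); // round-to-nearest-even by default
    __m128i i1 = _mm_cvtps_epi32(hi);
    __m128i s16 = _mm_packs_epi32(i0, i1); // 8 x int32 -> 8 x int16, saturating
    s16 = _mm_max_epi16(s16, _mm_set1_epi16(-127)); // keep the int8 range symmetric
    s16 = _mm_min_epi16(s16, _mm_set1_epi16(127));
    __m128i s8 = _mm_packs_epi16(s16, s16); // low 8 bytes hold the packed int8 result
    int64_t v;
    std::memcpy(&v, &s8, sizeof(v)); // avoids the type-punning store used in the kernel
    return v;
}

int main()
{
    const float scale = 100.f;
    const float row0[4] = {0.1f, -0.2f, 0.3f, 1.5f};
    const float row1[4] = {-1.5f, 0.05f, -0.05f, 0.6f};

    __m128 _scale = _mm_set1_ps(scale);
    __m128 _p0 = _mm_mul_ps(_mm_loadu_ps(row0), _scale);
    __m128 _p1 = _mm_mul_ps(_mm_loadu_ps(row1), _scale);

    // interleave the two quantized rows in 64-bit pairs, the same
    // unpacklo_pd/unpackhi_pd idiom the patch uses for the non-VNNI layout:
    // output order a0 a1 b0 b1 a2 a3 b2 b3
    __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));
    __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1)));

    int8_t pp[8];
    int64_t v = float2int8_sse_sketch(_t0, _t1);
    std::memcpy(pp, &v, sizeof(v));

    for (int i = 0; i < 8; i++)
        printf("%d ", pp[i]); // expected: 10 -20 -127 5 30 127 -5 60
    printf("\n");
    return 0;
}

Built with any SSE2-capable compiler (e.g. -msse2), this prints the eight packed int8 values in the interleaved a0 a1 b0 b1 a2 a3 b2 b3 order produced by the patch's non-VNNI branch, with the out-of-range inputs clamped to -127 and 127.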