Skip to content

Commit

Permalink
avxvnniint8 without wshift
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 16, 2024
1 parent 2c08968 commit 9d759d9
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/layer/x86/gemm_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7640,10 +7640,18 @@ static int gemm_x86_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob

Mat ATX;
#if NCNN_AVX512VNNI || NCNN_AVXVNNI
if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni()))
bool has_w_shift = false;
if (TILE_K >= 4)
{
has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni();
#if NCNN_AVXVNNIINT8
if (ncnn::cpu_support_x86_avx_vnni_int8())
has_w_shift = false;
#endif // NCNN_AVXVNNIINT8
}
if (has_w_shift)
{
int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1;
// NCNN_LOGE("w_shift_count = %d", w_shift_count);
ATX.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator);
}
else
Expand Down Expand Up @@ -7905,7 +7913,16 @@ static int gemm_BT_x86_int8(const Mat& A, const Mat& BT, float B_int8_scale, con

Mat ATX;
#if NCNN_AVX512VNNI || NCNN_AVXVNNI
if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni()))
bool has_w_shift = false;
if (TILE_K >= 4)
{
has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni();
#if NCNN_AVXVNNIINT8
if (ncnn::cpu_support_x86_avx_vnni_int8())
has_w_shift = false;
#endif // NCNN_AVXVNNIINT8
}
if (has_w_shift)
{
int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1;
ATX.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator);
Expand Down Expand Up @@ -8077,7 +8094,16 @@ int Gemm_x86::create_pipeline_int8(const Option& opt)
const int nn_M = (M + TILE_M - 1) / TILE_M;

#if NCNN_AVX512VNNI || NCNN_AVXVNNI
if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni()))
bool has_w_shift = false;
if (TILE_K >= 4)
{
has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni();
#if NCNN_AVXVNNIINT8
if (ncnn::cpu_support_x86_avx_vnni_int8())
has_w_shift = false;
#endif // NCNN_AVXVNNIINT8
}
if (has_w_shift)
{
int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1;
AT_data.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0);
Expand Down

0 comments on commit 9d759d9

Please sign in to comment.