From c0b807a04d0557cfd7611763bdb0d3665391b7e8 Mon Sep 17 00:00:00 2001 From: NishantPrabhuFujitsu Date: Thu, 2 Jan 2025 14:07:45 +0530 Subject: [PATCH] sve veclen compilation fixes --- .../cross_compile/cross_compiled_func.cmake | 2 +- src/plugins/intel_cpu/CMakeLists.txt | 10 +- .../src/nodes/kernels/scaled_attn/common.hpp | 10 +- .../kernels/scaled_attn/mha_single_token.cpp | 95 ++++++++++--------- .../kernels/scaled_attn/softmax_kernel.hpp | 27 +++--- 5 files changed, 76 insertions(+), 68 deletions(-) diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake index b3e2c60fb5c936..d83450e6d238bd 100644 --- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake +++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake @@ -11,7 +11,7 @@ set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$") set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$") set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$") set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$") -set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$") +set(_ACCEPTED_ARCHS_SVE "^(ANY|NEON_FP16|SVE)$") ## Arch specific definitions set(_DEFINE_ANY "") diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index aa6ce49a051e00..c88ec8fa2a1331 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -283,16 +283,16 @@ target_include_directories(${TARGET_NAME} PRIVATE $::value && std::is_same::value) { # if defined(HAVE_SVE) - size_t inc = vec_len_f16_sve; + size_t inc = vec_len_f16_sve(); svbool_t pg = svptrue_b16(); while (i < n) { - if (n - i < vec_len_f16_sve) { + if (n - i < vec_len_f16_sve()) { inc = n - i; pg = svwhilelt_b16(0, static_cast(inc)); } @@ -88,11 +88,11 @@ void cvt_copy(TA* dst, TB* src, size_t n) { # else # if defined(HAVE_SVE) auto _dst = reinterpret_cast(dst); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); svbool_t pg = svptrue_b32(); while (i < n) { - if (n - i < vec_len_f32_sve) { + if (n - i < vec_len_f32_sve()) { inc = n - i; pg = svwhilelt_b32(0, static_cast(inc)); } @@ -138,11 +138,11 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal # if defined(HAVE_SVE) auto _v = reinterpret_cast(v); svfloat32_t attn_w_vec_fp32 = svdup_n_f32(weight); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); svbool_t pg = svptrue_b32(); while (i < S) { - if (S - i < vec_len_f32_sve) { + if (S - i < vec_len_f32_sve()) { inc = S - i; pg = svwhilelt_b32(0, static_cast(inc)); } @@ -180,10 +180,10 @@ static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S, # if defined(HAVE_SVE) svfloat16_t attn_w_vec_fp16 = svdup_n_f16(weight); svbool_t pg = svptrue_b16(); - size_t inc = vec_len_f16_sve; + size_t inc = vec_len_f16_sve(); while (i < S) { - if (S - i < vec_len_f16_sve) { + if (S - i < vec_len_f16_sve()) { inc = S - i; pg = svwhilelt_b16(0, static_cast(inc)); } @@ -440,30 +440,31 @@ static float sum_q_head(T* a, size_t n) { svfloat32_t sum2 = svdup_n_f32(0.0f); svfloat32_t sum3 = svdup_n_f32(0.0f); svbool_t pg = svptrue_b32(); + auto vec_len = vec_len_f32_sve(); - for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) { + for (; i + 4 * vec_len <= n; i += 4 * vec_len) { svfloat32_t a0 = svld1_f32(pg, a + i); - svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve); - svfloat32_t a2 = svld1_f32(pg, a + i + vec_len_f32_sve * 2); - svfloat32_t a3 = svld1_f32(pg, a + i + vec_len_f32_sve * 3); + svfloat32_t a1 = svld1_f32(pg, a + i + vec_len); + svfloat32_t a2 = svld1_f32(pg, a + i + vec_len * 2); + svfloat32_t a3 = svld1_f32(pg, a + i + vec_len * 3); sum0 = svadd_f32_z(pg, a0, sum0); sum1 = svadd_f32_z(pg, a1, sum1); sum2 = svadd_f32_z(pg, a2, sum2); sum3 = svadd_f32_z(pg, a3, sum3); } - if (i + 2 * vec_len_f32_sve <= n) { + if (i + 2 * vec_len <= n) { svfloat32_t a0 = svld1_f32(pg, a + i); - svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve); + svfloat32_t a1 = svld1_f32(pg, a + i + vec_len); sum0 = svadd_f32_z(pg, a0, sum0); sum1 = svadd_f32_z(pg, a1, sum1); - i += 2 * vec_len_f32_sve; + i += 2 * vec_len; } - if (i + vec_len_f32_sve <= n) { + if (i + vec_len <= n) { svfloat32_t a0 = svld1_f32(pg, a + i); sum0 = svadd_f32_z(pg, a0, sum0); - i += vec_len_f32_sve; + i += vec_len; } // Process tail elements parallely as well (if any) if (i != n) { @@ -623,42 +624,43 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* svfloat32_t sum1 = svdup_n_f32(0.0f); svfloat32_t sum2 = svdup_n_f32(0.0f); svfloat32_t sum3 = svdup_n_f32(0.0f); + auto vec_len = vec_len_f32_sve(); auto _a = reinterpret_cast(a); auto _b = reinterpret_cast(b); - for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) { + for (; i + 4 * vec_len <= n; i += 4 * vec_len) { svfloat32_t a0 = svld1_f32(pg, _a + i); - svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve); - svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len_f32_sve * 2); - svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len_f32_sve * 3); + svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len); + svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len * 2); + svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len * 3); svfloat32_t b0 = svld1_f32(pg, _b + i); - svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve); - svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len_f32_sve * 2); - svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len_f32_sve * 3); + svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len); + svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len * 2); + svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len * 3); sum0 = svmla_f32_z(pg, sum0, a0, b0); sum1 = svmla_f32_z(pg, sum1, a1, b1); sum2 = svmla_f32_z(pg, sum2, a2, b2); sum3 = svmla_f32_z(pg, sum3, a3, b3); } - if (i + 2 * vec_len_f32_sve <= n) { + if (i + 2 * vec_len <= n) { svfloat32_t a0 = svld1_f32(pg, _a + i); - svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve); + svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len); svfloat32_t b0 = svld1_f32(pg, _b + i); - svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve); + svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len); sum0 = svmla_f32_z(pg, sum0, a0, b0); sum1 = svmla_f32_z(pg, sum1, a1, b1); - i += 2 * vec_len_f32_sve; + i += 2 * vec_len; } - if (i + vec_len_f32_sve <= n) { + if (i + vec_len <= n) { svfloat32_t a0 = svld1_f32(pg, _a + i); svfloat32_t b0 = svld1_f32(pg, _b + i); sum0 = svmla_f32_z(pg, sum0, a0, b0); - i += vec_len_f32_sve; + i += vec_len; } // Process the tail elements parallely as well (if any) if (i != n) { @@ -746,39 +748,40 @@ static ov::float16 dot_product_fp16(ov::float16* a, svfloat16_t sum1 = svdup_n_f16(0.0f); svfloat16_t sum2 = svdup_n_f16(0.0f); svfloat16_t sum3 = svdup_n_f16(0.0f); + auto vec_len = vec_len_f16_sve(); - for (; i + 4 * vec_len_f16_sve <= n; i += 4 * vec_len_f16_sve) { + for (; i + 4 * vec_len <= n; i += 4 * vec_len) { svfloat16_t a0 = svld1_f16(pg, _a + i); - svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve); - svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len_f16_sve * 2); - svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len_f16_sve * 3); + svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len); + svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len * 2); + svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len * 3); svfloat16_t b0 = svld1_f16(pg, _b + i); - svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve); - svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len_f16_sve * 2); - svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len_f16_sve * 3); + svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len); + svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len * 2); + svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len * 3); sum0 = svmla_f16_z(pg, sum0, a0, b0); sum1 = svmla_f16_z(pg, sum1, a1, b1); sum2 = svmla_f16_z(pg, sum2, a2, b2); sum3 = svmla_f16_z(pg, sum3, a3, b3); } - if (i + 2 * vec_len_f16_sve <= n) { + if (i + 2 * vec_len <= n) { svfloat16_t a0 = svld1_f16(pg, _a + i); - svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve); + svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len); svfloat16_t b0 = svld1_f16(pg, _b + i); - svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve); + svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len); sum0 = svmla_f16_z(pg, sum0, a0, b0); sum1 = svmla_f16_z(pg, sum1, a1, b1); - i += 2 * vec_len_f16_sve; + i += 2 * vec_len; } - if (i + vec_len_f16_sve <= n) { + if (i + vec_len <= n) { svfloat16_t a0 = svld1_f16(pg, _a + i); svfloat16_t b0 = svld1_f16(pg, _b + i); sum0 = svmla_f16_z(pg, sum0, a0, b0); - i += vec_len_f16_sve; + i += vec_len; } // Process the tail elements parallely as well (if any) if (i != n) { @@ -1029,11 +1032,11 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str #elif defined(OPENVINO_ARCH_ARM64) # if defined(HAVE_SVE) auto _dst = reinterpret_cast(dst); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); svbool_t pg = svptrue_b32(); while (i < S) { - if (S - i < vec_len_f32_sve) { + if (S - i < vec_len_f32_sve()) { inc = S - i; pg = svwhilelt_b32(0, static_cast(inc)); } @@ -1079,7 +1082,7 @@ static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S, # if defined(HAVE_SVE) svbool_t pg = svptrue_b16(); - for (; i + vec_len_f16_sve <= S; i += vec_len_f16_sve) { + for (; i + vec_len_f16_sve() <= S; i += vec_len_f16_sve()) { auto* src = temp + i; auto result_vec_fp16 = svdup_n_f16(0.0f); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp index b31cdc9ceb3b43..d8eaa7889b9502 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -481,10 +481,10 @@ inline void scale_add2_reduce_max(ov::float16* a, svbool_t pg_f16 = svptrue_b16(); svbool_t pg_u8 = svptrue_b8(); svbool_t pg_u16 = svptrue_b16(); - size_t inc = vec_len_f16_sve; + size_t inc = vec_len_f16_sve(); while (i < size) { - if (size - i < vec_len_f16_sve) { + if (size - i < vec_len_f16_sve()) { inc = size - i; pg_f16 = svwhilelt_b16(0, static_cast(inc)); pg_u8 = svwhilelt_b8(0, static_cast(inc)); @@ -717,12 +717,11 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float& svfloat32_t v_a; svfloat32_t v_max = svdup_n_f32(max); svfloat32_t v_sum = svdup_n_f32(0.0f); - size_t vec_len_f32_sve = svcntw(); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); svbool_t pg = svptrue_b32(); while (i < size) { - if (size - i < vec_len_f32_sve) { + if (size - i < vec_len_f32_sve()) { inc = size - i; pg = svwhilelt_b32(0, static_cast(inc)); } @@ -767,12 +766,12 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size svbool_t pg_f32 = svptrue_b32(); svbool_t pg_f16 = svptrue_b16(); svfloat16_t zero = svdup_n_f16(0.0); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); while (i < size) { - if (size - i < vec_len_f16_sve) + if (size - i < vec_len_f16_sve()) pg_f16 = svwhilelt_b16(0, static_cast(size - i)); - if (size - i < vec_len_f32_sve) { + if (size - i < vec_len_f32_sve()) { pg_f32 = svwhilelt_b32(0, static_cast(size - i)); inc = size - i; } @@ -838,10 +837,10 @@ inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t s svfloat16_t v_max = svdup_n_f16(max); svfloat16_t v_sum = svdup_n_f16(0.0f); svbool_t pg = svptrue_b16(); - size_t inc = vec_len_f16_sve; + size_t inc = vec_len_f16_sve(); while (i < size) { - if (size - i < vec_len_f16_sve) { + if (size - i < vec_len_f16_sve()) { inc = size - i; pg = svwhilelt_b16(0, static_cast(inc)); } @@ -919,11 +918,11 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_ #elif defined(OPENVINO_ARCH_ARM64) # if defined(HAVE_SVE) svfloat32_t v_scale = svdup_n_f32(val); - size_t inc = vec_len_f32_sve; + size_t inc = vec_len_f32_sve(); svbool_t pg = svptrue_b32(); while (i < size) { - if (size - i < vec_len_f32_sve) { + if (size - i < vec_len_f32_sve()) { inc = size - i; pg = svwhilelt_b32(0, static_cast(inc)); } @@ -1020,11 +1019,11 @@ inline void multiply_scalar(ov::float16* a, ov::float16* a_dst, const ov::float1 size_t i = 0; # if defined(HAVE_SVE) svfloat16_t v_scale = svdup_n_f16(val); - size_t inc = vec_len_f16_sve; + size_t inc = vec_len_f16_sve(); svbool_t pg = svptrue_b16(); while (i < size) { - if (size - i < vec_len_f16_sve) { + if (size - i < vec_len_f16_sve()) { inc = size - i; pg = svwhilelt_b16(0, static_cast(inc)); }