diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
index 962aa5d373a4db..b3e2c60fb5c936 100644
--- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake
+++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
@@ -186,10 +186,10 @@ endfunction()
 #  Return currently requested ARCH id
 #
 function(_currently_requested_top_arch VAR)
-    if(ENABLE_NEON_FP16)
-        set(RES NEON_FP16)
-    elseif(ENABLE_SVE)
+    if(ENABLE_SVE)
         set(RES SVE)
+    elseif(ENABLE_NEON_FP16)
+        set(RES NEON_FP16)
     elseif(ENABLE_AVX512F)
         set(RES AVX512F)
     elseif(ENABLE_AVX2)
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
index 63cbbb4464ee92..8d599283e3ce84 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -39,7 +39,8 @@ static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
 static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
 
 #if defined(HAVE_SVE)
-static constexpr size_t vec_len_f32_sve = svcntw();
+static size_t vec_len_f32_sve = svcntw();
+static size_t vec_len_f16_sve = svcnth();
 #endif
 
 #ifdef HAVE_AVX512F
@@ -403,6 +404,28 @@ inline void __vst1q_f32(ov::bfloat16* a, float32x4_t b) {
 #endif
 
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#    if defined(HAVE_SVE)
+inline svfloat16_t exp_ps_sve_f16(svbool_t& pg, svfloat16_t& src) {
+    svbool_t pg_f32 = svtrn1_b16(pg, svpfalse());
+
+    // Extract lower and upper halves of src into two separate vecs and convert
+    svfloat16_t zero = svdup_n_f16(0.0);
+    svfloat16_t low_f16 = svtrn1_f16(src, zero);
+    svfloat16_t high_f16 = svtrn2_f16(src, zero);
+    svfloat32_t low_f32 = svcvt_f32_f16_z(pg, low_f16);
+    svfloat32_t high_f32 = svcvt_f32_f16_z(pg, high_f16);
+
+    // Perform exp and convert back to f16
+    svfloat32_t low_exp_f32 = exp_ps_sve(pg_f32, low_f32);
+    svfloat32_t high_exp_f32 = exp_ps_sve(pg_f32, high_f32);
+    svfloat16_t low_exp_f16 = svcvt_f16_f32_z(pg_f32, low_exp_f32);
+    svfloat16_t high_exp_f16 = svcvt_f16_f32_z(pg_f32, high_exp_f32);
+
+    // Interleave both to get the final result
+    svfloat16_t res = svtrn1_f16(low_exp_f16, high_exp_f16);
+    return res;
+}
+#    else
 inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
     const float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x));
     const float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x));
@@ -411,6 +434,7 @@ inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
     const float16x8_t res = vcombine_f16(vcvt_f16_f32(exp_ps_neon_f32(x_low)), vcvt_f16_f32(exp_ps_neon_f32(x_high)));
     return res;
 }
+#    endif
 inline float16_t hsum(float16x8_t vec) {
     float16x4_t sum1 = vpadd_f16(vget_low_f16(vec), vget_high_f16(vec));
     float16x4_t sum2 = vpadd_f16(sum1, sum1);
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
index 5a6f0d66f1f221..3a4ab829b773c0 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
@@ -63,6 +63,28 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
         mm256_uni_storeu_ps(dst + i, vb);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+    if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
+#    if defined(HAVE_SVE)
+        size_t inc = vec_len_f16_sve;
+        svbool_t pg = svptrue_b16();
+
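+        // Full vectors use an all-true predicate; the tail switches to an svwhilelt_b16 predicate instead of a scalar loop.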
+        while (i < n) {
+            if (n - i < vec_len_f16_sve) {
+                inc = n - i;
+                pg = svwhilelt_b16(0, static_cast<int>(inc));
+            }
+            svfloat16_t b1 = svld1_f16(pg, reinterpret_cast<float16_t*>(src + i));
+            svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), b1);
+            i += inc;
+        }
+#    else
+        for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) {
+            auto vb1 = vld1q_f16(reinterpret_cast<float16_t*>(src + i));
+            vst1q_f16(reinterpret_cast<float16_t*>(dst + i), vb1);
+        }
+#    endif
+    }
+#else
 #    if defined(HAVE_SVE)
     auto _dst = reinterpret_cast<float32_t*>(dst);
     size_t inc = vec_len_f32_sve;
@@ -158,15 +180,35 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
 template <typename T>
 static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S, float* scale, float* zp) {
     size_t i = 0;
-    auto attn_w_vec_fp16 = vdupq_n_f16(weight);
     auto _v = reinterpret_cast<float16_t*>(v);
     auto _out = reinterpret_cast<float16_t*>(out);
+
+#    if defined(HAVE_SVE)
+    svfloat16_t attn_w_vec_fp16 = svdup_n_f16(weight);
+    svbool_t pg = svptrue_b16();
+    size_t inc = vec_len_f16_sve;
+
+    while (i < S) {
+        if (S - i < vec_len_f16_sve) {
+            inc = S - i;
+            pg = svwhilelt_b16(0, static_cast<int>(inc));
+        }
+        svfloat16_t v_value = svld1_f16(pg, _v + i);
+        svfloat16_t v_out = svld1_f16(pg, _out + i);
+
+        v_out = svmla_f16_m(pg, v_out, attn_w_vec_fp16, v_value);
+        svst1_f16(pg, _out + i, v_out);
+        i += inc;
+    }
+#    else
+    auto attn_w_vec_fp16 = vdupq_n_f16(weight);
     for (; i + vec_len_f16_neon <= S; i += vec_len_f16_neon) {
         auto v_value = vld1q_f16(_v + i);
         auto v_out = vld1q_f16(_out + i);
         v_out = vfmaq_f16(v_out, attn_w_vec_fp16, v_value);
         vst1q_f16(_out + i, v_out);
     }
+#    endif
     for (; i < S; i++) {
         out[i] += weight * v[i];
     }
@@ -701,12 +743,67 @@ static ov::float16 dot_product_fp16(ov::float16* a,
                                     float* head_sum) {
     size_t i = 0;
     ov::float16 sum = 0.0f;
+    auto _a = reinterpret_cast<float16_t*>(a);
+    auto _b = reinterpret_cast<float16_t*>(b);
+
+#    if defined(HAVE_SVE)
+    svbool_t pg = svptrue_b16();
+    svfloat16_t sum0 = svdup_n_f16(0.0f);
+    svfloat16_t sum1 = svdup_n_f16(0.0f);
+    svfloat16_t sum2 = svdup_n_f16(0.0f);
+    svfloat16_t sum3 = svdup_n_f16(0.0f);
+
+    for (; i + 4 * vec_len_f16_sve <= n; i += 4 * vec_len_f16_sve) {
+        svfloat16_t a0 = svld1_f16(pg, _a + i);
+        svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);
+        svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len_f16_sve * 2);
+        svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len_f16_sve * 3);
+
+        svfloat16_t b0 = svld1_f16(pg, _b + i);
+        svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);
+        svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len_f16_sve * 2);
+        svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len_f16_sve * 3);
+
+        sum0 = svmla_f16_z(pg, sum0, a0, b0);
+        sum1 = svmla_f16_z(pg, sum1, a1, b1);
+        sum2 = svmla_f16_z(pg, sum2, a2, b2);
+        sum3 = svmla_f16_z(pg, sum3, a3, b3);
+    }
+    if (i + 2 * vec_len_f16_sve <= n) {
+        svfloat16_t a0 = svld1_f16(pg, _a + i);
+        svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);
+
+        svfloat16_t b0 = svld1_f16(pg, _b + i);
+        svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);
+
+        sum0 = svmla_f16_z(pg, sum0, a0, b0);
+        sum1 = svmla_f16_z(pg, sum1, a1, b1);
+        i += 2 * vec_len_f16_sve;
+    }
+    if (i + vec_len_f16_sve <= n) {
+        svfloat16_t a0 = svld1_f16(pg, _a + i);
+        svfloat16_t b0 = svld1_f16(pg, _b + i);
+        sum0 = svmla_f16_z(pg, sum0, a0, b0);
+        i += vec_len_f16_sve;
+    }
+    // Process the tail elements in parallel as well (if any)
+    if (i != n) {
+        svbool_t pg_rem = svwhilelt_b16(0, static_cast<int>(n - i));
+        svfloat16_t a0 = svld1_f16(pg_rem, _a + i);
+        svfloat16_t b0 = svld1_f16(pg_rem, _b + i);
+        sum0 = svmla_f16_m(pg_rem, sum0, a0, b0);
+        i = n;
+    }
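+    // Horizontally add each accumulator and combine the four partial sums.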
+    float16_t sum_0 = svaddv_f16(pg, sum0);
+    float16_t sum_1 = svaddv_f16(pg, sum1);
+    float16_t sum_2 = svaddv_f16(pg, sum2);
+    float16_t sum_3 = svaddv_f16(pg, sum3);
+    sum = static_cast<ov::float16>(sum_0 + sum_1 + sum_2 + sum_3);
+#    else
     auto vsum0 = vdupq_n_f16(0.0f);
     auto vsum1 = vdupq_n_f16(0.0f);
     auto vsum2 = vdupq_n_f16(0.0f);
     auto vsum3 = vdupq_n_f16(0.0f);
-    auto _a = reinterpret_cast<float16_t*>(a);
-    auto _b = reinterpret_cast<float16_t*>(b);
 
     for (; i + 4 * vec_len_f16_neon <= n; i += vec_len_f16_neon * 4) {
         auto va0 = vld1q_f16(_a + i);
@@ -747,7 +844,7 @@ static ov::float16 dot_product_fp16(ov::float16* a,
     vsum0 = vaddq_f16(vsum0, vsum2);
 
     sum = hsum(vsum0);
-
+#    endif
     for (; i < n; i++) {
         sum += a[i] * b[i];
     }
@@ -985,6 +1082,21 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S, size_t temp_stride) {
     size_t i = 0;
+#    if defined(HAVE_SVE)
+    svbool_t pg = svptrue_b16();
+
+    for (; i + vec_len_f16_sve <= S; i += vec_len_f16_sve) {
+        auto* src = temp + i;
+        auto result_vec_fp16 = svdup_n_f16(0.0f);
+
+        for (size_t m = 0; m < M; m++) {
+            auto o_vec_fp16 = svld1_f16(pg, reinterpret_cast<float16_t*>(src));
+            result_vec_fp16 = svadd_f16_m(pg, result_vec_fp16, o_vec_fp16);
+            src += temp_stride;
+        }
+        svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), result_vec_fp16);
+    }
+#    else
     for (; i + vec_len_f16_neon <= S; i += vec_len_f16_neon) {
         auto* src = temp + i;
         auto result_vec_fp16 = vdupq_n_f16(0.0f);
@@ -995,6 +1107,7 @@ static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S,
         }
         vst1q_f16(reinterpret_cast<float16_t*>(dst + i), result_vec_fp16);
     }
+#    endif
     for (; i < S; i++) {
         auto* src = temp + i;
         float sum = 0.0f;
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
index 35aab5b59c7d0e..b31cdc9ceb3b43 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "common.hpp"
@@ -464,10 +465,61 @@ inline void scale_add2_reduce_max(ov::float16* a,
                                   size_t size,
                                   float alibi_slope,
                                   ov::float16& max) {
+    size_t i = 0;
+#    if defined(HAVE_SVE)
+    svfloat16_t v_max = svdup_n_f16(static_cast<float16_t>(-FLT_MAX));
+    svfloat16_t v_scale = svdup_n_f16(static_cast<float16_t>(scale));
+    svfloat16_t v_a;
+    svuint16_t v_zeroi16 = svdup_n_u16(0);
+    svfloat16_t v_nfltmax = svdup_n_f16(static_cast<float16_t>(-FLT_MAX));
+    svfloat16_t v_alibi_slope = svdup_n_f16(static_cast<float16_t>(alibi_slope));
+
+    svbool_t mask_xor = svptrue_b16();
+    if (!select_nfltmax_at_0)
+        mask_xor = svnot_z(svptrue_b16(), mask_xor);
+
+    svbool_t pg_f16 = svptrue_b16();
+    svbool_t pg_u8 = svptrue_b8();
+    svbool_t pg_u16 = svptrue_b16();
+    size_t inc = vec_len_f16_sve;
+
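+    // Separate predicates cover the f16 values and the u8/u16 causal-mask loads; all are narrowed for the tail.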
+    while (i < size) {
+        if (size - i < vec_len_f16_sve) {
+            inc = size - i;
+            pg_f16 = svwhilelt_b16(0, static_cast<int>(inc));
+            pg_u8 = svwhilelt_b8(0, static_cast<int>(inc));
+            pg_u16 = svwhilelt_b16(0, static_cast<int>(inc));
+        }
+        v_a = svld1_f16(pg_f16, reinterpret_cast<float16_t*>(a + i));
+        v_a = svmul_f16_z(pg_f16, v_a, v_scale);
+
+        if (has_alibi) {
+            svfloat16_t v_lookup = svld1_f16(pg_f16, reinterpret_cast<const float16_t*>(alibi_lookup + i));
+            v_a = svmla_f16_z(pg_f16, v_a, v_lookup, v_alibi_slope);
+        }
+
+        if (has_attn_mask) {
+            svfloat16_t v_mask = svld1_f16(pg_f16, reinterpret_cast<const float16_t*>(attn_mask + i));
+            v_a = svadd_f16_z(pg_f16, v_a, v_mask);
+        }
+
+        if (has_causal_mask) {
+            svuint8_t v_maski8 = svld1_u8(pg_u8, causal_mask + i);
+            svuint16_t v_maski16 = svtrn1_u16(svreinterpret_u16_u8(v_maski8), svdup_n_u16(0));
+            svbool_t kmask = svcmpeq_u16(pg_u16, v_maski16, v_zeroi16);
+            kmask = sveor_z(pg_u16, kmask, mask_xor);
+            v_a = svsel_f16(kmask, v_nfltmax, v_a);
+        }
+
+        v_max = svmax_f16_z(pg_f16, v_max, v_a);
+        svst1_f16(pg_f16, reinterpret_cast<float16_t*>(a + i), v_a);
+        i += inc;
+    }
+    max = svmaxv_f16(pg_f16, v_max);
+#    else
     float16x8_t v_max = vdupq_n_f16(static_cast<float16_t>(-FLT_MAX));
     float16x8_t v_scale = vdupq_n_f16(static_cast<float16_t>(scale));
     float16x8_t v_a;
-    size_t i = 0;
     uint16x8_t v_zeroi16 = vdupq_n_u16(0);
     float16x8_t v_nfltmax = vdupq_n_f16(static_cast<float16_t>(-FLT_MAX));
     uint16x8_t mask_xor = vdupq_n_u16(select_nfltmax_at_0 ? 0xFFFF : 0);
@@ -500,6 +552,7 @@ inline void scale_add2_reduce_max(ov::float16* a,
         vst1q_f16(reinterpret_cast<float16_t*>(a + i), v_a);
     }
     max = vmaxvq_f16(v_max);
+#    endif
     // process tails
     for (; i < size; i++) {
         a[i] *= scale;
@@ -705,10 +758,47 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
 #if defined(OPENVINO_ARCH_ARM64)
 inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size_t size, ov::float16& sum) {
+    size_t i = 0;
+#    if defined(HAVE_SVE)
+    svfloat32_t v_a;
+    svfloat32_t v_max = svdup_n_f32(static_cast<float>(max));
+    svfloat32_t v_sum = svdup_n_f32(0.0f);
+
+    svbool_t pg_f32 = svptrue_b32();
+    svbool_t pg_f16 = svptrue_b16();
+    svfloat16_t zero = svdup_n_f16(0.0);
+    size_t inc = vec_len_f32_sve;
+
+    while (i < size) {
+        if (size - i < vec_len_f16_sve)
+            pg_f16 = svwhilelt_b16(0, static_cast<int>(size - i));
+        if (size - i < vec_len_f32_sve) {
+            pg_f32 = svwhilelt_b32(0, static_cast<int>(size - i));
+            inc = size - i;
+        }
+        // Load f16 elements and interleave them with zeros so each value sits in the low half of an f32 lane
+        svfloat16_t v_a_f16 = svld1_f16(pg_f16, reinterpret_cast<float16_t*>(a + i));
+        v_a_f16 = svzip1_f16(v_a_f16, zero);
+
+        // Convert to f32 and perform required operations
+        v_a = svcvt_f32_f16_z(pg_f16, v_a_f16);
+        v_a = svsub_f32_z(pg_f32, v_a, v_max);
+        v_a = exp_ps_sve(pg_f32, v_a);
+        v_sum = svadd_f32_z(pg_f32, v_sum, v_a);
+
+        // Convert to f16 and compact non-zero elements (even indices) to the low part
+        // so that we can store them in the result using svwhilelt
+        svfloat16_t v_result = svcvt_f16_f32_z(pg_f32, v_a);
+        v_result = svtbl_f16(v_result, svindex_u16(0, 2));
+
+        svst1_f16(svwhilelt_b16(0, static_cast<int>(inc)), reinterpret_cast<float16_t*>(a + i), v_result);
+        i += inc;
+    }
+    float total_sum = svaddv_f32(svptrue_b32(), v_sum);
+#    else
     float32x4_t v_a;
     float32x4_t v_max = vdupq_n_f32(static_cast<float>(max));
     float32x4_t v_sum = vdupq_n_f32(0.0f);
-    size_t i = 0;
 
     // Process 4 FP32 elements at a time
     for (; i + vec_len_f32_neon <= size; i += vec_len_f32_neon) {
@@ -728,7 +818,7 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
     // Reduce sum
     float total_sum = vaddvq_f32(v_sum);
-
+#    endif
     // Handle remaining elements
     for (; i < size; ++i) {
         float val = exp(static_cast<float>(a[i] - max));
@@ -742,11 +832,32 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t size, ov::float16& sum) {
+    size_t i = 0;
+#    if defined(HAVE_SVE)
+    svfloat16_t v_a;
+    svfloat16_t v_max = svdup_n_f16(max);
+    svfloat16_t v_sum = svdup_n_f16(0.0f);
+    svbool_t pg = svptrue_b16();
+    size_t inc = vec_len_f16_sve;
+
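+    // The predicated loop also covers the tail, so the scalar epilogue below only runs for the NEON path.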
+    while (i < size) {
+        if (size - i < vec_len_f16_sve) {
+            inc = size - i;
+            pg = svwhilelt_b16(0, static_cast<int>(inc));
+        }
+        v_a = svld1_f16(pg, reinterpret_cast<float16_t*>(a + i));
+        v_a = svsub_f16_z(pg, v_a, v_max);
+        v_a = exp_ps_sve_f16(pg, v_a);
+        v_sum = svadd_f16_m(pg, v_sum, v_a);
+        svst1_f16(pg, reinterpret_cast<float16_t*>(a + i), v_a);
+        i += inc;
+    }
+    sum = svaddv_f16(svptrue_b16(), v_sum);
+#    else
     const size_t vec_len_f16_neon = 8;
     float16x8_t v_a;
     float16x8_t v_max = vdupq_n_f16(max);
     float16x8_t v_sum = vdupq_n_f16(0.0f);
-    size_t i = 0;
 
     for (; i + vec_len_f16_neon <= size; i += vec_len_f16_neon) {
         v_a = vld1q_f16(reinterpret_cast<float16_t*>(a + i));
@@ -761,7 +872,7 @@ inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t s
     float16x4_t v_sum_final = vpadd_f16(v_sum_pair, v_sum_pair);
 
     sum += vget_lane_f16(v_sum_final, 0);
-
+#    endif
     for (; i < size; ++i) {
         a[i] = static_cast<ov::float16>(exp(static_cast<float>(a[i] - max)));
         sum += a[i];
     }
@@ -864,6 +975,7 @@ inline void multiply_scalar(float* a, T* a_dst, const float val, const size_t si
     }
 #endif
 }
+
 #if defined(OPENVINO_ARCH_ARM64)
 inline void multiply_scalar(ov::float16* a, float* a_dst, const ov::float16 val, const size_t size) {
     float16x4_t v_a_f16;
@@ -902,21 +1014,40 @@ inline void multiply_scalar_f32(ov::float16* a, ov::float16* a_dst, const ov::fl
     }
 }
 #endif
+
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 inline void multiply_scalar(ov::float16* a, ov::float16* a_dst, const ov::float16 val, const size_t size) {
+    size_t i = 0;
+#    if defined(HAVE_SVE)
+    svfloat16_t v_scale = svdup_n_f16(val);
+    size_t inc = vec_len_f16_sve;
+    svbool_t pg = svptrue_b16();
+
+    while (i < size) {
+        if (size - i < vec_len_f16_sve) {
+            inc = size - i;
+            pg = svwhilelt_b16(0, static_cast<int>(inc));
+        }
+        svfloat16_t v_a = svld1_f16(pg, reinterpret_cast<float16_t*>(a + i));
+        v_a = svmul_f16_z(pg, v_a, v_scale);
+        svst1_f16(pg, reinterpret_cast<float16_t*>(a_dst + i), v_a);
+        i += inc;
+    }
+#    else
     float16x8_t v_a, v_res;
     float16x8_t v_val = vdupq_n_f16(val);
-    size_t i = 0;
 
     for (; i + vec_len_f16_neon <= size; i += vec_len_f16_neon) {
         v_a = vld1q_f16(reinterpret_cast<float16_t*>(a + i));
         v_res = vmulq_f16(v_a, v_val);
         vst1q_f16(reinterpret_cast<float16_t*>(a_dst + i), v_res);
     }
+#    endif
     for (; i < size; ++i) {
         a_dst[i] = a[i] * val;
     }
 }
 #endif
+
 template <typename T>
 inline void attn_softmax_kernel(T* a,
                                 void* a_dst,