Skip to content

Commit

Permalink
adds f16 sve functions for MHASingleToken
Browse files Browse the repository at this point in the history
  • Loading branch information
NishantPrabhuFujitsu committed Dec 23, 2024
1 parent 9d4c1ab commit b758384
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,10 @@ endfunction()
# Return currently requested ARCH id
#
function(_currently_requested_top_arch VAR)
if(ENABLE_NEON_FP16)
set(RES NEON_FP16)
elseif(ENABLE_SVE)
if(ENABLE_SVE)
set(RES SVE)
elseif(ENABLE_NEON_FP16)
set(RES NEON_FP16)
elseif(ENABLE_AVX512F)
set(RES AVX512F)
elseif(ENABLE_AVX2)
Expand Down
26 changes: 25 additions & 1 deletion src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);

#if defined(HAVE_SVE)
static constexpr size_t vec_len_f32_sve = svcntw();
static size_t vec_len_f32_sve = svcntw();
static size_t vec_len_f16_sve = svcnth();
#endif

#ifdef HAVE_AVX512F
Expand Down Expand Up @@ -403,6 +404,28 @@ inline void __vst1q_f32(ov::bfloat16* a, float32x4_t b) {
#endif

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
# if defined(HAVE_SVE)
// Element-wise exp() of an f16 SVE vector, computed at f32 precision.
// The input is de-interleaved into two zero-padded vectors, widened to f32,
// run through the f32 helper exp_ps_sve, narrowed back to f16, and
// re-interleaved into the original element order.
// NOTE(review): assumes the lanes enabled by `pg` cover every element the
// caller will consume — TODO confirm against call sites.
inline svfloat16_t exp_ps_sve_f16(svbool_t& pg, svfloat16_t& src) {
// TRN1 of pg with an all-false predicate keeps pg's even 16-bit lanes and
// clears the odd ones; reinterpreted as a 32-bit-element predicate this
// governs the widened f32 vectors below.
svbool_t pg_f32 = svtrn1_b16(pg, svpfalse());

// De-interleave src: even-indexed elements (zero-padded via TRN1) into one
// vector, odd-indexed elements (via TRN2) into another, then widen to f32.
svfloat16_t zero = svdup_n_f16(0.0);
svfloat16_t low_f16 = svtrn1_f16(src, zero);
svfloat16_t high_f16 = svtrn2_f16(src, zero);
svfloat32_t low_f32 = svcvt_f32_f16_z(pg, low_f16);
svfloat32_t high_f32 = svcvt_f32_f16_z(pg, high_f16);

// Perform exp at f32 precision, then narrow each result back to f16.
svfloat32_t low_exp_f32 = exp_ps_sve(pg_f32, low_f32);
svfloat32_t high_exp_f32 = exp_ps_sve(pg_f32, high_f32);
svfloat16_t low_exp_f16 = svcvt_f16_f32_z(pg_f32, low_exp_f32);
svfloat16_t high_exp_f16 = svcvt_f16_f32_z(pg_f32, high_exp_f32);

// Interleave the two partial results to restore the original element order.
svfloat16_t res = svtrn1_f16(low_exp_f16, high_exp_f16);
return res;
}
# else
inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
const float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x));
const float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x));
Expand All @@ -411,6 +434,7 @@ inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
const float16x8_t res = vcombine_f16(vcvt_f16_f32(exp_ps_neon_f32(x_low)), vcvt_f16_f32(exp_ps_neon_f32(x_high)));
return res;
}
# endif
inline float16_t hsum(float16x8_t vec) {
float16x4_t sum1 = vpadd_f16(vget_low_f16(vec), vget_high_f16(vec));
float16x4_t sum2 = vpadd_f16(sum1, sum1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,28 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
mm256_uni_storeu_ps(dst + i, vb);
}
#elif defined(OPENVINO_ARCH_ARM64)
if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
# if defined(HAVE_SVE)
size_t inc = vec_len_f16_sve;
svbool_t pg = svptrue_b16();

while (i < n) {
if (n - i < vec_len_f16_sve) {
inc = n - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
svfloat16_t b1 = svld1_f16(pg, reinterpret_cast<const float16_t*>(src + i));
svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), b1);
i += inc;
}
# else
for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) {
auto vb1 = vld1q_f16(reinterpret_cast<const float16_t*>(src + i));
vst1q_f16(reinterpret_cast<float16_t*>(dst + i), vb1);
}
# endif
}
#else
# if defined(HAVE_SVE)
auto _dst = reinterpret_cast<float32_t*>(dst);
size_t inc = vec_len_f32_sve;
Expand Down Expand Up @@ -158,15 +180,35 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
template <typename T>
static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S, float* scale, float* zp) {
size_t i = 0;
auto attn_w_vec_fp16 = vdupq_n_f16(weight);
auto _v = reinterpret_cast<float16_t*>(v);
auto _out = reinterpret_cast<float16_t*>(out);

# if defined(HAVE_SVE)
svfloat16_t attn_w_vec_fp16 = svdup_n_f16(weight);
svbool_t pg = svptrue_b16();
size_t inc = vec_len_f16_sve;

while (i < S) {
if (S - i < vec_len_f16_sve) {
inc = S - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
svfloat16_t v_value = svld1_f16(pg, _v + i);
svfloat16_t v_out = svld1_f16(pg, _out + i);

v_out = svmla_f16_m(pg, v_out, attn_w_vec_fp16, v_value);
svst1_f16(pg, _out + i, v_out);
i += inc;
}
# else
auto attn_w_vec_fp16 = vdupq_n_f16(weight);
for (; i + vec_len_f16_neon <= S; i += vec_len_f16_neon) {
auto v_value = vld1q_f16(_v + i);
auto v_out = vld1q_f16(_out + i);
v_out = vfmaq_f16(v_out, attn_w_vec_fp16, v_value);
vst1q_f16(_out + i, v_out);
}
# endif
for (; i < S; i++) {
out[i] += weight * v[i];
}
Expand Down Expand Up @@ -701,12 +743,67 @@ static ov::float16 dot_product_fp16(ov::float16* a,
float* head_sum) {
size_t i = 0;
ov::float16 sum = 0.0f;
auto _a = reinterpret_cast<float16_t*>(a);
auto _b = reinterpret_cast<float16_t*>(b);

# if defined(HAVE_SVE)
svbool_t pg = svptrue_b16();
svfloat16_t sum0 = svdup_n_f16(0.0f);
svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);

for (; i + 4 * vec_len_f16_sve <= n; i += 4 * vec_len_f16_sve) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);
svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len_f16_sve * 2);
svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len_f16_sve * 3);

svfloat16_t b0 = svld1_f16(pg, _b + i);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);
svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len_f16_sve * 2);
svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len_f16_sve * 3);

sum0 = svmla_f16_z(pg, sum0, a0, b0);
sum1 = svmla_f16_z(pg, sum1, a1, b1);
sum2 = svmla_f16_z(pg, sum2, a2, b2);
sum3 = svmla_f16_z(pg, sum3, a3, b3);
}
if (i + 2 * vec_len_f16_sve <= n) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);

svfloat16_t b0 = svld1_f16(pg, _b + i);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);

sum0 = svmla_f16_z(pg, sum0, a0, b0);
sum1 = svmla_f16_z(pg, sum1, a1, b1);
i += 2 * vec_len_f16_sve;
}
if (i + vec_len_f16_sve <= n) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t b0 = svld1_f16(pg, _b + i);
sum0 = svmla_f16_z(pg, sum0, a0, b0);
i += vec_len_f16_sve;
}
// Process any remaining tail elements in parallel as well, using a predicated pass
if (i != n) {
svbool_t pg_rem = svwhilelt_b16(0, static_cast<int>(n - i));
svfloat16_t a0 = svld1_f16(pg_rem, _a + i);
svfloat16_t b0 = svld1_f16(pg_rem, _b + i);
sum0 = svmla_f16_m(pg_rem, sum0, a0, b0);
i = n;
}
float16_t sum_0 = svaddv_f16(pg, sum0);
float16_t sum_1 = svaddv_f16(pg, sum1);
float16_t sum_2 = svaddv_f16(pg, sum2);
float16_t sum_3 = svaddv_f16(pg, sum3);
sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
# else
auto vsum0 = vdupq_n_f16(0.0f);
auto vsum1 = vdupq_n_f16(0.0f);
auto vsum2 = vdupq_n_f16(0.0f);
auto vsum3 = vdupq_n_f16(0.0f);
auto _a = reinterpret_cast<float16_t*>(a);
auto _b = reinterpret_cast<float16_t*>(b);

for (; i + 4 * vec_len_f16_neon <= n; i += vec_len_f16_neon * 4) {
auto va0 = vld1q_f16(_a + i);
Expand Down Expand Up @@ -747,7 +844,7 @@ static ov::float16 dot_product_fp16(ov::float16* a,
vsum0 = vaddq_f16(vsum0, vsum2);

sum = hsum(vsum0);

# endif
for (; i < n; i++) {
sum += a[i] * b[i];
}
Expand Down Expand Up @@ -985,6 +1082,21 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S, size_t temp_stride) {
size_t i = 0;
# if defined(HAVE_SVE)
svbool_t pg = svptrue_b16();

for (; i + vec_len_f16_sve <= S; i += vec_len_f16_sve) {
auto* src = temp + i;
auto result_vec_fp16 = svdup_n_f16(0.0f);

for (size_t m = 0; m < M; m++) {
auto o_vec_fp16 = svld1_f16(pg, reinterpret_cast<float16_t*>(src));
result_vec_fp16 = svadd_f16_m(pg, result_vec_fp16, o_vec_fp16);
src += temp_stride;
}
svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), result_vec_fp16);
}
# else
for (; i + vec_len_f16_neon <= S; i += vec_len_f16_neon) {
auto* src = temp + i;
auto result_vec_fp16 = vdupq_n_f16(0.0f);
Expand All @@ -995,6 +1107,7 @@ static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S,
}
vst1q_f16(reinterpret_cast<float16_t*>(dst + i), result_vec_fp16);
}
# endif
for (; i < S; i++) {
auto* src = temp + i;
float sum = 0.0f;
Expand Down
Loading

0 comments on commit b758384

Please sign in to comment.