Skip to content

Commit

Permalink
sve veclen compilation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NishantPrabhuFujitsu committed Jan 2, 2025
1 parent 7638e02 commit c0b807a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$")
set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$")
set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$")
set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$")
set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$")
set(_ACCEPTED_ARCHS_SVE "^(ANY|NEON_FP16|SVE)$")

## Arch specific definitions
set(_DEFINE_ANY "")
Expand Down
10 changes: 5 additions & 5 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,16 @@ target_include_directories(${TARGET_NAME} PRIVATE $<TARGET_PROPERTY:openvino::re
set(SOFTMAX_ARCH_LIST AVX512F AVX2)
set(MHA_SINGLE_TOKEN_ARCH_LIST AVX512F AVX2)

if(ENABLE_NEON_FP16)
list(APPEND SOFTMAX_ARCH_LIST NEON_FP16)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST NEON_FP16)
endif()

if(ENABLE_SVE)
list(APPEND SOFTMAX_ARCH_LIST SVE)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST SVE)
endif()

if(ENABLE_NEON_FP16)
list(APPEND SOFTMAX_ARCH_LIST NEON_FP16)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST NEON_FP16)
endif()

list(APPEND SOFTMAX_ARCH_LIST ANY)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST ANY)

Expand Down
10 changes: 8 additions & 2 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);

#if defined(HAVE_SVE)
static size_t vec_len_f32_sve = svcntw();
static size_t vec_len_f16_sve = svcnth();
inline size_t vec_len_f32_sve() {
static size_t len = svcntw();
return len;
}
inline size_t vec_len_f16_sve() {
static size_t len = svcnth();
return len;
}
#endif

#ifdef HAVE_AVX512F
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
# if defined(HAVE_SVE)
size_t inc = vec_len_f16_sve;
size_t inc = vec_len_f16_sve();
svbool_t pg = svptrue_b16();

while (i < n) {
if (n - i < vec_len_f16_sve) {
if (n - i < vec_len_f16_sve()) {
inc = n - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
Expand All @@ -88,11 +88,11 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
# else
# if defined(HAVE_SVE)
auto _dst = reinterpret_cast<float32_t*>(dst);
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

while (i < n) {
if (n - i < vec_len_f32_sve) {
if (n - i < vec_len_f32_sve()) {
inc = n - i;
pg = svwhilelt_b32(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -138,11 +138,11 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
# if defined(HAVE_SVE)
auto _v = reinterpret_cast<float32_t*>(v);
svfloat32_t attn_w_vec_fp32 = svdup_n_f32(weight);
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

while (i < S) {
if (S - i < vec_len_f32_sve) {
if (S - i < vec_len_f32_sve()) {
inc = S - i;
pg = svwhilelt_b32(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -180,10 +180,10 @@ static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S,
# if defined(HAVE_SVE)
svfloat16_t attn_w_vec_fp16 = svdup_n_f16(weight);
svbool_t pg = svptrue_b16();
size_t inc = vec_len_f16_sve;
size_t inc = vec_len_f16_sve();

while (i < S) {
if (S - i < vec_len_f16_sve) {
if (S - i < vec_len_f16_sve()) {
inc = S - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -440,30 +440,31 @@ static float sum_q_head(T* a, size_t n) {
svfloat32_t sum2 = svdup_n_f32(0.0f);
svfloat32_t sum3 = svdup_n_f32(0.0f);
svbool_t pg = svptrue_b32();
auto vec_len = vec_len_f32_sve();

for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
for (; i + 4 * vec_len <= n; i += 4 * vec_len) {
svfloat32_t a0 = svld1_f32(pg, a + i);
svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
svfloat32_t a2 = svld1_f32(pg, a + i + vec_len_f32_sve * 2);
svfloat32_t a3 = svld1_f32(pg, a + i + vec_len_f32_sve * 3);
svfloat32_t a1 = svld1_f32(pg, a + i + vec_len);
svfloat32_t a2 = svld1_f32(pg, a + i + vec_len * 2);
svfloat32_t a3 = svld1_f32(pg, a + i + vec_len * 3);

sum0 = svadd_f32_z(pg, a0, sum0);
sum1 = svadd_f32_z(pg, a1, sum1);
sum2 = svadd_f32_z(pg, a2, sum2);
sum3 = svadd_f32_z(pg, a3, sum3);
}
if (i + 2 * vec_len_f32_sve <= n) {
if (i + 2 * vec_len <= n) {
svfloat32_t a0 = svld1_f32(pg, a + i);
svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
svfloat32_t a1 = svld1_f32(pg, a + i + vec_len);

sum0 = svadd_f32_z(pg, a0, sum0);
sum1 = svadd_f32_z(pg, a1, sum1);
i += 2 * vec_len_f32_sve;
i += 2 * vec_len;
}
if (i + vec_len_f32_sve <= n) {
if (i + vec_len <= n) {
svfloat32_t a0 = svld1_f32(pg, a + i);
sum0 = svadd_f32_z(pg, a0, sum0);
i += vec_len_f32_sve;
i += vec_len;
}
// Process tail elements parallely as well (if any)
if (i != n) {
Expand Down Expand Up @@ -623,42 +624,43 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
svfloat32_t sum1 = svdup_n_f32(0.0f);
svfloat32_t sum2 = svdup_n_f32(0.0f);
svfloat32_t sum3 = svdup_n_f32(0.0f);
auto vec_len = vec_len_f32_sve();

auto _a = reinterpret_cast<float32_t*>(a);
auto _b = reinterpret_cast<float32_t*>(b);

for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
for (; i + 4 * vec_len <= n; i += 4 * vec_len) {
svfloat32_t a0 = svld1_f32(pg, _a + i);
svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len_f32_sve * 2);
svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len_f32_sve * 3);
svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);
svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len * 2);
svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len * 3);

svfloat32_t b0 = svld1_f32(pg, _b + i);
svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len_f32_sve * 2);
svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len_f32_sve * 3);
svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);
svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len * 2);
svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len * 3);

sum0 = svmla_f32_z(pg, sum0, a0, b0);
sum1 = svmla_f32_z(pg, sum1, a1, b1);
sum2 = svmla_f32_z(pg, sum2, a2, b2);
sum3 = svmla_f32_z(pg, sum3, a3, b3);
}
if (i + 2 * vec_len_f32_sve <= n) {
if (i + 2 * vec_len <= n) {
svfloat32_t a0 = svld1_f32(pg, _a + i);
svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len);

svfloat32_t b0 = svld1_f32(pg, _b + i);
svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len);

sum0 = svmla_f32_z(pg, sum0, a0, b0);
sum1 = svmla_f32_z(pg, sum1, a1, b1);
i += 2 * vec_len_f32_sve;
i += 2 * vec_len;
}
if (i + vec_len_f32_sve <= n) {
if (i + vec_len <= n) {
svfloat32_t a0 = svld1_f32(pg, _a + i);
svfloat32_t b0 = svld1_f32(pg, _b + i);
sum0 = svmla_f32_z(pg, sum0, a0, b0);
i += vec_len_f32_sve;
i += vec_len;
}
// Process the tail elements parallely as well (if any)
if (i != n) {
Expand Down Expand Up @@ -746,39 +748,40 @@ static ov::float16 dot_product_fp16(ov::float16* a,
svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);
auto vec_len = vec_len_f16_sve();

for (; i + 4 * vec_len_f16_sve <= n; i += 4 * vec_len_f16_sve) {
for (; i + 4 * vec_len <= n; i += 4 * vec_len) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);
svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len_f16_sve * 2);
svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len_f16_sve * 3);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len);
svfloat16_t a2 = svld1_f16(pg, _a + i + vec_len * 2);
svfloat16_t a3 = svld1_f16(pg, _a + i + vec_len * 3);

svfloat16_t b0 = svld1_f16(pg, _b + i);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);
svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len_f16_sve * 2);
svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len_f16_sve * 3);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len);
svfloat16_t b2 = svld1_f16(pg, _b + i + vec_len * 2);
svfloat16_t b3 = svld1_f16(pg, _b + i + vec_len * 3);

sum0 = svmla_f16_z(pg, sum0, a0, b0);
sum1 = svmla_f16_z(pg, sum1, a1, b1);
sum2 = svmla_f16_z(pg, sum2, a2, b2);
sum3 = svmla_f16_z(pg, sum3, a3, b3);
}
if (i + 2 * vec_len_f16_sve <= n) {
if (i + 2 * vec_len <= n) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len_f16_sve);
svfloat16_t a1 = svld1_f16(pg, _a + i + vec_len);

svfloat16_t b0 = svld1_f16(pg, _b + i);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len_f16_sve);
svfloat16_t b1 = svld1_f16(pg, _b + i + vec_len);

sum0 = svmla_f16_z(pg, sum0, a0, b0);
sum1 = svmla_f16_z(pg, sum1, a1, b1);
i += 2 * vec_len_f16_sve;
i += 2 * vec_len;
}
if (i + vec_len_f16_sve <= n) {
if (i + vec_len <= n) {
svfloat16_t a0 = svld1_f16(pg, _a + i);
svfloat16_t b0 = svld1_f16(pg, _b + i);
sum0 = svmla_f16_z(pg, sum0, a0, b0);
i += vec_len_f16_sve;
i += vec_len;
}
// Process the tail elements parallely as well (if any)
if (i != n) {
Expand Down Expand Up @@ -1029,11 +1032,11 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
#elif defined(OPENVINO_ARCH_ARM64)
# if defined(HAVE_SVE)
auto _dst = reinterpret_cast<float32_t*>(dst);
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

while (i < S) {
if (S - i < vec_len_f32_sve) {
if (S - i < vec_len_f32_sve()) {
inc = S - i;
pg = svwhilelt_b32(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -1079,7 +1082,7 @@ static void attn_reduce(ov::float16* dst, ov::float16* temp, size_t M, size_t S,
# if defined(HAVE_SVE)
svbool_t pg = svptrue_b16();

for (; i + vec_len_f16_sve <= S; i += vec_len_f16_sve) {
for (; i + vec_len_f16_sve() <= S; i += vec_len_f16_sve()) {
auto* src = temp + i;
auto result_vec_fp16 = svdup_n_f16(0.0f);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,10 @@ inline void scale_add2_reduce_max(ov::float16* a,
svbool_t pg_f16 = svptrue_b16();
svbool_t pg_u8 = svptrue_b8();
svbool_t pg_u16 = svptrue_b16();
size_t inc = vec_len_f16_sve;
size_t inc = vec_len_f16_sve();

while (i < size) {
if (size - i < vec_len_f16_sve) {
if (size - i < vec_len_f16_sve()) {
inc = size - i;
pg_f16 = svwhilelt_b16(0, static_cast<int>(inc));
pg_u8 = svwhilelt_b8(0, static_cast<int>(inc));
Expand Down Expand Up @@ -717,12 +717,11 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
svfloat32_t v_a;
svfloat32_t v_max = svdup_n_f32(max);
svfloat32_t v_sum = svdup_n_f32(0.0f);
size_t vec_len_f32_sve = svcntw();
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

while (i < size) {
if (size - i < vec_len_f32_sve) {
if (size - i < vec_len_f32_sve()) {
inc = size - i;
pg = svwhilelt_b32(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -767,12 +766,12 @@ inline void exp_reduce_sum_f32(ov::float16* a, const ov::float16 max, const size
svbool_t pg_f32 = svptrue_b32();
svbool_t pg_f16 = svptrue_b16();
svfloat16_t zero = svdup_n_f16(0.0);
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();

while (i < size) {
if (size - i < vec_len_f16_sve)
if (size - i < vec_len_f16_sve())
pg_f16 = svwhilelt_b16(0, static_cast<int>(size - i));
if (size - i < vec_len_f32_sve) {
if (size - i < vec_len_f32_sve()) {
pg_f32 = svwhilelt_b32(0, static_cast<int>(size - i));
inc = size - i;
}
Expand Down Expand Up @@ -838,10 +837,10 @@ inline void exp_reduce_sum(ov::float16* a, const ov::float16 max, const size_t s
svfloat16_t v_max = svdup_n_f16(max);
svfloat16_t v_sum = svdup_n_f16(0.0f);
svbool_t pg = svptrue_b16();
size_t inc = vec_len_f16_sve;
size_t inc = vec_len_f16_sve();

while (i < size) {
if (size - i < vec_len_f16_sve) {
if (size - i < vec_len_f16_sve()) {
inc = size - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -919,11 +918,11 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_
#elif defined(OPENVINO_ARCH_ARM64)
# if defined(HAVE_SVE)
svfloat32_t v_scale = svdup_n_f32(val);
size_t inc = vec_len_f32_sve;
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();

while (i < size) {
if (size - i < vec_len_f32_sve) {
if (size - i < vec_len_f32_sve()) {
inc = size - i;
pg = svwhilelt_b32(0, static_cast<int>(inc));
}
Expand Down Expand Up @@ -1020,11 +1019,11 @@ inline void multiply_scalar(ov::float16* a, ov::float16* a_dst, const ov::float1
size_t i = 0;
# if defined(HAVE_SVE)
svfloat16_t v_scale = svdup_n_f16(val);
size_t inc = vec_len_f16_sve;
size_t inc = vec_len_f16_sve();
svbool_t pg = svptrue_b16();

while (i < size) {
if (size - i < vec_len_f16_sve) {
if (size - i < vec_len_f16_sve()) {
inc = size - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
Expand Down

0 comments on commit c0b807a

Please sign in to comment.