diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp index 6d08a7f55a6..20379fa299b 100644 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ b/src/cpu/aarch64/jit_uni_reorder.cpp @@ -161,14 +161,20 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { static bool applicable(const prb_t &p) { using namespace data_type; + bool bf16_ok + = (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16) + && !interim_f32_needed(p, false) && p.beta == 0.f) + || (p.itype != bf16 && p.otype != bf16) + || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16() + && p.beta == 0.f); + bool ok = true && p.ndims > 0 - && utils::one_of(p.itype, f32, s32, data_type::s8, u8) + && utils::one_of(p.itype, f32, bf16, s32, data_type::s8, u8) && utils::one_of(p.otype, f32, bf16, s32, data_type::s8, u8) && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) - && IMPLICATION( - p.otype == bf16, p.itype == f32 && mayiuse_bf16()); + && bf16_ok; return ok; } @@ -702,7 +708,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { const int load_tail_step = !can_load_xmm && can_store_xmm ? ur_step : load_step; - const bool interim_f32 = interim_f32_needed(); + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); const bool need_saturation = (utils::one_of(prb_.otype, u8, data_type::s8, s32) @@ -1285,17 +1291,17 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } } - bool interim_f32_needed() { + static bool interim_f32_needed(const prb_t &prb, bool compensation_needed) { using namespace data_type; - - return utils::one_of(f32, prb_.itype, prb_.otype) - || prb_.src_scale_type != scale_type_t::NONE - || prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f - || ((prb_.req_src_zp || prb_.req_dst_zp) - ? !(prb_.itype == s32 && prb_.otype == s32) + bool ret = utils::one_of(f32, prb.itype, prb.otype) + || prb.src_scale_type != scale_type_t::NONE + || prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f + || ((prb.req_src_zp || prb.req_dst_zp) + ? !(prb.itype == s32 && prb.otype == s32) : false) - || (prb_.itype != f32 && compensation_needed_) - || prb_.scale_adjust != 1.f; + || (prb.itype != f32 && compensation_needed) + || prb.scale_adjust != 1.f; + return ret; } void process_unroll_generic( @@ -1313,7 +1319,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { int curr = 0; // will switch between 0 and 1 - const bool interim_f32 = interim_f32_needed(); + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); if (prb_.req_src_zp) { add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0);