Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
liubo-intel committed Dec 20, 2024
1 parent f4853e3 commit 3674a36
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener

apply_post_ops(true, jep_.oc_size > 1 ? j * sizeof(float) : 0);

store_scalar(ptr[reg_dst + j * jep.dst_prc.size()], xmm_dst, exec_prc, jep.dst_prc);
store_scalar(ptr[reg_dst + j * jep.dst_prc.size()],
xmm_dst,
exec_prc,
jep.dst_prc,
jep.do_output_saturation);
}

for (size_t i = 0; i < jep.inputs_number; i++)
Expand Down Expand Up @@ -549,7 +553,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener

apply_post_ops(true);

store_scalar(ptr[reg_dst], xmm_dst, exec_prc, jep.dst_prc);
store_scalar(ptr[reg_dst], xmm_dst, exec_prc, jep.dst_prc, jep.do_output_saturation);

for (size_t i = 0; i < jep.inputs_number; i++)
if (jep.src_size[i] != 1)
Expand Down Expand Up @@ -1015,7 +1019,8 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
inline void store_scalar(const Xbyak::Address& op,
Xmm xmm_dst,
ov::element::Type src_prc,
ov::element::Type dst_prc) {
ov::element::Type dst_prc,
const bool do_output_saturation) {
if (src_prc == dst_prc) {
switch (src_prc.size()) {
case 4:
Expand Down Expand Up @@ -1050,7 +1055,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
uni_vmovss(op, xmm_dst);
break;
case ov::element::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
if (do_output_saturation)
uni_vpsrld(xmm_dst, xmm_dst, 16);
else
uni_vcvtneps2bf16->emit_code({static_cast<size_t>(xmm_dst.getIdx())},
{static_cast<size_t>(xmm_dst.getIdx())});
uni_vpextrw(op, xmm_dst, 0x0);
break;
case ov::element::f16:
Expand Down Expand Up @@ -1424,7 +1433,7 @@ struct EltwiseKey {
result = result && (inpDims[i] == rhs.inpDims[i]);
}
}
if ((outPrc == ov::element::bf16) && (doOutputSaturation != rhs.doOutputSaturation))
if (doOutputSaturation != rhs.doOutputSaturation)
return false;
}

Expand Down Expand Up @@ -2874,8 +2883,11 @@ void Eltwise::prepareParams() {
"'");
}
}
// Do output saturation if any input is a constant; this saturation process will be moved to the compilation stage
// in the future.

// FP32 constant inputs may contain values outside the BF16 representable range. When the output precision is BF16,
// we choose the "saturation" mode for the fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in
// the outputs. Since the "saturation" conversion is more time-consuming, a better solution would be to clamp the
// constants at the compilation stage (ticket: 159589).
key.doOutputSaturation = false;
for (size_t i = 0; i < getParentEdges().size(); i++) {
if (getParentEdgeAt(i)->getParent()->isConstant()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
using namespace CPUTestUtils;
namespace ov {
namespace test {
/*
This test aims to cover Eltwise node BF16 output precision conversion logic in "saturation" mode. In this test, we
have a select node with a condition input of boolean type and then/else inputs of f32 type (as constant nodes with
bf16 overflow data). The select node is followed by a convolution node to ensure that it is converted to bf16 precision.
*/
using selectParams = std::tuple<InputShape, // Condition shapes
ElementType // Then/Else precision
>;
Expand Down

0 comments on commit 3674a36

Please sign in to comment.