Skip to content

Commit

Permalink
bf16_staturation Jit Impl
Browse files Browse the repository at this point in the history
  • Loading branch information
liubo-intel committed Jan 19, 2025
1 parent 4a325ce commit 2139cd4
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 40 deletions.
10 changes: 2 additions & 8 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,15 +2895,9 @@ void Eltwise::prepareParams() {

// FP32 constant inputs may contain values out of BF16 representable range. In case output precision is BF16 we
// choose "saturation" mode for fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in the
// outputs. Since "saturation" conversion is more time consuming, better solution would be to clamp constants on
// compilation stage (ticket: 159589).
// outputs. Since "saturation" conversion during kernel runtime is more time consuming, current solution is
// clamp constants on compilation stage.
key.doOutputSaturation = false;
for (size_t i = 0; i < getParentEdges().size(); i++) {
if (getParentEdgeAt(i)->getParent()->isConstant()) {
key.doOutputSaturation = true;
break;
}
}

auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, buildExecutor);
Expand Down
215 changes: 183 additions & 32 deletions src/plugins/intel_cpu/src/nodes/input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ namespace node {

#if defined(OPENVINO_ARCH_X86_64)
namespace {
struct jit_has_subnormals_base : public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base)
struct jit_subnormals_bf16saturation_check_base : public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_subnormals_bf16saturation_check_base)

typedef struct {
const float* src;
const size_t count;
bool hasSubnormals;
bool hasTargetValues;
} args_t;

typedef void (*fn_t)(const args_t*);

jit_has_subnormals_base() : jit_generator(jit_name()) {
jit_subnormals_bf16saturation_check_base() : jit_generator(jit_name()) {
jit_ker_ = nullptr;
}

Expand Down Expand Up @@ -110,8 +110,35 @@ struct jit_has_subnormals_base : public jit_generator {
uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
}

void check_bf16_saturations(const Xbyak::Reg64& src,
const Xbyak::Ymm& bf16_max_mask,
const Xbyak::Ymm& bf16_min_mask) {
auto a = ymm1;
auto b = ymm2;
auto c = ymm3;
vmovdqu(a, yword[src]); // load 8 floats
vcmpps(b, a, bf16_max_mask, _CMP_GT_OQ); // b = (a > bf16_max) ? 1 : 0
vcmpps(c, a, bf16_min_mask, _CMP_LT_OQ); // c = (a < bf16_min) ? 1 : 0
vorps(b, b, c); // b = b | c
vptest(b, b); // if (b != 0) CF = 1 else CF = 0
}

void check_bf16_saturations(const Xbyak::Reg64& src,
const Xbyak::Xmm& bf16_max_mask,
const Xbyak::Xmm& bf16_min_mask) {
auto a = xmm1;
auto b = xmm2;
auto c = xmm3;

uni_vmovdqu(a, xword[src]); // load 4 floats
uni_vcmpps(b, a, bf16_max_mask, _CMP_GT_OQ); // b = (a > bf16_max) ? 1 : 0
uni_vcmpps(c, a, bf16_max_mask, _CMP_LT_OQ); // c = (a < bf16_min) ? 1 : 0
uni_vorps(b, b, c); // b = b | c
uni_vtestps(b, b); // if (b != 0) CF = 1 else CF = 0
}

protected:
Label exit, has_subnormals, no_subnormals;
Label exit, has_target_values, no_target_values;

const Reg64& reg_src = rax;
const Reg64& reg_dst = rbx;
Expand All @@ -121,16 +148,35 @@ struct jit_has_subnormals_base : public jit_generator {

static const uint32_t exponent_mask_data[8];
static const uint32_t mantissa_mask_data[8];
static const float bf16_max_mask_data[8];
static const float bf16_min_mask_data[8];
};

const uint32_t jit_has_subnormals_base::exponent_mask_data[8] =
const uint32_t jit_subnormals_bf16saturation_check_base::exponent_mask_data[8] =
{0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000};

const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] =
const uint32_t jit_subnormals_bf16saturation_check_base::mantissa_mask_data[8] =
{0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff};

const float jit_subnormals_bf16saturation_check_base::bf16_max_mask_data[8] = {3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f,
3.38953139e+38f};

const float jit_subnormals_bf16saturation_check_base::bf16_min_mask_data[8] = {-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f,
-3.38953139e+38f};
template <cpu_isa_t isa>
struct jit_has_subnormals : public jit_has_subnormals_base {
struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;

const Vmm rmm4 = Vmm(4);
Expand All @@ -150,7 +196,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {

// Get arguments addresses
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]);
lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]);
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);

// Initialize necessary consts
Expand All @@ -167,7 +213,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {

foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) {
check_subnormals(reg_src, exponent_mask, mantissa_mask, zero);
jnc(has_subnormals);
jnc(has_target_values);
add(reg_src, sizeof(float) * vlen);
})
;
Expand All @@ -186,25 +232,98 @@ struct jit_has_subnormals : public jit_has_subnormals_base {

copy_floats(r8, reg_src, reg_sz);
check_subnormals(r8, exponent_mask, mantissa_mask, zero);
jc(no_subnormals);
jc(no_target_values);
add(rsp, vlen * sizeof(float));

L(has_subnormals);
L(has_target_values);

mov(rax, 1);
mov(byte[reg_dst], al);
jmp(exit);

L(no_subnormals);
L(no_target_values);
add(rsp, vlen * sizeof(float));

L(exit);

postamble();
}
};
template <cpu_isa_t isa>
struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base {
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;

const Vmm rmm4 = Vmm(4);
const Vmm rmm5 = Vmm(5);
const Vmm rmm6 = Vmm(6);
const int length = isa == sse41 ? 4 : 8;

void generate() override final { // NOLINT
size_t const vlen = length;
const int sh_bits = std::ilogb(vlen);

auto zero = rmm4;
auto bf16_max_mask = rmm5;
auto bf16_min_mask = rmm6;

preamble();

// Get arguments addresses
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]);
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);

// Initialize necessary consts
uni_vpxor(zero, zero, zero);
mov(reg_mask_addr, (size_t)bf16_max_mask_data);
uni_vmovdqu(bf16_max_mask, ptr[reg_mask_addr]);
mov(reg_mask_addr, (size_t)bf16_min_mask_data);
uni_vmovdqu(bf16_min_mask, ptr[reg_mask_addr]);

// Main loop
xor_(reg_idx, reg_idx);
mov(r8, reg_sz);
shr(r8, sh_bits);

foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) {
check_bf16_saturations(reg_src, bf16_max_mask, bf16_min_mask);
jnz(has_target_values, T_NEAR);
add(reg_src, sizeof(float) * vlen);
})
;

// Tail
shl(reg_idx, sh_bits);
sub(reg_sz, reg_idx);
test(reg_sz, reg_sz);
jz(exit);

jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
// use space on stack for 4 or 8 floats
sub(rsp, vlen * sizeof(float));
mov(r8, rsp);

uni_vmovdqu(ptr[r8], zero);

copy_floats(r8, reg_src, reg_sz);
check_bf16_saturations(r8, bf16_max_mask, bf16_min_mask);
jz(no_target_values, T_NEAR);
add(rsp, vlen * sizeof(float));

L(has_target_values);

mov(rax, 1);
mov(byte[reg_dst], al);
jmp(exit);

L(no_target_values);
add(rsp, vlen * sizeof(float));

L(exit);

postamble();
}
};
jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() {
if (mayiuse(cpu_isa_t::avx2)) {
static jit_has_subnormals<cpu_isa_t::avx2> generator;
static auto fn = generator.get();
Expand All @@ -216,6 +335,18 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
}
return nullptr;
}
jit_subnormals_bf16saturation_check_base::fn_t jit_has_bf16_overflows_function() {
if (mayiuse(cpu_isa_t::avx2)) {
static jit_has_bf16_overflows<cpu_isa_t::avx2> generator;
static auto fn = generator.get();
return fn;
} else if (mayiuse(cpu_isa_t::sse41)) {
static jit_has_bf16_overflows<cpu_isa_t::sse41> generator;
static auto fn = generator.get();
return fn;
}
return nullptr;
}

} // namespace
#endif
Expand Down Expand Up @@ -271,49 +402,69 @@ void Input::cloneBlobIfRequired() {
if (!size)
return;

const float bf16_max = 3.3895313899137927e38f;
const bool do_bf16_saturation_check =
(context->getConfig().inferencePrecision == ov::element::bf16) ? true : false;

#if defined(OPENVINO_ARCH_X86_64)
if (auto fn = jit_has_subnormals_function()) {
auto fn = jit_has_subnormals_function();
auto fn_bf16_check = jit_has_bf16_overflows_function();
if (fn && fn_bf16_check) {
static const size_t batch_size = 2048;
const size_t iterations_num = size / batch_size + 1;

volatile bool has_subnormals_local = false;
volatile bool has_bf16_overflows_local = false;

parallel_for(iterations_num, [&](int n) {
auto ptr = u32data + n * batch_size;
const jit_has_subnormals_base::args_t args = {reinterpret_cast<float const*>(ptr),
std::min(batch_size, (size_t)(u32data + size - ptr)),
false};
const jit_subnormals_bf16saturation_check_base::args_t args1 = {
reinterpret_cast<float const*>(ptr),
std::min(batch_size, (size_t)(u32data + size - ptr)),
false};

fn(&args);
fn(&args1);

if (args.hasSubnormals)
if (args1.hasTargetValues)
has_subnormals_local = true;
});

has_subnormals = has_subnormals_local;
//TODO: opt with jit
for (size_t i = 0; i < size; ++i) {
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
has_bf16_overflows = true;
return;
}
if (do_bf16_saturation_check) {
parallel_for(iterations_num, [&](int n) {
auto ptr2 = f32data + n * batch_size;
const jit_subnormals_bf16saturation_check_base::args_t args2 = {
reinterpret_cast<float const*>(ptr2),
std::min(batch_size, (size_t)(f32data + size - ptr2)),
false};

fn_bf16_check(&args2);

if (args2.hasTargetValues)
has_bf16_overflows_local = true;
});
}

has_subnormals = has_subnormals_local;
has_bf16_overflows = has_bf16_overflows_local;

return;
}
#endif

uint32_t mantissaMask = 0x007fffff;
uint32_t exponentMask = 0x7f800000;
const float bf16_max = 3.3895313899137927e38f;
for (size_t i = 0; i < size; ++i) {
if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
has_subnormals = true;
}
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
has_bf16_overflows = true;
}
if (has_subnormals && has_bf16_overflows) {
if (do_bf16_saturation_check) {
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
has_bf16_overflows = true;
}
if (has_subnormals && has_bf16_overflows) {
return;
}
} else if (has_subnormals) {
return;
}
}
Expand Down

0 comments on commit 2139cd4

Please sign in to comment.