Skip to content

Commit

Permalink
added elementwise logical or
Browse files Browse the repository at this point in the history
  • Loading branch information
NishantPrabhuFujitsu committed Dec 13, 2024
1 parent a0c705f commit 1407cfb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1447,26 +1447,33 @@ std::set<std::vector<element::Type>> jit_logical_and_emitter::get_supported_prec

/// LOGICAL_OR ///
jit_logical_or_emitter::jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
}

jit_logical_or_emitter::jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
}

size_t jit_logical_or_emitter::get_inputs_count() const { return 2; }
size_t jit_logical_or_emitter::get_inputs_count() const {
return 2;
}

size_t jit_logical_or_emitter::get_aux_vecs_count() const { return 1; }
size_t jit_logical_or_emitter::get_aux_vecs_count() const {
return 1;
}

size_t jit_logical_or_emitter::get_aux_gprs_count() const { return 1; }
size_t jit_logical_or_emitter::get_aux_gprs_count() const {
return 1;
}

void jit_logical_or_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
void jit_logical_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
const std::vector<size_t>& out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
Expand All @@ -1475,7 +1482,8 @@ void jit_logical_or_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, c
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_logical_or_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
void jit_logical_or_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
const std::vector<size_t>& out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
Expand All @@ -1493,7 +1501,8 @@ void jit_logical_or_emitter::register_table_entries() {
push_arg_entry_of("one", 0x3f800000, true);
}

std::set<std::vector<element::Type>> jit_logical_or_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
std::set<std::vector<element::Type>> jit_logical_or_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,13 +644,13 @@ class jit_logical_and_emitter : public jit_emitter {

class jit_logical_or_emitter : public jit_emitter {
public:
jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);
jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n);
jit_logical_or_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n);

size_t get_inputs_count() const override;

Expand All @@ -659,13 +659,13 @@ class jit_logical_or_emitter : public jit_emitter {
size_t get_aux_gprs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);
const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;

void register_table_entries() override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseIsInf, ov::intel_cpu::aarch64::jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseLessEqual, ov::intel_cpu::aarch64::jit_less_equal_emitter),
OV_CASE(Algorithm::EltwiseLogicalAnd, ov::intel_cpu::aarch64::jit_logical_and_emitter),
OV_CASE(Algorithm::EltwiseLogicalOr, ov::intel_cpu::aarch64::jit_logical_or_emitter),
OV_CASE(Algorithm::EltwiseLogicalNot, ov::intel_cpu::aarch64::jit_logical_not_emitter),
OV_CASE(Algorithm::EltwiseLogicalXor, ov::intel_cpu::aarch64::jit_logical_xor_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, ov::intel_cpu::aarch64::jit_is_nan_emitter),
Expand Down Expand Up @@ -860,6 +861,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter),
OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter),
OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter),
OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter),
OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter),
Expand Down

0 comments on commit 1407cfb

Please sign in to comment.