Skip to content

Commit

Permalink
fix review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
usstq committed Dec 28, 2024
1 parent efc7c57 commit ef6a6c1
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 136 deletions.
11 changes: 0 additions & 11 deletions src/plugins/intel_cpu/src/nodes/act_sparse_fc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,6 @@ struct ActSparseFC::Executor : public ActSparseFC::ExecutorBase {
MemoryPtr m_scales;
ActSparseFCNode::Config& m_config;

void show(const char* name, uint8_t* src, int stride, int rows, int cols) {
printf("===== %s \n", name);
for (int r = 0; r < rows; r++, src += stride) {
for (int c = 0; c < cols; c++) {
printf("%02X,", src[c]);
}
printf("\n");
}
}

Executor(ActSparseFC* pnode, DnnlScratchPadPtr scrachPad)
: m_node(pnode),
m_scrachPad(scrachPad),
Expand All @@ -64,7 +54,6 @@ struct ActSparseFC::Executor : public ActSparseFC::ExecutorBase {
const auto& context = m_node->context;
const auto& engine = m_node->getEngine();

std::cout << m_node->getName() << std::endl;
auto create_weight = [&]() {
auto raw_weight_mem = m_node->getSrcMemoryAtPort(1);
MemoryPtr weight_mem;
Expand Down
112 changes: 56 additions & 56 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/act_sparse_fc_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <cstring>

#include "openvino/core/except.hpp"

#if defined(OPENVINO_ARCH_X86_64)

# include "openvino/core/parallel.hpp"
Expand Down Expand Up @@ -34,14 +36,14 @@ static std::shared_ptr<SIMDJit> jit_compile_gemmRegBlk(int rows, int cols, int p
};

// load all arguments into register
auto A_ptr = jit->get_sreg(0);
auto A_stride = jit->get_sreg(1);
auto B_ptr = jit->get_sreg(2);
auto B_stride = jit->get_sreg(3);
auto dst_ptr = jit->get_sreg(4);
auto dst_stride = jit->get_sreg(5);
auto K = jit->get_sreg(6);
auto accumulate = jit->get_sreg(7);
auto A_ptr = jit->get_arg(0);
auto A_stride = jit->get_arg(1);
auto B_ptr = jit->get_arg(2);
auto B_stride = jit->get_arg(3);
auto dst_ptr = jit->get_arg(4);
auto dst_stride = jit->get_arg(5);
auto K = jit->get_arg(6);
auto accumulate = jit->get_arg(7);

auto stemp = jit->get_sreg();

Expand Down Expand Up @@ -95,7 +97,7 @@ static std::shared_ptr<SIMDJit> jit_compile_gemmRegBlk(int rows, int cols, int p
jit->simd_broadcast_ss(vmmA(r), jit->ptr[A_ptr3 + 2 * A_stride]);
break;
default:
throw std::runtime_error("number of reg-blocking rows is not supported");
OPENVINO_ASSERT(false, "number of reg-blocking rows is not supported");
}
};

Expand Down Expand Up @@ -181,20 +183,20 @@ static void gemm6x2_Mx2(const float* pA,
static std::shared_ptr<SIMDJit> jit_compile_accumulate_weight_i4(bool with_zero_point) {
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();
auto dst = jit->get_sreg(0); // float*
auto p_w0 = jit->get_sreg(1); // int4*
auto p_w1 = jit->get_sreg(2); // int4*
auto p_w2 = jit->get_sreg(3); // int4*
auto p_w3 = jit->get_sreg(4); // int4*
auto dense_x = jit->get_sreg(5); // float*
auto OC = jit->get_sreg(6); // float*
auto scales = jit->get_sreg(7); // float*
auto zero_points = jit->get_sreg(8); // float*
auto dst = jit->get_arg(0); // float*
auto p_w0 = jit->get_arg(1); // int4*
auto p_w1 = jit->get_arg(2); // int4*
auto p_w2 = jit->get_arg(3); // int4*
auto p_w3 = jit->get_arg(4); // int4*
auto dense_x = jit->get_arg(5); // float*
auto OC = jit->get_arg(6); // float*
auto scales = jit->get_arg(7); // float*
auto zero_points = jit->get_arg(8); // float*

auto oc = jit->get_sreg();

auto vx = [&](int i) {
ASSERT(i < 4);
OPENVINO_ASSERT(i < 4);
return jit->Vmm(i);
};

Expand Down Expand Up @@ -278,14 +280,14 @@ static std::shared_ptr<SIMDJit> jit_compile_accumulate_weight(WeightCompressionT
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();
// load all arguments into register
auto dst = jit->get_sreg(0); // float*
auto OC = jit->get_sreg(1);
auto gate_ids = jit->get_sreg(2); // int32_t *
auto gate_cnt = jit->get_sreg(3); // int
auto pw0 = jit->get_sreg(4); // ov::float16* / uint8_t*
auto dense_x = jit->get_sreg(5); //
auto scales = jit->get_sreg(6); // float*
auto zero_points = jit->get_sreg(7); // float*
auto dst = jit->get_arg(0); // float*
auto OC = jit->get_arg(1);
auto gate_ids = jit->get_arg(2); // int32_t *
auto gate_cnt = jit->get_arg(3); // int
auto pw0 = jit->get_arg(4); // ov::float16* / uint8_t*
auto dense_x = jit->get_arg(5); //
auto scales = jit->get_arg(6); // float*
auto zero_points = jit->get_arg(7); // float*

auto g = jit->get_sreg();
auto i = jit->get_sreg();
Expand Down Expand Up @@ -383,11 +385,11 @@ static std::shared_ptr<SIMDJit> jit_compile_reduce_outputs() {
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();
// load all arguments into register
auto dst0 = jit->get_sreg(0); // float*
auto src0 = jit->get_sreg(1); // float*
auto num_copies = jit->get_sreg(2); // int
auto OC = jit->get_sreg(3); // int
auto stride = jit->get_sreg(4); // int
auto dst0 = jit->get_arg(0); // float*
auto src0 = jit->get_arg(1); // float*
auto num_copies = jit->get_arg(2); // int
auto OC = jit->get_arg(3); // int
auto stride = jit->get_arg(4); // int

auto i = jit->get_sreg();
auto k = jit->get_sreg();
Expand Down Expand Up @@ -434,14 +436,14 @@ static std::shared_ptr<SIMDJit> jit_compile_repack_3xsimdw_1xsimdw(bool with_zp)
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();
// load all arguments into register
auto src = jit->get_sreg(0); // uint8_t*
auto strideW = jit->get_sreg(1); // int
auto scales = jit->get_sreg(2); // float*
auto zero_points = jit->get_sreg(3); // float*
auto K = jit->get_sreg(4); // int
auto N = jit->get_sreg(5); // int
auto repacked_B_nx3 = jit->get_sreg(6); // float*
auto repacked_B_nx1 = jit->get_sreg(7); // float*
auto src = jit->get_arg(0); // uint8_t*
auto strideW = jit->get_arg(1); // int
auto scales = jit->get_arg(2); // float*
auto zero_points = jit->get_arg(3); // float*
auto K = jit->get_arg(4); // int
auto N = jit->get_arg(5); // int
auto repacked_B_nx3 = jit->get_arg(6); // float*
auto repacked_B_nx1 = jit->get_arg(7); // float*

auto k = jit->get_sreg();
auto n0 = jit->get_sreg();
Expand Down Expand Up @@ -531,12 +533,12 @@ static std::shared_ptr<SIMDJit> jit_compile_repack_2xsimdw(WeightCompressionType
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();
// load all arguments into register
auto src = jit->get_sreg(0); // pointer to ov::float16/u8/i8/i4
auto src_stride = jit->get_sreg(1); // in unit of f16 or bytes (int8/int4)
auto dst = jit->get_sreg(2); // float*
auto bK = jit->get_sreg(3);
auto scales = jit->get_sreg(4); // scales
auto zero_point = jit->get_sreg(5); // zero-point
auto src = jit->get_arg(0); // pointer to ov::float16/u8/i8/i4
auto src_stride = jit->get_arg(1); // in unit of f16 or bytes (int8/int4)
auto dst = jit->get_arg(2); // float*
auto bK = jit->get_arg(3);
auto scales = jit->get_arg(4); // scales
auto zero_point = jit->get_arg(5); // zero-point

auto k = jit->get_sreg();

Expand Down Expand Up @@ -627,7 +629,7 @@ T* ActSparseFcKernel::scratch_alloc(size_t cnt) {
# else
thread_local uint8_t scratch[1024 * 1024 * 2];
# endif
ASSERT(cnt * sizeof(T) < sizeof(scratch));
OPENVINO_ASSERT(cnt * sizeof(T) < sizeof(scratch));
// DEBUG_LOG(reinterpret_cast<void*>(scratch));
return reinterpret_cast<T*>(scratch);
}
Expand Down Expand Up @@ -668,9 +670,9 @@ static std::shared_ptr<SIMDJit> get_decompress_zp_u8() {
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();

auto zp_input_u8 = jit->get_sreg(0);
auto zp_output_f32 = jit->get_sreg(1);
auto cnt = jit->get_sreg(2);
auto zp_input_u8 = jit->get_arg(0);
auto zp_output_f32 = jit->get_arg(1);
auto cnt = jit->get_arg(2);

auto n = jit->get_sreg();

Expand All @@ -689,9 +691,9 @@ static std::shared_ptr<SIMDJit> get_decompress_zp_u4() {
auto jit = std::make_shared<SIMDJit>(__func__);
auto simd_width = SIMDJit::vmm_width<float>();

auto zp_input_u4 = jit->get_sreg(0);
auto zp_output_f32 = jit->get_sreg(1);
auto cnt = jit->get_sreg(2);
auto zp_input_u4 = jit->get_arg(0);
auto zp_output_f32 = jit->get_arg(1);
auto cnt = jit->get_arg(2);

auto n = jit->get_sreg();

Expand Down Expand Up @@ -1011,9 +1013,7 @@ void ActSparseFcKernel::operator()(const float* input,
const float* scales,
const uint8_t* zp) {
const auto SIMDW = SIMDJit::vmm_width<float>();
if (OC % (2 * SIMDW)) {
throw std::runtime_error(std::string("ActSparseFcKernel: OC is not multiple of ") + std::to_string(2 * SIMDW));
}
OPENVINO_ASSERT((OC % (2 * SIMDW)) == 0, "ActSparseFcKernel: OC is not multiple of ", 2 * SIMDW);

if (M > 1) {
const auto SIMDW = SIMDJit::vmm_width<float>();
Expand Down
Loading

0 comments on commit ef6a6c1

Please sign in to comment.