refactor executor

usstq committed Dec 30, 2024
1 parent ef6a6c1 commit 2f714a3
Showing 6 changed files with 356 additions and 324 deletions.
275 changes: 140 additions & 135 deletions src/plugins/intel_cpu/src/nodes/act_sparse_fc.cpp
@@ -30,148 +30,20 @@ namespace ov {
 namespace intel_cpu {
 namespace node {
 
-#if defined(OPENVINO_ARCH_X86_64)
-
-struct ActSparseFC::Executor : public ActSparseFC::ExecutorBase {
-    ActSparseFC* m_node;
-    std::shared_ptr<ActSparseFcKernel> m_kernel;
-    DnnlScratchPadPtr m_scrachPad;
-    MemoryPtr m_weight;
-    MemoryPtr m_zp;
-    MemoryPtr m_scales;
-    ActSparseFCNode::Config& m_config;
-
-    Executor(ActSparseFC* pnode, DnnlScratchPadPtr scrachPad)
-        : m_node(pnode),
-          m_scrachPad(scrachPad),
-          m_config(m_node->m_config) {
-        m_kernel = std::make_shared<ActSparseFcKernel>(
-            m_config.is_quantized ? (m_config.is_int4 ? WeightCompressionType::INT4 : WeightCompressionType::INT8)
-                                  : (WeightCompressionType::FP16),
-            m_config.with_zero_point,
-            m_config.ic_q_group_size);
-        // reorder weights
-        const auto& context = m_node->context;
-        const auto& engine = m_node->getEngine();
-
-        auto create_weight = [&]() {
-            auto raw_weight_mem = m_node->getSrcMemoryAtPort(1);
-            MemoryPtr weight_mem;
-            if (m_config.is_int4) {
-                // weight : [OC, IC/group_size, group_size] => [IC, OC/2, 2]
-                // each row is further reordered in unit of 16 x i4 in [0,8,1,9,2,a,3,b,4,c,5,d,6,e,7,f] order
-                weight_mem = std::make_shared<Memory>(engine, raw_weight_mem->getDescPtr());
-
-                const auto& dims = raw_weight_mem->getShape().getStaticDims();
-                OPENVINO_ASSERT(dims.size() == 3);
-                OPENVINO_ASSERT(dims[0] == static_cast<size_t>(m_config.oc));
-                OPENVINO_ASSERT(dims[1] == static_cast<size_t>(m_config.ic / m_config.ic_q_group_size));
-                OPENVINO_ASSERT(dims[2] == static_cast<size_t>(m_config.ic_q_group_size));
-
-                auto* src = raw_weight_mem->getDataAs<uint8_t>();
-                auto* dst = weight_mem->getDataAs<uint8_t>();
-                m_kernel->repack_weights_i4(src, dst, m_config.ic, m_config.oc);
-            } else {
-                // raw [OC, IC] layout
-                // target [IC, OC] layout
-                ArbitraryOrderDescCreator descCreator({1, 0});
-                auto dst_mem_desc =
-                    descCreator.createSharedDesc(raw_weight_mem->getPrecision(), raw_weight_mem->getShape());
-
-                weight_mem = std::make_shared<Memory>(engine, dst_mem_desc);
-                node::Reorder::reorderData(*raw_weight_mem, *weight_mem, context->getParamsCache());
-            }
-            return weight_mem;
-        };
-
-        auto create_zp_i4 = [&]() {
-            // [OC, IC/group_size, 1] => [IC/group_size, OC]
-            auto raw_zp_mem = m_node->getSrcMemoryAtPort(3);
-            auto zp_mem = std::make_shared<Memory>(engine, raw_zp_mem->getDescPtr());
-
-            auto* src = raw_zp_mem->getDataAs<uint8_t>();
-            auto* dst = zp_mem->getDataAs<uint8_t>();
-
-            m_kernel->repack_weights_i4(src, dst, m_config.ic / m_config.ic_q_group_size, m_config.oc);
-            return zp_mem;
-        };
-
-        auto create_scales_i4 = [&]() {
-            // [OC, IC/group_size, 1] => [IC/group_size, OC]
-            auto raw_scales_mem = m_node->getSrcMemoryAtPort(2);
-            ArbitraryOrderDescCreator descCreator({2, 1, 0});
-            auto dst_mem_desc =
-                descCreator.createSharedDesc(raw_scales_mem->getPrecision(), raw_scales_mem->getShape());
-
-            auto scales_mem = std::make_shared<Memory>(engine, dst_mem_desc);
-            node::Reorder::reorderData(*raw_scales_mem, *scales_mem, context->getParamsCache());
-            return scales_mem;
-        };
-
-        if (!m_config.is_int4) {
-            // int8 is perOC, no need for reorder
-            if (m_config.is_quantized)
-                m_scales = m_node->getSrcMemoryAtPort(2);
-            if (m_config.with_zero_point)
-                m_zp = m_node->getSrcMemoryAtPort(3);
-        }
-
-        auto weightCache = context->getWeightsCache();
-        if (weightCache != nullptr) {
-            const auto string_hash = m_node->getOriginalLayers() + std::to_string(m_config.is_int4);
-            m_weight = *weightCache->findOrCreate(string_hash + "_weight", create_weight);
-            if (m_config.is_int4) {
-                if (m_config.with_zero_point)
-                    m_zp = *weightCache->findOrCreate(string_hash + "_zp_i4", create_zp_i4);
-                if (m_config.is_quantized)
-                    m_scales = *weightCache->findOrCreate(string_hash + "_scales_i4", create_scales_i4);
-            }
-        } else {
-            m_weight = create_weight();
-            if (m_config.is_int4) {
-                if (m_config.with_zero_point)
-                    m_zp = create_zp_i4();
-                if (m_config.is_quantized)
-                    m_scales = create_scales_i4();
-            }
-        }
-    }
-
-    void execute() override {
-        const auto* input = m_node->getSrcDataAtPortAs<float>(0);
+void ActSparseFC::execute(dnnl::stream strm) {
+    MAYBE_UNUSED(strm);
+    if (m_executor) {
+        const auto* input = getSrcDataAtPortAs<float>(0);
         const auto* weight = m_weight->getDataAs<uint8_t>();
         const auto* zp = m_config.with_zero_point ? m_zp->getDataAs<uint8_t>() : nullptr;
         const auto* scales = m_config.is_quantized ? m_scales->getDataAs<float>() : nullptr;
-        auto* output = m_node->getDstDataAtPortAs<float>(0);
+        auto* output = getDstDataAtPortAs<float>(0);
 
-        const auto& ishape = m_node->getSrcMemoryAtPort(0)->getStaticDims();
+        const auto& ishape = getSrcMemoryAtPort(0)->getStaticDims();
         int M = shape_size(ishape) / ishape[ishape.size() - 1];
 
-        (*m_kernel)(input, output, M, m_config.ic, m_config.oc, m_config.threshold, 0, weight, scales, zp);
+        (*m_executor)(input, output, M, m_config.ic, m_config.oc, m_config.threshold, 0, weight, scales, zp);
     }
-};
-#else
-struct ActSparseFC::Executor : public ActSparseFC::ExecutorBase {
-    ActSparseFC* m_pnode;
-    Executor(ActSparseFC* pnode, DnnlScratchPadPtr scrachPad) : m_pnode(pnode) {}
-    void execute() override {}
-};
-#endif
-
-void ActSparseFC::createPrimitive() {
-    auto rtPrecision = getInputPrecisions()[0];
-#ifdef OPENVINO_ARCH_X86_64
-    m_executor = std::make_shared<Executor>(this, context->getScratchPad());
-#endif
-
-    if (!m_executor) {
-        OPENVINO_THROW("ActSparseFC Executor creation fails with precision " + rtPrecision.to_string());
-    }
-}
-
-void ActSparseFC::execute(dnnl::stream strm) {
-    MAYBE_UNUSED(strm);
-    m_executor->execute();
-}
+}
 
 ActSparseFC::ActSparseFC(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
@@ -185,6 +57,139 @@ ActSparseFC::ActSparseFC(const std::shared_ptr<ov::Node>& op, const GraphContext
     m_config = node->get_config();
 }
 
+struct ActSparseFCKey {
+    bool is_quantized;
+    bool is_int4;
+    bool with_zero_point;
+    int ic_q_group_size;
+
+    size_t hash() const {
+        using namespace dnnl::impl::primitive_hashing;
+        size_t seed = 0;
+        seed = hash_combine(seed, is_quantized);
+        seed = hash_combine(seed, is_int4);
+        seed = hash_combine(seed, with_zero_point);
+        seed = hash_combine(seed, ic_q_group_size);
+        return seed;
+    }
+
+    bool operator==(const ActSparseFCKey& rhs) const {
+        return is_quantized == rhs.is_quantized && is_int4 == rhs.is_int4 &&
+               with_zero_point == rhs.with_zero_point && ic_q_group_size == rhs.ic_q_group_size;
+    }
+};
+
+void ActSparseFC::createPrimitive() {
+    ActSparseFCKey key;
+    key.is_quantized = m_config.is_quantized;
+    key.is_int4 = m_config.is_int4;
+    key.with_zero_point = m_config.with_zero_point;
+    key.ic_q_group_size = m_config.ic_q_group_size;
+
+    auto buildExecutor = [](const ActSparseFCKey& key) -> std::shared_ptr<ActSparseFcKernel> {
+#if defined(OPENVINO_ARCH_X86_64)
+        return std::make_shared<ActSparseFcKernel>(key.is_quantized,
+                                                   key.is_int4,
+                                                   key.with_zero_point,
+                                                   key.ic_q_group_size);
+#else
+        return nullptr;
+#endif
+    };
+
+    m_executor = nullptr;
+    auto cache = context->getParamsCache();
+    auto result = cache->getOrCreate(key, buildExecutor);
+    m_executor = result.first;
+
+    if (!m_executor)
+        OPENVINO_THROW("Failed to create executor for node ", getName(), ".");
+
+    // reorder weights
+    const auto& engine = getEngine();
+
+    auto create_weight = [&]() {
+        auto raw_weight_mem = getSrcMemoryAtPort(1);
+        MemoryPtr weight_mem;
+        if (m_config.is_int4) {
+            // weight : [OC, IC/group_size, group_size] => [IC, OC/2, 2]
+            // each row is further reordered in unit of 16 x i4 in [0,8,1,9,2,a,3,b,4,c,5,d,6,e,7,f] order
+            weight_mem = std::make_shared<Memory>(engine, raw_weight_mem->getDescPtr());
+
+            const auto& dims = raw_weight_mem->getShape().getStaticDims();
+            OPENVINO_ASSERT(dims.size() == 3);
+            OPENVINO_ASSERT(dims[0] == static_cast<size_t>(m_config.oc));
+            OPENVINO_ASSERT(dims[1] == static_cast<size_t>(m_config.ic / m_config.ic_q_group_size));
+            OPENVINO_ASSERT(dims[2] == static_cast<size_t>(m_config.ic_q_group_size));
+
+            auto* src = raw_weight_mem->getDataAs<uint8_t>();
+            auto* dst = weight_mem->getDataAs<uint8_t>();
+            m_executor->repack_weights_i4(src, dst, m_config.ic, m_config.oc);
+        } else {
+            // raw [OC, IC] layout
+            // target [IC, OC] layout
+            ArbitraryOrderDescCreator descCreator({1, 0});
+            auto dst_mem_desc =
+                descCreator.createSharedDesc(raw_weight_mem->getPrecision(), raw_weight_mem->getShape());
+
+            weight_mem = std::make_shared<Memory>(engine, dst_mem_desc);
+            node::Reorder::reorderData(*raw_weight_mem, *weight_mem, context->getParamsCache());
+        }
+        return weight_mem;
+    };
+
+    auto create_zp_i4 = [&]() {
+        // [OC, IC/group_size, 1] => [IC/group_size, OC]
+        auto raw_zp_mem = getSrcMemoryAtPort(3);
+        auto zp_mem = std::make_shared<Memory>(engine, raw_zp_mem->getDescPtr());
+
+        auto* src = raw_zp_mem->getDataAs<uint8_t>();
+        auto* dst = zp_mem->getDataAs<uint8_t>();
+
+        m_executor->repack_weights_i4(src, dst, m_config.ic / m_config.ic_q_group_size, m_config.oc);
+        return zp_mem;
+    };
+
+    auto create_scales_i4 = [&]() {
+        // [OC, IC/group_size, 1] => [IC/group_size, OC]
+        auto raw_scales_mem = getSrcMemoryAtPort(2);
+        ArbitraryOrderDescCreator descCreator({2, 1, 0});
+        auto dst_mem_desc =
+            descCreator.createSharedDesc(raw_scales_mem->getPrecision(), raw_scales_mem->getShape());
+
+        auto scales_mem = std::make_shared<Memory>(engine, dst_mem_desc);
+        node::Reorder::reorderData(*raw_scales_mem, *scales_mem, context->getParamsCache());
+        return scales_mem;
+    };
+
+    if (!m_config.is_int4) {
+        // int8 is perOC, no need for reorder
+        if (m_config.is_quantized)
+            m_scales = getSrcMemoryAtPort(2);
+        if (m_config.with_zero_point)
+            m_zp = getSrcMemoryAtPort(3);
+    }
+
+    auto weightCache = context->getWeightsCache();
+    if (weightCache != nullptr) {
+        const auto string_hash = getOriginalLayers() + std::to_string(m_config.is_int4);
+        m_weight = *weightCache->findOrCreate(string_hash + "_weight", create_weight);
+        if (m_config.is_int4) {
+            if (m_config.with_zero_point)
+                m_zp = *weightCache->findOrCreate(string_hash + "_zp_i4", create_zp_i4);
+            if (m_config.is_quantized)
+                m_scales = *weightCache->findOrCreate(string_hash + "_scales_i4", create_scales_i4);
+        }
+    } else {
+        m_weight = create_weight();
+        if (m_config.is_int4) {
+            if (m_config.with_zero_point)
+                m_zp = create_zp_i4();
+            if (m_config.is_quantized)
+                m_scales = create_scales_i4();
+        }
+    }
+}
 
 void ActSparseFC::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;
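Note: the refactor replaces the per-node `Executor` object with a shared, cached `ActSparseFcKernel`. `createPrimitive()` packs only compile-time parameters into `ActSparseFCKey` (hashed via `hash_combine`) and asks `context->getParamsCache()->getOrCreate(key, buildExecutor)` for the executor, so every node with the same quantization config reuses one JIT kernel. Below is a minimal standalone sketch of that key-hashed caching pattern; the `hash_combine` constants, the `Kernel` stand-in, and the static map are illustrative assumptions (the real code uses dnnl's `hash_combine` and OpenVINO's params cache, which also handles synchronization), not the plugin's actual API.

```cpp
// Standalone sketch of the key-hashed executor cache used by createPrimitive().
// hash_combine, Kernel, and the static map are stand-ins for dnnl/OpenVINO machinery.
#include <cstddef>
#include <iostream>
#include <memory>
#include <unordered_map>

// Boost-style hash_combine; dnnl's version uses different constants.
static size_t hash_combine(size_t seed, size_t value) {
    return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Mirrors ActSparseFCKey: only compile-time kernel parameters, so two nodes
// with the same quantization config map to the same kernel.
struct Key {
    bool is_quantized;
    bool is_int4;
    bool with_zero_point;
    int ic_q_group_size;

    size_t hash() const {
        size_t seed = 0;
        seed = hash_combine(seed, is_quantized);
        seed = hash_combine(seed, is_int4);
        seed = hash_combine(seed, with_zero_point);
        seed = hash_combine(seed, static_cast<size_t>(ic_q_group_size));
        return seed;
    }
    bool operator==(const Key& o) const {
        return is_quantized == o.is_quantized && is_int4 == o.is_int4 &&
               with_zero_point == o.with_zero_point && ic_q_group_size == o.ic_q_group_size;
    }
};

struct Kernel {  // stand-in for ActSparseFcKernel (JIT-compiled in the plugin)
    explicit Kernel(const Key&) { std::cout << "building kernel\n"; }
};

std::shared_ptr<Kernel> get_or_create(const Key& key) {
    struct KeyHash {
        size_t operator()(const Key& k) const { return k.hash(); }
    };
    // Simplification: the real params cache handles locking; this static map does not.
    static std::unordered_map<Key, std::shared_ptr<Kernel>, KeyHash> cache;
    auto it = cache.find(key);
    if (it == cache.end())
        it = cache.emplace(key, std::make_shared<Kernel>(key)).first;
    return it->second;
}

int main() {
    Key k{true, true, false, 128};
    auto a = get_or_create(k);      // prints "building kernel"
    auto b = get_or_create(k);      // cache hit: no second build
    std::cout << (a == b) << "\n";  // 1: both users share one kernel instance
}
```

Because the key deliberately excludes shapes (the kernel is shape-agnostic and `needPrepareParams()` returns false), a cached executor never needs rebuilding when input shapes change; only the weight, zero-point, and scale reorders remain per-node work.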
42 changes: 34 additions & 8 deletions src/plugins/intel_cpu/src/nodes/act_sparse_fc.h
@@ -8,7 +8,34 @@
#include "transformations/cpu_opset/x64/op/act_sparse_fc.hpp"

#if defined(OPENVINO_ARCH_X86_64)
# include "kernels/x64/mlp_kernel.hpp"
# include "kernels/x64/act_sparse_fc_kernel.hpp"
#else
namespace ov {
namespace intel_cpu {
class ActSparseFcKernel {
public:
// compile time parameters
ActSparseFcKernel(bool is_quantized, bool is_int4, bool with_zero_points, int ic_group_size);

void operator()(const float* input,
float* output,
int M,
int IC,
int OC,
float threshold,
float zero_point,
const void* W,
const float* scales,
const uint8_t* zp) {
OPENVINO_THROW("Unsupported platform.");
}

void repack_weights_i4(uint8_t* src, uint8_t* dst, int IC, int OC) {
OPENVINO_THROW("Unsupported platform.");
}
};
} // namespace intel_cpu
} // namespace ov
#endif

namespace ov {
@@ -24,7 +51,7 @@ class ActSparseFC : public Node {
         return getType() == Type::ActSparseFC;
     }
     bool needPrepareParams() const override {
-        return false;
+        return false;  // this is a shape-agnostic kernel
     }
     void createPrimitive() override;
     void executeDynamicImpl(dnnl::stream strm) override {
@@ -35,12 +62,11 @@ class ActSparseFC : public Node {
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
 
 private:
-    struct ExecutorBase {
-        virtual void execute() = 0;
-        virtual ~ExecutorBase() = default;
-    };
-    std::shared_ptr<ExecutorBase> m_executor;
-    struct Executor;
+    std::shared_ptr<ov::intel_cpu::ActSparseFcKernel> m_executor;
+
+    MemoryPtr m_weight;
+    MemoryPtr m_zp;
+    MemoryPtr m_scales;
 
     ActSparseFCNode::Config m_config;
 };
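For reference on the int4 weight layout: the `create_weight` comment above states that weights go from `[OC, IC/group_size, group_size]` to `[IC, OC/2, 2]`, with each row further reordered in units of 16 int4 values into `[0,8,1,9,2,a,3,b,4,c,5,d,6,e,7,f]` order, i.e. elements k and k+8 land in adjacent nibbles, which suits SIMD nibble unpacking. The scalar sketch below illustrates one plausible reading of that 16-element interleave (output slot j takes input slot perm[j]); it is not the plugin's `repack_weights_i4`, which also performs the transpose and whose SIMD implementation in `act_sparse_fc_kernel.hpp` is the authoritative definition.

```cpp
// Scalar illustration of the per-row 16 x i4 interleave named in the repack
// comment. One plausible reading only; consult repack_weights_i4 for the
// kernel's actual layout.
#include <array>
#include <cstdint>
#include <cstdio>

// Packed int4 helpers: element i lives in byte i/2, low nibble first.
static uint8_t get_i4(const uint8_t* p, int i) {
    return (i & 1) ? (p[i >> 1] >> 4) : (p[i >> 1] & 0xF);
}
static void set_i4(uint8_t* p, int i, uint8_t v) {
    if (i & 1)
        p[i >> 1] = static_cast<uint8_t>((p[i >> 1] & 0x0F) | (v << 4));
    else
        p[i >> 1] = static_cast<uint8_t>((p[i >> 1] & 0xF0) | (v & 0xF));
}

// Interleave one 16-element unit (8 packed bytes): output j takes input perm[j],
// pairing elements k and k+8 into adjacent nibbles.
static void interleave16(const uint8_t* src, uint8_t* dst) {
    static const std::array<int, 16> perm = {0, 8, 1, 9, 2, 10, 3, 11,
                                             4, 12, 5, 13, 6, 14, 7, 15};
    for (int j = 0; j < 16; ++j)
        set_i4(dst, j, get_i4(src, perm[j]));
}

int main() {
    uint8_t src[8], dst[8] = {};
    for (int i = 0; i < 16; ++i)  // fill with 0..f so the output order is visible
        set_i4(src, i, static_cast<uint8_t>(i));
    interleave16(src, dst);
    for (int j = 0; j < 16; ++j)
        std::printf("%x ", get_i4(dst, j));  // prints: 0 8 1 9 2 a 3 b 4 c 5 d 6 e 7 f
    std::printf("\n");
}
```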