Skip to content

Commit

Permalink
+ Support 1D dims for the param/const scale in the DQMatMulCWi matcher
Browse files Browse the repository at this point in the history
- Remove the extra optional reshape.
  • Loading branch information
esmirno committed Dec 6, 2024
1 parent 7e55f15 commit f4ee57b
Showing 1 changed file with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,21 @@ DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {

auto qcoeff_shape = matched_node_qcoeff->output(0).get_shape();

LOG_DEBUG("DQMatMulCWi matched_qweight->get_element_type(): " << matched_qweight->get_element_type());
LOG_DEBUG("DQMatMulCWi matched_node_qcoeff: " << matched_node_qcoeff->get_friendly_name());
LOG_DEBUG("DQMatMulCWi qcoeff_shape: " << qcoeff_shape);
LOG_DEBUG("DQMatMulCWi matched_matmul->get_transpose_a(): " << matched_matmul->get_transpose_a());
LOG_DEBUG("DQMatMulCWi matched_matmul->get_transpose_b(): " << matched_matmul->get_transpose_b());
LOG_DEBUG("DQMatMulCWi ctx.get().mm_dq_full: " << ctx.get().mm_dq_full);

if ((ov::element::i4 == matched_qweight->get_element_type() ||
ov::element::i8 == matched_qweight->get_element_type()) &&
(ov::op::util::is_parameter(matched_node_qcoeff) || ov::op::util::is_constant(matched_node_qcoeff)) &&
qcoeff_shape[1] == 1 && !matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
(qcoeff_shape.size() == 1 || qcoeff_shape[1] == 1) && !matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
auto matched_node_cvtw = node_to_output.at(qcvtw).get_node_shared_ptr();
auto matched_node_muls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_mmi = node_to_output.at(qmmi).get_node_shared_ptr();
auto& matched_node_qcoeff_out = uat::_(node_to_output).at_or_at_or_at(qcvtc, reshapec, qcoeff);
auto& matched_node_qcoeff_out = uat::_(node_to_output).at_or_at(qcvtc, qcoeff);
auto& matched_node_muls_out = uat::_(node_to_output).at_or_at(qcvtm, qmuls);

if (!ctx.get().mm_dq_full) {
Expand All @@ -189,7 +196,7 @@ DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto mm_readers = matched_matmul->output(0).get_target_inputs();

// Introduce a Reshape to alter Scale factor's shape
auto new_dims = std::vector<std::size_t>{qcoeff_shape[1], qcoeff_shape[0]};
auto new_dims = std::vector<std::size_t>{1, qcoeff_shape[0]};
auto new_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims);
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(matched_node_qcoeff_out, new_const, false);

Expand Down

0 comments on commit f4ee57b

Please sign in to comment.