// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"

#include <memory>
#include <vector>

#include "itt.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/read_value_base.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/lora_subgraph.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::LoraSubgraphFusion::LoraSubgraphFusion() {
    MATCHER_SCOPE(LoraSubgraphFusion);
    using namespace pass::pattern;
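    // Matched LoRA adapter pattern (each intermediate node is restricted to a single consumer):
    //   lora_input -> [optional Transpose] -> MatMul(state A) -> Multiply(state alpha) -> MatMul(state B)
    //              -> [optional Transpose] -> Add
    // The second Add input is the main flow (a MatMul or Convolution fed by the same lora_input),
    // and the LoRA weights/scale are supplied via ReadValue states.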
    auto lora_input_m = any_input();
    auto transpose_const1_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto transpose1_m = optional<ov::op::v1::Transpose>({lora_input_m, transpose_const1_m}, consumers_count(1));
    auto read_value1_m = wrap_type<ov::op::util::ReadValueBase>();
    auto matmul1_m = wrap_type<ov::op::v0::MatMul>({transpose1_m, read_value1_m}, consumers_count(1));
    auto read_value2_m = wrap_type<ov::op::util::ReadValueBase>();
    auto multiply_m = wrap_type<ov::op::v1::Multiply>({matmul1_m, read_value2_m}, consumers_count(1));
    auto read_value3_m = wrap_type<ov::op::util::ReadValueBase>();
    auto matmul2_m = wrap_type<ov::op::v0::MatMul>({multiply_m, read_value3_m}, consumers_count(1));
    auto transpose_const2_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto transpose2_m = optional<ov::op::v1::Transpose>({matmul2_m, transpose_const2_m}, consumers_count(1));
    auto main_flow_m = wrap_type<ov::op::v0::MatMul, ov::op::v1::Convolution>({lora_input_m, any_input()});
    auto add_m = wrap_type<ov::op::v1::Add>({transpose2_m, main_flow_m});

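    // The callback cuts the matched LoRA branch out of the graph and wraps it into a LoraSubgraph op,
    // keeping the main flow output, the lora input and the ReadValue states as the fused op's external inputs.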
    ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        const auto& lora_input = pattern_map.at(lora_input_m);
        const auto& matmul1 = pattern_map.at(matmul1_m);
        const auto& read_value1 = pattern_map.at(read_value1_m);
        const auto& multiply = pattern_map.at(multiply_m);
        const auto& read_value2 = pattern_map.at(read_value2_m);
        const auto& matmul2 = pattern_map.at(matmul2_m);
        const auto& read_value3 = pattern_map.at(read_value3_m);
        const auto& main_flow = pattern_map.at(main_flow_m);
        const auto& add = pattern_map.at(add_m);

        const auto add_node = add.get_node_shared_ptr();
        if (transformation_callback(add_node)) {
            return false;
        }

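        // Helper: find the child's input which is fed by `parent`; throws if the nodes are not directly connected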
        auto find_connected_input = [](ov::Node* child, ov::Node* parent) {
            for (size_t i = 0; i < child->get_input_size(); ++i) {
                auto input = child->input(i);
                if (input.get_source_output().get_node() == parent)
                    return input;
            }
            OPENVINO_THROW("Ops are not connected");
        };

        // Note: the order of internal_inputs/external_connections must match the LoraSubgraph semantics
        const std::vector<ov::Input<ov::Node>> internal_inputs{
            // For commutative eltwise ops, the matched operand may be at either input index, so it must be computed
            find_connected_input(add.get_node(), main_flow.get_node()),
            pattern_map.count(transpose1_m) ? pattern_map.at(transpose1_m).get_node()->input(0)
                                            : matmul1.get_node()->input(0),
            matmul1.get_node()->input(1),
            find_connected_input(multiply.get_node(), read_value2.get_node()),
            matmul2.get_node()->input(1),
        };
        const ov::OutputVector external_connections{
            main_flow,
            lora_input,
            read_value1,
            read_value2,
            read_value3,
        };

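        // Detach each internal input from the outer graph: it is re-fed from a new Parameter,
        // which becomes an input of the internal LoRA model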
        ov::ParameterVector subgraph_parameters;
        subgraph_parameters.reserve(internal_inputs.size());
        for (auto& in : internal_inputs) {
            auto new_parameter = std::make_shared<ov::op::v0::Parameter>(in.get_element_type(), in.get_partial_shape());
            subgraph_parameters.push_back(new_parameter);
            in.replace_source_output(new_parameter);
        }
        // Note: lora consumers must be collected before the lora_subgraph is created,
        // because only the original consumers should be redirected to lora's output
        const auto& lora_consumers = add.get_target_inputs();
        const auto lora_subgraph = std::make_shared<ov::Model>(ov::OutputVector{add}, subgraph_parameters);
        const auto lora_node = std::make_shared<ov::op::internal::LoraSubgraph>(external_connections, lora_subgraph);
        ov::copy_runtime_info(m.get_matched_nodes(), lora_node);
        lora_node->set_friendly_name(add_node->get_friendly_name());

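        // Redirect the original Add consumers to the fused node and preserve the output tensor names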
        for (const auto& consumer : lora_consumers)
            consumer.replace_source_output(lora_node->output(0));
        if (!add.get_names().empty())
            lora_node->output(0).set_names(add.get_names());
        return true;
    };

    auto m = std::make_shared<Matcher>(add_m, matcher_name);
    this->register_matcher(m, callback);
}