Skip to content

Commit 2daa8b9

Browse files
authored
[Transformation] LoraSubgraph fusion (#27068)
### Details: - *Introduced `LoraSubgraph` operation, which is used for LoRA subgraphs fusion for further optimizations (for the details, please refer to the description in the header)* - *Introduced `LoraSubgraphFusion` pass* - *The changes are covered by transformation tests* ### Tickets: - *CVS-153035* - *CVS-155112*
1 parent 8d6f517 commit 2daa8b9

File tree

5 files changed

+457
-0
lines changed

5 files changed

+457
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/op/op.hpp"
8+
#include "openvino/op/util/sub_graph_base.hpp"
9+
#include "transformations_visibility.hpp"
10+
11+
namespace ov {
12+
namespace op {
13+
namespace internal {
14+
/**
15+
* @interface LoraSubgraph
16+
* @brief LoraSubgraph operation, which is used for LoRA subgraphs fusion.
17+
* It always has only 1 output, and the following inputs, whose order is fixed:
18+
* 1. main_flow_input: input from original model.
19+
* 2. LoRA_input: input to which the Low-Rank adaptation is applied.
20+
* The adapted input is combined with `main_flow_input`.
21+
* 3. LoRA_matrices: 3 Low-Rank adaptation matrices applied to `LoRA_input`.
22+
* The fused subgraph can be optimized in runtime based on LoRA semantic.
23+
* For instance, `main_flow_input` can be fast-forwarded to output in case of empty `LoRA_matrices`.
24+
*/
25+
class TRANSFORMATIONS_API LoraSubgraph : public ov::op::util::SubGraphOp {
public:
    OPENVINO_OP("LoraSubgraph", "ie_internal_opset", ov::op::util::SubGraphOp);

    LoraSubgraph() = default;
    /// \param args  Operation inputs; their order must follow the fixed semantic
    ///              described above (main_flow_input, LoRA_input, 3 LoRA_matrices).
    /// \param body  The fused LoRA subgraph; its parameters are connected 1:1,
    ///              in order, to `args`.
    LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body);

    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};
35+
36+
} // namespace internal
37+
} // namespace op
38+
} // namespace ov
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <memory>
8+
#include <vector>
9+
10+
#include "openvino/pass/matcher_pass.hpp"
11+
#include "transformations_visibility.hpp"
12+
13+
namespace ov {
14+
namespace pass {
15+
16+
class TRANSFORMATIONS_API LoraSubgraphFusion;
17+
18+
} // namespace pass
19+
} // namespace ov
20+
21+
/**
 * @brief LoraSubgraphFusion matches a LoRA pattern (low-rank MatMul/Multiply/MatMul
 * chain with optional Transposes, added to the main-flow MatMul/Convolution result)
 * and replaces it with a single ov::op::internal::LoraSubgraph operation.
 */
class ov::pass::LoraSubgraphFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("LoraSubgraphFusion", "0");
    LoraSubgraphFusion();
};
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "ov_ops/lora_subgraph.hpp"
6+
7+
#include "itt.hpp"
8+
9+
namespace ov {
10+
namespace op {
11+
namespace internal {
12+
13+
// Wraps `body` as the op's function and wires every op input to the matching
// body parameter (invariant, i.e. passed through unchanged on each evaluation),
// and every body result to the matching op output.
LoraSubgraph::LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body) : SubGraphOp(args) {
    SubGraphOp::set_function(body);
    const auto param_count = body->get_parameters().size();
    for (size_t idx = 0; idx < param_count; ++idx) {
        // i-th op input feeds the i-th body parameter as-is
        m_input_descriptions[0].push_back(std::make_shared<InvariantInputDescription>(idx, idx));
    }
    const auto result_count = body->get_output_size();
    for (size_t idx = 0; idx < result_count; ++idx) {
        // i-th body result produces the i-th op output
        m_output_descriptions[0].push_back(std::make_shared<BodyOutputDescription>(idx, idx));
    }
    constructor_validate_and_infer_types();
}
21+
22+
std::shared_ptr<Node> LoraSubgraph::clone_with_new_inputs(const OutputVector& new_args) const {
23+
INTERNAL_OP_SCOPE(internal_LoraSubgraph_clone_with_new_inputs);
24+
check_new_args_count(this, new_args);
25+
return std::make_shared<LoraSubgraph>(new_args, get_function()->clone());
26+
}
27+
28+
void LoraSubgraph::validate_and_infer_types() {
29+
INTERNAL_OP_SCOPE(internal_LoraSubgraph_validate_and_infer_types);
30+
OPENVINO_ASSERT(get_input_size() == 5, "LoraSubgraph must have 5 inputs whereas it has ", get_input_size());
31+
OPENVINO_ASSERT(get_output_size() == 1, "LoraSubgraph must have 1 output whereas it has ", get_output_size());
32+
const auto& body = get_function();
33+
OPENVINO_ASSERT(body, "LoraSubgraph must have initialized body");
34+
validate_and_infer_type_body(body, m_input_descriptions[0]);
35+
for (size_t i = 0; i < get_output_size(); ++i)
36+
set_output_type(i, body->get_output_element_type(i), body->get_output_partial_shape(i));
37+
}
38+
39+
} // namespace internal
40+
} // namespace op
41+
} // namespace ov
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"
6+
7+
#include <memory>
8+
#include <vector>
9+
10+
#include "itt.hpp"
11+
#include "openvino/op/add.hpp"
12+
#include "openvino/op/convolution.hpp"
13+
#include "openvino/op/matmul.hpp"
14+
#include "openvino/op/multiply.hpp"
15+
#include "openvino/op/parameter.hpp"
16+
#include "openvino/op/transpose.hpp"
17+
#include "openvino/op/util/read_value_base.hpp"
18+
#include "openvino/pass/pattern/op/optional.hpp"
19+
#include "openvino/pass/pattern/op/wrap_type.hpp"
20+
#include "ov_ops/lora_subgraph.hpp"
21+
#include "transformations/utils/utils.hpp"
22+
23+
ov::pass::LoraSubgraphFusion::LoraSubgraphFusion() {
    MATCHER_SCOPE(LoraSubgraphFusion);
    using namespace pass::pattern;
    // LoRA branch pattern: lora_input -> [Transpose] -> MatMul(A) -> Multiply(alpha)
    // -> MatMul(B) -> [Transpose], then Add with the main flow (MatMul/Convolution
    // fed by the same lora_input). The three LoRA matrices come from ReadValue ops
    // (model state). All intermediate nodes must have a single consumer so the
    // whole chain can be moved into the fused subgraph.
    auto lora_input_m = any_input();
    auto transpose_const1_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    // Transposes are optional: the pattern matches with or without them.
    auto transpose1_m = optional<ov::op::v1::Transpose>({lora_input_m, transpose_const1_m}, consumers_count(1));
    auto read_value1_m = wrap_type<ov::op::util::ReadValueBase>();
    auto matmul1_m = wrap_type<ov::op::v0::MatMul>({transpose1_m, read_value1_m}, consumers_count(1));
    auto read_value2_m = wrap_type<ov::op::util::ReadValueBase>();
    auto multiply_m = wrap_type<ov::op::v1::Multiply>({matmul1_m, read_value2_m}, consumers_count(1));
    auto read_value3_m = wrap_type<ov::op::util::ReadValueBase>();
    auto matmul2_m = wrap_type<ov::op::v0::MatMul>({multiply_m, read_value3_m}, consumers_count(1));
    auto transpose_const2_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto transpose2_m = optional<ov::op::v1::Transpose>({matmul2_m, transpose_const2_m}, consumers_count(1));
    auto main_flow_m = wrap_type<ov::op::v0::MatMul, ov::op::v1::Convolution>({lora_input_m, any_input()});
    auto add_m = wrap_type<ov::op::v1::Add>({transpose2_m, main_flow_m});

    ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        const auto& lora_input = pattern_map.at(lora_input_m);
        const auto& matmul1 = pattern_map.at(matmul1_m);
        const auto& read_value1 = pattern_map.at(read_value1_m);
        const auto& multiply = pattern_map.at(multiply_m);
        const auto& read_value2 = pattern_map.at(read_value2_m);
        const auto& matmul2 = pattern_map.at(matmul2_m);
        const auto& read_value3 = pattern_map.at(read_value3_m);
        const auto& main_flow = pattern_map.at(main_flow_m);
        const auto& add = pattern_map.at(add_m);

        const auto add_node = add.get_node_shared_ptr();
        // Let plugins veto the fusion via the transformation callback.
        if (transformation_callback(add_node)) {
            return false;
        }

        // Returns the input of `child` that is fed by `parent`; throws if they
        // are not directly connected (which would indicate a broken match).
        auto find_connected_input = [](ov::Node* child, ov::Node* parent) {
            for (size_t i = 0; i < child->get_input_size(); ++i) {
                auto input = child->input(i);
                if (input.get_source_output().get_node() == parent)
                    return input;
            }
            OPENVINO_THROW("Ops are not connected");
        };

        // Note: internal_inputs/external_connections order corresponds to LoraSubgraph semantic:
        // 1. main_flow, 2. lora_input, 3-5. the three LoRA state matrices.
        const std::vector<ov::Input<ov::Node>> internal_inputs{
            // For commutative eltwise ops, input idx may be any, so it must be computed
            find_connected_input(add.get_node(), main_flow.get_node()),
            // If the first Transpose matched, the subgraph boundary is its input;
            // otherwise it is the first input of MatMul(A).
            pattern_map.count(transpose1_m) ? pattern_map.at(transpose1_m).get_node()->input(0)
                                            : matmul1.get_node()->input(0),
            matmul1.get_node()->input(1),
            find_connected_input(multiply.get_node(), read_value2.get_node()),
            matmul2.get_node()->input(1),
        };
        const ov::OutputVector external_connections{
            main_flow,
            lora_input,
            read_value1,
            read_value2,
            read_value3,
        };

        // Cut the matched chain out of the model: each internal input is rewired
        // to a fresh Parameter, which becomes a parameter of the fused body.
        ov::ParameterVector subgraph_parameters;
        subgraph_parameters.reserve(internal_inputs.size());
        for (auto& in : internal_inputs) {
            auto new_parameter = std::make_shared<ov::op::v0::Parameter>(in.get_element_type(), in.get_partial_shape());
            subgraph_parameters.push_back(new_parameter);
            in.replace_source_output(new_parameter);
        }
        // Note: lora consumers should be taken before lora_subgraph creation,
        // because only original consumers should be replaced with lora's output
        const auto& lora_consumers = add.get_target_inputs();
        const auto lora_subgraph = std::make_shared<ov::Model>(ov::OutputVector{add}, subgraph_parameters);
        const auto lora_node = std::make_shared<ov::op::internal::LoraSubgraph>(external_connections, lora_subgraph);
        ov::copy_runtime_info(m.get_matched_nodes(), lora_node);
        lora_node->set_friendly_name(add_node->get_friendly_name());

        // Reconnect the original consumers of Add to the fused op's output and
        // preserve the tensor names of the replaced output.
        for (const auto& consumer : lora_consumers)
            consumer.replace_source_output(lora_node->output(0));
        if (!add.get_names().empty())
            lora_node->output(0).set_names(add.get_names());
        return true;
    };

    auto m = std::make_shared<Matcher>(add_m, matcher_name);
    this->register_matcher(m, callback);
}

0 commit comments

Comments
 (0)