-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Transformation] LoraSubgraph fusion (#27068)
### Details:
- *Introduced the `LoraSubgraph` operation, which is used to fuse LoRA subgraphs for further optimizations (for details, please refer to the description in the header)*
- *Introduced the `LoraSubgraphFusion` pass*
- *The changes are covered by transformation tests*

### Tickets:
- *CVS-153035*
- *CVS-155112*
- Loading branch information
Showing
5 changed files
with
457 additions
and
0 deletions.
There are no files selected for viewing
38 changes: 38 additions & 0 deletions
38
src/common/transformations/include/ov_ops/lora_subgraph.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "openvino/op/util/sub_graph_base.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
/** | ||
* @interface LoraSubgraph | ||
* @brief LoraSubgraph operation, which is used for LoRA subgraphs fusion. | ||
* It always has exactly 1 output and 5 inputs (both counts are asserted in `validate_and_infer_types`). | ||
* The order of the inputs is fixed: | ||
* 1. main_flow_input: input from original model. | ||
* 2. LoRA_input: input to which the Low-Rank adaptation is applied. | ||
* The adapted input is combined with `main_flow_input`. | ||
* 3-5. LoRA_matrices: 3 Low-Rank adaptation matrices applied to `LoRA_input`. | ||
* The fused subgraph can be optimized at runtime based on LoRA semantics. | ||
* For instance, `main_flow_input` can be fast-forwarded to output in case of empty `LoRA_matrices`. | ||
*/ | ||
class TRANSFORMATIONS_API LoraSubgraph : public ov::op::util::SubGraphOp { | ||
public: | ||
// Operation type declaration: type name, opset id and parent operation type. | ||
OPENVINO_OP("LoraSubgraph", "ie_internal_opset", ov::op::util::SubGraphOp); | ||
|
||
// Default constructor: creates the operation without inputs and without a body. | ||
LoraSubgraph() = default; | ||
// Creates the operation from its inputs (`args`) and the model holding the fused LoRA subgraph (`body`). | ||
LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body); | ||
|
||
// Checks the input/output counts and the body presence, then takes the output type and shape from the body. | ||
void validate_and_infer_types() override; | ||
// Creates a new LoraSubgraph on top of `new_args` with a cloned body. | ||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
25 changes: 25 additions & 0 deletions
25
...mon/transformations/include/transformations/common_optimizations/lora_subgraph_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API LoraSubgraphFusion; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
// LoraSubgraphFusion is a matcher pass which replaces the following pattern with a single | ||
// ov::op::internal::LoraSubgraph operation: | ||
//   - LoRA branch: [optional Transpose] -> MatMul(ReadValue) -> Multiply(ReadValue) -> MatMul(ReadValue) | ||
//     -> [optional Transpose], applied to the LoRA input; | ||
//   - main flow: MatMul or Convolution which consumes the same LoRA input; | ||
//   - Add which combines the LoRA branch with the main flow (the root of the pattern). | ||
// The matched LoRA branch and the Add become the body of the LoraSubgraph operation. | ||
class ov::pass::LoraSubgraphFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("LoraSubgraphFusion", "0"); | ||
// Builds the pattern and registers the matcher together with the fusing callback. | ||
LoraSubgraphFusion(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/lora_subgraph.hpp" | ||
|
||
#include "itt.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/**
 * @brief Constructs the LoraSubgraph operation from its inputs and the fused body model.
 *
 * Every body parameter `i` is bound to the operation input `i` via an invariant input description,
 * and every body output `i` is exposed as the operation output `i`.
 *
 * @param args Inputs of the operation; their count is checked in `validate_and_infer_types`.
 * @param body Model which holds the fused LoRA subgraph. Must not be null.
 */
LoraSubgraph::LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body) : SubGraphOp(args) {
    // The body is dereferenced right below, so a null model must be rejected here:
    // the same check in validate_and_infer_types() would be reached too late.
    OPENVINO_ASSERT(body, "LoraSubgraph must have initialized body");
    SubGraphOp::set_function(body);

    const size_t parameters_count = body->get_parameters().size();
    for (size_t i = 0; i < parameters_count; ++i) {
        m_input_descriptions[0].push_back(std::make_shared<InvariantInputDescription>(i, i));
    }

    const size_t outputs_count = body->get_output_size();
    for (size_t i = 0; i < outputs_count; ++i) {
        m_output_descriptions[0].push_back(std::make_shared<BodyOutputDescription>(i, i));
    }

    constructor_validate_and_infer_types();
}
|
||
std::shared_ptr<Node> LoraSubgraph::clone_with_new_inputs(const OutputVector& new_args) const { | ||
INTERNAL_OP_SCOPE(internal_LoraSubgraph_clone_with_new_inputs); | ||
check_new_args_count(this, new_args); | ||
return std::make_shared<LoraSubgraph>(new_args, get_function()->clone()); | ||
} | ||
|
||
void LoraSubgraph::validate_and_infer_types() { | ||
INTERNAL_OP_SCOPE(internal_LoraSubgraph_validate_and_infer_types); | ||
OPENVINO_ASSERT(get_input_size() == 5, "LoraSubgraph must have 5 inputs whereas it has ", get_input_size()); | ||
OPENVINO_ASSERT(get_output_size() == 1, "LoraSubgraph must have 1 output whereas it has ", get_output_size()); | ||
const auto& body = get_function(); | ||
OPENVINO_ASSERT(body, "LoraSubgraph must have initialized body"); | ||
validate_and_infer_type_body(body, m_input_descriptions[0]); | ||
for (size_t i = 0; i < get_output_size(); ++i) | ||
set_output_type(i, body->get_output_element_type(i), body->get_output_partial_shape(i)); | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
108 changes: 108 additions & 0 deletions
108
src/common/transformations/src/transformations/common_optimizations/lora_subgraph_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "itt.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/convolution.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/util/read_value_base.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "ov_ops/lora_subgraph.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
// Builds the LoRA pattern and registers a callback which fuses every match into a LoraSubgraph operation. | ||
ov::pass::LoraSubgraphFusion::LoraSubgraphFusion() { | ||
MATCHER_SCOPE(LoraSubgraphFusion); | ||
using namespace pass::pattern; | ||
// LoRA branch: [Transpose] -> MatMul(ReadValue) -> Multiply(ReadValue) -> MatMul(ReadValue) -> [Transpose]. | ||
// Both transposes are optional; intermediate nodes of the branch must have a single consumer. | ||
auto lora_input_m = any_input(); | ||
auto transpose_const1_m = wrap_type<ov::op::v0::Constant>(consumers_count(1)); | ||
auto transpose1_m = optional<ov::op::v1::Transpose>({lora_input_m, transpose_const1_m}, consumers_count(1)); | ||
auto read_value1_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto matmul1_m = wrap_type<ov::op::v0::MatMul>({transpose1_m, read_value1_m}, consumers_count(1)); | ||
auto read_value2_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto multiply_m = wrap_type<ov::op::v1::Multiply>({matmul1_m, read_value2_m}, consumers_count(1)); | ||
auto read_value3_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto matmul2_m = wrap_type<ov::op::v0::MatMul>({multiply_m, read_value3_m}, consumers_count(1)); | ||
auto transpose_const2_m = wrap_type<ov::op::v0::Constant>(consumers_count(1)); | ||
auto transpose2_m = optional<ov::op::v1::Transpose>({matmul2_m, transpose_const2_m}, consumers_count(1)); | ||
// Main flow: MatMul or Convolution fed by the same input as the LoRA branch. | ||
auto main_flow_m = wrap_type<ov::op::v0::MatMul, ov::op::v1::Convolution>({lora_input_m, any_input()}); | ||
// Pattern root: Add which combines the LoRA branch with the main flow. | ||
auto add_m = wrap_type<ov::op::v1::Add>({transpose2_m, main_flow_m}); | ||
|
||
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
const auto& lora_input = pattern_map.at(lora_input_m); | ||
const auto& matmul1 = pattern_map.at(matmul1_m); | ||
const auto& read_value1 = pattern_map.at(read_value1_m); | ||
const auto& multiply = pattern_map.at(multiply_m); | ||
const auto& read_value2 = pattern_map.at(read_value2_m); | ||
const auto& matmul2 = pattern_map.at(matmul2_m); | ||
const auto& read_value3 = pattern_map.at(read_value3_m); | ||
const auto& main_flow = pattern_map.at(main_flow_m); | ||
const auto& add = pattern_map.at(add_m); | ||
|
||
// The fusion is skipped when the transformation callback returns true for the matched Add. | ||
const auto add_node = add.get_node_shared_ptr(); | ||
if (transformation_callback(add_node)) { | ||
return false; | ||
} | ||
|
||
// Returns the input port of `child` whose source node is `parent`; throws if there is no such port. | ||
auto find_connected_input = [](ov::Node* child, ov::Node* parent) { | ||
for (size_t i = 0; i < child->get_input_size(); ++i) { | ||
auto input = child->input(i); | ||
if (input.get_source_output().get_node() == parent) | ||
return input; | ||
} | ||
OPENVINO_THROW("Ops are not connected"); | ||
}; | ||
|
||
// Note: internal_inputs/external_connections order corresponds to LoraSubgraph semantic | ||
// internal_inputs[i] is the port inside the future body which is currently fed by external_connections[i]. | ||
const std::vector<ov::Input<ov::Node>> internal_inputs{ | ||
// For commutative eltwise ops, input idx may be any, so it must be computed | ||
find_connected_input(add.get_node(), main_flow.get_node()), | ||
// The LoRA input enters the body via the first Transpose if it was matched, otherwise via the first MatMul. | ||
pattern_map.count(transpose1_m) ? pattern_map.at(transpose1_m).get_node()->input(0) | ||
: matmul1.get_node()->input(0), | ||
matmul1.get_node()->input(1), | ||
find_connected_input(multiply.get_node(), read_value2.get_node()), | ||
matmul2.get_node()->input(1), | ||
}; | ||
const ov::OutputVector external_connections{ | ||
main_flow, | ||
lora_input, | ||
read_value1, | ||
read_value2, | ||
read_value3, | ||
}; | ||
|
||
// Each internal input is reconnected to a new Parameter with the same element type and partial shape. | ||
// These Parameters become the parameters of the body model. | ||
ov::ParameterVector subgraph_parameters; | ||
subgraph_parameters.reserve(internal_inputs.size()); | ||
for (auto& in : internal_inputs) { | ||
auto new_parameter = std::make_shared<ov::op::v0::Parameter>(in.get_element_type(), in.get_partial_shape()); | ||
subgraph_parameters.push_back(new_parameter); | ||
in.replace_source_output(new_parameter); | ||
} | ||
// Note: lora consumers should be taken before lora_subgraph creation, | ||
// because only original consumers should be replaced with lora's output | ||
const auto& lora_consumers = add.get_target_inputs(); | ||
// The body model has the matched Add as its only output. | ||
const auto lora_subgraph = std::make_shared<ov::Model>(ov::OutputVector{add}, subgraph_parameters); | ||
const auto lora_node = std::make_shared<ov::op::internal::LoraSubgraph>(external_connections, lora_subgraph); | ||
ov::copy_runtime_info(m.get_matched_nodes(), lora_node); | ||
lora_node->set_friendly_name(add_node->get_friendly_name()); | ||
|
||
// The original consumers of Add are switched to the LoraSubgraph output, which also takes over Add's names. | ||
for (const auto& consumer : lora_consumers) | ||
consumer.replace_source_output(lora_node->output(0)); | ||
if (!add.get_names().empty()) | ||
lora_node->output(0).set_names(add.get_names()); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<Matcher>(add_m, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
Oops, something went wrong.