Skip to content

Commit 53eb79f

Browse files
ytaousEthan Tao
and
Ethan Tao
authored
Gemm/Transpose fusion - additional pattern coverage (#8941)
* gemm transpose fixes * enforce condition * add comments * rm redundant code Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
1 parent eebcc20 commit 53eb79f

File tree

5 files changed

+193
-10
lines changed

5 files changed

+193
-10
lines changed

onnxruntime/core/optimizer/gemm_transpose_fusion.cc

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,37 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
2626

2727
// check if input A is a Transpose
2828
if (A_node_ptr != nullptr && A_node_ptr->OpType() == "Transpose") {
29-
Node& A_node = *graph.GetNode(A_node_ptr->Index());
30-
transA = !transA;
31-
nodes_to_remove.push_back(A_node);
32-
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
29+
// make sure all consumers are gemm nodes to avoid possible double transpose
30+
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*A_node_ptr, "Gemm");
31+
if (gemm_nodes.size() == A_node_ptr->GetOutputEdgesCount()) {
32+
Node& A_node = *graph.GetNode(A_node_ptr->Index());
33+
transA = !transA;
34+
if (A_node.GetOutputEdgesCount() > 1) {
35+
// remove only the edge between the Transpose and Gemm nodes, the Transpose won't be removed
36+
// since it's still connected to other Gemm. When transformation for the last connected Gemm is
37+
// being processed, it would fall into the else {} below to remove the Transpose node
38+
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(A_node, gemm_node.MutableInputDefs()[0]->Name());
39+
graph.RemoveEdge(A_node.Index(), gemm_node.Index(), output_idx, 0);
40+
} else {
41+
nodes_to_remove.push_back(A_node);
42+
}
43+
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
44+
}
3345
}
3446
// check if input B is a Transpose
3547
if (B_node_ptr != nullptr && B_node_ptr->OpType() == "Transpose") {
36-
Node& B_node = *graph.GetNode(B_node_ptr->Index());
37-
transB = !transB;
38-
nodes_to_remove.push_back(B_node);
39-
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
48+
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*B_node_ptr, "Gemm");
49+
if (gemm_nodes.size() == B_node_ptr->GetOutputEdgesCount()) {
50+
Node& B_node = *graph.GetNode(B_node_ptr->Index());
51+
transB = !transB;
52+
if (B_node.GetOutputEdgesCount() > 1) {
53+
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(B_node, gemm_node.MutableInputDefs()[1]->Name());
54+
graph.RemoveEdge(B_node.Index(), gemm_node.Index(), output_idx, 1);
55+
} else {
56+
nodes_to_remove.push_back(B_node);
57+
}
58+
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
59+
}
4060
}
4161

4262
nodes_to_remove.push_back(gemm_node);
@@ -82,11 +102,14 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
82102
// Fusion can be applied if there is a transpose at either of the inputs
83103
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
84104
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
85-
node_it->GetOutputEdgesCount() == 1 &&
86105
!graph.NodeProducesGraphOutput(*node_it) &&
87106
// Make sure the two nodes do not span execution providers.
88107
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
89-
return true;
108+
// acceptable if all consumer(s) are gemm node(s)
109+
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*node_it, "Gemm");
110+
if (gemm_nodes.size() == node_it->GetOutputEdgesCount()) {
111+
return true;
112+
}
90113
}
91114
}
92115

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,91 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2Inputs) {
12191219
ASSERT_TRUE(new_input_defs[1]->Name() == "B");
12201220
}
12211221

1222+
// (A')'B' = AB' where transpose has multiple consumers
1223+
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
1224+
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose.onnx";
1225+
std::shared_ptr<Model> p_model;
1226+
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
1227+
Graph& graph = p_model->MainGraph();
1228+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
1229+
ASSERT_EQ(op_to_count["Transpose"], 2);
1230+
ASSERT_EQ(op_to_count["Gemm"], 1);
1231+
ASSERT_EQ(op_to_count["Identity"], 1);
1232+
1233+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
1234+
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
1235+
rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
1236+
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
1237+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
1238+
1239+
op_to_count = CountOpsInGraph(graph);
1240+
ASSERT_EQ(op_to_count["Transpose"], 1);
1241+
ASSERT_EQ(op_to_count["Gemm"], 1);
1242+
ASSERT_EQ(op_to_count["Identity"], 1);
1243+
1244+
auto gemm_node =
1245+
std::find_if(
1246+
graph.Nodes().cbegin(), graph.Nodes().cend(),
1247+
[](const Node& node) { return node.Name() == "Gemm_transformed"; });
1248+
1249+
auto& node = *gemm_node;
1250+
ASSERT_TRUE(node.OpType() == "Gemm");
1251+
ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transA").i()));
1252+
ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));
1253+
auto new_input_defs = node.InputDefs();
1254+
ASSERT_TRUE(new_input_defs[0]->Name() == "tp0");
1255+
ASSERT_TRUE(new_input_defs[1]->Name() == "B");
1256+
}
1257+
1258+
// (A')'B' = AB' and (B')'C = BC where transpose has multiple consumers
1259+
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemms) {
1260+
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx";
1261+
std::shared_ptr<Model> p_model;
1262+
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
1263+
Graph& graph = p_model->MainGraph();
1264+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
1265+
ASSERT_EQ(op_to_count["Transpose"], 2);
1266+
ASSERT_EQ(op_to_count["Gemm"], 2);
1267+
ASSERT_EQ(op_to_count["Identity"], 1);
1268+
1269+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
1270+
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
1271+
rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
1272+
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
1273+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
1274+
1275+
op_to_count = CountOpsInGraph(graph);
1276+
ASSERT_EQ(op_to_count["Transpose"], 1);
1277+
ASSERT_EQ(op_to_count["Gemm"], 2);
1278+
ASSERT_EQ(op_to_count["Identity"], 1);
1279+
1280+
auto gemm1_node =
1281+
std::find_if(
1282+
graph.Nodes().cbegin(), graph.Nodes().cend(),
1283+
[](const Node& node) { return node.Name() == "Gemm1_transformed"; });
1284+
1285+
auto& node1 = *gemm1_node;
1286+
ASSERT_TRUE(node1.OpType() == "Gemm");
1287+
ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transA").i()));
1288+
ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transB").i()));
1289+
auto new_input_defs1 = node1.InputDefs();
1290+
ASSERT_TRUE(new_input_defs1[0]->Name() == "tp0");
1291+
ASSERT_TRUE(new_input_defs1[1]->Name() == "B");
1292+
1293+
auto gemm2_node =
1294+
std::find_if(
1295+
graph.Nodes().cbegin(), graph.Nodes().cend(),
1296+
[](const Node& node) { return node.Name() == "Gemm2_transformed"; });
1297+
1298+
auto& node2 = *gemm2_node;
1299+
ASSERT_TRUE(node2.OpType() == "Gemm");
1300+
ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transA").i()));
1301+
ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transB").i()));
1302+
auto new_input_defs2 = node2.InputDefs();
1303+
ASSERT_TRUE(new_input_defs2[0]->Name() == "B");
1304+
ASSERT_TRUE(new_input_defs2[1]->Name() == "C");
1305+
}
1306+
12221307
// (A'B)' = B'A
12231308
TEST_F(GraphTransformationTests, GemmTransposeFusionOutput) {
12241309
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_output_transposed.onnx";
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import onnx
2+
from onnx import helper
3+
from onnx import TensorProto
4+
from onnx import OperatorSetIdProto
5+
6+
onnxdomain = OperatorSetIdProto()
7+
onnxdomain.version = 12
8+
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
9+
onnxdomain.domain = ""
10+
msdomain = OperatorSetIdProto()
11+
msdomain.version = 1
12+
msdomain.domain = "com.microsoft"
13+
opsets = [onnxdomain, msdomain]
14+
15+
16+
def save(model_path, nodes, inputs, outputs, initializers):
17+
graph = helper.make_graph(
18+
nodes,
19+
"TransposeGemmTest",
20+
inputs, outputs, initializers)
21+
22+
model = helper.make_model(
23+
graph, opset_imports=opsets, producer_name="onnxruntime-test")
24+
25+
onnx.save(model, model_path)
26+
27+
# (A')'B' = AB'
28+
def gemm_transpose_2outputs_from_transpose(model_path):
29+
nodes = [
30+
helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"),
31+
helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"),
32+
helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1),
33+
helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"),
34+
]
35+
36+
inputs = [
37+
helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']),
38+
helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K'])
39+
]
40+
41+
outputs = [
42+
helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']),
43+
helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M'])
44+
]
45+
46+
save(model_path, nodes, inputs, outputs, [])
47+
48+
49+
# (A')'B' = AB' and (B')'C = BC
50+
def gemm_transpose_2outputs_from_transpose_to_2gemms(model_path):
51+
nodes = [
52+
helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"),
53+
helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"),
54+
helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm1", alpha=3.0, transA=1),
55+
helper.make_node("Gemm", ["tp1", "C"], ["output3"], "Gemm2", alpha=3.0, transA=1),
56+
helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"),
57+
]
58+
59+
inputs = [
60+
helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']),
61+
helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']),
62+
helper.make_tensor_value_info("C", TensorProto.FLOAT, ['K', 'L'])
63+
]
64+
65+
outputs = [
66+
helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']),
67+
helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M']),
68+
helper.make_tensor_value_info("output3", TensorProto.FLOAT, ['N', 'L'])
69+
]
70+
71+
save(model_path, nodes, inputs, outputs, [])
72+
73+
gemm_transpose_2outputs_from_transpose("gemm_transpose_2outputs_from_transpose.onnx")
74+
gemm_transpose_2outputs_from_transpose_to_2gemms("gemm_transpose_2outputs_from_transpose_to_2gemms.onnx")
75+

0 commit comments

Comments
 (0)