@@ -1219,6 +1219,91 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2Inputs) {
1219
1219
ASSERT_TRUE (new_input_defs[1 ]->Name () == " B" );
1220
1220
}
1221
1221
1222
+ // (A')'B' = AB' where transpose has multiple consumers
1223
+ TEST_F (GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
1224
+ auto model_uri = MODEL_FOLDER " fusion/gemm_transpose_2outputs_from_transpose.onnx" ;
1225
+ std::shared_ptr<Model> p_model;
1226
+ ASSERT_STATUS_OK (Model::Load (model_uri, p_model, nullptr , *logger_));
1227
+ Graph& graph = p_model->MainGraph ();
1228
+ std::map<std::string, int > op_to_count = CountOpsInGraph (graph);
1229
+ ASSERT_EQ (op_to_count[" Transpose" ], 2 );
1230
+ ASSERT_EQ (op_to_count[" Gemm" ], 1 );
1231
+ ASSERT_EQ (op_to_count[" Identity" ], 1 );
1232
+
1233
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5 };
1234
+ auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>(" RuleTransformer1" );
1235
+ rule_transformer_L1->Register (std::make_unique<GemmTransposeFusion>());
1236
+ graph_transformation_mgr.Register (std::move (rule_transformer_L1), TransformerLevel::Level1);
1237
+ ASSERT_STATUS_OK (graph_transformation_mgr.ApplyTransformers (graph, TransformerLevel::Level1, *logger_));
1238
+
1239
+ op_to_count = CountOpsInGraph (graph);
1240
+ ASSERT_EQ (op_to_count[" Transpose" ], 1 );
1241
+ ASSERT_EQ (op_to_count[" Gemm" ], 1 );
1242
+ ASSERT_EQ (op_to_count[" Identity" ], 1 );
1243
+
1244
+ auto gemm_node =
1245
+ std::find_if (
1246
+ graph.Nodes ().cbegin (), graph.Nodes ().cend (),
1247
+ [](const Node& node) { return node.Name () == " Gemm_transformed" ; });
1248
+
1249
+ auto & node = *gemm_node;
1250
+ ASSERT_TRUE (node.OpType () == " Gemm" );
1251
+ ASSERT_TRUE (static_cast <bool >(node.GetAttributes ().at (" transA" ).i ()));
1252
+ ASSERT_TRUE (static_cast <bool >(node.GetAttributes ().at (" transB" ).i ()));
1253
+ auto new_input_defs = node.InputDefs ();
1254
+ ASSERT_TRUE (new_input_defs[0 ]->Name () == " tp0" );
1255
+ ASSERT_TRUE (new_input_defs[1 ]->Name () == " B" );
1256
+ }
1257
+
1258
+ // (A')'B' = AB' and (B')'C = BC where transpose has multiple consumers
1259
+ TEST_F (GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemms) {
1260
+ auto model_uri = MODEL_FOLDER " fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx" ;
1261
+ std::shared_ptr<Model> p_model;
1262
+ ASSERT_STATUS_OK (Model::Load (model_uri, p_model, nullptr , *logger_));
1263
+ Graph& graph = p_model->MainGraph ();
1264
+ std::map<std::string, int > op_to_count = CountOpsInGraph (graph);
1265
+ ASSERT_EQ (op_to_count[" Transpose" ], 2 );
1266
+ ASSERT_EQ (op_to_count[" Gemm" ], 2 );
1267
+ ASSERT_EQ (op_to_count[" Identity" ], 1 );
1268
+
1269
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5 };
1270
+ auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>(" RuleTransformer1" );
1271
+ rule_transformer_L1->Register (std::make_unique<GemmTransposeFusion>());
1272
+ graph_transformation_mgr.Register (std::move (rule_transformer_L1), TransformerLevel::Level1);
1273
+ ASSERT_STATUS_OK (graph_transformation_mgr.ApplyTransformers (graph, TransformerLevel::Level1, *logger_));
1274
+
1275
+ op_to_count = CountOpsInGraph (graph);
1276
+ ASSERT_EQ (op_to_count[" Transpose" ], 1 );
1277
+ ASSERT_EQ (op_to_count[" Gemm" ], 2 );
1278
+ ASSERT_EQ (op_to_count[" Identity" ], 1 );
1279
+
1280
+ auto gemm1_node =
1281
+ std::find_if (
1282
+ graph.Nodes ().cbegin (), graph.Nodes ().cend (),
1283
+ [](const Node& node) { return node.Name () == " Gemm1_transformed" ; });
1284
+
1285
+ auto & node1 = *gemm1_node;
1286
+ ASSERT_TRUE (node1.OpType () == " Gemm" );
1287
+ ASSERT_TRUE (static_cast <bool >(node1.GetAttributes ().at (" transA" ).i ()));
1288
+ ASSERT_TRUE (static_cast <bool >(node1.GetAttributes ().at (" transB" ).i ()));
1289
+ auto new_input_defs1 = node1.InputDefs ();
1290
+ ASSERT_TRUE (new_input_defs1[0 ]->Name () == " tp0" );
1291
+ ASSERT_TRUE (new_input_defs1[1 ]->Name () == " B" );
1292
+
1293
+ auto gemm2_node =
1294
+ std::find_if (
1295
+ graph.Nodes ().cbegin (), graph.Nodes ().cend (),
1296
+ [](const Node& node) { return node.Name () == " Gemm2_transformed" ; });
1297
+
1298
+ auto & node2 = *gemm2_node;
1299
+ ASSERT_TRUE (node2.OpType () == " Gemm" );
1300
+ ASSERT_FALSE (static_cast <bool >(node2.GetAttributes ().at (" transA" ).i ()));
1301
+ ASSERT_FALSE (static_cast <bool >(node2.GetAttributes ().at (" transB" ).i ()));
1302
+ auto new_input_defs2 = node2.InputDefs ();
1303
+ ASSERT_TRUE (new_input_defs2[0 ]->Name () == " B" );
1304
+ ASSERT_TRUE (new_input_defs2[1 ]->Name () == " C" );
1305
+ }
1306
+
1222
1307
// (A'B)' = B'A
1223
1308
TEST_F (GraphTransformationTests, GemmTransposeFusionOutput) {
1224
1309
auto model_uri = MODEL_FOLDER " fusion/gemm_transpose_output_transposed.onnx" ;
0 commit comments