@@ -546,6 +546,68 @@ def forward(self, x, src, dim, index):
546
546
f"Select_scatter TRT outputs don't match with the original model." ,
547
547
)
548
548
549
+ def test_lowering_select_scatter_multidimension_module (self ):
550
+ class selectScatter (torch .nn .Module ):
551
+ def __init__ (self , * args , ** kwargs ) -> None :
552
+ super ().__init__ (* args , ** kwargs )
553
+
554
+ def forward (self , x , src , dim , index ):
555
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
556
+ return y
557
+
558
+ # Operations expected to be removed in the traced graph after decompositions
559
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
560
+ unexpected_ops = {
561
+ torch .ops .aten .select_scatter .default ,
562
+ torch .ops .aten .slice_scatter .default ,
563
+ }
564
+
565
+ inputs = [torch .zeros (2 , 3 , 4 ).cuda (), torch .ones (2 , 4 ).cuda (), 1 , 0 ]
566
+
567
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
568
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
569
+ fx_graph ,
570
+ inputs ,
571
+ expected_ops = expected_ops ,
572
+ unexpected_ops = unexpected_ops ,
573
+ min_block_size = 1 ,
574
+ )
575
+
576
+ self .assertEquals (
577
+ len (unexpected_ops_seen ),
578
+ 0 ,
579
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
580
+ )
581
+
582
+ self .assertEquals (
583
+ len (expected_ops_unseen ),
584
+ 0 ,
585
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
586
+ )
587
+
588
+ torch ._dynamo .reset ()
589
+
590
+ # Validate that the results between Torch and Torch-TRT are similar
591
+ optimized_model = torch_tensorrt .compile (
592
+ fx_graph ,
593
+ "torch_compile" ,
594
+ inputs ,
595
+ min_block_size = 1 ,
596
+ truncate_long_and_double = True ,
597
+ pass_through_build_failures = True ,
598
+ )
599
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
600
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
601
+
602
+ max_diff = float (
603
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
604
+ )
605
+ self .assertAlmostEqual (
606
+ max_diff ,
607
+ 0 ,
608
+ DECIMALS_OF_AGREEMENT ,
609
+ f"Select_scatter TRT outputs don't match with the original model." ,
610
+ )
549
611
550
612
if __name__ == "__main__" :
551
613
run_tests ()
0 commit comments