@@ -420,7 +420,7 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
- def test_lowering_select_scatter_module (self ):
423
+ def test_lowering_select_scatter_dimZero_module (self ):
424
424
class selectScatter (torch .nn .Module ):
425
425
def __init__ (self , * args , ** kwargs ) -> None :
426
426
super ().__init__ (* args , ** kwargs )
@@ -484,5 +484,67 @@ def forward(self, x, src, dim, index):
484
484
)
485
485
486
486
487
def test_lowering_select_scatter_dimOne_module(self):
    """Check lowering of ``aten.select_scatter`` along ``dim=1``.

    The test scatters a ``torch.ones(2)`` source into a ``torch.zeros(2, 2)``
    tensor at ``dim=1, index=0`` and verifies two things:

    1. After lowering, ``aten.select_scatter.default`` no longer appears in
       the graph, while ``aten.slice.Tensor``, ``aten.squeeze.dim`` and
       ``aten.cat.default`` all do.
    2. The Torch-TensorRT compiled module and the traced Torch module agree
       to ``DECIMALS_OF_AGREEMENT`` decimal places.
    """

    class selectScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, index):
            y = torch.ops.aten.select_scatter.default(x, src, dim, index)
            return y

    # Ops that must be present in the lowered graph (the decomposition
    # targets) ...
    expected_ops = {
        torch.ops.aten.slice.Tensor,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.cat.default,
    }
    # ... and the op that must have been removed by the decomposition.
    unexpected_ops = {torch.ops.aten.select_scatter.default}

    # Positional arguments for forward(): x, src, dim, index.
    inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]

    fx_graph = torch.fx.symbolic_trace(selectScatter())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual, not the deprecated assertEquals alias, which was removed
    # in Python 3.12.
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen} ",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen} ",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    # Largest element-wise absolute difference between the two outputs.
    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Select_scatter TRT outputs don't match with the original model.",
    )
487
549
# Script entry point: execute the test suite only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_tests()
0 commit comments