@@ -420,6 +420,195 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
+ def test_lowering_select_scatter_dimZero_module (self ):
424
+ class selectScatter (torch .nn .Module ):
425
+ def __init__ (self , * args , ** kwargs ) -> None :
426
+ super ().__init__ (* args , ** kwargs )
427
+
428
+ def forward (self , x , src , dim , index ):
429
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
430
+ return y
431
+
432
+ # Operations expected to be removed in the traced graph after decompositions
433
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
434
+ unexpected_ops = {
435
+ torch .ops .aten .select_scatter .default ,
436
+ torch .ops .aten .slice_scatter .default ,
437
+ }
438
+
439
+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 0 , 0 ]
440
+
441
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
442
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
443
+ fx_graph ,
444
+ inputs ,
445
+ expected_ops = expected_ops ,
446
+ unexpected_ops = unexpected_ops ,
447
+ min_block_size = 1 ,
448
+ )
449
+
450
+ self .assertEquals (
451
+ len (unexpected_ops_seen ),
452
+ 0 ,
453
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
454
+ )
455
+
456
+ self .assertEquals (
457
+ len (expected_ops_unseen ),
458
+ 0 ,
459
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
460
+ )
461
+
462
+ torch ._dynamo .reset ()
463
+
464
+ # Validate that the results between Torch and Torch-TRT are similar
465
+ optimized_model = torch_tensorrt .compile (
466
+ fx_graph ,
467
+ "torch_compile" ,
468
+ inputs ,
469
+ min_block_size = 1 ,
470
+ truncate_long_and_double = True ,
471
+ pass_through_build_failures = True ,
472
+ )
473
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
474
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
475
+
476
+ max_diff = float (
477
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
478
+ )
479
+ self .assertAlmostEqual (
480
+ max_diff ,
481
+ 0 ,
482
+ DECIMALS_OF_AGREEMENT ,
483
+ f"Select_scatter TRT outputs don't match with the original model." ,
484
+ )
485
+
486
+ def test_lowering_select_scatter_dimOne_module (self ):
487
+ class selectScatter (torch .nn .Module ):
488
+ def __init__ (self , * args , ** kwargs ) -> None :
489
+ super ().__init__ (* args , ** kwargs )
490
+
491
+ def forward (self , x , src , dim , index ):
492
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
493
+ return y
494
+
495
+ # Operations expected to be removed in the traced graph after decompositions
496
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
497
+ unexpected_ops = {
498
+ torch .ops .aten .select_scatter .default ,
499
+ torch .ops .aten .slice_scatter .default ,
500
+ }
501
+
502
+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
503
+
504
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
505
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
506
+ fx_graph ,
507
+ inputs ,
508
+ expected_ops = expected_ops ,
509
+ unexpected_ops = unexpected_ops ,
510
+ min_block_size = 1 ,
511
+ )
512
+
513
+ self .assertEquals (
514
+ len (unexpected_ops_seen ),
515
+ 0 ,
516
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
517
+ )
518
+
519
+ self .assertEquals (
520
+ len (expected_ops_unseen ),
521
+ 0 ,
522
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
523
+ )
524
+
525
+ torch ._dynamo .reset ()
526
+
527
+ # Validate that the results between Torch and Torch-TRT are similar
528
+ optimized_model = torch_tensorrt .compile (
529
+ fx_graph ,
530
+ "torch_compile" ,
531
+ inputs ,
532
+ min_block_size = 1 ,
533
+ truncate_long_and_double = True ,
534
+ pass_through_build_failures = True ,
535
+ )
536
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
537
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
538
+
539
+ max_diff = float (
540
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
541
+ )
542
+ self .assertAlmostEqual (
543
+ max_diff ,
544
+ 0 ,
545
+ DECIMALS_OF_AGREEMENT ,
546
+ f"Select_scatter TRT outputs don't match with the original model." ,
547
+ )
548
+
549
+ def test_lowering_select_scatter_multidimension_module (self ):
550
+ class selectScatter (torch .nn .Module ):
551
+ def __init__ (self , * args , ** kwargs ) -> None :
552
+ super ().__init__ (* args , ** kwargs )
553
+
554
+ def forward (self , x , src , dim , index ):
555
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
556
+ return y
557
+
558
+ # Operations expected to be removed in the traced graph after decompositions
559
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
560
+ unexpected_ops = {
561
+ torch .ops .aten .select_scatter .default ,
562
+ torch .ops .aten .slice_scatter .default ,
563
+ }
564
+
565
+ inputs = [torch .zeros (2 , 3 , 4 ).cuda (), torch .ones (2 , 4 ).cuda (), 1 , 0 ]
566
+
567
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
568
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
569
+ fx_graph ,
570
+ inputs ,
571
+ expected_ops = expected_ops ,
572
+ unexpected_ops = unexpected_ops ,
573
+ min_block_size = 1 ,
574
+ )
575
+
576
+ self .assertEquals (
577
+ len (unexpected_ops_seen ),
578
+ 0 ,
579
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
580
+ )
581
+
582
+ self .assertEquals (
583
+ len (expected_ops_unseen ),
584
+ 0 ,
585
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
586
+ )
587
+
588
+ torch ._dynamo .reset ()
589
+
590
+ # Validate that the results between Torch and Torch-TRT are similar
591
+ optimized_model = torch_tensorrt .compile (
592
+ fx_graph ,
593
+ "torch_compile" ,
594
+ inputs ,
595
+ min_block_size = 1 ,
596
+ truncate_long_and_double = True ,
597
+ pass_through_build_failures = True ,
598
+ )
599
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
600
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
601
+
602
+ max_diff = float (
603
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
604
+ )
605
+ self .assertAlmostEqual (
606
+ max_diff ,
607
+ 0 ,
608
+ DECIMALS_OF_AGREEMENT ,
609
+ f"Select_scatter TRT outputs don't match with the original model." ,
610
+ )
611
+
423
612
424
613
if __name__ == "__main__" :
425
614
run_tests ()
0 commit comments