Skip to content

Commit 5eb7eee

Browse files
committed
adding test case
1 parent cce29b7 commit 5eb7eee

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

Diff for: tests/py/dynamo/lowering/test_decompositions.py

+62
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,68 @@ def forward(self, x, src, dim, index):
546546
f"Select_scatter TRT outputs don't match with the original model.",
547547
)
548548

549+
def test_lowering_select_scatter_multidimension_module(self):
550+
class selectScatter(torch.nn.Module):
551+
def __init__(self, *args, **kwargs) -> None:
552+
super().__init__(*args, **kwargs)
553+
554+
def forward(self, x, src, dim, index):
555+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
556+
return y
557+
558+
# Operations expected to be removed in the traced graph after decompositions
559+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
560+
unexpected_ops = {
561+
torch.ops.aten.select_scatter.default,
562+
torch.ops.aten.slice_scatter.default,
563+
}
564+
565+
inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]
566+
567+
fx_graph = torch.fx.symbolic_trace(selectScatter())
568+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
569+
fx_graph,
570+
inputs,
571+
expected_ops=expected_ops,
572+
unexpected_ops=unexpected_ops,
573+
min_block_size=1,
574+
)
575+
576+
self.assertEquals(
577+
len(unexpected_ops_seen),
578+
0,
579+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
580+
)
581+
582+
self.assertEquals(
583+
len(expected_ops_unseen),
584+
0,
585+
f"The following expected ops were not encountered: {expected_ops_unseen}",
586+
)
587+
588+
torch._dynamo.reset()
589+
590+
# Validate that the results between Torch and Torch-TRT are similar
591+
optimized_model = torch_tensorrt.compile(
592+
fx_graph,
593+
"torch_compile",
594+
inputs,
595+
min_block_size=1,
596+
truncate_long_and_double=True,
597+
pass_through_build_failures=True,
598+
)
599+
optimized_model_results = optimized_model(*inputs).detach().cpu()
600+
torch_model_results = fx_graph(*inputs).detach().cpu()
601+
602+
max_diff = float(
603+
torch.max(torch.abs(optimized_model_results - torch_model_results))
604+
)
605+
self.assertAlmostEqual(
606+
max_diff,
607+
0,
608+
DECIMALS_OF_AGREEMENT,
609+
f"Select_scatter TRT outputs don't match with the original model.",
610+
)
549611

550612
if __name__ == "__main__":
551613
run_tests()

0 commit comments

Comments
 (0)