Skip to content

Commit 81a2715

Browse files
committed
Test case for select_scatter
1 parent 7040582 commit 81a2715

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

tests/py/dynamo/lowering/test_decompositions.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423-
def test_lowering_select_scatter_module(self):
423+
def test_lowering_select_scatter_dimZero_module(self):
424424
class selectScatter(torch.nn.Module):
425425
def __init__(self, *args, **kwargs) -> None:
426426
super().__init__(*args, **kwargs)
@@ -484,5 +484,67 @@ def forward(self, x, src, dim, index):
484484
)
485485

486486

487+
def test_lowering_select_scatter_dimOne_module(self):
488+
class selectScatter(torch.nn.Module):
489+
def __init__(self, *args, **kwargs) -> None:
490+
super().__init__(*args, **kwargs)
491+
492+
def forward(self, x, src, dim, index):
493+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
494+
return y
495+
496+
# Operations expected to be removed in the traced graph after decompositions
497+
expected_ops = {
498+
torch.ops.aten.slice.Tensor,
499+
torch.ops.aten.squeeze.dim,
500+
torch.ops.aten.cat.default,
501+
}
502+
unexpected_ops = {torch.ops.aten.select_scatter.default}
503+
504+
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
505+
506+
fx_graph = torch.fx.symbolic_trace(selectScatter())
507+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
508+
fx_graph,
509+
inputs,
510+
expected_ops=expected_ops,
511+
unexpected_ops=unexpected_ops,
512+
min_block_size=1,
513+
)
514+
515+
self.assertEquals(
516+
len(unexpected_ops_seen),
517+
0,
518+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
519+
)
520+
521+
self.assertEquals(
522+
len(expected_ops_unseen),
523+
0,
524+
f"The following expected ops were not encountered: {expected_ops_unseen}",
525+
)
526+
527+
torch._dynamo.reset()
528+
529+
# Validate that the results between Torch and Torch-TRT are similar
530+
optimized_model = torch_tensorrt.compile(
531+
fx_graph,
532+
"torch_compile",
533+
inputs,
534+
min_block_size=1,
535+
pass_through_build_failures=True,
536+
)
537+
optimized_model_results = optimized_model(*inputs).detach().cpu()
538+
torch_model_results = fx_graph(*inputs).detach().cpu()
539+
540+
max_diff = float(
541+
torch.max(torch.abs(optimized_model_results - torch_model_results))
542+
)
543+
self.assertAlmostEqual(
544+
max_diff,
545+
0,
546+
DECIMALS_OF_AGREEMENT,
547+
f"Select_scatter TRT outputs don't match with the original model.",
548+
)
487549
if __name__ == "__main__":
488550
run_tests()

0 commit comments

Comments
 (0)