Skip to content

Commit da4d71d

Browse files
committed
select_scatter decomp
Changing lowering of select_scatter select_scatter changes select_scatter changes Test case for select_scatter removing assertion adding select_scatter decomp lowering ops in test implement select_scatter using slice_scatter adding test case linting commit fix
1 parent 7f14221 commit da4d71d

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+13
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,19 @@ def var_decomposition(
162162
return variance
163163

164164

165+
@register_torch_trt_decomposition(
166+
torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
167+
)
168+
def select_scatter_decomposition(
169+
input_tensor: torch.Tensor,
170+
src_tensor: torch.Tensor,
171+
dim: int,
172+
index: int,
173+
) -> torch.Tensor:
174+
src_tensor = torch.unsqueeze(src_tensor, dim)
175+
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)
176+
177+
165178
def get_decompositions(
166179
enable_experimental_decompositions: bool = False,
167180
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

+189
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,195 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423+
def test_lowering_select_scatter_dimZero_module(self):
424+
class selectScatter(torch.nn.Module):
425+
def __init__(self, *args, **kwargs) -> None:
426+
super().__init__(*args, **kwargs)
427+
428+
def forward(self, x, src, dim, index):
429+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
430+
return y
431+
432+
# Operations expected to be removed in the traced graph after decompositions
433+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
434+
unexpected_ops = {
435+
torch.ops.aten.select_scatter.default,
436+
torch.ops.aten.slice_scatter.default,
437+
}
438+
439+
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]
440+
441+
fx_graph = torch.fx.symbolic_trace(selectScatter())
442+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
443+
fx_graph,
444+
inputs,
445+
expected_ops=expected_ops,
446+
unexpected_ops=unexpected_ops,
447+
min_block_size=1,
448+
)
449+
450+
self.assertEquals(
451+
len(unexpected_ops_seen),
452+
0,
453+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
454+
)
455+
456+
self.assertEquals(
457+
len(expected_ops_unseen),
458+
0,
459+
f"The following expected ops were not encountered: {expected_ops_unseen}",
460+
)
461+
462+
torch._dynamo.reset()
463+
464+
# Validate that the results between Torch and Torch-TRT are similar
465+
optimized_model = torch_tensorrt.compile(
466+
fx_graph,
467+
"torch_compile",
468+
inputs,
469+
min_block_size=1,
470+
truncate_long_and_double=True,
471+
pass_through_build_failures=True,
472+
)
473+
optimized_model_results = optimized_model(*inputs).detach().cpu()
474+
torch_model_results = fx_graph(*inputs).detach().cpu()
475+
476+
max_diff = float(
477+
torch.max(torch.abs(optimized_model_results - torch_model_results))
478+
)
479+
self.assertAlmostEqual(
480+
max_diff,
481+
0,
482+
DECIMALS_OF_AGREEMENT,
483+
f"Select_scatter TRT outputs don't match with the original model.",
484+
)
485+
486+
def test_lowering_select_scatter_dimOne_module(self):
487+
class selectScatter(torch.nn.Module):
488+
def __init__(self, *args, **kwargs) -> None:
489+
super().__init__(*args, **kwargs)
490+
491+
def forward(self, x, src, dim, index):
492+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
493+
return y
494+
495+
# Operations expected to be removed in the traced graph after decompositions
496+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
497+
unexpected_ops = {
498+
torch.ops.aten.select_scatter.default,
499+
torch.ops.aten.slice_scatter.default,
500+
}
501+
502+
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
503+
504+
fx_graph = torch.fx.symbolic_trace(selectScatter())
505+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
506+
fx_graph,
507+
inputs,
508+
expected_ops=expected_ops,
509+
unexpected_ops=unexpected_ops,
510+
min_block_size=1,
511+
)
512+
513+
self.assertEquals(
514+
len(unexpected_ops_seen),
515+
0,
516+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
517+
)
518+
519+
self.assertEquals(
520+
len(expected_ops_unseen),
521+
0,
522+
f"The following expected ops were not encountered: {expected_ops_unseen}",
523+
)
524+
525+
torch._dynamo.reset()
526+
527+
# Validate that the results between Torch and Torch-TRT are similar
528+
optimized_model = torch_tensorrt.compile(
529+
fx_graph,
530+
"torch_compile",
531+
inputs,
532+
min_block_size=1,
533+
truncate_long_and_double=True,
534+
pass_through_build_failures=True,
535+
)
536+
optimized_model_results = optimized_model(*inputs).detach().cpu()
537+
torch_model_results = fx_graph(*inputs).detach().cpu()
538+
539+
max_diff = float(
540+
torch.max(torch.abs(optimized_model_results - torch_model_results))
541+
)
542+
self.assertAlmostEqual(
543+
max_diff,
544+
0,
545+
DECIMALS_OF_AGREEMENT,
546+
f"Select_scatter TRT outputs don't match with the original model.",
547+
)
548+
549+
def test_lowering_select_scatter_multidimension_module(self):
550+
class selectScatter(torch.nn.Module):
551+
def __init__(self, *args, **kwargs) -> None:
552+
super().__init__(*args, **kwargs)
553+
554+
def forward(self, x, src, dim, index):
555+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
556+
return y
557+
558+
# Operations expected to be removed in the traced graph after decompositions
559+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
560+
unexpected_ops = {
561+
torch.ops.aten.select_scatter.default,
562+
torch.ops.aten.slice_scatter.default,
563+
}
564+
565+
inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]
566+
567+
fx_graph = torch.fx.symbolic_trace(selectScatter())
568+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
569+
fx_graph,
570+
inputs,
571+
expected_ops=expected_ops,
572+
unexpected_ops=unexpected_ops,
573+
min_block_size=1,
574+
)
575+
576+
self.assertEquals(
577+
len(unexpected_ops_seen),
578+
0,
579+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
580+
)
581+
582+
self.assertEquals(
583+
len(expected_ops_unseen),
584+
0,
585+
f"The following expected ops were not encountered: {expected_ops_unseen}",
586+
)
587+
588+
torch._dynamo.reset()
589+
590+
# Validate that the results between Torch and Torch-TRT are similar
591+
optimized_model = torch_tensorrt.compile(
592+
fx_graph,
593+
"torch_compile",
594+
inputs,
595+
min_block_size=1,
596+
truncate_long_and_double=True,
597+
pass_through_build_failures=True,
598+
)
599+
optimized_model_results = optimized_model(*inputs).detach().cpu()
600+
torch_model_results = fx_graph(*inputs).detach().cpu()
601+
602+
max_diff = float(
603+
torch.max(torch.abs(optimized_model_results - torch_model_results))
604+
)
605+
self.assertAlmostEqual(
606+
max_diff,
607+
0,
608+
DECIMALS_OF_AGREEMENT,
609+
f"Select_scatter TRT outputs don't match with the original model.",
610+
)
611+
423612

424613
if __name__ == "__main__":
425614
run_tests()

0 commit comments

Comments
 (0)