Skip to content

Commit cce29b7

Browse files
committed
implement select_scatter using slice_scatter
1 parent c4ff602 commit cce29b7

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,8 @@ def select_scatter_decomposition(
171171
dim: int,
172172
index: int,
173173
) -> torch.Tensor:
174-
unbind_tensors = torch.unbind(input_tensor, dim)
175-
unbind_tensors_list = list(unbind_tensors)
176-
unbind_tensors_list[index] = src_tensor
177-
return torch.stack(tuple(unbind_tensors_list), dim)
174+
src_tensor = torch.unsqueeze(src_tensor, dim)
175+
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)
178176

179177

180178
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -430,13 +430,11 @@ def forward(self, x, src, dim, index):
430430
return y
431431

432432
# Operations expected to be removed in the traced graph after decompositions
433-
expected_ops = {
434-
torch.ops.aten.slice.Tensor,
435-
torch.ops.aten.squeeze.dim,
436-
torch.ops.aten.cat.default,
437-
torch.ops.aten.reshape.default,
433+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
434+
unexpected_ops = {
435+
torch.ops.aten.select_scatter.default,
436+
torch.ops.aten.slice_scatter.default,
438437
}
439-
unexpected_ops = {torch.ops.aten.select_scatter.default}
440438

441439
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]
442440

@@ -469,6 +467,7 @@ def forward(self, x, src, dim, index):
469467
"torch_compile",
470468
inputs,
471469
min_block_size=1,
470+
truncate_long_and_double=True,
472471
pass_through_build_failures=True,
473472
)
474473
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -494,13 +493,11 @@ def forward(self, x, src, dim, index):
494493
return y
495494

496495
# Operations expected to be removed in the traced graph after decompositions
497-
expected_ops = {
498-
torch.ops.aten.slice.Tensor,
499-
torch.ops.aten.squeeze.dim,
500-
torch.ops.aten.unsqueeze.default,
501-
torch.ops.aten.cat.default,
496+
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
497+
unexpected_ops = {
498+
torch.ops.aten.select_scatter.default,
499+
torch.ops.aten.slice_scatter.default,
502500
}
503-
unexpected_ops = {torch.ops.aten.select_scatter.default}
504501

505502
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
506503

@@ -533,6 +530,7 @@ def forward(self, x, src, dim, index):
533530
"torch_compile",
534531
inputs,
535532
min_block_size=1,
533+
truncate_long_and_double=True,
536534
pass_through_build_failures=True,
537535
)
538536
optimized_model_results = optimized_model(*inputs).detach().cpu()

0 commit comments

Comments
 (0)