Skip to content

Commit 77c4b96

Browse files
apboseperi044
authored andcommitted
Removing grid lowering (#2686)
1 parent dee74c4 commit 77c4b96

File tree

3 files changed

+108
-66
lines changed

3 files changed

+108
-66
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+2
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ def aten_ops_fmod(
332332

333333
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
334334
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
335+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default)
336+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default)
335337
@enforce_tensor_types(
336338
{
337339
0: (TRTTensor,),

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
aten.gelu,
4848
aten.gelu_backward,
4949
aten.glu_backward,
50-
aten.grid_sampler_2d,
5150
aten.hardshrink,
5251
aten.hardshrink_backward,
5352
aten.hardsigmoid,

tests/py/dynamo/conversion/test_grid_aten.py

+106-65
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,74 @@
66
from torch.testing._internal.common_utils import run_tests
77
from torch_tensorrt import Input
88

9+
grid_sampler_aten_ops = {
10+
"torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler,
11+
"torch.ops.aten.grid_sampler_2d": torch.ops.aten.grid_sampler_2d,
12+
"torch.ops.aten.grid_sampler.default": torch.ops.aten.grid_sampler.default,
13+
"torch.ops.aten.grid_sampler_2d.default": torch.ops.aten.grid_sampler_2d.default,
14+
}
15+
916
grid_sampler_ops = [
1017
(
1118
"input_grid_interpolation_nearest_sample_fill",
12-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
19+
"torch.ops.aten.grid_sampler",
20+
(lambda x, grid, op: op(x, grid, 0, 0, True)),
1321
[1, 1, 5, 5],
1422
[1, 5, 2, 2],
1523
),
1624
(
1725
"input_grid_interpolation_nearest_sample_clamp",
18-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
26+
"torch.ops.aten.grid_sampler",
27+
(lambda x, grid, op: op(x, grid, 0, 1, True)),
1928
[1, 1, 5, 5],
2029
[1, 5, 2, 2],
2130
),
2231
(
2332
"input_grid_interpolation_nearest_sample_reflect",
24-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
33+
"torch.ops.aten.grid_sampler",
34+
(lambda x, grid, op: op(x, grid, 0, 2, True)),
2535
[1, 1, 5, 5],
2636
[1, 5, 2, 2],
2737
),
2838
(
2939
"input_grid_interpolation_linear_sample_fill",
30-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
40+
"torch.ops.aten.grid_sampler",
41+
(lambda x, grid, op: op(x, grid, 1, 0, True)),
3142
[1, 1, 5, 5],
3243
[1, 5, 2, 2],
3344
),
3445
(
3546
"input_grid_interpolation_linear_sample_clamp",
36-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
47+
"torch.ops.aten.grid_sampler",
48+
(lambda x, grid, op: op(x, grid, 1, 1, True)),
3749
[1, 1, 5, 5],
3850
[1, 5, 2, 2],
3951
),
4052
(
4153
"input_grid_interpolation_linear_sample_reflect",
42-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
54+
"torch.ops.aten.grid_sampler",
55+
(lambda x, grid, op: op(x, grid, 1, 2, True)),
4356
[1, 1, 5, 5],
4457
[1, 5, 2, 2],
4558
),
4659
(
4760
"input_grid_interpolation_cubic_sample_fill",
48-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
61+
"torch.ops.aten.grid_sampler",
62+
(lambda x, grid, op: op(x, grid, 2, 0, True)),
4963
[1, 1, 5, 5],
5064
[1, 5, 2, 2],
5165
),
5266
(
5367
"input_grid_interpolation_cubic_sample_clamp",
54-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
68+
"torch.ops.aten.grid_sampler",
69+
(lambda x, grid, op: op(x, grid, 2, 1, True)),
5570
[1, 1, 5, 5],
5671
[1, 5, 2, 2],
5772
),
5873
(
5974
"input_grid_interpolation_cubic_sample_reflect",
60-
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
61-
[1, 1, 5, 5],
62-
[1, 5, 2, 2],
63-
),
64-
(
65-
"input_grid_interpolation_nearest_sample_fill_2d",
66-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
67-
[1, 1, 5, 5],
68-
[1, 5, 2, 2],
69-
),
70-
(
71-
"input_grid_interpolation_nearest_sample_clamp_2d",
72-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
73-
[1, 1, 5, 5],
74-
[1, 5, 2, 2],
75-
),
76-
(
77-
"input_grid_interpolation_nearest_sample_reflect_2d",
78-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
79-
[1, 1, 5, 5],
80-
[1, 5, 2, 2],
81-
),
82-
(
83-
"input_grid_interpolation_linear_sample_fill_2d",
84-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
85-
[1, 1, 5, 5],
86-
[1, 5, 2, 2],
87-
),
88-
(
89-
"input_grid_interpolation_linear_sample_clamp_2d",
90-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
91-
[1, 1, 5, 5],
92-
[1, 5, 2, 2],
93-
),
94-
(
95-
"input_grid_interpolation_linear_sample_reflect_2d",
96-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
97-
[1, 1, 5, 5],
98-
[1, 5, 2, 2],
99-
),
100-
(
101-
"input_grid_interpolation_cubic_sample_fill_2d",
102-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
103-
[1, 1, 5, 5],
104-
[1, 5, 2, 2],
105-
),
106-
(
107-
"input_grid_interpolation_cubic_sample_clamp_2d",
108-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
109-
[1, 1, 5, 5],
110-
[1, 5, 2, 2],
111-
),
112-
(
113-
"input_grid_interpolation_cubic_sample_reflect_2d",
114-
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
75+
"torch.ops.aten.grid_sampler",
76+
(lambda x, grid, op: op(x, grid, 2, 2, True)),
11577
[1, 1, 5, 5],
11678
[1, 5, 2, 2],
11779
),
@@ -126,19 +88,98 @@ class TestGridConverter(DispatchTestCase):
12688
grid_sampler_op[1],
12789
grid_sampler_op[2],
12890
grid_sampler_op[3],
91+
grid_sampler_op[4],
92+
)
93+
for grid_sampler_op in grid_sampler_ops
94+
]
95+
)
96+
def test_grid(self, _, op_name, op, input_shape, dim_shape):
97+
class TestModule(nn.Module):
98+
def __init__(self, grid_sampler_op):
99+
super().__init__()
100+
self.grid_sampler_op = grid_sampler_op
101+
102+
def forward(self, x):
103+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
104+
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])
105+
106+
inputs = [torch.randn(input_shape, dtype=torch.float32)]
107+
grid_model = TestModule(op)
108+
self.run_test(grid_model, inputs)
109+
110+
@parameterized.expand(
111+
[
112+
(
113+
grid_sampler_op[0],
114+
grid_sampler_op[1] + "_2d",
115+
grid_sampler_op[2],
116+
grid_sampler_op[3],
117+
grid_sampler_op[4],
118+
)
119+
for grid_sampler_op in grid_sampler_ops
120+
]
121+
)
122+
def test_grid_2d(self, _, op_name, op, input_shape, dim_shape):
123+
class TestModule(nn.Module):
124+
def __init__(self, grid_sampler_op):
125+
super().__init__()
126+
self.grid_sampler_op = grid_sampler_op
127+
128+
def forward(self, x):
129+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
130+
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])
131+
132+
inputs = [torch.randn(input_shape, dtype=torch.float32)]
133+
grid_model = TestModule(op)
134+
self.run_test(grid_model, inputs)
135+
136+
@parameterized.expand(
137+
[
138+
(
139+
grid_sampler_op[0],
140+
grid_sampler_op[1] + ".default",
141+
grid_sampler_op[2],
142+
grid_sampler_op[3],
143+
grid_sampler_op[4],
144+
)
145+
for grid_sampler_op in grid_sampler_ops
146+
]
147+
)
148+
def test_grid_default(self, _, op_name, op, input_shape, dim_shape):
149+
class TestModule(nn.Module):
150+
def __init__(self, grid_sampler_op):
151+
super().__init__()
152+
self.grid_sampler_op = grid_sampler_op
153+
154+
def forward(self, x):
155+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
156+
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])
157+
158+
inputs = [torch.randn(input_shape, dtype=torch.float32)]
159+
grid_model = TestModule(op)
160+
self.run_test(grid_model, inputs)
161+
162+
@parameterized.expand(
163+
[
164+
(
165+
grid_sampler_op[0],
166+
grid_sampler_op[1] + "_2d.default",
167+
grid_sampler_op[2],
168+
grid_sampler_op[3],
169+
grid_sampler_op[4],
129170
)
130171
for grid_sampler_op in grid_sampler_ops
131172
]
132173
)
133-
def test_grid(self, _, op, input_shape, dim_shape):
174+
def test_grid_2d_default(self, _, op_name, op, input_shape, dim_shape):
134175
class TestModule(nn.Module):
135176
def __init__(self, grid_sampler_op):
136177
super().__init__()
137178
self.grid_sampler_op = grid_sampler_op
138179

139180
def forward(self, x):
140181
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
141-
return self.grid_sampler_op(x, grid)
182+
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])
142183

143184
inputs = [torch.randn(input_shape, dtype=torch.float32)]
144185
grid_model = TestModule(op)

0 commit comments

Comments
 (0)