6
6
from torch .testing ._internal .common_utils import run_tests
7
7
from torch_tensorrt import Input
8
8
9
+ grid_sampler_aten_ops = {
10
+ "torch.ops.aten.grid_sampler" : torch .ops .aten .grid_sampler ,
11
+ "torch.ops.aten.grid_sampler_2d" : torch .ops .aten .grid_sampler_2d ,
12
+ "torch.ops.aten.grid_sampler.default" : torch .ops .aten .grid_sampler .default ,
13
+ "torch.ops.aten.grid_sampler_2d.default" : torch .ops .aten .grid_sampler_2d .default ,
14
+ }
15
+
9
16
grid_sampler_ops = [
10
17
(
11
18
"input_grid_interpolation_nearest_sample_fill" ,
12
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
19
+ "torch.ops.aten.grid_sampler" ,
20
+ (lambda x , grid , op : op (x , grid , 0 , 0 , True )),
13
21
[1 , 1 , 5 , 5 ],
14
22
[1 , 5 , 2 , 2 ],
15
23
),
16
24
(
17
25
"input_grid_interpolation_nearest_sample_clamp" ,
18
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
26
+ "torch.ops.aten.grid_sampler" ,
27
+ (lambda x , grid , op : op (x , grid , 0 , 1 , True )),
19
28
[1 , 1 , 5 , 5 ],
20
29
[1 , 5 , 2 , 2 ],
21
30
),
22
31
(
23
32
"input_grid_interpolation_nearest_sample_reflect" ,
24
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
33
+ "torch.ops.aten.grid_sampler" ,
34
+ (lambda x , grid , op : op (x , grid , 0 , 2 , True )),
25
35
[1 , 1 , 5 , 5 ],
26
36
[1 , 5 , 2 , 2 ],
27
37
),
28
38
(
29
39
"input_grid_interpolation_linear_sample_fill" ,
30
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
40
+ "torch.ops.aten.grid_sampler" ,
41
+ (lambda x , grid , op : op (x , grid , 1 , 0 , True )),
31
42
[1 , 1 , 5 , 5 ],
32
43
[1 , 5 , 2 , 2 ],
33
44
),
34
45
(
35
46
"input_grid_interpolation_linear_sample_clamp" ,
36
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
47
+ "torch.ops.aten.grid_sampler" ,
48
+ (lambda x , grid , op : op (x , grid , 1 , 1 , True )),
37
49
[1 , 1 , 5 , 5 ],
38
50
[1 , 5 , 2 , 2 ],
39
51
),
40
52
(
41
53
"input_grid_interpolation_linear_sample_reflect" ,
42
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
54
+ "torch.ops.aten.grid_sampler" ,
55
+ (lambda x , grid , op : op (x , grid , 1 , 2 , True )),
43
56
[1 , 1 , 5 , 5 ],
44
57
[1 , 5 , 2 , 2 ],
45
58
),
46
59
(
47
60
"input_grid_interpolation_cubic_sample_fill" ,
48
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
61
+ "torch.ops.aten.grid_sampler" ,
62
+ (lambda x , grid , op : op (x , grid , 2 , 0 , True )),
49
63
[1 , 1 , 5 , 5 ],
50
64
[1 , 5 , 2 , 2 ],
51
65
),
52
66
(
53
67
"input_grid_interpolation_cubic_sample_clamp" ,
54
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
68
+ "torch.ops.aten.grid_sampler" ,
69
+ (lambda x , grid , op : op (x , grid , 2 , 1 , True )),
55
70
[1 , 1 , 5 , 5 ],
56
71
[1 , 5 , 2 , 2 ],
57
72
),
58
73
(
59
74
"input_grid_interpolation_cubic_sample_reflect" ,
60
- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
61
- [1 , 1 , 5 , 5 ],
62
- [1 , 5 , 2 , 2 ],
63
- ),
64
- (
65
- "input_grid_interpolation_nearest_sample_fill_2d" ,
66
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
67
- [1 , 1 , 5 , 5 ],
68
- [1 , 5 , 2 , 2 ],
69
- ),
70
- (
71
- "input_grid_interpolation_nearest_sample_clamp_2d" ,
72
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
73
- [1 , 1 , 5 , 5 ],
74
- [1 , 5 , 2 , 2 ],
75
- ),
76
- (
77
- "input_grid_interpolation_nearest_sample_reflect_2d" ,
78
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
79
- [1 , 1 , 5 , 5 ],
80
- [1 , 5 , 2 , 2 ],
81
- ),
82
- (
83
- "input_grid_interpolation_linear_sample_fill_2d" ,
84
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
85
- [1 , 1 , 5 , 5 ],
86
- [1 , 5 , 2 , 2 ],
87
- ),
88
- (
89
- "input_grid_interpolation_linear_sample_clamp_2d" ,
90
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
91
- [1 , 1 , 5 , 5 ],
92
- [1 , 5 , 2 , 2 ],
93
- ),
94
- (
95
- "input_grid_interpolation_linear_sample_reflect_2d" ,
96
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
97
- [1 , 1 , 5 , 5 ],
98
- [1 , 5 , 2 , 2 ],
99
- ),
100
- (
101
- "input_grid_interpolation_cubic_sample_fill_2d" ,
102
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
103
- [1 , 1 , 5 , 5 ],
104
- [1 , 5 , 2 , 2 ],
105
- ),
106
- (
107
- "input_grid_interpolation_cubic_sample_clamp_2d" ,
108
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
109
- [1 , 1 , 5 , 5 ],
110
- [1 , 5 , 2 , 2 ],
111
- ),
112
- (
113
- "input_grid_interpolation_cubic_sample_reflect_2d" ,
114
- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
75
+ "torch.ops.aten.grid_sampler" ,
76
+ (lambda x , grid , op : op (x , grid , 2 , 2 , True )),
115
77
[1 , 1 , 5 , 5 ],
116
78
[1 , 5 , 2 , 2 ],
117
79
),
@@ -126,19 +88,98 @@ class TestGridConverter(DispatchTestCase):
126
88
grid_sampler_op [1 ],
127
89
grid_sampler_op [2 ],
128
90
grid_sampler_op [3 ],
91
+ grid_sampler_op [4 ],
92
+ )
93
+ for grid_sampler_op in grid_sampler_ops
94
+ ]
95
+ )
96
+ def test_grid (self , _ , op_name , op , input_shape , dim_shape ):
97
+ class TestModule (nn .Module ):
98
+ def __init__ (self , grid_sampler_op ):
99
+ super ().__init__ ()
100
+ self .grid_sampler_op = grid_sampler_op
101
+
102
+ def forward (self , x ):
103
+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
104
+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
105
+
106
+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
107
+ grid_model = TestModule (op )
108
+ self .run_test (grid_model , inputs )
109
+
110
+ @parameterized .expand (
111
+ [
112
+ (
113
+ grid_sampler_op [0 ],
114
+ grid_sampler_op [1 ] + "_2d" ,
115
+ grid_sampler_op [2 ],
116
+ grid_sampler_op [3 ],
117
+ grid_sampler_op [4 ],
118
+ )
119
+ for grid_sampler_op in grid_sampler_ops
120
+ ]
121
+ )
122
+ def test_grid_2d (self , _ , op_name , op , input_shape , dim_shape ):
123
+ class TestModule (nn .Module ):
124
+ def __init__ (self , grid_sampler_op ):
125
+ super ().__init__ ()
126
+ self .grid_sampler_op = grid_sampler_op
127
+
128
+ def forward (self , x ):
129
+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
130
+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
131
+
132
+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
133
+ grid_model = TestModule (op )
134
+ self .run_test (grid_model , inputs )
135
+
136
+ @parameterized .expand (
137
+ [
138
+ (
139
+ grid_sampler_op [0 ],
140
+ grid_sampler_op [1 ] + ".default" ,
141
+ grid_sampler_op [2 ],
142
+ grid_sampler_op [3 ],
143
+ grid_sampler_op [4 ],
144
+ )
145
+ for grid_sampler_op in grid_sampler_ops
146
+ ]
147
+ )
148
+ def test_grid_default (self , _ , op_name , op , input_shape , dim_shape ):
149
+ class TestModule (nn .Module ):
150
+ def __init__ (self , grid_sampler_op ):
151
+ super ().__init__ ()
152
+ self .grid_sampler_op = grid_sampler_op
153
+
154
+ def forward (self , x ):
155
+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
156
+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
157
+
158
+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
159
+ grid_model = TestModule (op )
160
+ self .run_test (grid_model , inputs )
161
+
162
+ @parameterized .expand (
163
+ [
164
+ (
165
+ grid_sampler_op [0 ],
166
+ grid_sampler_op [1 ] + "_2d.default" ,
167
+ grid_sampler_op [2 ],
168
+ grid_sampler_op [3 ],
169
+ grid_sampler_op [4 ],
129
170
)
130
171
for grid_sampler_op in grid_sampler_ops
131
172
]
132
173
)
133
- def test_grid (self , _ , op , input_shape , dim_shape ):
174
+ def test_grid_2d_default (self , _ , op_name , op , input_shape , dim_shape ):
134
175
class TestModule (nn .Module ):
135
176
def __init__ (self , grid_sampler_op ):
136
177
super ().__init__ ()
137
178
self .grid_sampler_op = grid_sampler_op
138
179
139
180
def forward (self , x ):
140
181
grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
141
- return self .grid_sampler_op (x , grid )
182
+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [ op_name ] )
142
183
143
184
inputs = [torch .randn (input_shape , dtype = torch .float32 )]
144
185
grid_model = TestModule (op )
0 commit comments