
Commit 192d2cf

vfdev-5 authored and zou3519 committed
[functorch] Added grid_sample backward batch rule (pytorch/functorch#284)
* Added grid_sample backward batch rule

  Description:
  - Added grid_sample backward batch rule: CPU and CUDA
  - Updated tests

  Notes: I had to expand on dim 0 in most cases and could not use tricks from the forward pass (where the batch dim is merged either with the channel or the H_out dim), since those produce wrong grid grads in this case.

* Code updates according to the review
* Updated OutOfPlacePlumbing.cpp to the latest pytorch
1 parent c56f769 commit 192d2cf

5 files changed: +1787 -1135 lines changed
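For context, a minimal sketch of what this batch rule enables (an illustration, not part of the commit: the shapes and the `sample_loss` helper are made up, while `vmap` and `grad` are functorch's public API). Computing per-sample gradients through `grid_sample` under `vmap` now hits the new `grid_sampler_2d_backward` batch rule instead of a slow fallback:

```python
import torch
import torch.nn.functional as F
from functorch import vmap, grad

def sample_loss(input, grid):
    # Per-sample view: grid_sample expects batched (N, C, H, W) and
    # (N, H_out, W_out, 2) inputs, so add a singleton batch dim.
    out = F.grid_sample(input.unsqueeze(0), grid.unsqueeze(0),
                        align_corners=False)
    return out.sum()

inputs = torch.randn(8, 3, 16, 16)  # 8 samples of (C, H, W)
grids = torch.randn(8, 10, 10, 2)   # one sampling grid per sample

# grad differentiates w.r.t. the first argument; vmap maps over dim 0.
# The backward of grid_sample is what the new batch rule vectorizes.
per_sample_grads = vmap(grad(sample_loss))(inputs, grids)
print(per_sample_grads.shape)  # torch.Size([8, 3, 16, 16])
```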

functorch/codegen/codegen_outofplacebatching.py

+3 -3
@@ -157,7 +157,9 @@ def parse_return(return_t):
     return tuple([x.strip() for x in m.group(1).split(',')])
 
 def parse_args(args_t):
-    args = args_t.split(',')
+    # There is an assumption made that args are separated with comma-space
+    # and types like std::array<bool,2> do not contain spaces after the comma
+    args = args_t.split(', ')
     result = []
     for arg in args:
         split_idx = arg.rfind(' ')
@@ -172,8 +174,6 @@ def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', includ
     for line in lines:
         if 'void' in line:
             continue
-        if 'std::array' in line:
-            continue
         m = re.match(r'(.*) \w+\((.*)\); // {"schema": "aten::(\w+\.?\w*)\(.*', line)
         if m is None:
             continue
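To see why the delimiter matters, a small illustration (not part of the diff; the `args_t` string below is a made-up example in the style of RegistrationDeclarations.h): splitting on a bare comma cuts `std::array<bool,2>` in half, while splitting on comma-space keeps each argument whole, which is what allows the `std::array` skip in `get_signatures` to be deleted.

```python
args_t = "const Tensor & grad_output, std::array<bool,2> output_mask"

# Old behaviour: the comma inside std::array<bool,2> also splits.
print(args_t.split(','))
# ['const Tensor & grad_output', ' std::array<bool', '2> output_mask']

# New behaviour: only comma-space separates arguments.
print(args_t.split(', '))
# ['const Tensor & grad_output', 'std::array<bool,2> output_mask']
```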

functorch/functorch/csrc/BatchRulesModules.cpp

+127
@@ -255,6 +255,86 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
   return result;
 }
 
+std::tuple<Tensor, Tensor, Tensor, int64_t>
+grid_sample_backward_helper_in(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim) {
+
+  auto batch_size = get_bdim_size3(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
+  grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size);
+  grad_output_ = reshape_dim_into(0, 0, grad_output_);
+
+  auto input_ = moveBatchDimToFront(input, input_bdim);
+  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
+  input_ = reshape_dim_into(0, 0, input_);
+
+  auto grid_ = moveBatchDimToFront(grid, grid_bdim);
+  grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size);
+  grid_ = reshape_dim_into(0, 0, grid_);
+
+  return std::make_tuple(grad_output_, input_, grid_, batch_size);
+}
+
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_helper_out(
+    const std::tuple<Tensor, Tensor> & bw_out,
+    optional<int64_t> grad_input_out_bdim,
+    optional<int64_t> grad_grid_out_bdim,
+    int64_t bdim_size) {
+  auto grad_input = std::get<0>(bw_out);
+  auto grad_grid = std::get<1>(bw_out);
+  grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
+  grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
+  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
+  return result;
+}
+
+
+template<typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    ExtraArgs... extra_args) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto new_grid = std::get<2>(new_bw_input);
+  int64_t batch_size = std::get<3>(new_bw_input);
+
+  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
+
+  return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
+}
+
+template<typename F, F Func>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+cudnn_grid_sample_backward_batch_rule(
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto new_grid = std::get<2>(new_bw_input);
+  int64_t bdim_size = std::get<3>(new_bw_input);
+
+  auto bw_out = Func(new_input, new_grid, new_grad_output);
+
+  return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
+}
+
 std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim,
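A note on the data movement in the hunk above: `grid_sample_backward_helper_in` moves each batch dim to the front, materializes it when absent (the "expand on dim 0" from the commit message), and folds it into the sample dim N, so the non-batched backward kernel sees one flat batch of size B*N; `grid_sample_backward_helper_out` then splits dim 0 of both gradients back into (B, N) via `reshape_dim_outof`. A rough Python analogue (a sketch under the assumption that it mirrors the C++ helpers; `merge_bdim_into_n` is a hypothetical name):

```python
import torch

def merge_bdim_into_n(t, bdim, batch_size):
    """Rough analogue of moveBatchDimToFront + ensure_has_bdim +
    reshape_dim_into(0, 0, t) from the hunk above."""
    if bdim is None:
        # ensure_has_bdim: materialize a broadcast batch dim of size B
        t = t.unsqueeze(0).expand(batch_size, *t.shape)
    else:
        # moveBatchDimToFront
        t = t.movedim(bdim, 0)
    # reshape_dim_into(0, 0, t): fold (B, N, ...) into (B*N, ...)
    return t.reshape(t.shape[0] * t.shape[1], *t.shape[2:])

grad_output = torch.randn(8, 4, 3, 10, 10)  # (B, N, C, H_out, W_out)
flat = merge_bdim_into_n(grad_output, 0, 8)
print(flat.shape)  # torch.Size([32, 3, 10, 10])

# helper_out reverses this: split dim 0 back into (B, N, ...)
restored = flat.reshape(8, -1, *flat.shape[1:])
print(restored.shape)  # torch.Size([8, 4, 3, 10, 10])
```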
@@ -370,12 +450,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   }
 };
 
+template <typename A, A a, typename C>
+struct GridSampleBackwardBatchRuleHelper;
+
+template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
+struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      T... extra_args) {
+    return grid_sample_backward_batch_rule<F, Func, T...>(
+        grad_output, grad_output_batch_dim,
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        std::forward<T>(extra_args)...);
+  }
+};
+
+template <typename F, F Func>
+struct CudnnGridSampleBackwardBatchRuleHelper {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
+    return cudnn_grid_sample_backward_batch_rule<F, Func>(
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        grad_output, grad_output_batch_dim
+    );
+  }
+};
+
 #define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
     GridSampleBatchRuleHelper<\
       decltype(&ATEN_FN(fn)),\
       &ATEN_FN(fn),\
       c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
 
+#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
+    GridSampleBackwardBatchRuleHelper<\
+      decltype(&ATEN_FN(fn)),\
+      &ATEN_FN(fn),\
+      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
+
+#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
+    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply
+
 #define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
     UpsampleBackwardBatchRuleHelper<\
       decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +507,7 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   EXISTING_BDIM2(op, vec); \
   EXISTING_BDIM(op);
 
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("convolution", convolution_batch_rule);
   // m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +522,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(im2col_backward);
 
   VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));
+
   VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
+  VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));
+
   VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
   VMAP_SUPPORT("cross", cross_batch_rule);
 
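A quick way to exercise these registrations end to end (a hedged sketch; the commit's actual coverage is in the updated functorch tests, and this standalone check only approximates them): the vmapped gradient must agree with a plain Python loop over the batch.

```python
import torch
import torch.nn.functional as F
from functorch import vmap, grad

def loss(input, grid):
    return F.grid_sample(input.unsqueeze(0), grid.unsqueeze(0),
                         align_corners=False).sum()

inputs = torch.randn(4, 3, 8, 8)
grids = torch.randn(4, 5, 5, 2)

# Batched path: dispatches to the grid_sampler_2d_backward batch rule.
batched = vmap(grad(loss))(inputs, grids)
# Reference path: one backward call per sample, no vmap involved.
looped = torch.stack([grad(loss)(i, g) for i, g in zip(inputs, grids)])
assert torch.allclose(batched, looped, atol=1e-6)
```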

functorch/functorch/csrc/BatchRulesScatterOps.cpp

+4 -32
@@ -158,34 +158,6 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
   return self;
 }
 
-int64_t bdim_size(
-    const Tensor& a, optional<int64_t> a_bdim,
-    const Tensor& b, optional<int64_t> b_bdim,
-    const Tensor& c, optional<int64_t> c_bdim) {
-  if (a_bdim) {
-    return a.size(*a_bdim);
-  }
-  if (b_bdim) {
-    return b.size(*b_bdim);
-  }
-  if (c_bdim) {
-    return c.size(*c_bdim);
-  }
-  TORCH_INTERNAL_ASSERT(false);
-}
-
-int64_t bdim_size(
-    const Tensor& a, optional<int64_t> a_bdim,
-    const Tensor& b, optional<int64_t> b_bdim) {
-  if (a_bdim) {
-    return a.size(*a_bdim);
-  }
-  if (b_bdim) {
-    return b.size(*b_bdim);
-  }
-  TORCH_INTERNAL_ASSERT(false);
-}
-
 namespace {
 
 template<typename Func, typename ...Args>
@@ -197,7 +169,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
     const Scalar& value, Args... args) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
-  auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
 
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -230,7 +202,7 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
   auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
-  auto batch_size = bdim_size(self, self_bdim, index, index_bdim, src, src_bdim);
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);
 
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -314,7 +286,7 @@ std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
     bool sparse_grad) {
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
-  auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
 
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -343,7 +315,7 @@ std::tuple<Tensor,optional<int64_t>> gather_backward_batch_rule(
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
     bool sparse_grad) {
-  auto batch_size = bdim_size(grad, grad_bdim, self, self_bdim, index, index_bdim);
+  auto batch_size = get_bdim_size3(grad, grad_bdim, self, self_bdim, index, index_bdim);
   auto grad_ = moveBatchDimToFront(grad, grad_bdim);
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto index_ = moveBatchDimToFront(index, index_bdim);
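For reference, the deleted local overloads above are replaced by shared `get_bdim_size2`/`get_bdim_size3` helpers, which the new grid_sample backward rule also uses. A Python rendering of the contract (a sketch, assuming the shared helpers keep the semantics of the deleted code, as the call sites suggest): return the size of the first batch dim that is actually present, asserting that at least one is.

```python
from typing import Optional
import torch

def get_bdim_size2(a: torch.Tensor, a_bdim: Optional[int],
                   b: torch.Tensor, b_bdim: Optional[int]) -> int:
    # Return the batch size from the first tensor that carries a batch
    # dim; mirrors the TORCH_INTERNAL_ASSERT in the deleted C++ code.
    if a_bdim is not None:
        return a.size(a_bdim)
    if b_bdim is not None:
        return b.size(b_bdim)
    raise AssertionError("expected at least one batch dim")

x = torch.randn(5, 3)
y = torch.randn(3)
print(get_bdim_size2(x, 0, y, None))  # 5
```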
