@@ -255,6 +255,86 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
   return result;
 }
 
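+// Flattens the vmap batch dim into grid_sample's leading N dim: each
+// tensor gets its batch dim moved to the front, materialized if missing,
+// and merged into dim 0, so the unbatched backward kernel runs once.
+// E.g. with batch size B, input [B, N, C, H, W] and grid [B, N, H, W, 2]
+// become [B*N, C, H, W] and [B*N, H, W, 2].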
+std::tuple<Tensor, Tensor, Tensor, int64_t>
+grid_sample_backward_helper_in(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim) {
+
+  auto batch_size = get_bdim_size3(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
+  grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size);
+  grad_output_ = reshape_dim_into(0, 0, grad_output_);
+
+  auto input_ = moveBatchDimToFront(input, input_bdim);
+  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
+  input_ = reshape_dim_into(0, 0, input_);
+
+  auto grid_ = moveBatchDimToFront(grid, grid_bdim);
+  grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size);
+  grid_ = reshape_dim_into(0, 0, grid_);
+
+  return std::make_tuple(grad_output_, input_, grid_, batch_size);
+}
+
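+// Inverse of the helper above: splits the flattened batch dim back out
+// of dim 0 of both gradients and reports where each bdim now lives.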
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_helper_out(
+    const std::tuple<Tensor, Tensor>& bw_out,
+    optional<int64_t> grad_input_out_bdim,
+    optional<int64_t> grad_grid_out_bdim,
+    int64_t bdim_size) {
+  auto grad_input = std::get<0>(bw_out);
+  auto grad_grid = std::get<1>(bw_out);
+  grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
+  grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
+  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
+  return result;
+}
+
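+// Generic batch rule for grid_sampler_{2d,3d}_backward: flatten the
+// batch dim into N, call the unbatched backward op once, then unflatten
+// both gradients (their bdim ends up at dim 0).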
+template <typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    ExtraArgs... extra_args) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto new_grid = std::get<2>(new_bw_input);
+  int64_t batch_size = std::get<3>(new_bw_input);
+
+  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
+
+  return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
+}
+
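+// Same flatten/unflatten scheme for the cuDNN op, which takes its
+// arguments in (input, grid, grad_output) order instead.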
+template <typename F, F Func>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+cudnn_grid_sample_backward_batch_rule(
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto new_grid = std::get<2>(new_bw_input);
+  int64_t bdim_size = std::get<3>(new_bw_input);
+
+  auto bw_out = Func(new_input, new_grid, new_grad_output);
+
+  return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
+}
+
 std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim,
@@ -370,12 +450,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   }
 };
 
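+// Unpacks the backward op's parameter typelist: the first three entries
+// are the (grad_output, input, grid) tensors handled by the batch rule,
+// and the rest are forwarded as extra (non-tensor) arguments.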
+template <typename A, A a, typename C>
+struct GridSampleBackwardBatchRuleHelper;
+
+template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
+struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      T... extra_args) {
+    return grid_sample_backward_batch_rule<F, Func, T...>(
+        grad_output, grad_output_batch_dim,
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        std::forward<T>(extra_args)...);
+  }
+};
+
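+// The cuDNN backward schema has no trailing non-tensor arguments, so no
+// typelist unpacking is needed here.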
+template <typename F, F Func>
+struct CudnnGridSampleBackwardBatchRuleHelper {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
+    return cudnn_grid_sample_backward_batch_rule<F, Func>(
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        grad_output, grad_output_batch_dim);
+  }
+};
+
 #define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
     GridSampleBatchRuleHelper<\
       decltype(&ATEN_FN(fn)),\
       &ATEN_FN(fn),\
       c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
 
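+// Instantiates the backward helper from the op's function traits, so the
+// extra-arg typelist is derived automatically from the ATen schema.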
+#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
+    GridSampleBackwardBatchRuleHelper<\
+      decltype(&ATEN_FN(fn)),\
+      &ATEN_FN(fn),\
+      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
+
+#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
+  CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply
+
 #define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
     UpsampleBackwardBatchRuleHelper<\
       decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +507,7 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   EXISTING_BDIM2(op, vec); \
   EXISTING_BDIM(op);
 
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("convolution", convolution_batch_rule);
   // m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +522,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(im2col_backward);
 
   VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));
+
   VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
+  VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));
+
   VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
   VMAP_SUPPORT("cross", cross_batch_rule);