diff --git a/functorch/csrc/BatchRulesModules.cpp b/functorch/csrc/BatchRulesModules.cpp index 127c10ea0..39a756552 100644 --- a/functorch/csrc/BatchRulesModules.cpp +++ b/functorch/csrc/BatchRulesModules.cpp @@ -404,6 +404,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); VMAP_SUPPORT("cross", cross_batch_rule); + EXISTING_BDIM(pixel_shuffle); + EXISTING_BDIM(pixel_unshuffle); + VARIADIC_BDIMS(constant_pad_nd); EXISTING_BDIM(reflection_pad1d); EXISTING_BDIM(reflection_pad2d); diff --git a/test/test_vmap.py b/test/test_vmap.py index 12a23a841..ac68c4ccb 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3205,8 +3205,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('nn.functional.huber_loss'), xfail('nn.functional.instance_norm'), xfail('nn.functional.poisson_nll_loss'), - xfail('nn.functional.pixel_shuffle'), - xfail('nn.functional.pixel_unshuffle'), })) def test_op_has_batch_rule(self, device, dtype, op): def test():