diff --git a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
index b09d1c8c0..44ca61805 100644
--- a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
+++ b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
@@ -69,6 +69,7 @@ Tensor sum_backward(
 
 Tensor mean_backward(
     const Tensor& grad,
+    const Tensor& input,
     c10::SymIntArrayRef shape,
     OptionalIntArrayRef opt_dim,
     c10::SymInt numel,
@@ -76,7 +77,15 @@ Tensor mean_backward(
   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
   auto n =
       is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value());
-  return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  Tensor grad_input =
+      sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
+    grad_input = grad_input.contiguous(input.suggest_memory_format());
+  }
+
+  return grad_input;
 }
 
 } // namespace
@@ -98,6 +107,7 @@ Tensor XPUNativeFunctions::_adaptive_avg_pool2d_backward(
   if (grad_output.size(-1) == 1 && grad_output.size(-2) == 1) {
     return mean_backward(
         grad_output,
+        input,
         input.sym_sizes().vec(),
         {-1, -2},
         input.sym_numel(),