From 58682de855cd45c1c2ca8914015e3ce2f9195845 Mon Sep 17 00:00:00 2001
From: Mao Yunfei
Date: Wed, 31 Jul 2024 18:38:18 +0800
Subject: [PATCH] adaptive_avgpool2d_backward: Propagate memory format
 correctly (#668)

In Inductor usage, the generated Triton kernel assumes a specific memory
layout after compilation. Fallback operators in the graph need to propagate
the memory format correctly.

---
 src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
index b09d1c8c0..44ca61805 100644
--- a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
+++ b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
@@ -69,6 +69,7 @@ Tensor sum_backward(
 
 Tensor mean_backward(
     const Tensor& grad,
+    const Tensor& input,
     c10::SymIntArrayRef shape,
     OptionalIntArrayRef opt_dim,
     c10::SymInt numel,
@@ -76,7 +77,15 @@ Tensor mean_backward(
   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
   auto n =
       is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value());
-  return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  Tensor grad_input =
+      sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
+    grad_input = grad_input.contiguous(input.suggest_memory_format());
+  }
+
+  return grad_input;
 }
 
 } // namespace
@@ -98,6 +107,7 @@ Tensor XPUNativeFunctions::_adaptive_avg_pool2d_backward(
   if (grad_output.size(-1) == 1 && grad_output.size(-2) == 1) {
     return mean_backward(
         grad_output,
+        input,
         input.sym_sizes().vec(),
         {-1, -2},
         input.sym_numel(),
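
For reference, a minimal libtorch sketch of the behavior this patch addresses.
This is an illustrative repro, not part of the patch: it assumes a PyTorch
build with the XPU backend available, and uses at::_adaptive_avg_pool2d_backward
as the dispatcher-level entry point to reach the 1x1-output branch that routes
into mean_backward.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Channels-last NCHW input: suggest_memory_format() reports ChannelsLast,
  // which is the condition the patch keys on. Drop .to(at::kXPU) to run the
  // same check against another device's kernel.
  at::Tensor input = at::randn({2, 3, 8, 8})
                         .contiguous(at::MemoryFormat::ChannelsLast)
                         .to(at::kXPU);

  // Global pooling: the output (and hence grad_output) is 1x1 spatially,
  // which sends _adaptive_avg_pool2d_backward down the mean_backward path
  // modified by this patch.
  at::Tensor grad_output = at::ones({2, 3, 1, 1}, input.options());

  at::Tensor grad_input =
      at::_adaptive_avg_pool2d_backward(grad_output, input);

  // With the patch, grad_input is materialized in the input's suggested
  // (channels-last) format rather than default contiguous, so this prints 1.
  std::cout << "grad_input channels-last: "
            << grad_input.is_contiguous(at::MemoryFormat::ChannelsLast)
            << std::endl;
  return 0;
}

Before this change, the 1x1-output branch returned the gradient without
regard to the input's memory format; with it, grad_input is made contiguous
in input.suggest_memory_format(), so an Inductor-generated Triton kernel
consuming the fallback's output sees the layout it was specialized for.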