adaptive_avgpool2d_backward: Propagate memory format correctly (#668)
In Inductor usage, the generated Triton kernel assumes a specific memory layout after compilation. A fallback operator in the graph therefore needs to propagate the memory format correctly.
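The fix boils down to reallocating the gradient tensor in the memory format suggested by the input, so a channels-last (NHWC) input yields a channels-last gradient. A minimal standalone sketch of that pattern (assuming a libtorch build; the helper name propagate_memory_format is hypothetical, not part of this commit):

#include <ATen/ATen.h>

// Hypothetical helper illustrating the pattern applied in this commit:
// re-lay-out the gradient in the memory format suggested by the input.
at::Tensor propagate_memory_format(
    const at::Tensor& grad_input,
    const at::Tensor& input) {
  if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
    // contiguous() copies only when grad_input is not already channels-last.
    return grad_input.contiguous(input.suggest_memory_format());
  }
  return grad_input;
}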
retonym authored Jul 31, 2024
1 parent ced287f commit 58682de
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp
@@ -69,14 +69,23 @@ Tensor sum_backward(
 
 Tensor mean_backward(
     const Tensor& grad,
+    const Tensor& input,
     c10::SymIntArrayRef shape,
     OptionalIntArrayRef opt_dim,
     c10::SymInt numel,
     bool keepdim) {
   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
   auto n =
       is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value());
-  return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  Tensor grad_input =
+      sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
+
+  if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
+    grad_input = grad_input.contiguous(input.suggest_memory_format());
+  }
+
+  return grad_input;
 }
 } // namespace
 
@@ -98,6 +107,7 @@ Tensor XPUNativeFunctions::_adaptive_avg_pool2d_backward(
   if (grad_output.size(-1) == 1 && grad_output.size(-2) == 1) {
     return mean_backward(
         grad_output,
+        input,
         input.sym_sizes().vec(),
         {-1, -2},
         input.sym_numel(),
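A quick end-to-end check, as a sketch assuming a libtorch build: with a channels-last input and a 1x1 output size, _adaptive_avg_pool2d_backward takes the mean_backward path shown above, so after this fix the gradient should keep the channels-last layout. The commit targets the XPU backend; the snippet below runs the analogous check on whatever device the build dispatches to.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Channels-last input; a 1x1 output size routes the backward pass
  // through the mean_backward fast path.
  auto input = torch::randn({2, 3, 8, 8})
                   .contiguous(torch::MemoryFormat::ChannelsLast)
                   .requires_grad_(true);
  auto out = torch::adaptive_avg_pool2d(input, {1, 1});
  out.sum().backward();
  // With the fix, the gradient preserves the input's suggested memory format.
  std::cout << input.grad().is_contiguous(torch::MemoryFormat::ChannelsLast)
            << std::endl;
}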
