Skip to content

Commit

Permalink
link mean dim kernels (pytorch#8053)
Browse files Browse the repository at this point in the history
Summary:

titled

Differential Revision: D68845587
  • Loading branch information
zonglinpeng authored and facebook-github-bot committed Jan 31, 2025
1 parent e8ee36c commit 2b1a907
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions backends/cadence/fusion_g3/operators/op_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ int prepare_data(
return num_axis_dims;
}

Tensor& mean_dim_out(
Tensor& mean_out(
KernelRuntimeContext& ctx,
const Tensor& in,
optional<ArrayRef<int64_t>> dim_list,
Expand Down Expand Up @@ -169,29 +169,32 @@ Tensor& mean_dim_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num =
torch::executor::get_reduced_dim_product(in, dim_list);
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = torch::executor::
map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) {
return acc + outv;
},
in,
dim_list,
out_ix);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
});
});
ET_SWITCH_REALHBBF16_TYPES(
in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num =
torch::executor::get_reduced_dim_product(in, dim_list);
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = torch::executor::
map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) {
return static_cast<CTYPE_OUT>(v);
},
[](CTYPE_OUT outv, CTYPE_OUT acc) {
return acc + outv;
},
in,
dim_list,
out_ix);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
});
});
}

return out;
Expand Down

0 comments on commit 2b1a907

Please sign in to comment.