Skip to content

Commit

Permalink
link mean dim kernels (pytorch#8053)
Browse files Browse the repository at this point in the history
Summary:

titled

Differential Revision: D68845587
  • Loading branch information
zonglinpeng authored and facebook-github-bot committed Jan 31, 2025
1 parent e8ee36c commit d6d53af
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions backends/cadence/fusion_g3/operators/op_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@ namespace impl {
namespace G3 {
namespace native {

template <typename CTYPE_IN, typename CTYPE_OUT>
void mean_out_(
const Tensor& in,
optional<ArrayRef<int64_t>> dim_list,
__ET_UNUSED bool keepdim,
__ET_UNUSED optional<ScalarType> dtype,
Tensor& out) {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = torch::executor::map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
in,
dim_list,
out_ix);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
}

int prepare_data(
const Tensor& in,
Tensor& out,
Expand Down Expand Up @@ -60,7 +83,7 @@ int prepare_data(
return num_axis_dims;
}

Tensor& mean_dim_out(
Tensor& mean_out(
KernelRuntimeContext& ctx,
const Tensor& in,
optional<ArrayRef<int64_t>> dim_list,
Expand Down Expand Up @@ -169,29 +192,8 @@ Tensor& mean_dim_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num =
torch::executor::get_reduced_dim_product(in, dim_list);
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = torch::executor::
map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) {
return acc + outv;
},
in,
dim_list,
out_ix);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
});
});
mean_out_<float, float>(in, dim_list, keepdim, dtype, out);
return out;
}

return out;
Expand Down

0 comments on commit d6d53af

Please sign in to comment.