diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp
index cd02714113..ae0cfd1e27 100644
--- a/backends/cadence/fusion_g3/operators/op_mean.cpp
+++ b/backends/cadence/fusion_g3/operators/op_mean.cpp
@@ -60,7 +60,7 @@ int prepare_data(
   return num_axis_dims;
 }
 
-Tensor& mean_dim_out(
+Tensor& mean_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     optional<ArrayRef<int64_t>> dim_list,
@@ -169,29 +169,32 @@ Tensor& mean_dim_out(
         InvalidArgument,
         out);
 
-    ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-      ET_SWITCH_FLOATH_TYPES(
-          out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-            CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-            const size_t num =
-                torch::executor::get_reduced_dim_product(in, dim_list);
-            for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-              CTYPE_OUT sum = 0;
-              if (in.numel() > 0) {
-                sum = torch::executor::
-                    map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
-                        [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                        [](CTYPE_OUT outv, CTYPE_OUT acc) {
-                          return acc + outv;
-                        },
-                        in,
-                        dim_list,
-                        out_ix);
-              }
-              out_data[out_ix] = sum / static_cast<CTYPE_OUT>(num);
-            }
-          });
-    });
+    ET_SWITCH_REALHBBF16_TYPES(
+        in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
+          ET_SWITCH_FLOATHBF16_TYPES(
+              out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
+                CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                const size_t num =
+                    torch::executor::get_reduced_dim_product(in, dim_list);
+                for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+                  CTYPE_OUT sum = 0;
+                  if (in.numel() > 0) {
+                    sum = torch::executor::
+                        map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                            [](CTYPE_IN v) {
+                              return static_cast<CTYPE_OUT>(v);
+                            },
+                            [](CTYPE_OUT outv, CTYPE_OUT acc) {
+                              return acc + outv;
+                            },
+                            in,
+                            dim_list,
+                            out_ix);
+                  }
+                  out_data[out_ix] = sum / static_cast<CTYPE_OUT>(num);
+                }
+              });
+        });
   }
 
   return out;