diff --git a/kernels/portable/cpu/op_native_batch_norm.cpp b/kernels/portable/cpu/op_native_batch_norm.cpp
index 100b1a7fb2..060abebac4 100644
--- a/kernels/portable/cpu/op_native_batch_norm.cpp
+++ b/kernels/portable/cpu/op_native_batch_norm.cpp
@@ -104,7 +104,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
 
   constexpr auto name = "native_batch_norm_legit_no_training.out";
 
-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
     const CTYPE* in_data = in.const_data_ptr<CTYPE>();
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
 
@@ -261,7 +261,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(
 
   constexpr auto name = "_native_batch_norm_legit.no_stats_out";
 
-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
     const CTYPE* in_data = in.const_data_ptr<CTYPE>();
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     CTYPE* mean_data = mean_out.mutable_data_ptr<CTYPE>();
@@ -282,10 +282,12 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(
     }
 
     // Compute mean and invstd for each channel
+    const CTYPE elements_per_channel_ct =
+        static_cast<CTYPE>(elements_per_channel);
     for (size_t c = 0; c < C; ++c) {
-      CTYPE mean = mean_data[c] / elements_per_channel;
+      CTYPE mean = mean_data[c] / elements_per_channel_ct;
       // Var[x] = E[x^2] - E[x]^2
-      CTYPE var = invstd_data[c] / elements_per_channel - mean * mean;
+      CTYPE var = invstd_data[c] / elements_per_channel_ct - mean * mean;
       CTYPE invstd = 1.0 / std::sqrt(var + eps);
       mean_data[c] = mean;
       invstd_data[c] = invstd;
diff --git a/kernels/test/op_native_batch_norm_test.cpp b/kernels/test/op_native_batch_norm_test.cpp
index ba593d8dc4..8c9581c66d 100644
--- a/kernels/test/op_native_batch_norm_test.cpp
+++ b/kernels/test/op_native_batch_norm_test.cpp
@@ -44,6 +44,112 @@ class OpNativeBatchNormLegitNoTrainingOutTest : public OperatorTest {
         out1,
         out2);
   }
+
+  template <ScalarType DTYPE>
+  void test_2d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
+
+    exec_aten::Tensor input = tf.make(
+        {4, 7}, {2.876736640930176, 7.67944860458374, 5.701690196990967,
+                 9.299789428710938, 3.023690700531006, 5.315116882324219,
+                 7.185585021972656, 6.911304473876953, 7.61051082611084,
+                 1.4963287115097046, 0.7381612062454224, 8.588483810424805,
+                 6.583977699279785, 8.831110000610352, 0.8165055513381958,
+                 7.087201118469238, 5.572513580322266, 4.446897983551025,
+                 4.444573402404785, 6.254056930541992, 5.906398296356201,
+                 9.971039772033691, 3.5423521995544434, 7.452159881591797,
+                 9.93700122833252, 1.8560808897018433, 1.524025797843933,
+                 7.3222975730896});
+    exec_aten::optional<exec_aten::Tensor> weight =
+        exec_aten::optional<exec_aten::Tensor>(tf.make(
+            {7},
+            {8.287437438964844,
+             8.227645874023438,
+             6.65926456451416,
+             9.436124801635742,
+             4.119281768798828,
+             8.593960762023926,
+             2.3760855197906494}));
+    exec_aten::optional<exec_aten::Tensor> bias =
+        exec_aten::optional<exec_aten::Tensor>(tf.make(
+            {7},
+            {7.824275970458984,
+             6.84327507019043,
+             8.354326248168945,
+             8.773970603942871,
+             3.89609694480896,
+             3.0753469467163086,
+             3.1105971336364746}));
+    exec_aten::Tensor running_mean = tf.make(
+        {7},
+        {9.700226783752441,
+         0.1234668493270874,
+         7.527220249176025,
+         8.993252754211426,
+         0.4736626148223877,
+         7.7135701179504395,
+         5.12320613861084});
+    exec_aten::Tensor running_var = tf.make(
+        {7},
+        {3.585531234741211,
+         6.615292549133301,
+         0.24084866046905518,
+         5.175800323486328,
+         0.5886000394821167,
+         6.23909854888916,
+         1.5029621124267578});
+    double momentum = 0.1;
+    double eps = 0;
+    exec_aten::Tensor out0 = tf.zeros({4, 7});
+    exec_aten::Tensor out1 = tf.zeros({0});
+    exec_aten::Tensor out2 = tf.zeros({0});
+    exec_aten::Tensor out0_expected = tf.make(
+        {4, 7}, {-22.039867401123047, 31.014127731323242, -16.416650772094727,
+                 10.04538631439209, 17.5877628326416, -5.17673921585083,
+                 7.1078033447265625, -4.381907939910889, 30.793603897094727,
+                 -73.48003387451172, -25.46548080444336, 47.46636962890625,
+                 -0.8111140131950378, 10.29708194732666, -31.056814193725586,
+                 29.119586944580078, -18.16947364807129, -10.082839965820312,
+                 25.216796875, -1.9462348222732544, 4.628543376922607,
+                 9.00953483581543, 17.779958724975586, 7.335818767547607,
+                 12.688335418701172, 11.318607330322266, -18.22031593322754,
+                 7.372773170471191});
+    exec_aten::Tensor out1_expected = tf.make({0}, {});
+    exec_aten::Tensor out2_expected = tf.make({0}, {});
+    op_native_batch_norm_legit_no_training_out(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        momentum,
+        eps,
+        out0,
+        out1,
+        out2);
+    if (DTYPE == exec_aten::ScalarType::Half ||
+        DTYPE == exec_aten::ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          out0_expected,
+          4e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out1,
+          out1_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out2,
+          out2_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out0, out0_expected);
+      EXPECT_TENSOR_CLOSE(out1, out1_expected);
+      EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    }
+  }
 };
 
 class OpNativeBatchNormLegitOutTest : public OperatorTest {
@@ -103,92 +209,72 @@ class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
         out1,
         out2);
   }
-};
 
-TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
-  torch::executor::testing::TensorFactory<ScalarType::Float> tfFloat;
+  template <ScalarType DTYPE>
+  void test_2d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
 
-  exec_aten::Tensor input = tfFloat.make(
-      {4, 7}, {2.876736640930176, 7.67944860458374, 5.701690196990967,
-               9.299789428710938, 3.023690700531006, 5.315116882324219,
-               7.185585021972656, 6.911304473876953, 7.61051082611084,
-               1.4963287115097046, 0.7381612062454224, 8.588483810424805,
-               6.583977699279785, 8.831110000610352, 0.8165055513381958,
-               7.087201118469238, 5.572513580322266, 4.446897983551025,
-               4.444573402404785, 6.254056930541992, 5.906398296356201,
-               9.971039772033691, 3.5423521995544434, 7.452159881591797,
-               9.93700122833252, 1.8560808897018433, 1.524025797843933,
-               7.3222975730896});
-  exec_aten::optional<exec_aten::Tensor> weight =
-      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
-          {7},
-          {8.287437438964844,
-           8.227645874023438,
-           6.65926456451416,
-           9.436124801635742,
-           4.119281768798828,
-           8.593960762023926,
-           2.3760855197906494}));
-  exec_aten::optional<exec_aten::Tensor> bias =
-      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
-          {7},
-          {7.824275970458984,
-           6.84327507019043,
-           8.354326248168945,
-           8.773970603942871,
-           3.89609694480896,
-           3.0753469467163086,
-           3.1105971336364746}));
-  exec_aten::Tensor running_mean = tfFloat.make(
-      {7},
-      {9.700226783752441,
-       0.1234668493270874,
-       7.527220249176025,
-       8.993252754211426,
-       0.4736626148223877,
-       7.7135701179504395,
-       5.12320613861084});
-  exec_aten::Tensor running_var = tfFloat.make(
-      {7},
-      {3.585531234741211,
-       6.615292549133301,
-       0.24084866046905518,
-       5.175800323486328,
-       0.5886000394821167,
-       6.23909854888916,
-       1.5029621124267578});
-  double momentum = 0.1;
-  double eps = 0;
-  exec_aten::Tensor out0 = tfFloat.zeros({4, 7});
-  exec_aten::Tensor out1 = tfFloat.zeros({0});
-  exec_aten::Tensor out2 = tfFloat.zeros({0});
-  exec_aten::Tensor out0_expected = tfFloat.make(
-      {4, 7}, {-22.039867401123047, 31.014127731323242, -16.416650772094727,
-               10.04538631439209, 17.5877628326416, -5.17673921585083,
-               7.1078033447265625, -4.381907939910889, 30.793603897094727,
-               -73.48003387451172, -25.46548080444336, 47.46636962890625,
-               -0.8111140131950378, 10.29708194732666, -31.056814193725586,
-               29.119586944580078, -18.16947364807129, -10.082839965820312,
-               25.216796875, -1.9462348222732544, 4.628543376922607,
-               9.00953483581543, 17.779958724975586, 7.335818767547607,
-               12.688335418701172, 11.318607330322266, -18.22031593322754,
-               7.372773170471191});
-  exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
-  exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
-  op_native_batch_norm_legit_no_training_out(
-      input,
-      weight,
-      bias,
-      running_mean,
-      running_var,
-      momentum,
-      eps,
-      out0,
-      out1,
-      out2);
-  EXPECT_TENSOR_CLOSE(out0, out0_expected);
-  EXPECT_TENSOR_CLOSE(out1, out1_expected);
-  EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    exec_aten::Tensor input =
+        tf.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
+    exec_aten::optional<exec_aten::Tensor> weight =
+        exec_aten::optional<exec_aten::Tensor>();
+    exec_aten::optional<exec_aten::Tensor> bias =
+        exec_aten::optional<exec_aten::Tensor>();
+    bool training = true;
+    double momentum = 1e-3;
+    double eps = 1e-5;
+    exec_aten::Tensor out0 = tf.zeros({3, 4});
+    exec_aten::Tensor out1 = tf.zeros({4});
+    exec_aten::Tensor out2 = tf.zeros({4});
+    exec_aten::Tensor out0_expected = tf.make(
+        {3, 4},
+        {-0.98058063,
+         -1.03422451,
+         -1.06904495,
+         -1.09332705,
+         -0.39223224,
+         -0.31822300,
+         -0.26726127,
+         -0.23017406,
+         1.37281299,
+         1.35244739,
+         1.33630610,
+         1.32350123});
+    exec_aten::Tensor out1_expected =
+        tf.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
+    exec_aten::Tensor out2_expected =
+        tf.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
+    op_native_batch_norm_legit_no_stats_out(
+        input, weight, bias, training, momentum, eps, out0, out1, out2);
+    if (DTYPE == exec_aten::ScalarType::Half ||
+        DTYPE == exec_aten::ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          out0_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out1,
+          out1_expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out2,
+          out2_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out0, out0_expected);
+      EXPECT_TENSOR_CLOSE(out1, out1_expected);
+      EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    }
+  }
+};
+
+TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D){
+#define TEST_ENTRY(ctype, dtype) test_2d_dtype<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
 
 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest3D) {
@@ -977,44 +1063,10 @@ TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
   EXPECT_TENSOR_CLOSE(out2, out2_expected);
 }
 
-TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
-  torch::executor::testing::TensorFactory<ScalarType::Float> tfFloat;
-
-  exec_aten::Tensor input =
-      tfFloat.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
-  exec_aten::optional<exec_aten::Tensor> weight =
-      exec_aten::optional<exec_aten::Tensor>();
-  exec_aten::optional<exec_aten::Tensor> bias =
-      exec_aten::optional<exec_aten::Tensor>();
-  bool training = true;
-  double momentum = 1e-3;
-  double eps = 1e-5;
-  exec_aten::Tensor out0 = tfFloat.zeros({3, 4});
-  exec_aten::Tensor out1 = tfFloat.zeros({4});
-  exec_aten::Tensor out2 = tfFloat.zeros({4});
-  exec_aten::Tensor out0_expected = tfFloat.make(
-      {3, 4},
-      {-0.98058063,
-       -1.03422451,
-       -1.06904495,
-       -1.09332705,
-       -0.39223224,
-       -0.31822300,
-       -0.26726127,
-       -0.23017406,
-       1.37281299,
-       1.35244739,
-       1.33630610,
-       1.32350123});
-  exec_aten::Tensor out1_expected =
-      tfFloat.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
-  exec_aten::Tensor out2_expected =
-      tfFloat.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
-  op_native_batch_norm_legit_no_stats_out(
-      input, weight, bias, training, momentum, eps, out0, out1, out2);
-  EXPECT_TENSOR_CLOSE(out0, out0_expected);
-  EXPECT_TENSOR_CLOSE(out1, out1_expected);
-  EXPECT_TENSOR_CLOSE(out2, out2_expected);
+TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D){
+#define TEST_ENTRY(ctype, dtype) test_2d_dtype<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
 
 TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {