Skip to content

Commit 34a1c14

Browse files
authored
Support Half/BFloat16 in upsample_nearest2d (#7911)
Partial fix for #7748.
1 parent 58bda89 commit 34a1c14

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

kernels/portable/cpu/op_upsample_nearest2d.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Tensor& upsample_nearest2d_vec_out(
7979
const auto kernel_scale_w = area_pixel_compute_scale<double>(
8080
in.sizes()[3], out.sizes()[3], false, scale_w);
8181

82-
ET_SWITCH_REAL_TYPES(
82+
ET_SWITCH_REALHBF16_TYPES(
8383
in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
8484
upsample_nearest2d_kernel_impl<CTYPE>(
8585
in, kernel_scale_h, kernel_scale_w, out);

kernels/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ set(all_test_sources
230230
"op_unbind_copy_test.cpp"
231231
"op_unsqueeze_copy_test.cpp"
232232
"op_upsample_bilinear2d_test.cpp"
233+
"op_upsample_nearest2d_test.cpp"
233234
"op_var_test.cpp"
234235
"op_view_copy_test.cpp"
235236
"op_where_test.cpp"

kernels/test/op_upsample_nearest2d_test.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class OpUpsampleNearest2dTest : public OperatorTest {
5252
op_upsample_nearest2d_out(
5353
input,
5454
OptionalArrayRef<int64_t>({output_size.data(), output_size.size()}),
55-
true,
5655
{},
5756
out);
5857

@@ -254,9 +253,9 @@ TEST_F(OpUpsampleNearest2dTest, MultiBatchAndChannel) {
254253
}
255254

256255
TEST_F(OpUpsampleNearest2dTest, DType) {
257-
#define TEST_ENTRY(ctype, dtype) \
258-
test_upsample_nearest2d_dtype<ctype, ScalarType::dtype>(); \
259-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
256+
#define TEST_ENTRY(ctype, dtype) \
257+
test_upsample_nearest2d_dtype<ctype, ScalarType::dtype>();
258+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
260259
#undef TEST_ENTRY
261260
}
262261

0 commit comments

Comments
 (0)