Skip to content

Commit

Permalink
Support Half/BFloat16 in ones (#7851)
Browse files Browse the repository at this point in the history
Partial fix for #7748.
  • Loading branch information
swolchok authored Jan 23, 2025
1 parent d5ef631 commit 74d4fb6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_ones.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Tensor& ones_out(KernelRuntimeContext& ctx, IntArrayRef size, Tensor& out) {
ctx, resize_tensor(out, size) == Error::Ok, InvalidArgument, out);

ScalarType out_type = out.scalar_type();
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&] {
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, __func__, CTYPE, [&] {
auto out_data = out.mutable_data_ptr<CTYPE>();
for (size_t i = 0; i < out.numel(); i++) {
out_data[i] = static_cast<CTYPE>(1);
Expand Down
2 changes: 1 addition & 1 deletion kernels/test/op_ones_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ class OpOnesOutTest : public OperatorTest {
test_ones_out<ScalarType::DTYPE>({2, 3, 4}); \
}

ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TEST)
ET_FORALL_REALHBBF16_TYPES(GENERATE_TEST)

0 comments on commit 74d4fb6

Please sign in to comment.