Skip to content

Commit 019df69

Browse files
authored
unbreak optimized gelu test (#7987)
We don't run optimized gelu in OSS currently (soon, I hope!), so I missed the need to fix it. Test plan: ran optimized gelu test internally.
1 parent 533c1aa commit 019df69

File tree

2 files changed

+9
-24
lines changed

2 files changed

+9
-24
lines changed

kernels/optimized/cpu/op_gelu.cpp

+6-23
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <cmath>
1515

16+
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
1617
#include <executorch/runtime/kernel/kernel_includes.h>
1718
#include <executorch/runtime/platform/assert.h>
1819

@@ -116,30 +117,12 @@ Tensor& opt_gelu_out(
116117
Tensor& out) {
117118
(void)context;
118119
ET_KERNEL_CHECK(
119-
context,
120-
tensors_have_same_shape_and_dtype(input, out),
121-
InvalidArgument,
122-
out);
120+
context, check_gelu_args(input, approximate, out), InvalidArgument, out);
123121

124-
// helper for generating the cases for different data types
125-
#define GELU(ctype, dtype) \
126-
case ScalarType::dtype: \
127-
gelu<ctype>(context, input, approximate, out); \
128-
break;
129-
130-
switch (input.scalar_type()) {
131-
// TODO support Double as well
132-
GELU(float, Float)
133-
default:
134-
ET_KERNEL_CHECK_MSG(
135-
context,
136-
false,
137-
InvalidArgument,
138-
out,
139-
"Unhandled dtype %" PRId8,
140-
static_cast<int8_t>(input.scalar_type()));
141-
}
142-
#undef GELU
122+
ET_SWITCH_FLOATHBF16_TYPES(
123+
input.scalar_type(), context, "gelu.out", CTYPE, [&]() {
124+
gelu<CTYPE>(context, input, approximate, out);
125+
});
143126

144127
return out;
145128
}

kernels/optimized/cpu/targets.bzl

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ _OPTIMIZED_ATEN_OPS = (
3333
"ovr_config//cpu:arm64": [
3434
"fbsource//third-party/sleef:sleef_arm",
3535
],
36-
}),
36+
}) + [
37+
"//executorch/kernels/portable/cpu/util:activation_ops_util",
38+
],
3739
),
3840
op_target(
3941
name = "op_le",

0 commit comments

Comments
 (0)