Implement GLU using internal views to avoid copying #11295

Merged: 6 commits, merged Jun 4, 2025
150 changes: 58 additions & 92 deletions kernels/portable/cpu/op_glu.cpp
@@ -8,6 +8,7 @@

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cinttypes>
@@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType;

namespace {

double exp_overload(double d) {
return exp(d);
}

float exp_overload(float f) {
return expf(f);
}

/**
 * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
*/
// TODO: T146333648, refactor this as a common helper function
template <typename CTYPE_OUT>
void sigmoid_tensor(Tensor& out) {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
for (const auto i : c10::irange(out.numel())) {
out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
}
}

/**
* Element-wise multiplication of the first half of `in` along the specified
* dimension and `out`, overwriting `out`.
*/
template <typename CTYPE_IN, typename CTYPE_OUT>
void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
size_t dim_length_in = static_cast<size_t>(in.size(dim));
size_t dim_length_out = static_cast<size_t>(out.size(dim));
size_t leading_dims = getLeadingDims(in, dim);
size_t trailing_dims = getTrailingDims(in, dim);

const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(leading_dims)) {
const CTYPE_IN* input_data =
input_data_base + i * dim_length_in * trailing_dims;
CTYPE_OUT* output_data =
output_data_base + i * dim_length_out * trailing_dims;
for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
for (const auto k : c10::irange(trailing_dims)) {
output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
}
input_data += trailing_dims;
output_data += trailing_dims;
}
}
}

/**
 * Slice the tensor along the given dim, from start to end. Assumes that `in`
 * and `out` have the same shape and dtype, that dim is non-negative, and that
 * start and end are valid non-negative indices.
*/
template <typename CTYPE_IN, typename CTYPE_OUT>
void slice_tensor(
const Tensor& in,
int64_t dim,
int64_t start,
int64_t end,
Tensor& out) {
size_t num_values = static_cast<size_t>(end - start);
size_t dim_length_in = static_cast<size_t>(in.size(dim));
size_t dim_length_out = static_cast<size_t>(out.size(dim));
size_t non_negative_start = static_cast<size_t>(start);
size_t leading_dims = getLeadingDims(in, dim);
size_t trailing_dims = getTrailingDims(in, dim);

const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(leading_dims)) {
const CTYPE_IN* input_data = input_data_base +
(i * dim_length_in + non_negative_start) * trailing_dims;
CTYPE_OUT* output_data =
output_data_base + i * dim_length_out * trailing_dims;
for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
for (const auto k : c10::irange(trailing_dims)) {
output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
}
input_data += trailing_dims;
output_data += trailing_dims;
}
}
}

/**
* Applies the gated linear unit function
*
@@ -120,11 +34,63 @@ void slice_tensor(
* 2. The output shall be in float types (Float, Double)
*/
template <typename CTYPE_IN, typename CTYPE_OUT>
Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
Tensor& glu_out_tensor(
KernelRuntimeContext& ctx,
const Tensor& self,
int64_t dim,
Tensor& out) {
const auto self_size = self.size(dim);
slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
sigmoid_tensor<CTYPE_OUT>(out);
mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
ET_KERNEL_CHECK(
ctx,
self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
InvalidArgument,
out);
std::array<executorch::aten::SizesType, kTensorDimensionLimit> half_sizes;
std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
half_sizes[dim] /= 2;
TensorImpl first_half_impl(
self.scalar_type(),
self.dim(),
half_sizes.data(),
self.mutable_data_ptr(),
const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
const_cast<executorch::aten::StridesType*>(self.strides().data()),
self.shape_dynamism());
TensorImpl second_half_impl(
self.scalar_type(),
self.dim(),
half_sizes.data(),
reinterpret_cast<char*>(self.mutable_data_ptr()) +
self.strides()[dim] * self_size / 2 * self.element_size(),
const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
const_cast<executorch::aten::StridesType*>(self.strides().data()),
self.shape_dynamism());
Tensor first_half(&first_half_impl);
Tensor second_half(&second_half_impl);
ScalarType compute_type =
executorch::runtime::isFloatingType(self.scalar_type())
? self.scalar_type()
: ScalarType::Float;
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "glu.out";
ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
// TODO: rewrite this to be vectorization-capable.
const auto one = static_cast<decltype(val_a)>(1.0);
return val_a * (one / (one + std::exp(-val_b)));
},
ctx,
first_half,
utils::SupportedTensorDtypes::FLOATHBF16,
second_half,
utils::SupportedTensorDtypes::FLOATHBF16,
out,
utils::internal::SupportNoncontiguousTensors());
Contributor:

why didn't you pass the support_non_contiguous_tensors here as a template parameter?

Contributor (Author):

the argument exists because specifying it as a template parameter is messy; see other thread

});
return out;
}
} // namespace
@@ -158,7 +124,7 @@ Tensor& glu_out(

ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
});
});

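A note for readers following the diff: the new glu_out_tensor avoids the old slice/sigmoid/multiply copies by building two TensorImpl views that alias the input's storage (sizes halved along dim, with the second view's data pointer offset by strides()[dim] * (self_size / 2) * element_size()), then applying a * sigmoid(b) elementwise over the pair. The standalone sketch below shows the same no-copy split on a plain contiguous buffer split along its last dimension; it does not use the ExecuTorch tensor APIs, and the function and variable names are invented for illustration.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Minimal GLU-by-views sketch on a contiguous [rows, cols] buffer, splitting
// along the last dimension. No data is copied: `first` and `second` are just
// pointers into the same buffer.
void glu_last_dim(const float* in, std::size_t rows, std::size_t cols, float* out) {
  const std::size_t half = cols / 2;
  for (std::size_t r = 0; r < rows; ++r) {
    const float* first = in + r * cols;          // view of in[r, 0:half]
    const float* second = in + r * cols + half;  // view of in[r, half:cols]
    float* out_row = out + r * half;
    for (std::size_t c = 0; c < half; ++c) {
      // GLU: first_half * sigmoid(second_half)
      out_row[c] = first[c] * (1.0f / (1.0f + std::exp(-second[c])));
    }
  }
}

int main() {
  std::vector<float> in = {1.f, 2.f, 0.f, -1.f,   // row 0
                           3.f, 4.f, 2.f, 5.f};   // row 1
  std::vector<float> out(4);
  glu_last_dim(in.data(), /*rows=*/2, /*cols=*/4, out.data());
  for (float v : out) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
  return 0;
}

On the review question above about passing the flag as a template parameter: the diff opts into stride-aware iteration by passing a tag object, utils::internal::SupportNoncontiguousTensors(), as a trailing function argument rather than adding another template parameter. The sketch below is a hypothetical illustration of that general trade-off (the names and signatures are not the real ExecuTorch ones): overload resolution on a tag leaves existing call sites unchanged, whereas an extra template parameter must be written out alongside the other explicit template arguments at every opting-in call site.

#include <iostream>

// Hypothetical tag type standing in for utils::internal::SupportNoncontiguousTensors.
struct SupportNoncontiguousTensors {};

// (a) Tag-argument style: the stride-aware path is selected by overload
// resolution on a trailing tag object; existing callers are untouched.
template <typename Fn>
void apply_fn(Fn fn) {
  std::cout << "contiguous fast path\n";
  fn();
}
template <typename Fn>
void apply_fn(Fn fn, SupportNoncontiguousTensors) {
  std::cout << "stride-respecting path\n";
  fn();
}

// (b) Template-parameter style: the flag becomes one more explicit template
// argument that opting-in callers must spell out next to the others.
template <bool kSupportNoncontiguous = false, typename Fn>
void apply_fn_tmpl(Fn fn) {
  std::cout << (kSupportNoncontiguous ? "stride-respecting path\n"
                                      : "contiguous fast path\n");
  fn();
}

int main() {
  auto op = [] {};
  apply_fn(op);                                       // fast path
  apply_fn(op, SupportNoncontiguousTensors{});        // opt in via tag argument
  apply_fn_tmpl</*kSupportNoncontiguous=*/true>(op);  // opt in via template argument
  return 0;
}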
22 changes: 16 additions & 6 deletions kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
std::equal(lhs_begin, lhs_end, rhs_begin);
}

template <std::size_t kNumInputs>
template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
class BroadcastIndexesIterator {
public:
using difference_type = ssize_t;
@@ -57,16 +57,20 @@ class BroadcastIndexesIterator {
template <typename... Args>
explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
: output_dim_or_zero_if_no_broadcasting_(
(sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
...)
!support_noncontiguous_tensors &&
(sizes_match_ignoring_leading_1s(
args.sizes(),
output.sizes()) &&
...)
? 0
: output.dim()),
output_shape_(output.sizes()) {
static_assert(
sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
"BroadcastIndexesIterator constructor requires kNumInputs input tensor"
"arguments!");
if (output_dim_or_zero_if_no_broadcasting_ != 0) {
if (support_noncontiguous_tensors ||
output_dim_or_zero_if_no_broadcasting_ != 0) {
effective_input_broadcast_strides_ = {
effective_input_broadcast_stride(output, args)...};
}
@@ -249,11 +253,17 @@ class BroadcastIndexesIterator {
* Unlike looping using delinearize_index() and
* linearize_access_indexes(), BroadcastIndexesRange avoids expensive
* division and modulo operations on each iteration.
*
* The support_noncontiguous_tensors argument disables an optimization
* that causes the iterators not to respect strides in some
* cases. This optimization is normally safe because ExecuTorch
* tensors are contiguous.
*/
template <std::size_t kNumInputs>
template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
class BroadcastIndexesRange {
public:
using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
using iterator = internal::
BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;

template <typename... Args>
BroadcastIndexesRange(const Tensor& output, const Args&... args)
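Why this flag matters for the GLU change: the half-views created in op_glu.cpp keep the full input's strides while halving one size, so their elements are no longer at consecutive linear offsets, and the usual "shapes match, so reuse a single incrementing linear index" shortcut would read the wrong elements. The standalone sketch below (not the ExecuTorch implementation; the shapes and names are invented) shows the arithmetic a stride-respecting iterator must effectively perform for such a view.

#include <array>
#include <cstddef>
#include <cstdint>

// Map a multi-dimensional coordinate to a flat storage offset using explicit
// strides; this is what a stride-respecting iterator must effectively compute.
template <std::size_t kDim>
std::int64_t strided_offset(
    const std::array<std::int64_t, kDim>& coord,
    const std::array<std::int64_t, kDim>& strides) {
  std::int64_t offset = 0;
  for (std::size_t d = 0; d < kDim; ++d) {
    offset += coord[d] * strides[d];
  }
  return offset;
}

int main() {
  // A contiguous [2, 4] tensor has strides {4, 1}. A "second half" view of
  // shape [2, 2] over the same storage keeps strides {4, 1} and starts at a
  // base offset of 2, so its elements are not at consecutive offsets.
  const std::array<std::int64_t, 2> strides = {4, 1};
  const std::int64_t base = 2;

  // Element (1, 0) of the view lives at base + 1*4 + 0*1 = 6 in the original
  // buffer, whereas a truly contiguous [2, 2] tensor would put it at offset 2.
  const std::array<std::int64_t, 2> coord = {1, 0};
  const std::int64_t view_offset = base + strided_offset<2>(coord, strides);
  return view_offset == 6 ? 0 : 1;
}

As the diff above shows, setting support_noncontiguous_tensors disables the shapes-match shortcut and always computes the per-input effective broadcast strides, which is exactly what those half-views need.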