Skip to content

Commit

Permalink
Fixes #37, lambda issue: missing HWY_ATTR, and cannot capture SVE in/…
Browse files Browse the repository at this point in the history
…out vectors.

PiperOrigin-RevId: 610260610
  • Loading branch information
jan-wassenberg authored and copybara-github committed Feb 26, 2024
1 parent 1243be7 commit 6a30858
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
hn::Transform(D(), x, size,
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
}

// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
Expand Down Expand Up @@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;

// Find max so we can subtract it below.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V max = vmin;
hn::Foreach(d, x, mask_pos, vmin,
[&max](D d, V v) { max = hn::Max(max, v); });
max = hn::MaxOfLanes(d, max); // broadcast
const size_t N = hn::Lanes(d);

// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
// cannot be lambda-captured.
// TODO(janwas): could be replaced with an hn::Accumulate algo.
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast

// Subtract max (avoid precision loss for large exponents) and exponentiate.
V sum = hn::Zero(d);
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
const V out = hn::Exp(d, hn::Sub(v, max));
// Also avoid hn::Transform because the additional `sum` output vector cannot
// be captured by a lambda.
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
return out;
});
hn::StoreN(out, d, x + idx, remaining);
}

// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
Expand All @@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;

const V inv_cap = hn::Set(d, 1.0f / cap);
const V vcap = hn::Set(d, cap);
const float inv_cap = 1.0f / cap;

hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(hn::Set(d, cap),
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
});
}

Expand Down

0 comments on commit 6a30858

Please sign in to comment.