Add parameter for base_frequency to CreateInvTimescale().
Extract a few local variables to make code easier to read (hopefully).

PiperOrigin-RevId: 716204947
danielkeysers authored and copybara-github committed Jan 23, 2025
1 parent a133b3d commit 0798491
Showing 2 changed files with 33 additions and 29 deletions.
22 changes: 12 additions & 10 deletions gemma/activations.h
@@ -74,17 +74,18 @@ struct Activations {
size_t seq_len;
size_t cache_pos_size = 0;

- static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
- PostQKType post_qk) {
+ static RowVectorBatch<float> CreateInvTimescale(
+ size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
const size_t rope_dim =
post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
- const float freq_exponents =
- static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
+ const double freq_exponents =
+ static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably.
- inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
+ inv_timescale.Batch(0)[dim] =
+ static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
}
return inv_timescale;
}
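For reference, the loop above fills inv_timescale with the usual RoPE inverse frequencies 1 / base^(2*dim / rope_dim); the new base_frequency parameter only replaces the previously hard-coded base of 10000.0, and the default keeps existing callers unchanged. A minimal standalone sketch of the same computation (plain std::vector instead of RowVectorBatch; the helper name is chosen here for illustration):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Illustrative only: mirrors the loop in CreateInvTimescale above.
    // inv[dim] = 1 / base^(2*dim / rope_dim) for dim in [0, rope_dim/2).
    // The default base of 10000.0 reproduces the previous hard-coded value;
    // a model trained with a different RoPE base would pass it explicitly.
    std::vector<float> InvTimescales(size_t rope_dim, double base = 10000.0) {
      std::vector<float> inv(rope_dim / 2);
      for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
        const double freq_exponent =
            static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
        inv[dim] = static_cast<float>(1.0 / std::pow(base, freq_exponent));
      }
      return inv;
    }

Because the parameter is defaulted, the call in Allocate() below stays as CreateInvTimescale(qkv_dim, post_qk); only a caller that needs a non-standard base has to pass the third argument.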
@@ -94,19 +95,20 @@ struct Activations {
const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t vocab_size = weights_config.vocab_size;
+ const size_t qkv_dim = layer_config.qkv_dim;
+ const size_t heads = layer_config.heads;

x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
q = RowVectorBatch<float>(
- Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
+ Extents2D(batch_size, heads * layer_config.QStride()));
if (vocab_size > 0) {
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
}

pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>(
- Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
- att_out = RowVectorBatch<float>(
- Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+ Extents2D(batch_size, heads * weights_config.seq_len));
+ att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));

bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
@@ -122,7 +124,7 @@ struct Activations {
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
}

- inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
+ inv_timescale = CreateInvTimescale(qkv_dim, post_qk);

env = std::make_unique<MatMulEnv>(pools);
}
40 changes: 21 additions & 19 deletions gemma/gemma-inl.h
@@ -216,15 +216,17 @@ class GemmaAttention {
template <typename U>
HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
const float mul, U* qk_out) {
+ // qk is either q or k, so qkv_dim is the length we operate on.
+ const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Const();
// PostQKType::Rope
(void)layer;
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
- hwy::CopyBytes(qk, qk_out, layer_config_.qkv_dim * sizeof(*qk));
- Rope(qk_out, layer_config_.qkv_dim / 2, inv_timescale, pos);
- MulByConst(mul, qk_out, layer_config_.qkv_dim);
+ hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
+ Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
+ MulByConst(mul, qk_out, qkv_dim);
} else {
- RopeAndMulBy(mul, qk, layer_config_.qkv_dim, inv_timescale, pos, qk_out);
+ RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
}
}
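PositionalEncodingQK applies rotary position embedding to a single q or k vector of length qkv_dim: each element pair is rotated by an angle pos * inv_timescale[d], and HalfRope restricts the rotation to the first half of the copied vector (hence the qkv_dim / 2 argument above) before MulByConst scales all qkv_dim elements. A scalar sketch of such a rotation, assuming the split-halves pairing of element d with element d + dim/2 (the vectorized Rope and RopeAndMulBy in ops are the authoritative implementations):

    #include <cmath>
    #include <cstddef>

    // Illustrative scalar RoPE: rotates (x[d], x[d + dim/2]) by
    // theta = pos * inv_timescale[d]. The pairing convention is an
    // assumption of this sketch; see the Rope implementation in ops.
    void RopeScalar(float* x, size_t dim, const float* inv_timescale,
                    size_t pos) {
      const size_t half = dim / 2;
      for (size_t d = 0; d < half; ++d) {
        const float theta = static_cast<float>(pos) * inv_timescale[d];
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[d];
        const float x1 = x[d + half];
        x[d] = x0 * c - x1 * s;
        x[d + half] = x0 * s + x1 * c;
      }
    }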

@@ -334,13 +336,14 @@ class GemmaAttention {
HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const size_t head_offset, const float* HWY_RESTRICT q,
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+ const size_t qkv_dim = layer_config_.qkv_dim;
if (HWY_LIKELY(last_pos < activations_.seq_len)) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t kv_offset =
pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
- const float score = Dot(q, k, layer_config_.qkv_dim);
+ const float score = Dot(q, k, qkv_dim);
head_att[pos] = score;
}
} else {
@@ -349,7 +352,7 @@
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
- const float score = Dot(q, k, layer_config_.qkv_dim);
+ const float score = Dot(q, k, qkv_dim);
head_att[pos % activations_.seq_len] = score;
}
}
@@ -364,26 +367,27 @@
const hwy::Divisor& div_seq_len,
const KVCache& kv_cache,
float* HWY_RESTRICT att_out) const {
- hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
+ const size_t qkv_dim = layer_config_.qkv_dim;
+ hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));

if (HWY_LIKELY(last_pos < activations_.seq_len)) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t kv_offset =
pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT v =
- kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
- MulByConstAndAdd(head_att[pos], v, att_out, layer_config_.qkv_dim);
+ kv_cache.kv_cache.get() + kv_offset + qkv_dim;
+ MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT v =
- kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
+ kv_cache.kv_cache.get() + kv_offset + qkv_dim;
MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
- layer_config_.qkv_dim);
+ qkv_dim);
}
}
}
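QDotK and WeightedSumV address the KV cache with the same flattened layout: for a given (position, layer, head), K occupies qkv_dim floats at kv_offset and V the next qkv_dim floats at kv_offset + qkv_dim, and positions past seq_len wrap around via div_seq_len.Remainder(pos). A small sketch of that index arithmetic (struct and member names here are illustrative, not part of this change):

    #include <cstddef>

    // Illustrative index helpers matching the offsets used above: the cache
    // stores, per position, all layers; per layer, all KV heads; per head,
    // K (qkv_dim floats) immediately followed by V (qkv_dim floats).
    struct KVIndex {
      size_t cache_pos_size;    // floats per cached position (all layers)
      size_t cache_layer_size;  // floats per layer within one position
      size_t qkv_dim;
      size_t seq_len;

      size_t KOffset(size_t pos, size_t layer, size_t head_offset) const {
        const size_t cache_pos = pos % seq_len;  // wraparound (Remainder)
        return cache_pos * cache_pos_size + layer * cache_layer_size +
               head_offset;
      }
      size_t VOffset(size_t pos, size_t layer, size_t head_offset) const {
        return KOffset(pos, layer, head_offset) + qkv_dim;  // V follows K
      }
    };

The factor of 2 in head_offset = (head / kHeadGroups) * qkv_dim * 2 in the next hunk reflects exactly this K-then-V pair per KV head.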
@@ -403,8 +407,8 @@
const size_t interleaved_idx = task / layer_config_.heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
- const size_t head_offset =
- (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+ const size_t qkv_dim = layer_config_.qkv_dim;
+ const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;
@@ -435,15 +439,14 @@

float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) +
- head * layer_config_.qkv_dim;
+ head * qkv_dim;
WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
div_seq_len_, kv_cache, att_out);
});
}

// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
- // head_dim
- // (`layer_config_.qkv_dim`) into output (`layer_out`).
+ // head_dim (`qkv_dim`) into output (`layer_out`).
HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.SumHeads");
// att_weights and att_out are concatenated heads, each of length
Expand Down Expand Up @@ -630,13 +633,12 @@ class VitAttention {
}

// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
- // head_dim
- // (`layer_config_.qkv_dim`) into output (`att_sums`).
+ // head_dim (`qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
// att_weights and att_out are concatenated heads, each of length
- // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
+ // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
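The comment above is the reason a single matmul both projects and sums over heads: att_out concatenates all heads of one token into a row of length heads * qkv_dim, att_weights maps that (heads * qkv_dim) inner dimension to model_dim, and because the inner dimension runs jointly over (head, dim), every head's projection accumulates into the same output element. A naive scalar reference of that shape argument (illustrative only; the row-major storage order is an assumption of this sketch, the real code calls MatMul on the mats built above, and VitAttention additionally adds the attn_out_b bias):

    #include <cstddef>
    #include <vector>

    // Illustrative: att_out is [num_tokens][heads * qkv_dim], att_weights is
    // [heads * qkv_dim][model_dim], att_sums is [num_tokens][model_dim].
    // Summing over the joint (head, dim) inner dimension is what adds the
    // per-head contributions together.
    void SumHeadsReference(const std::vector<float>& att_out,
                           const std::vector<float>& att_weights,
                           std::vector<float>& att_sums, size_t num_tokens,
                           size_t heads, size_t qkv_dim, size_t model_dim) {
      const size_t inner = heads * qkv_dim;
      for (size_t t = 0; t < num_tokens; ++t) {
        for (size_t m = 0; m < model_dim; ++m) {
          float sum = 0.0f;
          for (size_t h = 0; h < heads; ++h) {
            for (size_t d = 0; d < qkv_dim; ++d) {
              const size_t k = h * qkv_dim + d;
              sum += att_out[t * inner + k] * att_weights[k * model_dim + m];
            }
          }
          att_sums[t * model_dim + m] = sum;
        }
      }
    }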
