diff --git a/gemma/configs.h b/gemma/configs.h index 5bfa5187..10de09ec 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -97,6 +97,7 @@ struct ConfigGemma7B { static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; // SSM config. static constexpr int kConv1dWidth = 0; @@ -128,6 +129,7 @@ struct ConfigGemma2B { static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; // SSM config. static constexpr int kConv1dWidth = 0; @@ -187,6 +189,7 @@ struct ConfigGriffin2B { static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; // SSM config. static constexpr int kConv1dWidth = 4; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index e602aacc..41e7534c 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -87,6 +87,7 @@ struct Layer { static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; static constexpr bool kFFBiases = TConfig::kFFBiases; + static constexpr bool kPostNormScale = TConfig::kPostNormScale; static constexpr size_t kAOBiasDim = TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; static constexpr size_t kGriffinDim = @@ -121,6 +122,8 @@ struct Layer { ArrayT linear_w; ArrayT pre_attention_norm_scale; ArrayT pre_ffw_norm_scale; + ArrayT post_attention_norm_scale; + ArrayT post_ffw_norm_scale; ArrayT ffw_gating_biases; ArrayT ffw_output_biases; @@ -269,6 +272,10 @@ hwy::AlignedFreeUniquePtr LoadWeights( SCALE_WEIGHTS(linear_w); READ_WEIGHTS(pre_attention_norm_scale); READ_WEIGHTS(pre_ffw_norm_scale); + if (TConfig::kPostNormScale) { + READ_WEIGHTS(post_attention_norm_scale); + READ_WEIGHTS(post_ffw_norm_scale); + } if (TConfig::kFFBiases) { READ_WEIGHTS(ffw_gating_biases); READ_WEIGHTS(ffw_output_biases); @@ -311,6 +318,7 @@ struct CompressedLayer { static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize; static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth; static constexpr bool kFFBiases = TLayer::kFFBiases; + static constexpr bool kPostNormScale = TConfig::kPostNormScale; static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim; static constexpr size_t kGriffinDim = TLayer::kGriffinDim; @@ -346,6 +354,9 @@ struct CompressedLayer { // We don't yet have an RMSNorm that accepts all WeightT. ArrayT pre_attention_norm_scale; ArrayT pre_ffw_norm_scale; + ArrayT + post_attention_norm_scale; + ArrayT post_ffw_norm_scale; ArrayT ffw_gating_biases; ArrayT ffw_output_biases; @@ -949,6 +960,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, pool.Run(0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + if (TConfig::kPostNormScale) { + RMSNormInplace(layer_weights->post_attention_norm_scale.data(), + activations.att_post2.data() + token_idx * kModelDim, + kModelDim); + } AddFrom(activations.att_post2.data() + token_idx * kModelDim, activations.x.data() + token_idx * kModelDim, kModelDim); RMSNorm(activations.x.data() + token_idx * kModelDim, @@ -958,6 +974,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, }); FFW(activations, num_tokens, layer_weights, pool); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + if (TConfig::kPostNormScale) { + RMSNormInplace(layer_weights->post_ffw_norm_scale.data(), + activations.ffw_out.data() + token_idx * kModelDim, + kModelDim); + } AddFrom(activations.ffw_out.data() + token_idx * kModelDim, activations.x.data() + token_idx * kModelDim, kModelDim); } @@ -1005,10 +1026,18 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights, GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache, pool); } + if (TConfig::kPostNormScale) { + RMSNormInplace(layer_weights->post_attention_norm_scale.data(), + activations.att_post2.data(), kModelDim); + } AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim); RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(), activations.bf_pre_ffw_rms_out.data(), kModelDim); FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool); + if (TConfig::kPostNormScale) { + RMSNormInplace(layer_weights->post_ffw_norm_scale.data(), + activations.ffw_out.data(), kModelDim); + } AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim); if (layers_output != nullptr) { std::string block_name = "blocks." + std::to_string(layer); @@ -1336,6 +1365,10 @@ void ForEachTensor(const Weights* weights, CALL_FUNC("gr_a", griffin.a); } CALL_FUNC("pre_att_ns", pre_attention_norm_scale); + if (TConfig::kPostNormScale) { + CALL_FUNC("post_att_ns", post_attention_norm_scale); + CALL_FUNC("post_ff_ns", post_ffw_norm_scale); + } if (TConfig::kFFBiases) { CALL_FUNC("ffw_gat_b", ffw_gating_biases);