Skip to content

Commit 4a924f1

Browse files
Merge pull request #527 from ufownl:feature/gemma2_secondary_eos
PiperOrigin-RevId: 740327973
2 parents 6300c12 + d42deaa commit 4a924f1

File tree

3 files changed

+9
-23
lines changed

3 files changed

+9
-23
lines changed

gemma/configs.cc

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ static ModelConfig ConfigBaseGemmaV2() {
3535
ModelConfig config = ConfigNoSSM();
3636
config.att_cap = 50.0f;
3737
config.final_cap = 30.0f;
38+
config.eos_id = 1;
39+
config.secondary_eos_id = 107;
3840
return config;
3941
}
4042

gemma/gemma-inl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14271427
// Sanity check: prompts should not be empty, nor start with EOS.
14281428
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
14291429
const PromptTokens& prompt = queries_prompt[query_idx];
1430-
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
1430+
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
14311431
}
14321432

14331433
const size_t num_queries = queries_prompt.size();
@@ -1615,4 +1615,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
16151615
} // namespace gcpp
16161616
HWY_AFTER_NAMESPACE();
16171617

1618-
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
1618+
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_

gemma/run.cc

+5-21
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
8585
size_t abs_pos = 0; // across turns
8686
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
8787
size_t prompt_size = 0;
88-
bool end_of_turn_seen = false;
8988

9089
std::mt19937 gen;
9190
InitGenerator(args, gen);
@@ -118,12 +117,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
118117
// callback function invoked for each generated token.
119118
auto stream_token = [&](int token, float) {
120119
++abs_pos;
121-
if (model.GetModelConfig().IsEOS(token)) {
122-
if (app.verbosity >= 2) {
123-
std::cout << "\n[ End ]\n";
124-
}
125-
return true;
126-
}
127120
const bool in_prompt = tokens_generated_this_turn < prompt_size;
128121
const bool first_response_token = tokens_generated_this_turn == prompt_size;
129122
++tokens_generated_this_turn;
@@ -132,6 +125,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
132125
std::cerr << "." << std::flush;
133126
}
134127
return true;
128+
} else if (model.GetModelConfig().IsEOS(token)) {
129+
if (app.verbosity >= 2) {
130+
std::cout << "\n[ End ]\n";
131+
}
132+
return true;
135133
}
136134
std::string token_text;
137135
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -141,13 +139,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
141139
std::cout << "\n\n";
142140
}
143141
}
144-
if (token_text == "<end_of_turn>") {
145-
// We don't want to show the <end_of_turn> token to the user.
146-
// We also need to remember that we've seen it, so that we can rewind
147-
// abs_pos appropriately. We expect EOS as the next token.
148-
end_of_turn_seen = true;
149-
return true;
150-
}
151142
std::cout << token_text << std::flush;
152143
return true;
153144
};
@@ -233,13 +224,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
233224
HWY_ASSERT(abs_pos > 0);
234225
abs_pos--;
235226
}
236-
if (end_of_turn_seen && abs_pos > 0) {
237-
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
238-
// more, because we will prepend it again to the prompt in
239-
// WrapAndTokenize.
240-
abs_pos--;
241-
}
242-
end_of_turn_seen = false;
243227
}
244228
}
245229

0 commit comments

Comments
 (0)