Skip to content

Commit 5f17eb8

Browse files
apoorvreddy authored and copybara-github committed
internal change
PiperOrigin-RevId: 709234221
1 parent a133b3d commit 5f17eb8

File tree

2 files changed

+56
-37
lines changed

2 files changed

+56
-37
lines changed

gemma/gemma-inl.h

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include "gemma/configs.h"
2929
#include "gemma/gemma.h"
3030
#include "gemma/weights.h"
31-
// Placeholder for internal test4, do not remove
3231
#include "paligemma/image.h"
3332
#include "util/allocator.h"
3433
#include "util/basics.h"
@@ -1310,42 +1309,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
13101309
0.0f);
13111310
}
13121311

1313-
const size_t vocab_size = model.Config().vocab_size;
1314-
const double gen_start = hwy::platform::Now();
1315-
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
1316-
// Decode generates one token per query and increments queries_mutable_pos.
1317-
Transformer(QueriesToken(gen_tokens.data(), num_queries),
1318-
queries_mutable_pos, queries_prefix_end, weights, activations,
1319-
div_seq_len, kv_caches, runtime_config.layers_output,
1320-
runtime_config.activations_observer);
1321-
// queries_pos are incremented by Transformer.
1322-
1323-
bool all_queries_eos = true;
1324-
{
1325-
PROFILER_ZONE("Gen.EmbeddingMatmul");
1326-
// Compute logits from last layer activations.
1327-
MatMul(ConstMatFromBatch(num_queries, activations.x),
1328-
ConstMatFromWeights(weights.embedder_input_embedding),
1329-
/*add=*/nullptr, *activations.env,
1330-
RowPtrFromBatch(activations.logits));
1331-
}
1332-
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
1333-
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
1334-
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
1335-
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
1336-
const TokenAndProb tp = sample_token(logits, vocab_size);
1337-
timing_info.NotifyGenerated(prefill_start, gen_start);
1338-
1339-
const bool is_eos =
1340-
token_streamer(query_idx_start + query_idx,
1341-
queries_mutable_pos[query_idx], tp.token, tp.prob);
1342-
all_queries_eos &= is_eos;
1343-
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
1344-
}
1345-
if (all_queries_eos) break;
1346-
} // foreach token to generate
1347-
1348-
timing_info.NotifyGenerateDone(gen_start);
1312+
{
1313+
const size_t vocab_size = model.Config().vocab_size;
1314+
const double gen_start = hwy::platform::Now();
1315+
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
1316+
// Decode generates one token per query and increments
1317+
// queries_mutable_pos.
1318+
Transformer(QueriesToken(gen_tokens.data(), num_queries),
1319+
queries_mutable_pos, queries_prefix_end, weights, activations,
1320+
div_seq_len, kv_caches, runtime_config.layers_output,
1321+
runtime_config.activations_observer);
1322+
// queries_pos are incremented by Transformer.
1323+
1324+
bool all_queries_eos = true;
1325+
{
1326+
PROFILER_ZONE("Gen.EmbeddingMatmul");
1327+
// Compute logits from last layer activations.
1328+
MatMul(ConstMatFromBatch(num_queries, activations.x),
1329+
ConstMatFromWeights(weights.embedder_input_embedding),
1330+
/*add=*/nullptr, *activations.env,
1331+
RowPtrFromBatch(activations.logits));
1332+
}
1333+
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
1334+
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
1335+
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
1336+
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
1337+
vocab_size);
1338+
const TokenAndProb tp = sample_token(logits, vocab_size);
1339+
timing_info.NotifyGenerated(prefill_start, gen_start);
1340+
1341+
const bool is_eos =
1342+
token_streamer(query_idx_start + query_idx,
1343+
queries_mutable_pos[query_idx], tp.token, tp.prob);
1344+
all_queries_eos &= is_eos;
1345+
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
1346+
}
1347+
if (all_queries_eos) break;
1348+
} // foreach token to generate
1349+
timing_info.NotifyGenerateDone(gen_start);
1350+
}
13491351
}
13501352

13511353
template <typename T>

util/app.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
247247
bool multiturn;
248248
Path image_file;
249249

250+
bool use_heuristic_drafter;
251+
int heuristic_drafter_draft_length;
252+
int heuristic_drafter_history_context_length;
253+
int heuristic_drafter_robust_context_length;
254+
int heuristic_drafter_max_match_length;
255+
250256
// Returns error string or nullptr if OK.
251257
const char* Validate() const {
252258
if (max_generated_tokens > gcpp::kSeqLen) {
@@ -285,6 +291,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
285291
runtime_config.decode_qbatch_size = decode_qbatch_size;
286292
runtime_config.temperature = temperature;
287293
runtime_config.top_k = top_k;
294+
runtime_config.use_heuristic_drafter = use_heuristic_drafter;
295+
if (use_heuristic_drafter) {
296+
runtime_config.heuristic_drafter_draft_length =
297+
heuristic_drafter_draft_length;
298+
runtime_config.heuristic_drafter_history_context_length =
299+
heuristic_drafter_history_context_length;
300+
runtime_config.heuristic_drafter_robust_context_length =
301+
heuristic_drafter_robust_context_length;
302+
runtime_config.heuristic_drafter_max_match_length =
303+
heuristic_drafter_max_match_length;
304+
}
288305
}
289306
};
290307

0 commit comments

Comments
 (0)