|
28 | 28 | #include "gemma/configs.h"
|
29 | 29 | #include "gemma/gemma.h"
|
30 | 30 | #include "gemma/weights.h"
|
31 |
| -// Placeholder for internal test4, do not remove |
32 | 31 | #include "paligemma/image.h"
|
33 | 32 | #include "util/allocator.h"
|
34 | 33 | #include "util/basics.h"
|
@@ -1310,42 +1309,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
1310 | 1309 | 0.0f);
|
1311 | 1310 | }
|
1312 | 1311 |
|
1313 |
| - const size_t vocab_size = model.Config().vocab_size; |
1314 |
| - const double gen_start = hwy::platform::Now(); |
1315 |
| - for (size_t gen = 0; gen < max_generated_tokens; ++gen) { |
1316 |
| - // Decode generates one token per query and increments queries_mutable_pos. |
1317 |
| - Transformer(QueriesToken(gen_tokens.data(), num_queries), |
1318 |
| - queries_mutable_pos, queries_prefix_end, weights, activations, |
1319 |
| - div_seq_len, kv_caches, runtime_config.layers_output, |
1320 |
| - runtime_config.activations_observer); |
1321 |
| - // queries_pos are incremented by Transformer. |
1322 |
| - |
1323 |
| - bool all_queries_eos = true; |
1324 |
| - { |
1325 |
| - PROFILER_ZONE("Gen.EmbeddingMatmul"); |
1326 |
| - // Compute logits from last layer activations. |
1327 |
| - MatMul(ConstMatFromBatch(num_queries, activations.x), |
1328 |
| - ConstMatFromWeights(weights.embedder_input_embedding), |
1329 |
| - /*add=*/nullptr, *activations.env, |
1330 |
| - RowPtrFromBatch(activations.logits)); |
1331 |
| - } |
1332 |
| - PROFILER_ZONE("Gen.Softcap+Sample+Stream"); |
1333 |
| - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { |
1334 |
| - float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); |
1335 |
| - MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); |
1336 |
| - const TokenAndProb tp = sample_token(logits, vocab_size); |
1337 |
| - timing_info.NotifyGenerated(prefill_start, gen_start); |
1338 |
| - |
1339 |
| - const bool is_eos = |
1340 |
| - token_streamer(query_idx_start + query_idx, |
1341 |
| - queries_mutable_pos[query_idx], tp.token, tp.prob); |
1342 |
| - all_queries_eos &= is_eos; |
1343 |
| - gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token; |
1344 |
| - } |
1345 |
| - if (all_queries_eos) break; |
1346 |
| - } // foreach token to generate |
1347 |
| - |
1348 |
| - timing_info.NotifyGenerateDone(gen_start); |
| 1312 | + { |
| 1313 | + const size_t vocab_size = model.Config().vocab_size; |
| 1314 | + const double gen_start = hwy::platform::Now(); |
| 1315 | + for (size_t gen = 0; gen < max_generated_tokens; ++gen) { |
| 1316 | + // Decode generates one token per query and increments |
| 1317 | + // queries_mutable_pos. |
| 1318 | + Transformer(QueriesToken(gen_tokens.data(), num_queries), |
| 1319 | + queries_mutable_pos, queries_prefix_end, weights, activations, |
| 1320 | + div_seq_len, kv_caches, runtime_config.layers_output, |
| 1321 | + runtime_config.activations_observer); |
| 1322 | + // queries_pos are incremented by Transformer. |
| 1323 | + |
| 1324 | + bool all_queries_eos = true; |
| 1325 | + { |
| 1326 | + PROFILER_ZONE("Gen.EmbeddingMatmul"); |
| 1327 | + // Compute logits from last layer activations. |
| 1328 | + MatMul(ConstMatFromBatch(num_queries, activations.x), |
| 1329 | + ConstMatFromWeights(weights.embedder_input_embedding), |
| 1330 | + /*add=*/nullptr, *activations.env, |
| 1331 | + RowPtrFromBatch(activations.logits)); |
| 1332 | + } |
| 1333 | + PROFILER_ZONE("Gen.Softcap+Sample+Stream"); |
| 1334 | + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { |
| 1335 | + float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); |
| 1336 | + MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, |
| 1337 | + vocab_size); |
| 1338 | + const TokenAndProb tp = sample_token(logits, vocab_size); |
| 1339 | + timing_info.NotifyGenerated(prefill_start, gen_start); |
| 1340 | + |
| 1341 | + const bool is_eos = |
| 1342 | + token_streamer(query_idx_start + query_idx, |
| 1343 | + queries_mutable_pos[query_idx], tp.token, tp.prob); |
| 1344 | + all_queries_eos &= is_eos; |
| 1345 | + gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token; |
| 1346 | + } |
| 1347 | + if (all_queries_eos) break; |
| 1348 | + } // foreach token to generate |
| 1349 | + timing_info.NotifyGenerateDone(gen_start); |
| 1350 | + } |
1349 | 1351 | }
|
1350 | 1352 |
|
1351 | 1353 | template <typename T>
|
|
0 commit comments