Commit 5f17eb8

internal change
PiperOrigin-RevId: 709234221
apoorvreddy authored and copybara-github committed Jan 23, 2025
1 parent a133b3d commit 5f17eb8
Showing 2 changed files with 56 additions and 37 deletions.
76 changes: 39 additions & 37 deletions gemma/gemma-inl.h
@@ -28,7 +28,6 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
-// Placeholder for internal test4, do not remove
 #include "paligemma/image.h"
 #include "util/allocator.h"
 #include "util/basics.h"
@@ -1310,42 +1309,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                          0.0f);
   }
 
-  const size_t vocab_size = model.Config().vocab_size;
-  const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
-
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
-
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
-    if (all_queries_eos) break;
-  }  // foreach token to generate
-
-  timing_info.NotifyGenerateDone(gen_start);
+  {
+    const size_t vocab_size = model.Config().vocab_size;
+    const double gen_start = hwy::platform::Now();
+    for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
+      // Decode generates one token per query and increments
+      // queries_mutable_pos.
+      Transformer(QueriesToken(gen_tokens.data(), num_queries),
+                  queries_mutable_pos, queries_prefix_end, weights, activations,
+                  div_seq_len, kv_caches, runtime_config.layers_output,
+                  runtime_config.activations_observer);
+      // queries_pos are incremented by Transformer.
+
+      bool all_queries_eos = true;
+      {
+        PROFILER_ZONE("Gen.EmbeddingMatmul");
+        // Compute logits from last layer activations.
+        MatMul(ConstMatFromBatch(num_queries, activations.x),
+               ConstMatFromWeights(weights.embedder_input_embedding),
+               /*add=*/nullptr, *activations.env,
+               RowPtrFromBatch(activations.logits));
+      }
+      PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+      for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+        float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+        MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
+                           vocab_size);
+        const TokenAndProb tp = sample_token(logits, vocab_size);
+        timing_info.NotifyGenerated(prefill_start, gen_start);
+
+        const bool is_eos =
+            token_streamer(query_idx_start + query_idx,
+                           queries_mutable_pos[query_idx], tp.token, tp.prob);
+        all_queries_eos &= is_eos;
+        gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+      }
+      if (all_queries_eos) break;
+    }  // foreach token to generate
+    timing_info.NotifyGenerateDone(gen_start);
+  }
 }
 
 template <typename T>
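Note on the Softcap step in the loop above: Gemma-2-style configs bound the final logits before sampling via cap * tanh(x / cap), with a final_cap of 0 meaning the step is disabled. The real MaybeLogitsSoftCap is vectorized with Highway; the scalar sketch below is illustrative only, and the name SoftCapSketch is invented for this note.

#include <cmath>
#include <cstddef>

// Scalar sketch of logit soft-capping (not the library implementation):
// squashes each logit into (-cap, +cap) via cap * tanh(x / cap).
void SoftCapSketch(float cap, float* logits, size_t vocab_size) {
  if (cap == 0.0f) return;  // a cap of 0 disables capping
  for (size_t i = 0; i < vocab_size; ++i) {
    logits[i] = cap * std::tanh(logits[i] / cap);
  }
}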
17 changes: 17 additions & 0 deletions util/app.h
@@ -247,6 +247,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   bool multiturn;
   Path image_file;
 
+  bool use_heuristic_drafter;
+  int heuristic_drafter_draft_length;
+  int heuristic_drafter_history_context_length;
+  int heuristic_drafter_robust_context_length;
+  int heuristic_drafter_max_match_length;
+
   // Returns error string or nullptr if OK.
   const char* Validate() const {
     if (max_generated_tokens > gcpp::kSeqLen) {
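The new members above only take effect once registered as flags; gemma.cpp args structs typically expose their members through a visitor-based ForEach. The self-contained mock below sketches how such a registration might look for the drafter fields. The visitor lambda, defaults, and help strings are all invented for illustration and are not values from this commit.

#include <cstdio>

// Hypothetical, self-contained mock of the visitor pattern used by
// gemma.cpp args structs. Defaults and help text are assumptions.
struct DrafterArgsSketch {
  bool use_heuristic_drafter;
  int heuristic_drafter_draft_length;
  int heuristic_drafter_history_context_length;
  int heuristic_drafter_robust_context_length;
  int heuristic_drafter_max_match_length;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(use_heuristic_drafter, "use_heuristic_drafter", false,
            "Enable heuristic (lookup-based) drafting.");
    visitor(heuristic_drafter_draft_length, "heuristic_drafter_draft_length",
            4, "Draft tokens proposed per decode step.");
    visitor(heuristic_drafter_history_context_length,
            "heuristic_drafter_history_context_length", 1024,
            "Recent tokens searched for a match.");
    visitor(heuristic_drafter_robust_context_length,
            "heuristic_drafter_robust_context_length", 16,
            "Context used for conservative matching (assumed meaning).");
    visitor(heuristic_drafter_max_match_length,
            "heuristic_drafter_max_match_length", 8,
            "Longest suffix match attempted (assumed meaning).");
  }
};

int main() {
  DrafterArgsSketch args;
  // Apply defaults and print the flag surface this would create.
  args.ForEach([](auto& member, const char* name, auto init,
                  const char* help) {
    member = init;
    std::printf("--%s : %s\n", name, help);
  });
  return 0;
}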
@@ -285,6 +291,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.decode_qbatch_size = decode_qbatch_size;
     runtime_config.temperature = temperature;
     runtime_config.top_k = top_k;
+    runtime_config.use_heuristic_drafter = use_heuristic_drafter;
+    if (use_heuristic_drafter) {
+      runtime_config.heuristic_drafter_draft_length =
+          heuristic_drafter_draft_length;
+      runtime_config.heuristic_drafter_history_context_length =
+          heuristic_drafter_history_context_length;
+      runtime_config.heuristic_drafter_robust_context_length =
+          heuristic_drafter_robust_context_length;
+      runtime_config.heuristic_drafter_max_match_length =
+          heuristic_drafter_max_match_length;
+    }
   }
 };
 
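This commit only plumbs the drafter parameters from InferenceArgs into runtime_config; the drafter itself is not in this diff. Heuristic drafters for speculative decoding typically look up the current suffix of generated tokens in recent history and propose the tokens that followed the match as a cheap draft. The sketch below is a hypothetical illustration under assumed parameter meanings: DraftTokens is an invented name, and robust_context_length is omitted because its semantics cannot be inferred from this diff.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical prompt-lookup style drafter (NOT the implementation this
// commit wires up; parameter meanings are assumptions from flag names).
// Finds the longest recent occurrence of the current token suffix and
// proposes the tokens that followed it as the draft.
std::vector<int> DraftTokens(const std::vector<int>& history,
                             size_t draft_length,
                             size_t history_context_length,
                             size_t max_match_length) {
  std::vector<int> draft;
  if (history.size() < 2) return draft;
  // Restrict the search to the most recent history_context_length tokens.
  const size_t begin = history.size() > history_context_length
                           ? history.size() - history_context_length
                           : 0;
  // Prefer the longest suffix match, then back off to shorter ones.
  const size_t max_suffix = std::min(max_match_length, history.size() - 1);
  for (size_t suffix_len = max_suffix; suffix_len > 0; --suffix_len) {
    const size_t suffix_start = history.size() - suffix_len;
    for (size_t pos = begin; pos + suffix_len < history.size(); ++pos) {
      if (std::equal(history.begin() + pos,
                     history.begin() + pos + suffix_len,
                     history.begin() + suffix_start)) {
        // Tokens right after the earlier occurrence become the draft.
        const size_t copy_begin = pos + suffix_len;
        const size_t copy_end =
            std::min(copy_begin + draft_length, history.size());
        draft.assign(history.begin() + copy_begin,
                     history.begin() + copy_end);
        return draft;
      }
    }
  }
  return draft;  // empty draft: fall back to normal one-token decoding
}

int main() {
  // "the cat sat on the": after seeing "the" again, draft "cat sat".
  const std::vector<int> history = {10, 20, 30, 40, 10};
  for (int t : DraftTokens(history, /*draft_length=*/2,
                           /*history_context_length=*/64,
                           /*max_match_length=*/4)) {
    std::printf("%d ", t);  // prints: 20 30
  }
  return 0;
}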
