Support passing attention mask as optional input in text_generator_main.cc. #430
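This change lets the example binary drive models that expose an explicit attention mask: it detects an optional "attention_mask" input on the prefill and decode signatures and, when present, populates it alongside tokens and input_pos. It also accepts an input_pos of rank 1 ([Seq], as exported by ai_edge_torch) or rank 2 ([Batch, Seq], as exported by ai_edge_jax).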

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/cpp/BUILD
@@ -49,7 +49,7 @@ cc_binary(
"@com_google_absl//absl/strings",
"@com_google_sentencepiece//:sentencepiece_processor",
"@org_tensorflow//tensorflow/lite:framework",
- "@org_tensorflow//tensorflow/lite:util",
+ "@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
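The dependency swap mirrors the include change below: the new mask handling uses TfLiteTensor and the other C types from tensorflow/lite/c/common.h, which the //tensorflow/lite/c:common target provides, while //tensorflow/lite:util appears to be unused after this change.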
43 changes: 39 additions & 4 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "ai_edge_torch/generative/examples/cpp/utils.h"
#include "src/sentencepiece_processor.h"
+ #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/genai/genai_ops.h"
#include "tensorflow/lite/interpreter.h"
@@ -182,8 +183,11 @@ tflite::SignatureRunner* GetPrefillRunner(
}
TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())
                              ->input_tensor("input_pos");
- // The expected shape for input position is [Seq].
- int seq_size = input_pos->dims->data[0];
+ // The expected shape for input position is [Seq] (from ai_edge_torch) or
+ // [Batch, Seq] (from ai_edge_jax).
+ MINIMAL_CHECK(input_pos->dims->size == 1 || input_pos->dims->size == 2);
+ int seq_size = input_pos->dims->size == 1 ? input_pos->dims->data[0]
+                                           : input_pos->dims->data[1];
if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
runner = interpreter->GetSignatureRunner(key->c_str());
delta = seq_size - num_input_tokens;
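In isolation, the rank-tolerant lookup reduces to reading the last dimension. A minimal sketch under the same assumptions (illustrative only, not code from this PR):

#include "tensorflow/lite/c/common.h"

// input_pos is either [Seq] (rank 1) or [Batch, Seq] (rank 2); either way
// the sequence length is the last dimension.
int SeqSizeOf(const TfLiteTensor* input_pos) {
  return input_pos->dims->data[input_pos->dims->size - 1];
}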
@@ -229,6 +233,17 @@ int GreedySampler(const TfLiteTensor* logits) {
return max_index;
}

+ // Scans through the input tensor names to check if the attention mask is
+ // passed as an input tensor.
+ bool AttentionMaskInInput(tflite::SignatureRunner* runner) {
+   for (int i = 0; i < runner->input_names().size(); ++i) {
+     if (strcmp(runner->input_names()[i], "attention_mask") == 0) {
+       return true;
+     }
+   }
+   return false;
+ }

} // namespace
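The AttentionMaskInInput helper above is a linear scan over the signature's input names. It could equivalently be written with <algorithm> (a sketch, not part of the PR; HasInputNamed is a hypothetical name):

#include <algorithm>
#include <cstring>

#include "tensorflow/lite/interpreter.h"

// True if the signature declares an input tensor with the given name.
bool HasInputNamed(tflite::SignatureRunner* runner, const char* name) {
  const auto& names = runner->input_names();
  return std::any_of(names.begin(), names.end(), [name](const char* n) {
    return std::strcmp(n, name) == 0;
  });
}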

int main(int argc, char* argv[]) {
@@ -269,15 +284,25 @@ int main(int argc, char* argv[]) {
GetDecodeRunner(interpreter.get(), kv_cache);
MINIMAL_CHECK(decode_runner != nullptr);

+ // Check if the attention mask is passed as an input tensor.
+ bool attention_mask_as_input = AttentionMaskInInput(prefill_runner);
// Get Input Tensors for each of the runners.
// Shape: [Batch, Seq], Dtype: int32
TfLiteTensor* prefill_input = prefill_runner->input_tensor("tokens");
- // Shape: [Seq], Dtype: int32
+ // Shape: [Seq] or [Batch, Seq], Dtype: int32
TfLiteTensor* prefill_input_pos = prefill_runner->input_tensor("input_pos");
+ // Shape: [Batch, 1, Seq], Dtype: bool
+ TfLiteTensor* prefill_input_mask =
+     attention_mask_as_input ? prefill_runner->input_tensor("attention_mask")
+                             : nullptr;
// Shape: [Batch, Seq], Dtype: int32
TfLiteTensor* decode_input = decode_runner->input_tensor("tokens");
// Shape: [Seq], Dtype: int32
TfLiteTensor* decode_input_pos = decode_runner->input_tensor("input_pos");
+ // Shape: [Batch, 1, Seq], Dtype: bool
+ TfLiteTensor* decode_input_mask =
+     attention_mask_as_input ? decode_runner->input_tensor("attention_mask")
+                             : nullptr;
// shape: [Batch, kv_cache_max, num_query_groups, head_dim]
TfLiteTensor* kv_cache_k_0 = decode_runner->input_tensor("kv_cache_k_0");
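Since batch size is 1, each [Batch, 1, Seq] mask tensor is a flat bool buffer of length Seq, which is why the code below can index it by position directly. An illustrative helper under that assumption (hypothetical, not part of the PR):

#include <algorithm>
#include <cstring>

#include "tensorflow/lite/c/common.h"

// Clears a flat [1, 1, Seq] bool mask, then marks the first n positions
// visible; the prefill loop below does the same thing element by element.
void EnableMaskPrefix(TfLiteTensor* mask, int n) {
  std::memset(mask->data.b, 0, mask->bytes);
  std::fill_n(mask->data.b, n, true);
}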

@@ -290,9 +315,12 @@ int main(int argc, char* argv[]) {
std::min(static_cast<int>(prompt_tokens.size()), max_seq_size);
std::memset(prefill_input->data.i32, 0, prefill_input->bytes);
std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes);
- for (int i = 0; i < prefill_seq_size - 1; ++i) {
+ if (prefill_input_mask)
+   std::memset(prefill_input_mask->data.b, 0, prefill_input_mask->bytes);
+ for (int i = 0; i < prefill_seq_size; ++i) {
prefill_input->data.i32[i] = prompt_tokens[i];
prefill_input_pos->data.i32[i] = i;
+ if (prefill_input_mask) prefill_input_mask->data.b[i] = true;
}
MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);
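For instance, with prefill_seq_size = 3 and a prefill signature of length Seq = 8, the loop above leaves the inputs as (illustrative values; mask bools shown as 0/1):

tokens         = { t0, t1, t2, 0, 0, 0, 0, 0 }
input_pos      = {  0,  1,  2, 0, 0, 0, 0, 0 }
attention_mask = {  1,  1,  1, 0, 0, 0, 0, 0 }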

@@ -305,13 +333,20 @@ int main(int argc, char* argv[]) {
std::min(max_decode_steps, kv_cache_max_size - prefill_seq_size);
MINIMAL_CHECK(decode_steps > 0);

+ if (decode_input_mask) {
+   std::memset(decode_input_mask->data.b, 0, decode_input_mask->bytes);
+   std::memcpy(decode_input_mask->data.b, prefill_input_mask->data.b,
+               sizeof(bool) * prefill_seq_size);
+ }
std::vector<int> output_tokens;
output_tokens.reserve(decode_steps);
int next_token = prompt_tokens[prefill_seq_size - 1];
int next_position = prefill_seq_size - 1;
for (int i = 0; i < decode_steps; ++i) {
decode_input->data.i32[0] = next_token;
decode_input_pos->data.i32[0] = next_position;
+ if (decode_input_mask) decode_input_mask->data.b[next_position] = true;

MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
next_token = GreedySampler(decode_runner->output_tensor("logits"));
output_tokens.push_back(next_token);