diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 7f7a122a4b..d90a60718e 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -136,6 +136,9 @@ struct ModelState {
   int new_state_c_idx;
   int new_state_h_idx;
   int mfccs_idx;
+
+  std::vector<int> acoustic_exec_plan;
+  std::vector<int> mfcc_exec_plan;
 #endif
 
   ModelState();
@@ -446,6 +449,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
   memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
   memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
 
+  interpreter->SetExecutionPlan(acoustic_exec_plan);
   TfLiteStatus status = interpreter->Invoke();
   if (status != kTfLiteOk) {
     std::cerr << "Error running session: " << status << "\n";
@@ -501,6 +505,7 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
     input_samples[i] = samples[i];
   }
 
+  interpreter->SetExecutionPlan(mfcc_exec_plan);
   TfLiteStatus status = interpreter->Invoke();
   if (status != kTfLiteOk) {
     std::cerr << "Error running session: " << status << "\n";
@@ -592,6 +597,47 @@ tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
 {
   return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
 }
+
+void push_back_if_not_present(std::deque<int>& list, int value) {
+  if (std::find(list.begin(), list.end(), value) == list.end()) {
+    list.push_back(value);
+  }
+}
+
+// Backwards BFS on the node DAG. At each iteration we get the next tensor id
+// from the frontier list, then for each node which has that tensor id as an
+// output, add it to the parent list, and add its input tensors to the frontier
+// list. Because we start from the final tensor and work backwards to the inputs,
+// the parents list is constructed in reverse, by adding elements to its front.
+std::vector<int>
+tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
+{
+  std::deque<int> parents;
+  std::deque<int> frontier;
+  frontier.push_back(tensor_id);
+  while (!frontier.empty()) {
+    int next_tensor_id = frontier.front();
+    frontier.pop_front();
+    // Find all nodes that have next_tensor_id as an output
+    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
+      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
+      // Search node outputs for the tensor we're looking for
+      for (int i = 0; i < node.outputs->size; ++i) {
+        if (node.outputs->data[i] == next_tensor_id) {
+          // This node is part of the parent tree, add it to the parent list and
+          // add its input tensors to the frontier list
+          parents.push_front(node_id);
+          for (int j = 0; j < node.inputs->size; ++j) {
+            push_back_if_not_present(frontier, node.inputs->data[j]);
+          }
+        }
+      }
+    }
+  }
+
+  return std::vector<int>(parents.begin(), parents.end());
+}
+
 #endif
 
 int
@@ -746,6 +792,23 @@ DS_CreateModel(const char* aModelPath,
   model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
   model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
 
+  // When we call Interpreter::Invoke, the whole graph is executed by default,
+  // which means every time compute_mfcc is called the entire acoustic model is
+  // also executed. To work around that problem, we walk up the dependency DAG
+  // from the mfccs output tensor to find all the relevant nodes required for
+  // feature computation, building an execution plan that runs just those nodes.
+  auto mfcc_plan = tflite_find_parent_node_ids(model->interpreter.get(), model->mfccs_idx);
+  auto orig_plan = model->interpreter->execution_plan();
+
+  // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
+  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
+    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
+  });
+  orig_plan.erase(erase_begin, orig_plan.end());
+
+  model->acoustic_exec_plan = std::move(orig_plan);
+  model->mfcc_exec_plan = std::move(mfcc_plan);
+
   TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
 
   model->n_steps = dims_input_node->data[1];
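For reference, here is a minimal standalone sketch of the backwards BFS used by tflite_find_parent_node_ids above. The Node struct and the toy graph are hypothetical stand-ins for the interpreter's node table, so the traversal can be compiled and run on its own:

#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a TFLite node: just input/output tensor ids.
struct Node {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

void push_back_if_not_present(std::deque<int>& list, int value) {
  if (std::find(list.begin(), list.end(), value) == list.end()) {
    list.push_back(value);
  }
}

// Same traversal as the patch: start from a target tensor id, collect every
// node that produces it, then keep walking through those nodes' inputs.
// push_front means nodes closer to the graph inputs end up first in the
// result, matching execution order.
std::vector<int> find_parent_node_ids(const std::vector<Node>& nodes, int tensor_id) {
  std::deque<int> parents;
  std::deque<int> frontier{tensor_id};
  while (!frontier.empty()) {
    int next_tensor_id = frontier.front();
    frontier.pop_front();
    for (int node_id = 0; node_id < static_cast<int>(nodes.size()); ++node_id) {
      for (int out : nodes[node_id].outputs) {
        if (out == next_tensor_id) {
          parents.push_front(node_id);
          for (int in : nodes[node_id].inputs) {
            push_back_if_not_present(frontier, in);
          }
        }
      }
    }
  }
  return std::vector<int>(parents.begin(), parents.end());
}

int main() {
  // Toy graph: tensor 0 -> node 0 -> tensor 1 -> node 1 -> tensor 2, plus an
  // unrelated node 2 that consumes tensor 0 but produces the unused tensor 3.
  std::vector<Node> nodes = {{{0}, {1}}, {{1}, {2}}, {{0}, {3}}};
  for (int id : find_parent_node_ids(nodes, 2)) {
    std::cout << id << ' ';  // prints "0 1"; node 2 is excluded
  }
  std::cout << '\n';
}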
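Similarly, a small sketch of the plan-splitting step in DS_CreateModel, using made-up node ids in place of a real interpreter's execution plan. The erase-remove idiom leaves the acoustic plan as the order-preserving complement of the MFCC plan, so the two plans together cover every node exactly once:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Hypothetical full execution plan (all node ids, in execution order) and
  // the subset of nodes the mfccs tensor depends on, as found by the BFS.
  std::vector<int> full_plan = {0, 1, 2, 3, 4, 5};
  std::vector<int> mfcc_plan = {0, 2};

  // Same erase-remove step as the patch: drop every MFCC node from a copy of
  // the full plan, keeping the remaining nodes in their original order.
  std::vector<int> acoustic_plan = full_plan;
  auto erase_begin = std::remove_if(acoustic_plan.begin(), acoustic_plan.end(),
                                    [&mfcc_plan](int elem) {
    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
  });
  acoustic_plan.erase(erase_begin, acoustic_plan.end());

  assert((acoustic_plan == std::vector<int>{1, 3, 4, 5}));

  // At runtime, each Invoke() then only executes the active plan:
  //   interpreter->SetExecutionPlan(mfcc_plan);      // feature computation only
  //   interpreter->SetExecutionPlan(acoustic_plan);  // acoustic model only
}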