diff --git a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt index 93b35a697c..69b1ed2784 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. # model sharding with custom op -set(CUSTOM_OP_SRCS_FILE +set(CUSTOM_OP_SRCS_FILE "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp" ) add_library(custom_ops ${CUSTOM_OP_SRCS_FILE}) @@ -45,7 +45,7 @@ list( # build qnn llama3.2 1b runner add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs}) target_include_directories( - qnn_llama3_2_runner PUBLIC ${_common_include_directories} + qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump ) target_link_libraries( @@ -58,6 +58,8 @@ target_link_libraries( gflags re2::re2 custom_ops + etdump + ${FLATCCRT_LIB} ) target_compile_options( qnn_llama3_2_runner PUBLIC ${_common_compile_options} diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp index 2af882580e..64e6106c93 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp @@ -51,6 +51,11 @@ DEFINE_int32( DEFINE_double(logits_scale, 0.0, "Logits scale"); DEFINE_int32(logits_offset, 0, "Logits offset"); +DEFINE_bool( + gen_etdump, + false, + "false: Disable ET dump/ True: Enable ET dump (default: false)"); + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -61,7 +66,8 @@ int main(int argc, char** argv) { FLAGS_logits_scale, FLAGS_logits_offset, FLAGS_temperature, - FLAGS_eval_mode); + FLAGS_eval_mode, + FLAGS_gen_etdump); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream 
fout(FLAGS_output_path.c_str()); diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp index 02a53861b8..37f9e0c427 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp @@ -43,7 +43,8 @@ Runner::Runner( const float logits_scale, const int32_t logits_offset, const float temperature, - const int eval_mode) + const int eval_mode, + const bool gen_etdump) : n_bos_(1), n_eos_(1), tokenizer_path_(tokenizer_path), @@ -58,6 +59,28 @@ Runner::Runner( } ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); ET_LOG(Info, "eval mode=%d", eval_mode); + if (gen_etdump) { + gen_etdump_ = true; + switch (eval_mode) { + case EvalMode::kPrefill: + prefill_dump_ = std::make_unique(); + break; + case EvalMode::kKVCached: + decode_dump_ = std::make_unique(); + break; + case EvalMode::kHybrid: + prefill_dump_ = std::make_unique(); + decode_dump_ = std::make_unique(); + break; + default: + ET_CHECK_MSG(false, "Unsupported eval mode"); + break; + } + std::string etdump_dir = + models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1); + prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp"; + decode_etdump_path_ = etdump_dir + "decode_etdump.etdp"; + } } bool Runner::is_loaded() const { @@ -95,9 +118,17 @@ Error Runner::load() { for (std::shared_ptr& module : modules_) { if (!prefill_forward_name_.empty()) { + if (gen_etdump_) { + ET_CHECK_OK_OR_RETURN_ERROR( + module->load_method(prefill_forward_name_, prefill_dump_.get())); + } ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_)); } if (!kv_forward_name_.empty()) { + if (gen_etdump_) { + ET_CHECK_OK_OR_RETURN_ERROR( + module->load_method(kv_forward_name_, decode_dump_.get())); + } ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_)); } } @@ -424,6 +455,8 @@ Error Runner::generate( stats_.num_prompt_tokens = 
num_prompt_tokens; stats_.num_generated_tokens = pos - num_prompt_tokens; + if (gen_etdump_) + gen_etdump_data(); printReport(stats_); if (stats_callback) { stats_callback(stats_); @@ -432,6 +465,37 @@ Error Runner::generate( return Error::Ok; }
+void Runner::gen_etdump_data() {
+  // Flush every active ETDump generator to its own file: prefill_dump_ goes
+  // to prefill_etdump_path_, decode_dump_ to decode_etdump_path_. A null
+  // generator means that phase is not profiled and is skipped.
+  const auto write_etdump = [](const auto& dump, const std::string& path) {
+    if (dump == nullptr) {
+      return;
+    }
+    torch::executor::etdump_result result = dump->get_etdump_data();
+    if (result.buf == nullptr || result.size == 0) {
+      ET_LOG(Error, "No ETDump data available for %s", path.c_str());
+      return;
+    }
+    // The dump is written as raw bytes, so open the file in binary mode.
+    FILE* file = fopen(path.c_str(), "wb");
+    if (file == nullptr) {
+      // Previously a failed fopen fed a null FILE* to fwrite/fclose.
+      ET_LOG(Error, "Failed to open %s for writing ETDump", path.c_str());
+      return;
+    }
+    if (fwrite(result.buf, 1, result.size, file) != result.size) {
+      ET_LOG(Error, "Incomplete ETDump write to %s", path.c_str());
+    }
+    fclose(file);
+    // NOTE(review): ExecuTorch's sample runner frees result.buf after
+    // writing it out; confirm buffer ownership and release it here.
+  };
+  write_etdump(prefill_dump_, prefill_etdump_path_);
+  write_etdump(decode_dump_, decode_etdump_path_);
+}
+
 namespace { void printReport(const Runner::Stats& stats) { printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h index 75ad640219..ebe4be0cc8 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -32,7 +33,8 @@ class Runner { const float logits_scale, const int32_t logits_offset, const float temperature, - const int eval_mode); + const int eval_mode, + const bool gen_etdump); struct Stats { // Scaling factor for timestamps - in this case, we use ms.
@@ -71,6 +73,7 @@ class Runner { void stop(); std::vector> get_methods_meta(std::string& method_name); + void gen_etdump_data(); private: template @@ -98,6 +101,11 @@ class Runner { float temperature_; std::unique_ptr tokenizer_; std::unique_ptr sampler_; + std::unique_ptr prefill_dump_; + std::unique_ptr decode_dump_; + bool gen_etdump_ = false; + std::string prefill_etdump_path_; + std::string decode_etdump_path_; Stats stats_; std::unique_ptr io_mem_; EvalMode eval_mode_; diff --git a/examples/qualcomm/oss_scripts/llama3_2/targets.bzl b/examples/qualcomm/oss_scripts/llama3_2/targets.bzl index 64adc7eca9..5931165dfd 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/targets.bzl +++ b/examples/qualcomm/oss_scripts/llama3_2/targets.bzl @@ -28,6 +28,7 @@ def define_common_targets(): "//executorch/extension/llm/tokenizer:bpe_tokenizer", "//executorch/extension/evalue_util:print_evalue", "//executorch/backends/qualcomm/runtime:runtime", + "//executorch/devtools/etdump:etdump_flatcc", ], external_deps = [ "gflags",