From 74ceb28a472d3093b02b14f73bc0c0a1bc02a4fb Mon Sep 17 00:00:00 2001
From: Anastasiya Pronina
Date: Tue, 24 Dec 2024 00:40:00 +0000
Subject: [PATCH] Fixed according to review comments

---
 src/cpp/src/llm_pipeline.cpp        |  8 ++---
 src/cpp/src/llm_pipeline_static.cpp | 46 ++++++++++++++++-------------
 src/cpp/src/llm_pipeline_static.hpp | 14 +++++----
 3 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 2c0f3a2635..a76618d139 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -570,12 +570,12 @@ ov::genai::LLMPipeline::LLMPipeline(
         auto& scheduler_config = properties.at(ov::genai::scheduler_config.name()).as<SchedulerConfig>();
         m_pimpl = std::make_unique<ContinuousBatchingAdapter>(models_path, tokenizer, scheduler_config, device, config_without_scheduler_config);
     } else if ("NPU" == device) {
-        m_pimpl = StaticLLMPipelineFactory::create(models_path, tokenizer, device, properties);
+        m_pimpl = static_llm::LLMPipelineFactory::create(models_path, tokenizer, device, properties);
     } else {
         m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, tokenizer, device, properties);
     }
     auto stop_time = std::chrono::steady_clock::now();
-    m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
+    m_pimpl->m_load_time_ms = static_cast<float>(std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count());
 }
 
 ov::genai::LLMPipeline::LLMPipeline(
@@ -590,12 +590,12 @@ ov::genai::LLMPipeline::LLMPipeline(
         auto& scheduler_config = config.at(ov::genai::scheduler_config.name()).as<SchedulerConfig>();
         m_pimpl = std::make_unique<ContinuousBatchingAdapter>(models_path, scheduler_config, device, config_without_scheduler_config);
     } else if ("NPU" == device) {
-        m_pimpl = StaticLLMPipelineFactory::create(models_path, device, config);
+        m_pimpl = static_llm::LLMPipelineFactory::create(models_path, device, config);
     } else {
         m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, device, config);
     }
     auto stop_time = std::chrono::steady_clock::now();
-    m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
+    m_pimpl->m_load_time_ms = static_cast<float>(std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count());
 }
 
 ov::genai::GenerationConfig ov::genai::LLMPipeline::get_generation_config() const {
diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index ae6c8ebf79..fd417607fe 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -629,8 +629,9 @@ void copy_columns_by_row_chunks(const ov::Tensor& src, ov::Tensor& dst) {
 
 namespace ov {
 namespace genai {
+namespace static_llm {
 
-SMStaticLLMPipeline::SMStaticLLMPipeline(
+StatefulLLMPipeline::StatefulLLMPipeline(
     const std::filesystem::path& models_path,
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
@@ -666,7 +667,7 @@ SMStaticLLMPipeline::SMStaticLLMPipeline(
     m_request = compiled.create_infer_request();
 }
 
-DecodedResults SMStaticLLMPipeline::generate(
+DecodedResults StatefulLLMPipeline::generate(
     StringInputs inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -722,7 +723,7 @@ DecodedResults SMStaticLLMPipeline::generate(
     return decoded_results;
 }
 
-EncodedResults SMStaticLLMPipeline::generate(
+EncodedResults StatefulLLMPipeline::generate(
     const EncodedInputs& inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -822,39 +823,41 @@ EncodedResults SMStaticLLMPipeline::generate(
     }
 }
 
-void SMStaticLLMPipeline::start_chat(const std::string& system_message) {
+void StatefulLLMPipeline::start_chat(const std::string& system_message) {
     if (!system_message.empty()) {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
     }
     m_is_chat_conversation = true;
 };
 
-void SMStaticLLMPipeline::finish_chat() {
+void StatefulLLMPipeline::finish_chat() {
     m_is_chat_conversation = false;
     m_history.clear();
 };
 
 std::unique_ptr<LLMPipelineImplBase>
-StaticLLMPipelineFactory::create(const std::filesystem::path& models_path,
+LLMPipelineFactory::create(const std::filesystem::path& models_path,
                                  const ov::genai::Tokenizer& tokenizer,
                                  const std::string& device,
                                  const ov::AnyMap& config) {
     auto properties = config;
-    const auto use_sm_pipeline = pop_or_default(properties, "USE_SM_PIPELINE", false);
-    if (use_sm_pipeline) {
-        return std::make_unique<SMStaticLLMPipeline>(models_path, tokenizer, device, properties);
+    const auto pipeline_mode = pop_or_default(properties, "NPU_PIPELINE", std::string("STATELESS"));
+    OPENVINO_ASSERT(pipeline_mode == "STATELESS" || pipeline_mode == "STATEFUL",
+                    "Only STATELESS and STATEFUL NPU_PIPELINE modes are supported!");
+    if (pipeline_mode == "STATEFUL") {
+        return std::make_unique<StatefulLLMPipeline>(models_path, tokenizer, device, properties);
     }
-    return std::make_unique<StaticLLMPipeline>(models_path, tokenizer, device, properties);
+    return std::make_unique<StatelessLLMPipeline>(models_path, tokenizer, device, properties);
 }
 
 std::unique_ptr<LLMPipelineImplBase>
-StaticLLMPipelineFactory::create(const std::filesystem::path& models_path,
+LLMPipelineFactory::create(const std::filesystem::path& models_path,
                                  const std::string& device,
                                  const ov::AnyMap& config) {
     return create(models_path, Tokenizer(models_path), device, config);
 }
 
-StaticLLMPipeline::StaticLLMPipeline(
+StatelessLLMPipeline::StatelessLLMPipeline(
     const std::filesystem::path& models_path,
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
@@ -889,14 +892,14 @@ StaticLLMPipeline::StaticLLMPipeline(
     }
 };
 
-StaticLLMPipeline::StaticLLMPipeline(
+StatelessLLMPipeline::StatelessLLMPipeline(
     const std::filesystem::path& models_path,
     const std::string& device,
     const ov::AnyMap& properties
-) : StaticLLMPipeline(models_path, Tokenizer(models_path), device, properties) {
+) : StatelessLLMPipeline(models_path, Tokenizer(models_path), device, properties) {
 }
 
-void StaticLLMPipeline::setupAndCompileModels(
+void StatelessLLMPipeline::setupAndCompileModels(
     const std::filesystem::path& models_path,
     const std::string& device,
     ov::AnyMap& properties) {
@@ -970,7 +973,7 @@ void StaticLLMPipeline::setupAndCompileModels(
     ).create_infer_request();
 }
 
-void StaticLLMPipeline::setupAndImportModels(
+void StatelessLLMPipeline::setupAndImportModels(
     const std::filesystem::path& models_path,
     const std::string& device,
     ov::AnyMap& properties) {
@@ -1044,19 +1047,19 @@ void StaticLLMPipeline::setupAndImportModels(
     m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, 2u };
 }
 
-void StaticLLMPipeline::start_chat(const std::string& system_message) {
+void StatelessLLMPipeline::start_chat(const std::string& system_message) {
     if (!system_message.empty()) {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
     }
     m_is_chat_conversation = true;
 };
 
-void StaticLLMPipeline::finish_chat() {
+void StatelessLLMPipeline::finish_chat() {
     m_is_chat_conversation = false;
     m_history.clear();
 };
 
-void StaticLLMPipeline::prepare_for_new_conversation() {
+void StatelessLLMPipeline::prepare_for_new_conversation() {
     fill_tensor<int64_t>(m_prefill_request.get_tensor("input_ids"), m_tokenizer.get_pad_token_id());
     fill_tensor<int64_t>(m_prefill_request.get_tensor("position_ids"), 0u);
     fill_tensor<int64_t>(m_prefill_request.get_tensor("attention_mask"), 0u);
@@ -1064,7 +1067,7 @@ void StaticLLMPipeline::prepare_for_new_conversation() {
     m_kvcache_desc.num_stored_tokens = 0u;
 }
 
-DecodedResults StaticLLMPipeline::generate(
+DecodedResults StatelessLLMPipeline::generate(
     StringInputs inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -1119,7 +1122,7 @@ DecodedResults StaticLLMPipeline::generate(
     return decoded_results;
 }
 
-EncodedResults StaticLLMPipeline::generate(
+EncodedResults StatelessLLMPipeline::generate(
     const EncodedInputs& inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -1294,5 +1297,6 @@ EncodedResults StaticLLMPipeline::generate(
     return results;
 }
 
+} // namespace static_llm
 } // namespace genai
 } // namespace ov
diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp
index d1867e4b33..b119398b4f 100644
--- a/src/cpp/src/llm_pipeline_static.hpp
+++ b/src/cpp/src/llm_pipeline_static.hpp
@@ -9,8 +9,9 @@
 
 namespace ov {
 namespace genai {
+namespace static_llm {
 
-struct StaticLLMPipelineFactory {
+struct LLMPipelineFactory {
     static std::unique_ptr<LLMPipelineImplBase> create(const std::filesystem::path& path,
                                                        const ov::genai::Tokenizer& tokenizer,
                                                        const std::string& device,
@@ -21,9 +22,9 @@ struct StaticLLMPipelineFactory {
                                                        const ov::AnyMap& config);
 };
 
-class SMStaticLLMPipeline : public LLMPipelineImplBase {
+class StatefulLLMPipeline : public LLMPipelineImplBase {
 public:
-    SMStaticLLMPipeline(
+    StatefulLLMPipeline(
         const std::filesystem::path& path,
         const ov::genai::Tokenizer& tokenizer,
         const std::string& device,
@@ -51,16 +52,16 @@ class SMStaticLLMPipeline : public LLMPipelineImplBase {
     ChatHistory m_history;
 };
 
-class StaticLLMPipeline final : public LLMPipelineImplBase {
+class StatelessLLMPipeline final : public LLMPipelineImplBase {
 public:
-    StaticLLMPipeline(
+    StatelessLLMPipeline(
         const std::filesystem::path& path,
         const ov::genai::Tokenizer& tokenizer,
         const std::string& device,
         const ov::AnyMap& config
     );
 
-    StaticLLMPipeline(
+    StatelessLLMPipeline(
         const std::filesystem::path& path,
         const std::string& device,
         const ov::AnyMap& config
@@ -114,5 +115,6 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
 };
 
+} // namespace static_llm
 } // namespace genai
 } // namespace ov
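
Note for reviewers (illustration only, not part of the diff): a minimal sketch of how the renamed selector is expected to be used from client code, assuming the public ov::genai::LLMPipeline constructor touched in llm_pipeline.cpp above. The "NPU_PIPELINE" property and its "STATELESS"/"STATEFUL" values come from LLMPipelineFactory::create(); the model directory and prompt below are hypothetical placeholders.

// Usage sketch: selecting the stateful NPU pipeline introduced by this patch.
// Omitting "NPU_PIPELINE" (or passing "STATELESS") selects the default
// StatelessLLMPipeline; any other value trips the OPENVINO_ASSERT above.
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Hypothetical model directory; replace with a real exported model.
    ov::AnyMap config = { { "NPU_PIPELINE", std::string("STATEFUL") } };
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "NPU", config);
    std::cout << pipe.generate("What is OpenVINO?", ov::genai::max_new_tokens(64)) << std::endl;
    return 0;
}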