Commit b00d987

Fixed according to review comments

AsyaPronina committed Dec 24, 2024
1 parent 07f2b43 commit b00d987

Showing 3 changed files with 35 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/cpp/src/llm_pipeline.cpp
@@ -570,7 +570,7 @@ ov::genai::LLMPipeline::LLMPipeline(
         auto& scheduler_config = properties.at(ov::genai::scheduler_config.name()).as<SchedulerConfig>();
         m_pimpl = std::make_unique<ContinuousBatchingAdapter>(models_path, tokenizer, scheduler_config, device, config_without_scheduler_config);
     } else if ("NPU" == device) {
-        m_pimpl = StaticLLMPipelineFactory::create(models_path, tokenizer, device, properties);
+        m_pimpl = static_llm::LLMPipelineFactory::create(models_path, tokenizer, device, properties);
     } else {
         m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, tokenizer, device, properties);
     }
@@ -590,7 +590,7 @@ ov::genai::LLMPipeline::LLMPipeline(
         auto& scheduler_config = config.at(ov::genai::scheduler_config.name()).as<SchedulerConfig>();
         m_pimpl = std::make_unique<ContinuousBatchingAdapter>(models_path, scheduler_config, device, config_without_scheduler_config);
     } else if ("NPU" == device) {
-        m_pimpl = StaticLLMPipelineFactory::create(models_path, device, config);
+        m_pimpl = static_llm::LLMPipelineFactory::create(models_path, device, config);
     } else {
         m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, device, config);
     }
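For context, a minimal sketch of how this NPU branch is reached from user code, assuming the public ov::genai::LLMPipeline API; the model path and prompt are placeholders, not part of this commit:

#include <filesystem>
#include <iostream>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Hypothetical path to an exported OpenVINO GenAI model directory.
    std::filesystem::path models_path = "TinyLlama-1.1B-Chat-v1.0";

    // Passing "NPU" as the device routes construction through the renamed
    // static_llm::LLMPipelineFactory::create(...) shown in the hunk above.
    ov::genai::LLMPipeline pipe(models_path, "NPU");

    std::string result = pipe.generate("What is OpenVINO?", ov::genai::max_new_tokens(100));
    std::cout << result << std::endl;
    return 0;
}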
46 changes: 25 additions & 21 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -629,8 +629,9 @@ void copy_columns_by_row_chunks(const ov::Tensor& src, ov::Tensor& dst) {
 
 namespace ov {
 namespace genai {
+namespace static_llm {
 
-SMStaticLLMPipeline::SMStaticLLMPipeline(
+StatefulLLMPipeline::StatefulLLMPipeline(
     const std::filesystem::path& models_path,
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
@@ -666,7 +667,7 @@ SMStaticLLMPipeline::SMStaticLLMPipeline(
     m_request = compiled.create_infer_request();
 }
 
-DecodedResults SMStaticLLMPipeline::generate(
+DecodedResults StatefulLLMPipeline::generate(
     StringInputs inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -722,7 +723,7 @@ DecodedResults SMStaticLLMPipeline::generate(
     return decoded_results;
 }
 
-EncodedResults SMStaticLLMPipeline::generate(
+EncodedResults StatefulLLMPipeline::generate(
     const EncodedInputs& inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -822,39 +823,41 @@ EncodedResults SMStaticLLMPipeline::generate(
     }
 }
 
-void SMStaticLLMPipeline::start_chat(const std::string& system_message) {
+void StatefulLLMPipeline::start_chat(const std::string& system_message) {
     if (!system_message.empty()) {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
     }
     m_is_chat_conversation = true;
 };
 
-void SMStaticLLMPipeline::finish_chat() {
+void StatefulLLMPipeline::finish_chat() {
     m_is_chat_conversation = false;
     m_history.clear();
 };
 
 std::unique_ptr<LLMPipelineImplBase>
-StaticLLMPipelineFactory::create(const std::filesystem::path& models_path,
+LLMPipelineFactory::create(const std::filesystem::path& models_path,
                            const ov::genai::Tokenizer& tokenizer,
                            const std::string& device,
                            const ov::AnyMap& config) {
     auto properties = config;
-    const auto use_sm_pipeline = pop_or_default(properties, "USE_SM_PIPELINE", false);
-    if (use_sm_pipeline) {
-        return std::make_unique<SMStaticLLMPipeline>(models_path, tokenizer, device, properties);
+    const auto pipeline_mode = pop_or_default(properties, "NPU_PIPELINE", std::string("STATELESS"));
+    OPENVINO_ASSERT(pipeline_mode == "STATELESS" || pipeline_mode == "STATEFUL",
+                    "Only STATELESS and STATEFUL NPU_PIPELINE modes are supported!");
+    if (pipeline_mode == "STATEFUL") {
+        return std::make_unique<ov::genai::static_llm::StatefulLLMPipeline>(models_path, tokenizer, device, properties);
     }
-    return std::make_unique<StaticLLMPipeline>(models_path, tokenizer, device, properties);
+    return std::make_unique<ov::genai::static_llm::StatelessLLMPipeline>(models_path, tokenizer, device, properties);
 }
 
 std::unique_ptr<LLMPipelineImplBase>
-StaticLLMPipelineFactory::create(const std::filesystem::path& models_path,
+LLMPipelineFactory::create(const std::filesystem::path& models_path,
                            const std::string& device,
                            const ov::AnyMap& config) {
     return create(models_path, Tokenizer(models_path), device, config);
 }
 
-StaticLLMPipeline::StaticLLMPipeline(
+StatelessLLMPipeline::StatelessLLMPipeline(
     const std::filesystem::path& models_path,
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
@@ -889,14 +892,14 @@ StaticLLMPipeline::StaticLLMPipeline(
     }
 };
 
-StaticLLMPipeline::StaticLLMPipeline(
+StatelessLLMPipeline::StatelessLLMPipeline(
     const std::filesystem::path& models_path,
     const std::string& device,
     const ov::AnyMap& properties
-) : StaticLLMPipeline(models_path, Tokenizer(models_path), device, properties) {
+) : StatelessLLMPipeline(models_path, Tokenizer(models_path), device, properties) {
 }
 
-void StaticLLMPipeline::setupAndCompileModels(
+void StatelessLLMPipeline::setupAndCompileModels(
     const std::filesystem::path& models_path,
     const std::string& device,
     ov::AnyMap& properties) {
@@ -970,7 +973,7 @@ void StaticLLMPipeline::setupAndCompileModels(
     ).create_infer_request();
 }
 
-void StaticLLMPipeline::setupAndImportModels(
+void StatelessLLMPipeline::setupAndImportModels(
     const std::filesystem::path& models_path,
     const std::string& device,
     ov::AnyMap& properties) {
@@ -1044,27 +1047,27 @@ void StaticLLMPipeline::setupAndImportModels(
     m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, 2u };
 }
 
-void StaticLLMPipeline::start_chat(const std::string& system_message) {
+void StatelessLLMPipeline::start_chat(const std::string& system_message) {
     if (!system_message.empty()) {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
     }
     m_is_chat_conversation = true;
 };
 
-void StaticLLMPipeline::finish_chat() {
+void StatelessLLMPipeline::finish_chat() {
     m_is_chat_conversation = false;
     m_history.clear();
 };
 
-void StaticLLMPipeline::prepare_for_new_conversation() {
+void StatelessLLMPipeline::prepare_for_new_conversation() {
     fill_tensor<int64_t>(m_prefill_request.get_tensor("input_ids"), m_tokenizer.get_pad_token_id());
     fill_tensor<int64_t>(m_prefill_request.get_tensor("position_ids"), 0u);
     fill_tensor<int64_t>(m_prefill_request.get_tensor("attention_mask"), 0u);
     fill_tensor<int64_t>(m_kvcache_request.get_tensor("attention_mask"), 0u);
     m_kvcache_desc.num_stored_tokens = 0u;
 }
 
-DecodedResults StaticLLMPipeline::generate(
+DecodedResults StatelessLLMPipeline::generate(
     StringInputs inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -1119,7 +1122,7 @@ DecodedResults StaticLLMPipeline::generate(
     return decoded_results;
 }
 
-EncodedResults StaticLLMPipeline::generate(
+EncodedResults StatelessLLMPipeline::generate(
     const EncodedInputs& inputs,
     OptionalGenerationConfig generation_config,
     StreamerVariant streamer
@@ -1294,5 +1297,6 @@ EncodedResults StaticLLMPipeline::generate(
     return results;
 }
 
+} // namespace static_llm
 } // namespace genai
 } // namespace ov
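The factory change above replaces the boolean USE_SM_PIPELINE flag with a string-valued NPU_PIPELINE mode. A minimal sketch of selecting the stateful variant from user code, assuming the public LLMPipeline constructor forwards these properties to the factory; the model path is a placeholder:

#include <filesystem>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    std::filesystem::path models_path = "TinyLlama-1.1B-Chat-v1.0";  // placeholder

    // The value is an explicit std::string so that pop_or_default(properties,
    // "NPU_PIPELINE", std::string("STATELESS")) reads it back cleanly; any value
    // other than "STATELESS"/"STATEFUL" trips the OPENVINO_ASSERT above.
    ov::AnyMap properties{{"NPU_PIPELINE", std::string("STATEFUL")}};

    // Dispatches to static_llm::StatefulLLMPipeline; omitting the key (or
    // passing "STATELESS") selects static_llm::StatelessLLMPipeline instead.
    ov::genai::LLMPipeline pipe(models_path, "NPU", properties);
    return 0;
}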
14 changes: 8 additions & 6 deletions src/cpp/src/llm_pipeline_static.hpp
@@ -9,8 +9,9 @@
 
 namespace ov {
 namespace genai {
+namespace static_llm {
 
-struct StaticLLMPipelineFactory {
+struct LLMPipelineFactory {
     static std::unique_ptr<LLMPipelineImplBase> create(const std::filesystem::path& path,
                                                        const ov::genai::Tokenizer& tokenizer,
                                                        const std::string& device,
@@ -21,9 +22,9 @@ struct StaticLLMPipelineFactory {
                                                        const ov::AnyMap& config);
 };
 
-class SMStaticLLMPipeline : public LLMPipelineImplBase {
+class StatefulLLMPipeline : public LLMPipelineImplBase {
 public:
-    SMStaticLLMPipeline(
+    StatefulLLMPipeline(
         const std::filesystem::path& path,
         const ov::genai::Tokenizer& tokenizer,
         const std::string& device,
@@ -51,16 +52,16 @@ class SMStaticLLMPipeline : public LLMPipelineImplBase {
     ChatHistory m_history;
 };
 
-class StaticLLMPipeline final : public LLMPipelineImplBase {
+class StatelessLLMPipeline final : public LLMPipelineImplBase {
 public:
-    StaticLLMPipeline(
+    StatelessLLMPipeline(
         const std::filesystem::path& path,
         const ov::genai::Tokenizer& tokenizer,
         const std::string& device,
         const ov::AnyMap& config
     );
 
-    StaticLLMPipeline(
+    StatelessLLMPipeline(
         const std::filesystem::path& path,
         const std::string& device,
         const ov::AnyMap& config
@@ -114,5 +115,6 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
 };
 
+} // namespace static_llm
 } // namespace genai
 } // namespace ov
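Both classes expose the same start_chat()/generate()/finish_chat() surface inherited from LLMPipelineImplBase, so either NPU pipeline can back a chat session. A minimal sketch against the public API, again with a placeholder model path:

#include <filesystem>
#include <iostream>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    std::filesystem::path models_path = "TinyLlama-1.1B-Chat-v1.0";  // placeholder
    ov::genai::LLMPipeline pipe(models_path, "NPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    // start_chat()/finish_chat() toggle m_is_chat_conversation and reset
    // m_history, mirroring the implementations in the diff above.
    pipe.start_chat("You are a helpful assistant.");
    std::string a1 = pipe.generate("What is an NPU?", config);
    std::cout << a1 << "\n";
    std::string a2 = pipe.generate("How does it differ from a GPU?", config);
    std::cout << a2 << "\n";
    pipe.finish_chat();
    return 0;
}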
