diff --git a/src/BUILD b/src/BUILD
index c68a3f4942..9242757822 100644
--- a/src/BUILD
+++ b/src/BUILD
@@ -1906,6 +1906,7 @@ cc_test(
         "//conditions:default" : [
             # LLM logic uses Python for processing Jinja templates
             "test/llmnode_test.cpp",
+            "test/max_model_length_test.cpp",
             "test/llmtemplate_test.cpp",
             "test/text_streamer_test.cpp",],
     }) + select({
diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp
index 5b92d5ed87..6612e2bd2e 100644
--- a/src/llm/apis/openai_completions.cpp
+++ b/src/llm/apis/openai_completions.cpp
@@ -388,8 +388,13 @@ void OpenAIChatCompletionsHandler::incrementProcessedTokens(int numTokens) {
     usage.completionTokens += numTokens;
 }
 
-ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig() const {
-    return request.createGenerationConfig();
+ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig(std::optional<uint32_t> maxModelLength) const {
+    ov::genai::GenerationConfig config = request.createGenerationConfig();
+    if (maxModelLength.has_value()) {
+        config.max_length = maxModelLength.value();
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed max model length {}", maxModelLength.value());
+    }
+    return config;
 }
 
 absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp
index 217b48eee2..3e03431f30 100644
--- a/src/llm/apis/openai_completions.hpp
+++ b/src/llm/apis/openai_completions.hpp
@@ -156,6 +156,7 @@ class OpenAIChatCompletionsHandler {
    std::chrono::time_point<std::chrono::system_clock> created;
    ov::genai::Tokenizer tokenizer;
    size_t processedTokens = 0;  // tracks overall number of tokens processed by the pipeline
+   std::optional<uint32_t> maxModelLength;
 
    absl::Status parseCompletionsPart();
    absl::Status parseChatCompletionsPart();
@@ -180,7 +181,7 @@ class OpenAIChatCompletionsHandler {
    void incrementProcessedTokens(int numTokens = 1);
 
-   ov::genai::GenerationConfig createGenerationConfig() const;
+   ov::genai::GenerationConfig createGenerationConfig(std::optional<uint32_t> maxModelLength) const;
 
    absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);
diff --git a/src/llm/http_llm_calculator.cc b/src/llm/http_llm_calculator.cc
index d9221641e8..ebbc2b4e48 100644
--- a/src/llm/http_llm_calculator.cc
+++ b/src/llm/http_llm_calculator.cc
@@ -160,11 +160,10 @@ class HttpLLMCalculator : public CalculatorBase {
                 ov::Tensor finalPromptIds = nodeResources->cbPipe->get_tokenizer().encode(finalPrompt, ov::genai::add_special_tokens(encodeAddSpecialTokens)).input_ids;
                 this->apiHandler->setPromptTokensUsage(finalPromptIds.get_size());
                 SPDLOG_LOGGER_TRACE(llm_calculator_logger, "{}", getPromptTokensString(finalPromptIds));
-
                 this->generationHandle = nodeResources->cbPipe->add_request(
                     currentRequestId++, /*to be removed from API?*/
                     finalPromptIds,
-                    this->apiHandler->createGenerationConfig());
+                    this->apiHandler->createGenerationConfig(this->nodeResources->maxModelLength));
 
                 // TODO: Revert when drogon adds disconnection callbacks: https://github.com/drogonframework/drogon/pull/2204
                 // this->client->registerDisconnectionCallback([genHandle = this->generationHandle]() {
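
The change above only overrides `max_length` when a model limit was actually parsed; otherwise the config from the request is passed through untouched. A minimal standalone sketch of that optional-cap pattern (the `Config` struct here is a hypothetical stand-in for `ov::genai::GenerationConfig`, and the default value is assumed for illustration):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for ov::genai::GenerationConfig.
struct Config {
    uint32_t max_length = 1024;  // assumed default, for illustration only
};

Config createConfig(std::optional<uint32_t> maxModelLength) {
    Config config;
    // Same pattern as the patch: override only when config.json yielded a value.
    if (maxModelLength.has_value()) {
        config.max_length = maxModelLength.value();
    }
    return config;
}

int main() {
    std::cout << createConfig(std::nullopt).max_length << "\n";  // 1024: default preserved
    std::cout << createConfig(2048).max_length << "\n";          // 2048: model limit applied
}
```
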
"../mediapipe_internal/mediapipe_utils.hpp" #include "src/llm/llm_calculator.pb.h" #include "src/llm/llm_executor.hpp" @@ -119,6 +124,30 @@ void LLMNodeResources::loadTextProcessor(LLMNodeResources& nodeResources, const } } +std::optional LLMNodeResources::parseMaxModelLength(std::string& modelsPath) { + std::string configPath = modelsPath + "/config.json"; + std::optional maxModelLength; + if (std::filesystem::exists(configPath.c_str())) { + std::ifstream ifs(configPath); + if (!ifs.is_open()) { + return maxModelLength; + } + rapidjson::Document modelConfig; + rapidjson::IStreamWrapper isw(ifs); + rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw); + if (parseResult.Code()) { + return maxModelLength; + } + std::vector maxLengthFields = {"max_position_embeddings", "n_positions", "seq_len", "seq_length", "n_ctx", "sliding_window"}; + for (auto field : maxLengthFields) { + if (modelConfig.HasMember(field.c_str()) && modelConfig[field.c_str()].IsUint()) { + maxModelLength = modelConfig[field.c_str()].GetUint(); + } + } + } + return maxModelLength; +} + Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string graphPath) { mediapipe::LLMCalculatorOptions nodeOptions; graphNodeConfig.node_options(0).UnpackTo(&nodeOptions); @@ -144,6 +173,7 @@ Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResour SPDLOG_LOGGER_ERROR(modelmanager_logger, "LLM node models_path: {} is not a directory. ", basePath); return StatusCode::LLM_NODE_DIRECTORY_DOES_NOT_EXIST; } + nodeResources.maxModelLength = parseMaxModelLength(basePath); nodeResources.schedulerConfig = { .max_num_batched_tokens = nodeOptions.max_num_batched_tokens(), diff --git a/src/llm/llmnoderesources.hpp b/src/llm/llmnoderesources.hpp index debe5942e1..134a1356b7 100644 --- a/src/llm/llmnoderesources.hpp +++ b/src/llm/llmnoderesources.hpp @@ -112,9 +112,11 @@ struct LLMNodeResources { TextProcessor textProcessor; int maxTokensLimit; int bestOfLimit; + std::optional maxModelLength; static Status initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNode, std::string graphPath); static void loadTextProcessor(LLMNodeResources& nodeResources, const std::string& chatTemplateDirectory); + static std::optional parseMaxModelLength(std::string& modelsPath); LLMNodeResources(const LLMNodeResources&) = delete; LLMNodeResources& operator=(LLMNodeResources&) = delete; diff --git a/src/test/max_model_length_test.cpp b/src/test/max_model_length_test.cpp new file mode 100644 index 0000000000..86f61f5a47 --- /dev/null +++ b/src/test/max_model_length_test.cpp @@ -0,0 +1,142 @@ +//***************************************************************************** +// Copyright 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/src/test/max_model_length_test.cpp b/src/test/max_model_length_test.cpp
new file mode 100644
index 0000000000..86f61f5a47
--- /dev/null
+++ b/src/test/max_model_length_test.cpp
@@ -0,0 +1,142 @@
+//*****************************************************************************
+// Copyright 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <string>
+
+#include "../llm/llmnoderesources.hpp"
+#include "test_utils.hpp"
+
+using namespace ovms;
+
+class MaxModelLengthTest : public TestWithTempDir {
+protected:
+    std::string configFilePath;
+    rapidjson::Document doc;
+    ov::genai::Tokenizer dummyTokenizer;
+
+    void SetUp() {
+        TestWithTempDir::SetUp();
+        configFilePath = directoryPath + "/config.json";
+    }
+};
+
+TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_VALID) {
+    std::string modelConfigContent = R"({"max_position_embeddings" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_INVALID) {
+    std::string modelConfigContent = R"({"max_position_embeddings" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_nPositions_VALID) {
+    std::string modelConfigContent = R"({"n_positions" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_nPositions_INVALID) {
+    std::string modelConfigContent = R"({"n_positions" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_seqLen_VALID) {
+    std::string modelConfigContent = R"({"seq_len" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_seqLen_INVALID) {
+    std::string modelConfigContent = R"({"seq_len" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_seqLength_VALID) {
+    std::string modelConfigContent = R"({"seq_length" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_seqLength_INVALID) {
+    std::string modelConfigContent = R"({"seq_length" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_nCtx_VALID) {
+    std::string modelConfigContent = R"({"n_ctx" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_nCtx_INVALID) {
+    std::string modelConfigContent = R"({"n_ctx" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_VALID) {
+    std::string modelConfigContent = R"({"sliding_window" : 5})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 5);
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_INVALID) {
+    std::string modelConfigContent = R"({"sliding_window" : "INVALID"})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_emptyConfig) {
+    std::string modelConfigContent = R"({})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    EXPECT_FALSE(maxModelLength.has_value());
+}
+
+TEST_F(MaxModelLengthTest, maxModelLength_parsingOrder) {
+    std::string modelConfigContent = R"({"max_position_embeddings" : 5, "seq_length" : 6, "n_positions" : 7, "sliding_window" : 8, "seq_len" : 9, "n_ctx" : 10})";
+    createConfigFileWithContent(modelConfigContent, configFilePath);
+    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
+    ASSERT_TRUE(maxModelLength.has_value());
+    EXPECT_EQ(maxModelLength.value(), 8);
+}
+
+// TODO: Add e2e test
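
One branch the suite does not yet exercise is the directory with no `config.json` at all, where `parseMaxModelLength` returns an empty optional via the `std::filesystem::exists` early exit. A possible additional case, sketched under the same fixture assumptions (not part of the patch):

```cpp
// Hypothetical extra test: no config.json present in the model directory.
TEST_F(MaxModelLengthTest, maxModelLength_noConfigFile) {
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}
```
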