Mkulakow/llm model context length #2870

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions src/BUILD
@@ -1906,6 +1906,7 @@ cc_test(
         "//conditions:default" : [
             # LLM logic uses Python for processing Jinja templates
             "test/llmnode_test.cpp",
+            "test/max_model_length_test.cpp",
             "test/llmtemplate_test.cpp",
             "test/text_streamer_test.cpp",],
     }) + select({
9 changes: 7 additions & 2 deletions src/llm/apis/openai_completions.cpp
@@ -388,8 +388,13 @@ void OpenAIChatCompletionsHandler::incrementProcessedTokens(int numTokens) {
     usage.completionTokens += numTokens;
 }
 
-ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig() const {
-    return request.createGenerationConfig();
+ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig(std::optional<uint> maxModelLength) const {
+    ov::genai::GenerationConfig config = request.createGenerationConfig();
+    if (maxModelLength.has_value()) {
+        config.max_length = maxModelLength.value();
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed max model length {}", maxModelLength.value());
+    }
+    return config;
 }
 
 absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
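For context, in openvino.genai the GenerationConfig::max_length field caps the total sequence length, i.e. prompt tokens plus newly generated tokens, so overriding it with the model's context window keeps a single request from generating past what the model supports. A minimal sketch of the same clamping in isolation (the helper name and standalone setup are illustrative, not code from this PR; it assumes the usual openvino.genai header layout):

#include <optional>

#include "openvino/genai/generation_config.hpp"

// Illustrative helper: apply an optional model context limit to a generation config.
// Mirrors what createGenerationConfig(std::optional<uint>) does above.
ov::genai::GenerationConfig applyContextLimit(ov::genai::GenerationConfig config,
                                              std::optional<unsigned int> maxModelLength) {
    if (maxModelLength.has_value()) {
        config.max_length = maxModelLength.value();  // upper bound on prompt + completion tokens
    }
    return config;
}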
3 changes: 2 additions & 1 deletion src/llm/apis/openai_completions.hpp
@@ -156,6 +156,7 @@ class OpenAIChatCompletionsHandler {
     std::chrono::time_point<std::chrono::system_clock> created;
     ov::genai::Tokenizer tokenizer;
     size_t processedTokens = 0; // tracks overall number of tokens processed by the pipeline
+    std::optional<uint> maxModelLength;
 
     absl::Status parseCompletionsPart();
     absl::Status parseChatCompletionsPart();
@@ -180,7 +181,7 @@
 
     void incrementProcessedTokens(int numTokens = 1);
 
-    ov::genai::GenerationConfig createGenerationConfig() const;
+    ov::genai::GenerationConfig createGenerationConfig(std::optional<uint> maxModelLength) const;
 
     absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);
 
Expand Down
3 changes: 1 addition & 2 deletions src/llm/http_llm_calculator.cc
@@ -160,11 +160,10 @@ class HttpLLMCalculator : public CalculatorBase {
         ov::Tensor finalPromptIds = nodeResources->cbPipe->get_tokenizer().encode(finalPrompt, ov::genai::add_special_tokens(encodeAddSpecialTokens)).input_ids;
         this->apiHandler->setPromptTokensUsage(finalPromptIds.get_size());
         SPDLOG_LOGGER_TRACE(llm_calculator_logger, "{}", getPromptTokensString(finalPromptIds));
-
         this->generationHandle = nodeResources->cbPipe->add_request(
             currentRequestId++, /*to be removed from API?*/
             finalPromptIds,
-            this->apiHandler->createGenerationConfig());
+            this->apiHandler->createGenerationConfig(this->nodeResources->maxModelLength));
 
         // TODO: Revert when drogon adds disconnection callbacks: https://github.com/drogonframework/drogon/pull/2204
         // this->client->registerDisconnectionCallback([genHandle = this->generationHandle]() {
30 changes: 30 additions & 0 deletions src/llm/llmnoderesources.cpp
@@ -34,6 +34,11 @@
 #include "mediapipe/framework/calculator_graph.h"
 #pragma GCC diagnostic pop
 
+#include <fstream>
+
+#include <rapidjson/error/en.h>
+#include <rapidjson/istreamwrapper.h>
+
 #include "../mediapipe_internal/mediapipe_utils.hpp"
 #include "src/llm/llm_calculator.pb.h"
 #include "src/llm/llm_executor.hpp"
@@ -119,6 +124,30 @@ void LLMNodeResources::loadTextProcessor(LLMNodeResources& nodeResources, const
     }
 }
 
+std::optional<uint> LLMNodeResources::parseMaxModelLength(std::string& modelsPath) {
+    std::string configPath = modelsPath + "/config.json";
+    std::optional<uint> maxModelLength;
+    if (std::filesystem::exists(configPath.c_str())) {
+        std::ifstream ifs(configPath);
+        if (!ifs.is_open()) {
+            return maxModelLength;
+        }
+        rapidjson::Document modelConfig;
+        rapidjson::IStreamWrapper isw(ifs);
+        rapidjson::ParseResult parseResult = modelConfig.ParseStream(isw);
+        if (parseResult.Code()) {
+            return maxModelLength;
+        }
+        std::vector<std::string> maxLengthFields = {"max_position_embeddings", "n_positions", "seq_len", "seq_length", "n_ctx", "sliding_window"};
+        for (auto field : maxLengthFields) {
+            if (modelConfig.HasMember(field.c_str()) && modelConfig[field.c_str()].IsUint()) {
+                maxModelLength = modelConfig[field.c_str()].GetUint();
+            }
+        }
+    }
+    return maxModelLength;
+}
+
 Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string graphPath) {
     mediapipe::LLMCalculatorOptions nodeOptions;
     graphNodeConfig.node_options(0).UnpackTo(&nodeOptions);
@@ -144,6 +173,7 @@ Status LLMNodeResources::initializeLLMNodeResour
         SPDLOG_LOGGER_ERROR(modelmanager_logger, "LLM node models_path: {} is not a directory. ", basePath);
         return StatusCode::LLM_NODE_DIRECTORY_DOES_NOT_EXIST;
     }
+    nodeResources.maxModelLength = parseMaxModelLength(basePath);
 
     nodeResources.schedulerConfig = {
         .max_num_batched_tokens = nodeOptions.max_num_batched_tokens(),
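The candidate keys scanned above are the names Hugging Face style config.json files commonly use for the context window; which one is present depends on the exported architecture, which is why a single key is not enough. Purely as an illustration (field values made up), a config that satisfies the first candidate could look like:

{
    "model_type": "llama",
    "num_hidden_layers": 32,
    "max_position_embeddings": 4096
}

Other architectures expose the limit under seq_length, n_ctx, or sliding_window instead, and the function returns an empty optional when config.json is missing, unreadable, malformed, or carries none of the candidates as an unsigned integer.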
2 changes: 2 additions & 0 deletions src/llm/llmnoderesources.hpp
@@ -112,9 +112,11 @@ struct LLMNodeResources {
     TextProcessor textProcessor;
     int maxTokensLimit;
     int bestOfLimit;
+    std::optional<uint> maxModelLength;
 
     static Status initializeLLMNodeResources(LLMNodeResources& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNode, std::string graphPath);
     static void loadTextProcessor(LLMNodeResources& nodeResources, const std::string& chatTemplateDirectory);
+    static std::optional<uint> parseMaxModelLength(std::string& modelsPath);
 
     LLMNodeResources(const LLMNodeResources&) = delete;
     LLMNodeResources& operator=(LLMNodeResources&) = delete;
142 changes: 142 additions & 0 deletions src/test/max_model_length_test.cpp
@@ -0,0 +1,142 @@
//*****************************************************************************
// Copyright 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <rapidjson/document.h>

#include "../llm/llmnoderesources.hpp"
#include "test_utils.hpp"

using namespace ovms;

class MaxModelLengthTest : public TestWithTempDir {
protected:
    std::string configFilePath;
    rapidjson::Document doc;
    ov::genai::Tokenizer dummyTokenizer;

    void SetUp() {
        TestWithTempDir::SetUp();
        configFilePath = directoryPath + "/config.json";
    }
};

TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_VALID) {
    std::string modelConfigContent = R"({"max_position_embeddings" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_MaxPositionEmbeddings_INVALID) {
    std::string modelConfigContent = R"({"max_position_embeddings" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_nPositions_VALID) {
    std::string modelConfigContent = R"({"n_positions" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_nPositions_INVALID) {
    std::string modelConfigContent = R"({"n_positions" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLen_VALID) {
    std::string modelConfigContent = R"({"seq_len" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLen_INVALID) {
    std::string modelConfigContent = R"({"seq_len" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLength_VALID) {
    std::string modelConfigContent = R"({"seq_length" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_seqLength_INVALID) {
    std::string modelConfigContent = R"({"seq_length" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_nCtx_VALID) {
    std::string modelConfigContent = R"({"n_ctx" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_nCtx_INVALID) {
    std::string modelConfigContent = R"({"n_ctx" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_VALID) {
    std::string modelConfigContent = R"({"sliding_window" : 5})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 5);
}

TEST_F(MaxModelLengthTest, maxModelLength_slidingWindow_INVALID) {
    std::string modelConfigContent = R"({"sliding_window" : "INVALID"})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_emptyConfig) {
    std::string modelConfigContent = R"({})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    EXPECT_FALSE(maxModelLength.has_value());
}

TEST_F(MaxModelLengthTest, maxModelLength_parsingOrder) {
    std::string modelConfigContent = R"({"max_position_embeddings" : 5, "seq_length" : 6, "n_positions" : 7, "sliding_window" : 8, "seq_len" : 9, "n_ctx" : 10})";
    createConfigFileWithContent(modelConfigContent, configFilePath);
    auto maxModelLength = LLMNodeResources::parseMaxModelLength(directoryPath);
    ASSERT_TRUE(maxModelLength.has_value());
    EXPECT_EQ(maxModelLength.value(), 8);
}

// TODO: Add e2e test
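A note on the parsingOrder expectation: parseMaxModelLength walks the candidate fields in a fixed order and never breaks out of the loop, so when several fields are present the one scanned last wins; with all six fields set, that is sliding_window, hence the expected value 8. A standalone sketch of that selection rule (the helper and its std::map input are illustrative, not code from this PR):

#include <map>
#include <optional>
#include <string>
#include <vector>

// Illustrative: last matching candidate wins, mirroring parseMaxModelLength's loop.
std::optional<unsigned int> pickMaxModelLength(const std::map<std::string, unsigned int>& config) {
    const std::vector<std::string> fields = {"max_position_embeddings", "n_positions", "seq_len", "seq_length", "n_ctx", "sliding_window"};
    std::optional<unsigned int> result;
    for (const auto& field : fields) {
        auto it = config.find(field);
        if (it != config.end()) {
            result = it->second;  // no break: a later candidate overrides an earlier one
        }
    }
    return result;
}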