From 08b8841eebc1937a2b7dc422168f449489bbba41 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 27 Sep 2024 08:09:08 +0700 Subject: [PATCH] feat: sqlite database implementation (#1336) * feat: sqlite * chore: unit tests * fix: unit tests * refactor: db * fix: remove mutex * fix: rm file * fix: test * fix: test * refactor: LoadModelList * refactor: more * fix: transaction * fix: format * fix: make alias unique * fix: remove ModelStatus * fix: models --- engine/CMakeLists.txt | 8 +- engine/commands/chat_cmd.cc | 10 +- engine/commands/model_alias_cmd.cc | 4 +- engine/commands/model_get_cmd.cc | 10 +- engine/commands/model_import_cmd.cc | 10 +- engine/commands/model_list_cmd.cc | 38 +-- engine/commands/model_start_cmd.cc | 10 +- engine/commands/model_status_cmd.cc | 10 +- engine/commands/model_upd_cmd.cc | 8 +- engine/commands/model_upd_cmd.h | 4 +- engine/commands/run_cmd.cc | 10 +- engine/controllers/models.cc | 43 +-- engine/database/database.h | 28 ++ engine/database/models.cc | 292 ++++++++++++++++++ engine/database/models.h | 49 +++ engine/services/model_service.cc | 36 ++- engine/test/components/CMakeLists.txt | 4 +- .../test/components/test_modellist_utils.cc | 134 -------- engine/test/components/test_models_db.cc | 167 ++++++++++ engine/utils/modellist_utils.cc | 256 --------------- engine/utils/modellist_utils.h | 48 --- engine/vcpkg.json | 3 +- 22 files changed, 657 insertions(+), 525 deletions(-) create mode 100644 engine/database/database.h create mode 100644 engine/database/models.cc create mode 100644 engine/database/models.h delete mode 100644 engine/test/components/test_modellist_utils.cc create mode 100644 engine/test/components/test_models_db.cc delete mode 100644 engine/utils/modellist_utils.cc delete mode 100644 engine/utils/modellist_utils.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 76cdcf303..6279813a6 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -77,11 +77,11 @@ find_package(unofficial-minizip CONFIG REQUIRED) find_package(LibArchive REQUIRED) find_package(tabulate CONFIG REQUIRED) find_package(CURL REQUIRED) +find_package(SQLiteCpp REQUIRED) add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils/modellist_utils.cc ) target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) @@ -93,6 +93,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE tabulate::tabulate) target_link_libraries(${TARGET_NAME} PRIVATE CURL::libcurl) target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp Drogon::Drogon OpenSSL::SSL OpenSSL::Crypto yaml-cpp::yaml-cpp ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp) # ############################################################################## @@ -114,7 +115,8 @@ aux_source_directory(models MODEL_SRC) aux_source_directory(cortex-common CORTEX_COMMON) aux_source_directory(config CONFIG_SRC) aux_source_directory(commands COMMANDS_SRC) - +aux_source_directory(database DB_SRC) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC}) diff --git a/engine/commands/chat_cmd.cc b/engine/commands/chat_cmd.cc index 922dc32ed..bb44b476b 100644 --- a/engine/commands/chat_cmd.cc +++ b/engine/commands/chat_cmd.cc @@ -2,11 +2,11 @@ #include "httplib.h" #include "cortex_upd_cmd.h" +#include "database/models.h" #include "model_status_cmd.h" #include "server_start_cmd.h" #include "trantor/utils/Logger.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { namespace { @@ -39,11 +39,15 @@ struct ChunkParser { void ChatCmd::Exec(const std::string& host, int port, const std::string& model_handle, std::string msg) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto mc = yaml_handler.GetModelConfig(); Exec(host, port, mc, std::move(msg)); } catch (const std::exception& e) { diff --git a/engine/commands/model_alias_cmd.cc b/engine/commands/model_alias_cmd.cc index 2123d06cf..4a4ef98af 100644 --- a/engine/commands/model_alias_cmd.cc +++ b/engine/commands/model_alias_cmd.cc @@ -1,11 +1,11 @@ #include "model_alias_cmd.h" -#include "utils/modellist_utils.h" +#include "database/models.h" namespace commands { void ModelAliasCmd::Exec(const std::string& model_handle, const std::string& model_alias) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; try { if (modellist_handler.UpdateModelAlias(model_handle, model_alias)) { CLI_LOG("Successfully set model alias '" + model_alias + diff --git a/engine/commands/model_get_cmd.cc b/engine/commands/model_get_cmd.cc index 715728c1f..5f6658cba 100644 --- a/engine/commands/model_get_cmd.cc +++ b/engine/commands/model_get_cmd.cc @@ -5,18 +5,22 @@ #include #include "cmd_info.h" #include "config/yaml_config.h" +#include "database/models.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { void ModelGetCmd::Exec(const std::string& model_handle) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto model_config = yaml_handler.GetModelConfig(); std::cout << model_config.ToString() << std::endl; diff --git a/engine/commands/model_import_cmd.cc b/engine/commands/model_import_cmd.cc index 3fb047a9d..c12af5054 100644 --- a/engine/commands/model_import_cmd.cc +++ b/engine/commands/model_import_cmd.cc @@ -3,9 +3,9 @@ #include #include "config/gguf_parser.h" #include "config/yaml_config.h" +#include "database/models.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { @@ -16,15 +16,15 @@ ModelImportCmd::ModelImportCmd(std::string model_handle, std::string model_path) void ModelImportCmd::Exec() { config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; - modellist_utils::ModelListUtils modellist_utils_obj; + cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("imported") / std::filesystem::path(model_handle_ + ".yml")) .string(); - modellist_utils::ModelEntry model_entry{ + cortex::db::ModelEntry model_entry{ model_handle_, "local", "imported", - model_yaml_path, model_handle_, modellist_utils::ModelStatus::READY}; + model_yaml_path, model_handle_}; try { std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -34,7 +34,7 @@ void ModelImportCmd::Exec() { model_config.model = model_handle_; yaml_handler.UpdateModelConfig(model_config); - if (modellist_utils_obj.AddModelEntry(model_entry)) { + if (modellist_utils_obj.AddModelEntry(model_entry).value()) { yaml_handler.WriteYamlFile(model_yaml_path); CLI_LOG("Model is imported successfully!"); } else { diff --git a/engine/commands/model_list_cmd.cc b/engine/commands/model_list_cmd.cc index d9a84ec3b..3fe0b4700 100644 --- a/engine/commands/model_list_cmd.cc +++ b/engine/commands/model_list_cmd.cc @@ -3,15 +3,15 @@ #include #include #include "config/yaml_config.h" +#include "database/models.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { void ModelListCmd::Exec() { auto models_path = file_manager_utils::GetModelsContainerPath(); - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; tabulate::Table table; @@ -20,24 +20,24 @@ void ModelListCmd::Exec() { int count = 0; // Iterate through directory - try { - auto list_entry = modellist_handler.LoadModelList(); - for (const auto& model_entry : list_entry) { - // auto model_entry = modellist_handler.GetModelInfo(model_handle); - try { - count += 1; - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); - auto model_config = yaml_handler.GetModelConfig(); - table.add_row({std::to_string(count), model_entry.model_id, - model_entry.model_alias, model_config.engine, - model_config.version}); - yaml_handler.Reset(); - } catch (const std::exception& e) { - CTL_ERR("Fail to get list model information: " + std::string(e.what())); - } + auto list_entry = modellist_handler.LoadModelList(); + if (list_entry.has_error()) { + CTL_ERR("Fail to get list model information: " << list_entry.error()); + return; + } + for (const auto& model_entry : list_entry.value()) { + // auto model_entry = modellist_handler.GetModelInfo(model_handle); + try { + count += 1; + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto model_config = yaml_handler.GetModelConfig(); + table.add_row({std::to_string(count), model_entry.model_id, + model_entry.model_alias, model_config.engine, + model_config.version}); + yaml_handler.Reset(); + } catch (const std::exception& e) { + CTL_ERR("Fail to get list model information: " + std::string(e.what())); } - } catch (const std::exception& e) { - CTL_ERR("Fail to get list model information: " + std::string(e.what())); } for (int i = 0; i < 5; i++) { diff --git a/engine/commands/model_start_cmd.cc b/engine/commands/model_start_cmd.cc index 1340614d9..2b0c8f2b9 100644 --- a/engine/commands/model_start_cmd.cc +++ b/engine/commands/model_start_cmd.cc @@ -1,5 +1,6 @@ #include "model_start_cmd.h" #include "cortex_upd_cmd.h" +#include "database/models.h" #include "httplib.h" #include "model_status_cmd.h" #include "nlohmann/json.hpp" @@ -7,17 +8,20 @@ #include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { bool ModelStartCmd::Exec(const std::string& host, int port, const std::string& model_handle) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return false; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto mc = yaml_handler.GetModelConfig(); return Exec(host, port, mc); } catch (const std::exception& e) { diff --git a/engine/commands/model_status_cmd.cc b/engine/commands/model_status_cmd.cc index e6ba9bbe0..38ff17bdc 100644 --- a/engine/commands/model_status_cmd.cc +++ b/engine/commands/model_status_cmd.cc @@ -1,18 +1,22 @@ #include "model_status_cmd.h" #include "config/yaml_config.h" +#include "database/models.h" #include "httplib.h" #include "nlohmann/json.hpp" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { bool ModelStatusCmd::IsLoaded(const std::string& host, int port, const std::string& model_handle) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return false; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto mc = yaml_handler.GetModelConfig(); return IsLoaded(host, port, mc); } catch (const std::exception& e) { diff --git a/engine/commands/model_upd_cmd.cc b/engine/commands/model_upd_cmd.cc index 65883def3..ea10d2d95 100644 --- a/engine/commands/model_upd_cmd.cc +++ b/engine/commands/model_upd_cmd.cc @@ -11,7 +11,11 @@ void ModelUpdCmd::Exec( const std::unordered_map& options) { try { auto model_entry = model_list_utils_.GetModelInfo(model_handle_); - yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return; + } + yaml_handler_.ModelConfigFromFile(model_entry.value().path_to_model_yaml); model_config_ = yaml_handler_.GetModelConfig(); for (const auto& [key, value] : options) { @@ -21,7 +25,7 @@ void ModelUpdCmd::Exec( } yaml_handler_.UpdateModelConfig(model_config_); - yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml); + yaml_handler_.WriteYamlFile(model_entry.value().path_to_model_yaml); CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); } catch (const std::exception& e) { CLI_LOG("Failed to update model with model ID '" + model_handle_ + diff --git a/engine/commands/model_upd_cmd.h b/engine/commands/model_upd_cmd.h index 51f5a88d3..49c104157 100644 --- a/engine/commands/model_upd_cmd.h +++ b/engine/commands/model_upd_cmd.h @@ -5,8 +5,8 @@ #include #include #include "config/model_config.h" -#include "utils/modellist_utils.h" #include "config/yaml_config.h" +#include "database/models.h" namespace commands { class ModelUpdCmd { public: @@ -17,7 +17,7 @@ class ModelUpdCmd { std::string model_handle_; config::ModelConfig model_config_; config::YamlHandler yaml_handler_; - modellist_utils::ModelListUtils model_list_utils_; + cortex::db::Models model_list_utils_; void UpdateConfig(const std::string& key, const std::string& value); void UpdateVectorField(const std::string& key, const std::string& value); diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc index 0e69f523a..d1a88733d 100644 --- a/engine/commands/run_cmd.cc +++ b/engine/commands/run_cmd.cc @@ -1,18 +1,18 @@ #include "run_cmd.h" #include "chat_cmd.h" #include "config/yaml_config.h" +#include "database/models.h" #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" namespace commands { void RunCmd::Exec() { std::optional model_id = model_handle_; - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); @@ -31,7 +31,11 @@ void RunCmd::Exec() { try { auto model_entry = modellist_handler.GetModelInfo(*model_id); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto mc = yaml_handler.GetModelConfig(); // Check if engine existed. If not, download it diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 72706b2b1..6ac0c1664 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -2,12 +2,12 @@ #include #include "config/gguf_parser.h" #include "config/yaml_config.h" +#include "database/models.h" #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" #include "utils/http_util.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" #include "utils/string_utils.h" void Models::PullModel(const HttpRequestPtr& req, @@ -65,13 +65,12 @@ void Models::ListModel( // Iterate through directory - try { - modellist_utils::ModelListUtils modellist_handler; - config::YamlHandler yaml_handler; - - auto list_entry = modellist_handler.LoadModelList(); + cortex::db::Models modellist_handler; + config::YamlHandler yaml_handler; - for (const auto& model_entry : list_entry) { + auto list_entry = modellist_handler.LoadModelList(); + if (list_entry) { + for (const auto& model_entry : list_entry.value()) { // auto model_entry = modellist_handler.GetModelInfo(model_handle); try { @@ -91,9 +90,9 @@ void Models::ListModel( auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); callback(resp); - } catch (const std::exception& e) { - std::string message = - "Fail to get list model information: " + std::string(e.what()); + } else { + std::string message = "Fail to get list model information: " + + std::string(list_entry.error()); LOG_ERROR << message; ret["data"] = data; ret["result"] = "Fail to get list model information"; @@ -117,10 +116,14 @@ void Models::GetModel( Json::Value data(Json::arrayValue); try { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return; + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto model_config = yaml_handler.GetModelConfig(); Json::Value obj = model_config.ToJson(); @@ -172,14 +175,14 @@ void Models::UpdateModel( auto model_id = (*(req->getJsonObject())).get("modelId", "").asString(); auto json_body = *(req->getJsonObject()); try { - modellist_utils::ModelListUtils model_list_utils; + cortex::db::Models model_list_utils; auto model_entry = model_list_utils.GetModelInfo(model_id); config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); config::ModelConfig model_config = yaml_handler.GetModelConfig(); model_config.FromJson(json_body); yaml_handler.UpdateModelConfig(model_config); - yaml_handler.WriteYamlFile(model_entry.path_to_model_yaml); + yaml_handler.WriteYamlFile(model_entry.value().path_to_model_yaml); std::string message = "Successfully update model ID '" + model_id + "': " + json_body.toStyledString(); LOG_INFO << message; @@ -217,15 +220,15 @@ void Models::ImportModel( auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString(); config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; - modellist_utils::ModelListUtils modellist_utils_obj; + cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("imported") / std::filesystem::path(modelHandle + ".yml")) .string(); - modellist_utils::ModelEntry model_entry{ + cortex::db::ModelEntry model_entry{ modelHandle, "local", "imported", - model_yaml_path, modelHandle, modellist_utils::ModelStatus::READY}; + model_yaml_path, modelHandle}; try { std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -235,7 +238,7 @@ void Models::ImportModel( model_config.name = modelHandle; yaml_handler.UpdateModelConfig(model_config); - if (modellist_utils_obj.AddModelEntry(model_entry)) { + if (modellist_utils_obj.AddModelEntry(model_entry).value()) { yaml_handler.WriteYamlFile(model_yaml_path); std::string success_message = "Model is imported successfully!"; LOG_INFO << success_message; @@ -289,7 +292,7 @@ void Models::SetModelAlias( LOG_DEBUG << "GetModel, Model handle: " << model_handle << ", Model alias: " << model_alias; - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; try { if (modellist_handler.UpdateModelAlias(model_handle, model_alias)) { std::string message = "Successfully set model alias '" + model_alias + diff --git a/engine/database/database.h b/engine/database/database.h new file mode 100644 index 000000000..239337b18 --- /dev/null +++ b/engine/database/database.h @@ -0,0 +1,28 @@ +#pragma once +#include +#include +#include "SQLiteCpp/SQLiteCpp.h" +#include "utils/file_manager_utils.h" + +namespace cortex::db { +const std::string kDefaultDbPath = + file_manager_utils::GetCortexDataPath().string() + "/cortex.db"; +class Database { + public: + Database(Database const&) = delete; + Database& operator=(Database const&) = delete; + ~Database() {} + + static Database& GetInstance() { + static Database db; + return db; + } + + SQLite::Database& db() { return db_; } + + private: + Database() + : db_(kDefaultDbPath, SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE) {} + SQLite::Database db_; +}; +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.cc b/engine/database/models.cc new file mode 100644 index 000000000..cfaf275e7 --- /dev/null +++ b/engine/database/models.cc @@ -0,0 +1,292 @@ +#include "models.h" +#include +#include +#include +#include +#include +#include +#include "database.h" +#include "utils/file_manager_utils.h" +#include "utils/result.hpp" +#include "utils/scope_exit.h" + +namespace cortex::db { + +Models::Models() : db_(cortex::db::Database::GetInstance().db()) { + db_.exec( + "CREATE TABLE IF NOT EXISTS models (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT);"); +} + +Models::Models(SQLite::Database& db) : db_(db) { + db_.exec( + "CREATE TABLE IF NOT EXISTS models (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT UNIQUE);"); +} + +Models::~Models() {} + +cpp::result, std::string> Models::LoadModelList() + const { + try { + db_.exec("BEGIN TRANSACTION;"); + utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); + return LoadModelListNoLock(); + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +bool Models::IsUnique(const std::vector& entries, + const std::string& model_id, + const std::string& model_alias) const { + return std::none_of( + entries.begin(), entries.end(), [&](const ModelEntry& entry) { + return entry.model_id == model_id || entry.model_alias == model_id || + entry.model_id == model_alias || + entry.model_alias == model_alias; + }); +} + +cpp::result, std::string> Models::LoadModelListNoLock() + const { + try { + std::vector entries; + SQLite::Statement query(db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias FROM models"); + + while (query.executeStep()) { + ModelEntry entry; + entry.model_id = query.getColumn(0).getString(); + entry.author_repo_id = query.getColumn(1).getString(); + entry.branch_name = query.getColumn(2).getString(); + entry.path_to_model_yaml = query.getColumn(3).getString(); + entry.model_alias = query.getColumn(4).getString(); + entries.push_back(entry); + } + return entries; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +std::string Models::GenerateShortenedAlias( + const std::string& model_id, const std::vector& entries) const { + std::vector parts; + std::istringstream iss(model_id); + std::string part; + while (std::getline(iss, part, '/')) { + parts.push_back(part); + } + + if (parts.empty()) { + return model_id; // Return original if no parts + } + + // Extract the filename without extension + std::string filename = parts.back(); + size_t last_dot_pos = filename.find_last_of('.'); + if (last_dot_pos != std::string::npos) { + filename = filename.substr(0, last_dot_pos); + } + + // Convert to lowercase + std::transform(filename.begin(), filename.end(), filename.begin(), + [](unsigned char c) { return std::tolower(c); }); + + // Generate alias candidates + std::vector candidates; + candidates.push_back(filename); + + if (parts.size() >= 2) { + candidates.push_back(parts[parts.size() - 2] + ":" + filename); + } + + if (parts.size() >= 3) { + candidates.push_back(parts[parts.size() - 3] + ":" + + parts[parts.size() - 2] + "/" + filename); + } + + if (parts.size() >= 4) { + candidates.push_back(parts[0] + ":" + parts[1] + "/" + + parts[parts.size() - 2] + "/" + filename); + } + + // Find the first unique candidate + for (const auto& candidate : candidates) { + if (IsUnique(entries, model_id, candidate)) { + return candidate; + } + } + + // If all candidates are taken, append a number to the last candidate + std::string base_candidate = candidates.back(); + int suffix = 1; + std::string unique_candidate = base_candidate; + while (!IsUnique(entries, model_id, unique_candidate)) { + unique_candidate = base_candidate + "-" + std::to_string(suffix++); + } + + return unique_candidate; +} + +cpp::result Models::GetModelInfo( + const std::string& identifier) const { + try { + SQLite::Statement query(db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias FROM models " + "WHERE model_id = ? OR model_alias = ?"); + + query.bind(1, identifier); + query.bind(2, identifier); + if (query.executeStep()) { + ModelEntry entry; + entry.model_id = query.getColumn(0).getString(); + entry.author_repo_id = query.getColumn(1).getString(); + entry.branch_name = query.getColumn(2).getString(); + entry.path_to_model_yaml = query.getColumn(3).getString(); + entry.model_alias = query.getColumn(4).getString(); + return entry; + } else { + return cpp::fail("Model not found: " + identifier); + } + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +void Models::PrintModelInfo(const ModelEntry& entry) const { + LOG_INFO << "Model ID: " << entry.model_id; + LOG_INFO << "Author/Repo ID: " << entry.author_repo_id; + LOG_INFO << "Branch Name: " << entry.branch_name; + LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; + LOG_INFO << "Model Alias: " << entry.model_alias; +} + +cpp::result Models::AddModelEntry(ModelEntry new_entry, + bool use_short_alias) { + try { + db_.exec("BEGIN TRANSACTION;"); + utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); + auto model_list = LoadModelListNoLock(); + if (model_list.has_error()) { + CTL_WRN(model_list.error()); + std::cout << "Test: " << model_list.error(); + return cpp::fail(model_list.error()); + } + if (IsUnique(model_list.value(), new_entry.model_id, + new_entry.model_alias)) { + if (use_short_alias) { + new_entry.model_alias = + GenerateShortenedAlias(new_entry.model_id, model_list.value()); + } + + SQLite::Statement insert( + db_, + "INSERT INTO models (model_id, author_repo_id, " + "branch_name, path_to_model_yaml, model_alias) VALUES (?, ?, " + "?, ?, ?)"); + insert.bind(1, new_entry.model_id); + insert.bind(2, new_entry.author_repo_id); + insert.bind(3, new_entry.branch_name); + insert.bind(4, new_entry.path_to_model_yaml); + insert.bind(5, new_entry.model_alias); + insert.exec(); + + return true; + } + return false; // Entry not added due to non-uniqueness + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result Models::UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry) { + try { + SQLite::Statement upd(db_, + "UPDATE models " + "SET author_repo_id = ?, branch_name = ?, " + "path_to_model_yaml = ? " + "WHERE model_id = ? OR model_alias = ?"); + upd.bind(1, updated_entry.author_repo_id); + upd.bind(2, updated_entry.branch_name); + upd.bind(3, updated_entry.path_to_model_yaml); + upd.bind(4, identifier); + upd.bind(5, identifier); + return upd.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::UpdateModelAlias( + const std::string& model_id, const std::string& new_model_alias) { + try { + db_.exec("BEGIN TRANSACTION;"); + utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); + auto model_list = LoadModelListNoLock(); + if (model_list.has_error()) { + CTL_WRN(model_list.error()); + return cpp::fail(model_list.error()); + } + // Check new_model_alias is unique + if (IsUnique(model_list.value(), new_model_alias, new_model_alias)) { + SQLite::Statement upd(db_, + "UPDATE models " + "SET model_alias = ? " + "WHERE model_id = ? OR model_alias = ?"); + upd.bind(1, new_model_alias); + upd.bind(2, model_id); + upd.bind(3, model_id); + return upd.exec() == 1; + } + return false; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::DeleteModelEntry( + const std::string& identifier) { + try { + SQLite::Statement del( + db_, "DELETE from models WHERE model_id = ? OR model_alias = ?"); + del.bind(1, identifier); + del.bind(2, identifier); + return del.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +bool Models::HasModel(const std::string& identifier) const { + try { + SQLite::Statement query( + db_, + "SELECT COUNT(*) FROM models WHERE model_id = ? OR model_alias = ?"); + query.bind(1, identifier); + query.bind(2, identifier); + if (query.executeStep()) { + return query.getColumn(0).getInt() > 0; + } + return false; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return false; + } +} +} // namespace cortex::db diff --git a/engine/database/models.h b/engine/database/models.h new file mode 100644 index 000000000..184f1c6a6 --- /dev/null +++ b/engine/database/models.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include +#include "SQLiteCpp/SQLiteCpp.h" +#include "utils/result.hpp" + +namespace cortex::db { +struct ModelEntry { + std::string model_id; + std::string author_repo_id; + std::string branch_name; + std::string path_to_model_yaml; + std::string model_alias; +}; + +class Models { + + private: + SQLite::Database& db_; + + bool IsUnique(const std::vector& entries, + const std::string& model_id, + const std::string& model_alias) const; + + cpp::result, std::string> LoadModelListNoLock() const; + + public: + static const std::string kModelListPath; + cpp::result, std::string> LoadModelList() const; + Models(); + Models(SQLite::Database& db); + ~Models(); + std::string GenerateShortenedAlias( + const std::string& model_id, + const std::vector& entries) const; + cpp::result GetModelInfo(const std::string& identifier) const; + void PrintModelInfo(const ModelEntry& entry) const; + cpp::result AddModelEntry(ModelEntry new_entry, + bool use_short_alias = false); + cpp::result UpdateModelEntry(const std::string& identifier, + const ModelEntry& updated_entry); + cpp::result DeleteModelEntry(const std::string& identifier); + cpp::result UpdateModelAlias(const std::string& model_id, + const std::string& model_alias); + bool HasModel(const std::string& identifier) const; +}; +} // namespace cortex::db diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index bb5b02a96..0e25e71ab 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -4,11 +4,11 @@ #include #include "config/gguf_parser.h" #include "config/yaml_config.h" +#include "database/models.h" #include "utils/cli_selection_utils.h" #include "utils/file_manager_utils.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" -#include "utils/modellist_utils.h" #include "utils/result.hpp" #include "utils/string_utils.h" @@ -37,14 +37,12 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, CTL_INF("Adding model to modellist with branch: " << branch); auto author_id = author.has_value() ? author.value() : "cortexso"; - modellist_utils::ModelListUtils modellist_utils_obj; - modellist_utils::ModelEntry model_entry{ - .model_id = ggufDownloadItem.id, - .author_repo_id = author_id, - .branch_name = branch, - .path_to_model_yaml = yaml_name.string(), - .model_alias = ggufDownloadItem.id, - .status = modellist_utils::ModelStatus::READY}; + cortex::db::Models modellist_utils_obj; + cortex::db::ModelEntry model_entry{.model_id = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = ggufDownloadItem.id}; modellist_utils_obj.AddModelEntry(model_entry, true); } @@ -233,6 +231,7 @@ cpp::result ModelService::HandleUrl( if (async) { auto result = download_service_.AddAsyncDownloadTask(downloadTask, on_finished); + if (result.has_error()) { CTL_ERR(result.error()); return cpp::fail(result.error()); @@ -277,14 +276,13 @@ cpp::result ModelService::DownloadModelFromCortexso( yaml_handler.ModelConfigFromFile(model_yml_item->localPath.string()); auto mc = yaml_handler.GetModelConfig(); - modellist_utils::ModelListUtils modellist_utils_obj; - modellist_utils::ModelEntry model_entry{ + cortex::db::Models modellist_utils_obj; + cortex::db::ModelEntry model_entry{ .model_id = model_id, .author_repo_id = "cortexso", .branch_name = branch, .path_to_model_yaml = model_yml_item->localPath.string(), - .model_alias = model_id, - .status = modellist_utils::ModelStatus::READY}; + .model_alias = model_id}; modellist_utils_obj.AddModelEntry(model_entry); }; @@ -336,17 +334,21 @@ ModelService::DownloadHuggingFaceGgufModel(const std::string& author, cpp::result ModelService::DeleteModel( const std::string& model_handle) { - modellist_utils::ModelListUtils modellist_handler; + cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { auto model_entry = modellist_handler.GetModelInfo(model_handle); - yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + if (model_entry.has_error()) { + CLI_LOG("Error: " + model_entry.error()); + return cpp::fail(model_entry.error()); + } + yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml); auto mc = yaml_handler.GetModelConfig(); // Remove yaml file - std::filesystem::remove(model_entry.path_to_model_yaml); + std::filesystem::remove(model_entry.value().path_to_model_yaml); // Remove model files if they are not imported locally - if (model_entry.branch_name != "imported") { + if (model_entry.value().branch_name != "imported") { if (mc.files.size() > 0) { if (mc.engine == "cortex.llamacpp") { for (auto& file : mc.files) { diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 13bb9c526..a321b1821 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -5,12 +5,12 @@ enable_testing() add_executable(${PROJECT_NAME} ${SRCS} - ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/yaml_config.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/gguf_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../commands/cortex_upd_cmd.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../commands/server_stop_cmd.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../services/download_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../database/models.cc ) find_package(Drogon CONFIG REQUIRED) @@ -20,6 +20,7 @@ find_package(httplib CONFIG REQUIRED) find_package(unofficial-minizip CONFIG REQUIRED) find_package(LibArchive REQUIRED) find_package(CURL REQUIRED) +find_package(SQLiteCpp REQUIRED) target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp ${CMAKE_THREAD_LIBS_INIT}) @@ -28,6 +29,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib) target_link_libraries(${PROJECT_NAME} PRIVATE unofficial::minizip::minizip) target_link_libraries(${PROJECT_NAME} PRIVATE LibArchive::LibArchive) target_link_libraries(${PROJECT_NAME} PRIVATE CURL::libcurl) +target_link_libraries(${PROJECT_NAME} PRIVATE SQLiteCpp) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../) add_test(NAME ${PROJECT_NAME} diff --git a/engine/test/components/test_modellist_utils.cc b/engine/test/components/test_modellist_utils.cc deleted file mode 100644 index d1dbf91e3..000000000 --- a/engine/test/components/test_modellist_utils.cc +++ /dev/null @@ -1,134 +0,0 @@ -#include -#include -#include "gtest/gtest.h" -#include "utils/modellist_utils.h" -#include "utils/file_manager_utils.h" -class ModelListUtilsTestSuite : public ::testing::Test { - protected: - modellist_utils::ModelListUtils model_list_; - - const modellist_utils::ModelEntry kTestModel{ - "test_model_id", "test_author", - "main", "/path/to/model.yaml", - "test_alias", modellist_utils::ModelStatus::READY}; -}; - void SetUp() { - // Create a temporary directory for tests - file_manager_utils::CreateConfigFileIfNotExist(); - } - - void TearDown() { - // Clean up the temporary directory - std::remove((file_manager_utils::GetModelsContainerPath() / "model.list").string().c_str()); - } -TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) { - EXPECT_TRUE(model_list_.AddModelEntry(kTestModel)); - - auto retrieved_model = model_list_.GetModelInfo("test_model_id"); - EXPECT_EQ(retrieved_model.model_id, kTestModel.model_id); - EXPECT_EQ(retrieved_model.author_repo_id, kTestModel.author_repo_id); -} - -TEST_F(ModelListUtilsTestSuite, TestGetModelInfo) { - model_list_.AddModelEntry(kTestModel); - - auto model_by_id = model_list_.GetModelInfo("test_model_id"); - EXPECT_EQ(model_by_id.model_id, kTestModel.model_id); - - auto model_by_alias = model_list_.GetModelInfo("test_alias"); - EXPECT_EQ(model_by_alias.model_id, kTestModel.model_id); - - EXPECT_THROW(model_list_.GetModelInfo("non_existent_model"), - std::runtime_error); -} - -TEST_F(ModelListUtilsTestSuite, TestUpdateModelEntry) { - model_list_.AddModelEntry(kTestModel); - - modellist_utils::ModelEntry updated_model = kTestModel; - updated_model.status = modellist_utils::ModelStatus::RUNNING; - - EXPECT_TRUE(model_list_.UpdateModelEntry("test_model_id", updated_model)); - - auto retrieved_model = model_list_.GetModelInfo("test_model_id"); - EXPECT_EQ(retrieved_model.status, modellist_utils::ModelStatus::RUNNING); - updated_model.status = modellist_utils::ModelStatus::READY; - model_list_.UpdateModelEntry("test_model_id", updated_model); -} - -TEST_F(ModelListUtilsTestSuite, TestDeleteModelEntry) { - model_list_.AddModelEntry(kTestModel); - - EXPECT_TRUE(model_list_.DeleteModelEntry("test_model_id")); - EXPECT_THROW(model_list_.GetModelInfo("test_model_id"), std::runtime_error); -} - -TEST_F(ModelListUtilsTestSuite, TestGenerateShortenedAlias) { - auto alias = model_list_.GenerateShortenedAlias( - "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf", {}); - EXPECT_EQ(alias, "model_id_xxx"); - - // Test with existing entries to force longer alias - modellist_utils::ModelEntry existing_model = kTestModel; - existing_model.model_alias = "model_id_xxx"; - std::vector existing_entries = {existing_model}; - - alias = model_list_.GenerateShortenedAlias( - "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf", - existing_entries); - EXPECT_EQ(alias, "llama3.1-7b-gguf:model_id_xxx"); -} - -TEST_F(ModelListUtilsTestSuite, TestPersistence) { - model_list_.AddModelEntry(kTestModel); - - // Create a new ModelListUtils instance to test if it loads from file - modellist_utils::ModelListUtils new_model_list; - auto retrieved_model = new_model_list.GetModelInfo("test_model_id"); - - EXPECT_EQ(retrieved_model.model_id, kTestModel.model_id); - EXPECT_EQ(retrieved_model.author_repo_id, kTestModel.author_repo_id); - model_list_.DeleteModelEntry("test_model_id"); -} - -TEST_F(ModelListUtilsTestSuite, TestUpdateModelAlias) { - // Add the test model - ASSERT_TRUE(model_list_.AddModelEntry(kTestModel)); - - // Test successful update - EXPECT_TRUE(model_list_.UpdateModelAlias("test_model_id", "new_test_alias")); - auto updated_model = model_list_.GetModelInfo("new_test_alias"); - EXPECT_EQ(updated_model.model_alias, "new_test_alias"); - EXPECT_EQ(updated_model.model_id, "test_model_id"); - - // Test update with non-existent model - EXPECT_FALSE(model_list_.UpdateModelAlias("non_existent_model", "another_alias")); - - // Test update with non-unique alias - modellist_utils::ModelEntry another_model = kTestModel; - another_model.model_id = "another_model_id"; - another_model.model_alias = "another_alias"; - ASSERT_TRUE(model_list_.AddModelEntry(another_model)); - - EXPECT_FALSE(model_list_.UpdateModelAlias("test_model_id", "another_alias")); - - // Test update using model alias instead of model ID - EXPECT_TRUE(model_list_.UpdateModelAlias("new_test_alias", "final_test_alias")); - updated_model = model_list_.GetModelInfo("final_test_alias"); - EXPECT_EQ(updated_model.model_alias, "final_test_alias"); - EXPECT_EQ(updated_model.model_id, "test_model_id"); - - // Clean up - model_list_.DeleteModelEntry("test_model_id"); - model_list_.DeleteModelEntry("another_model_id"); -} - -TEST_F(ModelListUtilsTestSuite, TestHasModel) { - model_list_.AddModelEntry(kTestModel); - - EXPECT_TRUE(model_list_.HasModel("test_model_id")); - EXPECT_TRUE(model_list_.HasModel("test_alias")); - EXPECT_FALSE(model_list_.HasModel("non_existent_model")); - // Clean up - model_list_.DeleteModelEntry("test_model_id"); -} \ No newline at end of file diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc new file mode 100644 index 000000000..ee418d851 --- /dev/null +++ b/engine/test/components/test_models_db.cc @@ -0,0 +1,167 @@ +#include +#include +#include "database/models.h" +#include "gtest/gtest.h" +#include "utils/file_manager_utils.h" + +namespace cortex::db { +namespace { +constexpr const auto kTestDb = "./test.db"; +} +class ModelsTestSuite : public ::testing::Test { + public: + ModelsTestSuite() + : db_(kTestDb, SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE), + model_list_(db_) {} + void SetUp() { + try { + db_.exec("DELETE FROM models"); + } catch (const std::exception& e) {} + } + + protected: + SQLite::Database db_; + cortex::db::Models model_list_; + + const cortex::db::ModelEntry kTestModel{ + "test_model_id", "test_author", "main", + "/path/to/model.yaml", "test_alias"}; +}; + +TEST_F(ModelsTestSuite, TestAddModelEntry) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + auto retrieved_model = model_list_.GetModelInfo(kTestModel.model_id); + EXPECT_TRUE(retrieved_model); + EXPECT_EQ(retrieved_model.value().model_id, kTestModel.model_id); + EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); + + // // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} + +TEST_F(ModelsTestSuite, TestGetModelInfo) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + auto model_by_id = model_list_.GetModelInfo(kTestModel.model_id); + EXPECT_TRUE(model_by_id); + EXPECT_EQ(model_by_id.value().model_id, kTestModel.model_id); + + auto model_by_alias = model_list_.GetModelInfo("test_alias"); + EXPECT_TRUE(model_by_alias); + EXPECT_EQ(model_by_alias.value().model_id, kTestModel.model_id); + + EXPECT_TRUE(model_list_.GetModelInfo("non_existent_model").has_error()); + + // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} + +TEST_F(ModelsTestSuite, TestUpdateModelEntry) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + cortex::db::ModelEntry updated_model = kTestModel; + + EXPECT_TRUE( + model_list_.UpdateModelEntry(kTestModel.model_id, updated_model).value()); + + auto retrieved_model = model_list_.GetModelInfo(kTestModel.model_id); + EXPECT_TRUE(retrieved_model); + EXPECT_TRUE( + model_list_.UpdateModelEntry(kTestModel.model_id, updated_model).value()); + + // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} + +TEST_F(ModelsTestSuite, TestDeleteModelEntry) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); + EXPECT_TRUE(model_list_.GetModelInfo(kTestModel.model_id).has_error()); +} + +TEST_F(ModelsTestSuite, TestGenerateShortenedAlias) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + auto models1 = model_list_.LoadModelList(); + auto alias = model_list_.GenerateShortenedAlias( + "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf", + models1.value()); + EXPECT_EQ(alias, "model_id_xxx"); + EXPECT_TRUE(model_list_.UpdateModelAlias(kTestModel.model_id, alias).value()); + + // Test with existing entries to force longer alias + auto models2 = model_list_.LoadModelList(); + alias = model_list_.GenerateShortenedAlias( + "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf", + models2.value()); + EXPECT_EQ(alias, "llama3.1-7b-gguf:model_id_xxx"); + + // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} + +TEST_F(ModelsTestSuite, TestPersistence) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + // Create a new ModelListUtils instance to test if it loads from file + cortex::db::Models new_model_list(db_); + auto retrieved_model = new_model_list.GetModelInfo(kTestModel.model_id); + EXPECT_TRUE(retrieved_model); + EXPECT_EQ(retrieved_model.value().model_id, kTestModel.model_id); + EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} + +TEST_F(ModelsTestSuite, TestUpdateModelAlias) { + constexpr const auto kNewTestAlias = "new_test_alias"; + constexpr const auto kNonExistentModel = "non_existent_model"; + constexpr const auto kAnotherAlias = "another_alias"; + constexpr const auto kFinalTestAlias = "final_test_alias"; + constexpr const auto kAnotherModelId = "another_model_id"; + // Add the test model + ASSERT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + // Test successful update + EXPECT_TRUE( + model_list_.UpdateModelAlias(kTestModel.model_id, kNewTestAlias).value()); + auto updated_model = model_list_.GetModelInfo(kNewTestAlias); + EXPECT_TRUE(updated_model); + EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias); + EXPECT_EQ(updated_model.value().model_id, kTestModel.model_id); + + // Test update with non-existent model + EXPECT_FALSE( + model_list_.UpdateModelAlias(kNonExistentModel, kAnotherAlias).value()); + + // Test update with non-unique alias + cortex::db::ModelEntry another_model = kTestModel; + another_model.model_id = kAnotherModelId; + another_model.model_alias = kAnotherAlias; + ASSERT_TRUE(model_list_.AddModelEntry(another_model).value()); + + EXPECT_FALSE( + model_list_.UpdateModelAlias(kTestModel.model_id, kAnotherAlias).value()); + + // Test update using model alias instead of model ID + EXPECT_TRUE(model_list_.UpdateModelAlias(kNewTestAlias, kFinalTestAlias)); + updated_model = model_list_.GetModelInfo(kFinalTestAlias); + EXPECT_TRUE(updated_model); + EXPECT_EQ(updated_model.value().model_alias, kFinalTestAlias); + EXPECT_EQ(updated_model.value().model_id, kTestModel.model_id); + + // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); + EXPECT_TRUE(model_list_.DeleteModelEntry(kAnotherModelId).value()); +} + +TEST_F(ModelsTestSuite, TestHasModel) { + EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); + + EXPECT_TRUE(model_list_.HasModel(kTestModel.model_id)); + EXPECT_TRUE(model_list_.HasModel("test_alias")); + EXPECT_FALSE(model_list_.HasModel("non_existent_model")); + // Clean up + EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model_id).value()); +} +} // namespace cortex::db \ No newline at end of file diff --git a/engine/utils/modellist_utils.cc b/engine/utils/modellist_utils.cc deleted file mode 100644 index d577519f3..000000000 --- a/engine/utils/modellist_utils.cc +++ /dev/null @@ -1,256 +0,0 @@ -#include "modellist_utils.h" -#include -#include -#include -#include -#include -#include -#include "file_manager_utils.h" - -namespace modellist_utils { -const std::string ModelListUtils::kModelListPath = - (file_manager_utils::GetModelsContainerPath() / - std::filesystem::path("model.list")) - .string(); - -std::vector ModelListUtils::LoadModelList() const { - std::vector entries; - std::filesystem::path file_path(kModelListPath); - - // Check if the file exists, if not, create it - if (!std::filesystem::exists(file_path)) { - std::ofstream create_file(kModelListPath); - if (!create_file) { - throw std::runtime_error("Unable to create model.list file: " + - kModelListPath); - } - create_file.close(); - return entries; // Return empty vector for newly created file - } - - std::ifstream file(kModelListPath); - if (!file.is_open()) { - throw std::runtime_error("Unable to open model.list file: " + - kModelListPath); - } - - std::string line; - while (std::getline(file, line)) { - std::istringstream iss(line); - ModelEntry entry; - std::string status_str; - if (!(iss >> entry.model_id >> entry.author_repo_id >> entry.branch_name >> - entry.path_to_model_yaml >> entry.model_alias >> status_str)) { - LOG_WARN << "Invalid entry in model.list: " << line; - } else { - entry.status = - (status_str == "RUNNING") ? ModelStatus::RUNNING : ModelStatus::READY; - entries.push_back(entry); - } - } - return entries; -} - -bool ModelListUtils::IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const { - return std::none_of( - entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return entry.model_id == model_id || entry.model_alias == model_id || - entry.model_id == model_alias || - entry.model_alias == model_alias; - }); -} - -void ModelListUtils::SaveModelList( - const std::vector& entries) const { - std::ofstream file(kModelListPath); - if (!file.is_open()) { - throw std::runtime_error("Unable to open model.list file for writing: " + - kModelListPath); - } - - for (const auto& entry : entries) { - file << entry.model_id << " " << entry.author_repo_id << " " - << entry.branch_name << " " << entry.path_to_model_yaml << " " - << entry.model_alias << " " - << (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY") - << std::endl; - } -} - -std::string ModelListUtils::GenerateShortenedAlias( - const std::string& model_id, const std::vector& entries) const { - std::vector parts; - std::istringstream iss(model_id); - std::string part; - while (std::getline(iss, part, '/')) { - parts.push_back(part); - } - - if (parts.empty()) { - return model_id; // Return original if no parts - } - - // Extract the filename without extension - std::string filename = parts.back(); - size_t last_dot_pos = filename.find_last_of('.'); - if (last_dot_pos != std::string::npos) { - filename = filename.substr(0, last_dot_pos); - } - - // Convert to lowercase - std::transform(filename.begin(), filename.end(), filename.begin(), - [](unsigned char c) { return std::tolower(c); }); - - // Generate alias candidates - std::vector candidates; - candidates.push_back(filename); - - if (parts.size() >= 2) { - candidates.push_back(parts[parts.size() - 2] + ":" + filename); - } - - if (parts.size() >= 3) { - candidates.push_back(parts[parts.size() - 3] + ":" + - parts[parts.size() - 2] + "/" + filename); - } - - if (parts.size() >= 4) { - candidates.push_back(parts[0] + ":" + parts[1] + "/" + - parts[parts.size() - 2] + "/" + filename); - } - - // Find the first unique candidate - for (const auto& candidate : candidates) { - if (IsUnique(entries, model_id, candidate)) { - return candidate; - } - } - - // If all candidates are taken, append a number to the last candidate - std::string base_candidate = candidates.back(); - int suffix = 1; - std::string unique_candidate = base_candidate; - while (!IsUnique(entries, model_id, unique_candidate)) { - unique_candidate = base_candidate + "-" + std::to_string(suffix++); - } - - return unique_candidate; -} - -ModelEntry ModelListUtils::GetModelInfo(const std::string& identifier) const { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - auto it = std::find_if( - entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { - return entry.model_id == identifier || entry.model_alias == identifier; - }); - - if (it != entries.end()) { - return *it; - } else { - throw std::runtime_error("Model not found: " + identifier); - } -} - -void ModelListUtils::PrintModelInfo(const ModelEntry& entry) const { - LOG_INFO << "Model ID: " << entry.model_id; - LOG_INFO << "Author/Repo ID: " << entry.author_repo_id; - LOG_INFO << "Branch Name: " << entry.branch_name; - LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; - LOG_INFO << "Model Alias: " << entry.model_alias; - LOG_INFO << "Status: " - << (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY"); -} - -bool ModelListUtils::AddModelEntry(ModelEntry new_entry, bool use_short_alias) { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - - if (IsUnique(entries, new_entry.model_id, new_entry.model_alias)) { - if (use_short_alias) { - new_entry.model_alias = - GenerateShortenedAlias(new_entry.model_id, entries); - } - new_entry.status = ModelStatus::READY; // Set default status to READY - entries.push_back(std::move(new_entry)); - SaveModelList(entries); - return true; - } - return false; // Entry not added due to non-uniqueness -} - -bool ModelListUtils::UpdateModelEntry(const std::string& identifier, - const ModelEntry& updated_entry) { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - auto it = std::find_if( - entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { - return entry.model_id == identifier || entry.model_alias == identifier; - }); - - if (it != entries.end()) { - *it = updated_entry; - SaveModelList(entries); - return true; - } - return false; // Entry not found -} - -bool ModelListUtils::UpdateModelAlias(const std::string& model_id, - const std::string& new_model_alias) { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - auto it = std::find_if( - entries.begin(), entries.end(), [&model_id](const ModelEntry& entry) { - return entry.model_id == model_id || entry.model_alias == model_id; - }); - bool check_alias_unique = std::none_of( - entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return (entry.model_id == new_model_alias && - entry.model_id != model_id) || - entry.model_alias == new_model_alias; - }); - if (it != entries.end() && check_alias_unique) { - - (*it).model_alias = new_model_alias; - SaveModelList(entries); - return true; - } - return false; // Entry not found -} - -bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - auto it = std::find_if(entries.begin(), entries.end(), - [&identifier](const ModelEntry& entry) { - return (entry.model_id == identifier || - entry.model_alias == identifier) && - entry.status == ModelStatus::READY; - }); - - if (it != entries.end()) { - entries.erase(it); - SaveModelList(entries); - return true; - } - return false; // Entry not found or not in READY state -} - -bool ModelListUtils::HasModel(const std::string& identifier) const { - std::lock_guard lock(mutex_); - auto entries = LoadModelList(); - auto it = std::find_if( - entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { - return entry.model_id == identifier || entry.model_alias == identifier; - }); - - if (it != entries.end()) { - return true; - } else { - return false; - } -} -} // namespace modellist_utils diff --git a/engine/utils/modellist_utils.h b/engine/utils/modellist_utils.h deleted file mode 100644 index 113591f25..000000000 --- a/engine/utils/modellist_utils.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace modellist_utils { - -enum class ModelStatus { READY, RUNNING }; - -struct ModelEntry { - std::string model_id; - std::string author_repo_id; - std::string branch_name; - std::string path_to_model_yaml; - std::string model_alias; - ModelStatus status; -}; - -class ModelListUtils { - - private: - mutable std::mutex mutex_; // For thread safety - - bool IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const; - void SaveModelList(const std::vector& entries) const; - - public: - static const std::string kModelListPath; - std::vector LoadModelList() const; - ModelListUtils() = default; - std::string GenerateShortenedAlias( - const std::string& model_id, - const std::vector& entries) const; - ModelEntry GetModelInfo(const std::string& identifier) const; - void PrintModelInfo(const ModelEntry& entry) const; - bool AddModelEntry(ModelEntry new_entry, bool use_short_alias = false); - bool UpdateModelEntry(const std::string& identifier, - const ModelEntry& updated_entry); - bool DeleteModelEntry(const std::string& identifier); - bool UpdateModelAlias(const std::string& model_id, - const std::string& model_alias); - bool HasModel(const std::string& identifier) const; -}; -} // namespace modellist_utils diff --git a/engine/vcpkg.json b/engine/vcpkg.json index 40abc186e..25b5f3de4 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -15,6 +15,7 @@ "nlohmann-json", "yaml-cpp", "libarchive", - "tabulate" + "tabulate", + "sqlitecpp" ] }