From 2f6bdaa252e5374121400b4e86159d1afba7a464 Mon Sep 17 00:00:00 2001 From: Nguyen Tran Date: Wed, 2 Oct 2024 11:52:57 -0400 Subject: [PATCH] feat: Add YAML-to-struct conversion utilities and improve logging setup - Added YAML converters for `ModelSettings`, `PopulationDemographic`, `TransmissionSettings`, and `date::year_month_day` structs. - Introduced `ConfigData` struct to group configuration data sections. - Updated `Config` class to load, reload, and notify observers for configuration changes using the new `ConfigData` structure. - Integrated `sol2` and `Lua` for dynamic configuration validation. - Added `Logger` class for initializing loggers using `spdlog`, with a dedicated logger for `ConfigValidator` and network operations. - Updated CMakeLists to include dependencies: Sol2, Lua, spdlog, and date libraries. - Added unit tests for YAML converters, including date and configuration-related structs, using Google Test. - Removed example tests and added relevant configuration tests. - Added prompts for generating GTest test cases and YAML struct converters for `ModelSettings`. --- not_used/ConfigValidator.cpp | 147 +++++++++++++++ not_used/ConfigValidator.h | 46 +++++ not_used/ConfigValidator_test.cpp | 37 ++++ promts/generate_gtest.md | 6 + promts/struct_to_yaml_converter.md | 46 +++++ promts/yaml_to_struct.md | 7 + src/CMakeLists.txt | 15 +- src/Configuration/Config.cpp | 70 +++---- src/Configuration/Config.h | 37 +--- src/Configuration/ConfigData.h | 15 ++ src/Configuration/ModelSettings.h | 17 ++ src/Configuration/PopulationDemographic.h | 17 ++ src/Configuration/TransmissionSettings.h | 10 + src/Configuration/YAMLConverters.h | 176 ++++++++++++++++++ src/Utils/Logger.cpp | 41 ++++ src/Utils/Logger.h | 15 ++ tests/CMakeLists.txt | 1 + .../yaml_date_conversion_test.cpp | 70 +++++++ .../yaml_model_settings_conversion_test.cpp | 92 +++++++++ ...population_demographic_conversion_test.cpp | 129 +++++++++++++ tests/SpdlogEnvironment.cpp | 19 ++ tests/example_test.cpp | 7 - vcpkg.json | 6 +- 23 files changed, 942 insertions(+), 84 deletions(-) create mode 100644 not_used/ConfigValidator.cpp create mode 100644 not_used/ConfigValidator.h create mode 100644 not_used/ConfigValidator_test.cpp create mode 100644 promts/generate_gtest.md create mode 100644 promts/struct_to_yaml_converter.md create mode 100644 promts/yaml_to_struct.md create mode 100644 src/Configuration/ConfigData.h create mode 100644 src/Configuration/ModelSettings.h create mode 100644 src/Configuration/PopulationDemographic.h create mode 100644 src/Configuration/TransmissionSettings.h create mode 100644 src/Configuration/YAMLConverters.h create mode 100644 src/Utils/Logger.cpp create mode 100644 src/Utils/Logger.h create mode 100644 tests/Configuration/yaml_date_conversion_test.cpp create mode 100644 tests/Configuration/yaml_model_settings_conversion_test.cpp create mode 100644 tests/Configuration/yaml_population_demographic_conversion_test.cpp create mode 100644 tests/SpdlogEnvironment.cpp delete mode 100644 tests/example_test.cpp diff --git a/not_used/ConfigValidator.cpp b/not_used/ConfigValidator.cpp new file mode 100644 index 0000000..fad69e4 --- /dev/null +++ b/not_used/ConfigValidator.cpp @@ -0,0 +1,147 @@ +#include "ConfigValidator.h" + +#include + +#include "Config.h" + +bool ConfigValidator::Validate(const ConfigData &config) { + try { + // Add more validations as needed + return true; + } catch (const std::exception &e) { + std::cerr << "Validation failed: " << e.what() << std::endl; + return false; + } +} + +bool ConfigValidator::ValidateAgainstSchema(const ConfigData &config, + const YAML::Node &schema) { + // Implement dynamic validation logic here (e.g., using Lua or schema-based + // validation) + return true; // Placeholder +} + +// Recursive function to convert YAML node to Lua table using Sol2 +sol::table ConfigValidator::PushYamlToLua(sol::state &lua, + const YAML::Node &node) { + sol::table lua_table = lua.create_table(); + + for (auto it = node.begin(); it != node.end(); ++it) { + std::string key; + try { + key = it->first.as(); + } catch (const YAML::BadConversion &e) { + // Handle invalid key conversion + key = "invalid_key"; // Or handle appropriately + } + + const YAML::Node &value = it->second; + + try { + if (value.IsScalar()) { + if (value.Tag() == "tag:yaml.org,2002:int") { + lua_table[key] = value.as(); + } else if (value.Tag() == "tag:yaml.org,2002:float") { + lua_table[key] = value.as(); + } else if (value.Tag() == "tag:yaml.org,2002:bool") { + lua_table[key] = value.as(); + } else if (value.IsNull()) { + lua_table[key] = sol::lua_nil; + } else { + lua_table[key] = value.as(); + } + } else if (value.IsMap()) { + lua_table[key] = PushYamlToLua(lua, value); + } else if (value.IsSequence()) { + sol::table array_table = lua.create_table(); + int index = 1; + for (const auto &element : value) { + if (element.IsScalar()) { + if (element.Tag() == "tag:yaml.org,2002:int") { + array_table[index++] = element.as(); + } else if (element.Tag() == "tag:yaml.org,2002:float") { + array_table[index++] = element.as(); + } else if (element.Tag() == "tag:yaml.org,2002:bool") { + array_table[index++] = element.as(); + } else if (element.IsNull()) { + array_table[index++] = sol::lua_nil; + } else { + array_table[index++] = element.as(); + } + } else { + array_table[index++] = PushYamlToLua(lua, element); + } + } + lua_table[key] = array_table; + } + } catch (const YAML::BadConversion &e) { + // Handle conversion error, possibly logging and setting Lua to nil or a + // default value + lua_table[key] = sol::lua_nil; + } + } + + return lua_table; +} + +// Load YAML configuration into Lua +void ConfigValidator::LoadConfigToLua(sol::state &lua, + const YAML::Node &config) { + sol::table lua_config = PushYamlToLua(lua, config); + lua["config"] = lua_config; + // Debugging: Print Lua table contents + std::cout << "Lua 'config' table contents:\n"; + for (const auto &pair : lua_config) { + std::string key = pair.first.as(); + sol::object value = pair.second; + std::cout << key << " = "; + switch (value.get_type()) { + case sol::type::lua_nil: + std::cout << "nil"; + break; + case sol::type::boolean: + std::cout << (value.as() ? "true" : "false"); + break; + case sol::type::number: + std::cout << value.as(); + break; + case sol::type::string: + std::cout << "\"" << value.as() << "\""; + break; + case sol::type::table: + std::cout << "table"; + break; + default: + std::cout << "other"; + break; + } + std::cout << "\n"; + } +} + +// Example validation function using Lua for dynamic rules +bool ConfigValidator::ValidateAgainstLua(const ConfigData &config, + const YAML::Node &schema, + sol::state &lua) { + LoadConfigToLua(lua, schema); + + // You can define or load Lua conditions here based on the schema + std::string condition = R"( + if physics.gravity > 9.8 and simulation.duration > 1000 then + return false + else + return true + end + )"; + + sol::protected_function_result result = + lua.safe_script(condition, sol::script_pass_on_error); + if (!result.valid()) { + sol::error err = result; + std::cerr << "Lua validation error: " << err.what() << std::endl; + return false; + } + + return result; +} + diff --git a/not_used/ConfigValidator.h b/not_used/ConfigValidator.h new file mode 100644 index 0000000..de008f3 --- /dev/null +++ b/not_used/ConfigValidator.h @@ -0,0 +1,46 @@ +// ConfigValidator.h +#ifndef CONFIGVALIDATOR_H +#define CONFIGVALIDATOR_H + +#include + +#include + +class ConfigData; + +class ConfigValidator { +public: + ConfigValidator() = default; + ~ConfigValidator() = default; + + // Prevent copying and moving + ConfigValidator(const ConfigValidator &) = delete; + ConfigValidator(ConfigValidator &&) = delete; + ConfigValidator &operator=(const ConfigValidator &) = delete; + ConfigValidator &operator=(ConfigValidator &&) = delete; + + // Validate the config data + bool Validate(const ConfigData &config); + + bool ValidateAgainstLua(const ConfigData &config, const YAML::Node &schema, + sol::state &lua); + + // Optionally, validate against YAML schema rules + bool ValidateAgainstSchema(const ConfigData &config, + const YAML::Node &schema); + + // Helper method to load entire config into Lua + void LoadConfigToLua(sol::state &lua, const YAML::Node &config); + + // Helper method to convert YAML to Lua tables + sol::table PushYamlToLua(sol::state &lua, const YAML::Node &node); + +private: + // Helper functions for different validation rules + // void ValidateTimestep(double timestep); + // void ValidateGravity(double gravity); + // Other validation logic for fields +}; + +#endif // CONFIGVALIDATOR_H + diff --git a/not_used/ConfigValidator_test.cpp b/not_used/ConfigValidator_test.cpp new file mode 100644 index 0000000..e82730d --- /dev/null +++ b/not_used/ConfigValidator_test.cpp @@ -0,0 +1,37 @@ +#include "Configuration/ConfigValidator.h" + +#include +#include + +class ConfigValidatorTest : public ::testing::Test { +protected: + ConfigValidator validator; + + void SetUp() override { + // Optional: Initialize any shared resources + } + void TearDown() override { + // Optional: Clean up any shared resources + } +}; + +TEST_F(ConfigValidatorTest, Validate) { + YAML::Node node = YAML::Load("{ settings: { pi: 3.14159}, alpha: 0.5 }"); + // YAML::Emitter out; + // out << node; + // spdlog::info("node: {}", out.c_str()); + + // Create a Lua state and load Sol2's standard libraries + sol::state lua; + lua.open_libraries(sol::lib::base, sol::lib::package); + + validator.LoadConfigToLua(lua, node); + + sol::table config = lua["config"]; + + sol::table settings = config["settings"]; + + EXPECT_DOUBLE_EQ(settings.get("pi"), 3.14159); + + EXPECT_DOUBLE_EQ(config["alpha"], 0.5); +} diff --git a/promts/generate_gtest.md b/promts/generate_gtest.md new file mode 100644 index 0000000..f283ea3 --- /dev/null +++ b/promts/generate_gtest.md @@ -0,0 +1,6 @@ +``` +As an expert in C++ programming, would you mind help me to write a gtest for the following function. +Suggest me a test file name and using Test Fixture class. + + +``` diff --git a/promts/struct_to_yaml_converter.md b/promts/struct_to_yaml_converter.md new file mode 100644 index 0000000..222b17e --- /dev/null +++ b/promts/struct_to_yaml_converter.md @@ -0,0 +1,46 @@ +``` +Here is the example of using yaml-cpp to convert yaml to Vec3 class + +namespace YAML { +template<> +struct convert { + static Node encode(const Vec3& rhs) { + Node node; + node.push_back(rhs.x); + node.push_back(rhs.y); + node.push_back(rhs.z); + return node; + } + + static bool decode(const Node& node, Vec3& rhs) { + if(!node.IsSequence() || node.size() != 3) { + return false; + } + + rhs.x = node[0].as(); + rhs.y = node[1].as(); + rhs.z = node[2].as(); + return true; + } +}; +} +Then you could use Vec3 wherever you could use any other type: + +YAML::Node node = YAML::Load("start: [1, 3, 0]"); +Vec3 v = node["start"].as(); +node["end"] = Vec3(2, -1, 0); + + +As an expert in C++ developer, would you mind make the similar convert function for the following struct: + +struct ModelSettings { + int days_between_stdout_output; // Frequency of stdout output, in days + int initial_seed_number; // Seed for random number generator + bool record_genome_db; // Flag to record genomic data + date::year_month_day starting_date; // Simulation start date (YYYY/MM/DD) + date::year_month_day + start_of_comparison_period; // Start of comparison period (YYYY/MM/DD) + date::year_month_day ending_date; // Simulation end date (YYYY/MM/DD) + int start_collect_data_day; // Day to start collecting data +}; +``` diff --git a/promts/yaml_to_struct.md b/promts/yaml_to_struct.md new file mode 100644 index 0000000..9a3c9a9 --- /dev/null +++ b/promts/yaml_to_struct.md @@ -0,0 +1,7 @@ +``` +As an expret in C++ programming, would you mind help me to convert the yaml which is from the input file to C++ struct. + +For date type, you can use the date library from HowardHinnant. + + +``` diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dbaaae6..ea606bc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,8 +1,15 @@ find_package(fmt CONFIG REQUIRED) find_package(GSL REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) - -include_directories(${PROJECT_SOURCE_DIR}/src) +find_package(sol2 CONFIG REQUIRED) +find_package(Lua REQUIRED) +find_package(spdlog REQUIRED) +find_package(date CONFIG REQUIRED) + +include_directories( + ${PROJECT_SOURCE_DIR}/src + ${LUA_INCLUDE_DIR} +) # craete source files # Add source files for the core library @@ -23,6 +30,10 @@ target_link_libraries(MalaSimCore PUBLIC fmt::fmt GSL::gsl GSL::gslcblas yaml-cpp::yaml-cpp + ${LUA_LIBRARIES} + sol2 + spdlog::spdlog + date::date date::date-tz ) set_property(TARGET MalaSimCore PROPERTY CXX_STANDARD 20) diff --git a/src/Configuration/Config.cpp b/src/Configuration/Config.cpp index 1130823..00a1ae2 100644 --- a/src/Configuration/Config.cpp +++ b/src/Configuration/Config.cpp @@ -1,44 +1,32 @@ + #include "Config.h" -#include - -bool Config::ValidateNode(const YAML::Node &node, const YAML::Node &schema) { - for (auto it = schema.begin(); it != schema.end(); ++it) { - std::string key = it->first.as(); - const YAML::Node &schema_field = it->second; - - // Check if the field is required and present - if (schema_field["required"] && schema_field["required"].as() - && !node[key]) { - std::cerr << "Missing required field: " << key << std::endl; - return false; - } - - // If the field exists, check the type - if (node[key]) { - std::string expected_type = schema_field["type"].as(); - if (expected_type == "double" && !node[key].IsScalar()) { - std::cerr << "Invalid type for field: " << key << " (expected double)" - << std::endl; - return false; - } - if (expected_type == "string" && !node[key].IsScalar()) { - std::cerr << "Invalid type for field: " << key << " (expected string)" - << std::endl; - return false; - } - - // Additional checks like min, max can be added - if (expected_type == "double" && schema_field["min"]) { - double value = node[key].as(); - if (value < schema_field["min"].as()) { - std::cerr << "Value for " << key - << " is less than the minimum allowed: " - << schema_field["min"].as() << std::endl; - return false; - } - } - } - } - return true; +#include + +#include + +#include "YAMLConverters.h" + +void Config::Load(const std::string &filename) { + std::shared_lock lock(mutex_); + config_file_path_ = filename; + YAML::Node config = YAML::LoadFile(filename); + config_data_.model_settings = config["ModelSettings"].as(); + config_data_.transmission_settings = + config["TransmissionSettings"].as(); + config_data_.population_demographic = + config["PopulationDemographic"].as(); + NotifyObservers(); +} + +void Config::Reload() { Load(config_file_path_); } + +void Config::RegisterObserver(ConfigObserver observer) { + std::unique_lock lock(mutex_); + observers_.push_back(observer); } + +void Config::NotifyObservers() { + for (const auto &observer : observers_) { observer(config_data_); } +} + diff --git a/src/Configuration/Config.h b/src/Configuration/Config.h index edfbabd..715fbeb 100644 --- a/src/Configuration/Config.h +++ b/src/Configuration/Config.h @@ -2,40 +2,18 @@ #ifndef CONFIG_H #define CONFIG_H +#include #include #include #include -// #include -// #include -// #include +#include + +#include "ConfigData.h" // Forward declaration class Model; -// Define configuration sections -struct SimulationConfig { - double timestep; - double duration; -}; - -struct PhysicsConfig { - double gravity; - double air_resistance; -}; - -struct OutputConfig { - std::string log_level; - std::string file_path; -}; - -// Aggregated Config Structure -struct ConfigData { - SimulationConfig simulation; - PhysicsConfig physics; - OutputConfig output; -}; - // Observer Callback Type using ConfigObserver = std::function; @@ -48,16 +26,9 @@ class Config { // Load configuration from a YAML file void Load(const std::string &filename); - bool ValidateNode(const YAML::Node &node, const YAML::Node &schema); - // Reload configuration (useful for dynamic updates) void Reload(); - // Getters for configuration sections - SimulationConfig GetSimulationConfig() const; - PhysicsConfig GetPhysicsConfig() const; - OutputConfig GetOutputConfig() const; - // Register an observer for configuration changes void RegisterObserver(ConfigObserver observer); diff --git a/src/Configuration/ConfigData.h b/src/Configuration/ConfigData.h new file mode 100644 index 0000000..7860f08 --- /dev/null +++ b/src/Configuration/ConfigData.h @@ -0,0 +1,15 @@ +#ifndef CONFIG_DATA_H +#define CONFIG_DATA_H + +#include "ModelSettings.h" +#include "PopulationDemographic.h" +#include "TransmissionSettings.h" + +struct ConfigData { + ModelSettings model_settings; + TransmissionSettings transmission_settings; + PopulationDemographic population_demographic; +}; + +#endif // CONFIG_DATA_H + diff --git a/src/Configuration/ModelSettings.h b/src/Configuration/ModelSettings.h new file mode 100644 index 0000000..3b07d81 --- /dev/null +++ b/src/Configuration/ModelSettings.h @@ -0,0 +1,17 @@ +#ifndef MODEL_SETTINGS_H +#define MODEL_SETTINGS_H + +#include + +struct ModelSettings { + int days_between_stdout_output; + int initial_seed_number; + bool record_genome_db; + date::year_month_day starting_date; + date::year_month_day start_of_comparison_period; + date::year_month_day ending_date; + int start_collect_data_day; +}; + +#endif // MODEL_SETTINGS_H + diff --git a/src/Configuration/PopulationDemographic.h b/src/Configuration/PopulationDemographic.h new file mode 100644 index 0000000..0bf4755 --- /dev/null +++ b/src/Configuration/PopulationDemographic.h @@ -0,0 +1,17 @@ +#ifndef POPULATION_DEMOGRAPHIC_H +#define POPULATION_DEMOGRAPHIC_H + +#include + +struct PopulationDemographic { + int number_of_age_classes; + std::vector age_structure; + std::vector initial_age_structure; + double birth_rate; + std::vector death_rate_by_age_class; + std::vector mortality_when_treatment_fail_by_age_class; + double artificial_rescaling_of_population_size; +}; + +#endif // POPULATION_DEMOGRAPHIC_H + diff --git a/src/Configuration/TransmissionSettings.h b/src/Configuration/TransmissionSettings.h new file mode 100644 index 0000000..aea88f4 --- /dev/null +++ b/src/Configuration/TransmissionSettings.h @@ -0,0 +1,10 @@ +#ifndef TRANSMISSION_SETTINGS_H +#define TRANSMISSION_SETTINGS_H + +struct TransmissionSettings { + double transmission_parameter; + double p_infection_from_an_infectious_bite; +}; + +#endif // TRANSMISSION_SETTINGS_H + diff --git a/src/Configuration/YAMLConverters.h b/src/Configuration/YAMLConverters.h new file mode 100644 index 0000000..a20d8f8 --- /dev/null +++ b/src/Configuration/YAMLConverters.h @@ -0,0 +1,176 @@ +#ifndef YAML_CONVERTERS_H +#define YAML_CONVERTERS_H + +#include + +#include "ModelSettings.h" +#include "PopulationDemographic.h" +#include "TransmissionSettings.h" + +namespace YAML { +template <> +struct convert { + static Node encode(const date::year_month_day &rhs) { + std::stringstream ss; + ss << rhs; + return Node(ss.str()); + } + + static bool decode(const Node &node, date::year_month_day &rhs) { + if (!node.IsScalar()) { + throw std::runtime_error("Invalid date format: not a scalar."); + } + + std::stringstream ss(node.as()); + date::year_month_day ymd; + ss >> date::parse("%F", ymd); // %F matches YYYY-MM-DD format + + if (ss.fail()) { + throw std::runtime_error("Invalid date format: failed to parse."); + } + + rhs = ymd; + return true; + } +}; + +template <> +struct convert { + static Node encode(const ModelSettings &rhs) { + Node node; + node["days_between_stdout_output"] = rhs.days_between_stdout_output; + node["initial_seed_number"] = rhs.initial_seed_number; + node["record_genome_db"] = rhs.record_genome_db; + node["starting_date"] = rhs.starting_date; + node["start_of_comparison_period"] = rhs.start_of_comparison_period; + node["ending_date"] = rhs.ending_date; + node["start_collect_data_day"] = rhs.start_collect_data_day; + return node; + } + + static bool decode(const Node &node, ModelSettings &rhs) { + if (!node["days_between_stdout_output"]) { + throw std::runtime_error("Missing 'days_between_stdout_output' field."); + } + if (!node["initial_seed_number"]) { + throw std::runtime_error("Missing 'initial_seed_number' field."); + } + if (!node["record_genome_db"]) { + throw std::runtime_error("Missing 'record_genome_db' field."); + } + if (!node["starting_date"]) { + throw std::runtime_error("Missing 'starting_date' field."); + } + if (!node["start_of_comparison_period"]) { + throw std::runtime_error("Missing 'start_of_comparison_period' field."); + } + if (!node["ending_date"]) { + throw std::runtime_error("Missing 'ending_date' field."); + } + if (!node["start_collect_data_day"]) { + throw std::runtime_error("Missing 'start_collect_data_day' field."); + } + + // TODO: Add more error checking for each field + + rhs.days_between_stdout_output = + node["days_between_stdout_output"].as(); + rhs.initial_seed_number = node["initial_seed_number"].as(); + rhs.record_genome_db = node["record_genome_db"].as(); + rhs.starting_date = node["starting_date"].as(); + rhs.start_of_comparison_period = + node["start_of_comparison_period"].as(); + rhs.ending_date = node["ending_date"].as(); + rhs.start_collect_data_day = node["start_collect_data_day"].as(); + return true; + } +}; + +template <> +struct convert { + static Node encode(const TransmissionSettings &rhs) { + Node node; + node["transmission_parameter"] = rhs.transmission_parameter; + node["p_infection_from_an_infectious_bite"] = + rhs.p_infection_from_an_infectious_bite; + return node; + } + + static bool decode(const Node &node, TransmissionSettings &rhs) { + if (!node["transmission_parameter"]) { + throw std::runtime_error("Missing 'transmission_parameter' field."); + } + if (!node["p_infection_from_an_infectious_bite"]) { + throw std::runtime_error( + "Missing 'p_infection_from_an_infectious_bite' field."); + } + // TODO: Add more error checking for each field + + rhs.transmission_parameter = node["transmission_parameter"].as(); + rhs.p_infection_from_an_infectious_bite = + node["p_infection_from_an_infectious_bite"].as(); + return true; + } +}; + +template <> +struct convert { + static Node encode(const PopulationDemographic &rhs) { + Node node; + node["number_of_age_classes"] = rhs.number_of_age_classes; + node["age_structure"] = rhs.age_structure; + node["initial_age_structure"] = rhs.initial_age_structure; + node["birth_rate"] = rhs.birth_rate; + node["death_rate_by_age_class"] = rhs.death_rate_by_age_class; + node["mortality_when_treatment_fail_by_age_class"] = + rhs.mortality_when_treatment_fail_by_age_class; + node["artificial_rescaling_of_population_size"] = + rhs.artificial_rescaling_of_population_size; + return node; + } + + static bool decode(const Node &node, PopulationDemographic &rhs) { + if (!node["number_of_age_classes"]) { + throw std::runtime_error("Missing 'number_of_age_classes' field."); + } + if (!node["age_structure"]) { + throw std::runtime_error("Missing 'age_structure' field."); + } + if (!node["initial_age_structure"]) { + throw std::runtime_error("Missing 'initial_age_structure' field."); + } + if (!node["birth_rate"]) { + throw std::runtime_error("Missing 'birth_rate' field."); + } + if (!node["death_rate_by_age_class"]) { + throw std::runtime_error("Missing 'death_rate_by_age_class' field."); + } + if (!node["mortality_when_treatment_fail_by_age_class"]) { + throw std::runtime_error( + "Missing 'mortality_when_treatment_fail_by_age_class' field."); + } + if (!node["artificial_rescaling_of_population_size"]) { + throw std::runtime_error( + "Missing 'artificial_rescaling_of_population_size' field."); + } + + rhs.number_of_age_classes = node["number_of_age_classes"].as(); + rhs.age_structure = node["age_structure"].as>(); + rhs.initial_age_structure = + node["initial_age_structure"].as>(); + rhs.birth_rate = node["birth_rate"].as(); + rhs.death_rate_by_age_class = + node["death_rate_by_age_class"].as>(); + rhs.mortality_when_treatment_fail_by_age_class = + node["mortality_when_treatment_fail_by_age_class"] + .as>(); + rhs.artificial_rescaling_of_population_size = + node["artificial_rescaling_of_population_size"].as(); + return true; + } +}; + +} // namespace YAML + +#endif // YAML_CONVERTERS_H + diff --git a/src/Utils/Logger.cpp b/src/Utils/Logger.cpp new file mode 100644 index 0000000..8d916ba --- /dev/null +++ b/src/Utils/Logger.cpp @@ -0,0 +1,41 @@ +// Logger.cpp +#include "Logger.h" + +#include + +#include + +void Logger::Initialize(spdlog::level::level_enum log_level) { + try { + // Default logger + auto default_sink = std::make_shared(); + auto default_logger = + std::make_shared("default_logger", default_sink); + spdlog::set_default_logger(default_logger); + spdlog::set_level(spdlog::level::info); + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S] [%l] %v"); + spdlog::info("Default logger initialized."); + + // ConfigValidator logger + auto config_sink = std::make_shared(); + auto config_logger = + std::make_shared("config_validator", config_sink); + spdlog::register_logger(config_logger); + config_logger->set_level(spdlog::level::info); + config_logger->set_pattern("[%Y-%m-%d %H:%M:%S] [%l] [ConfigValidator] %v"); + config_logger->debug("ConfigValidator logger initialized."); + + // Network logger + auto network_sink = std::make_shared(); + auto network_logger = + std::make_shared("network", network_sink); + spdlog::register_logger(network_logger); + network_logger->set_level(spdlog::level::info); + network_logger->set_pattern("[%Y-%m-%d %H:%M:%S] [%l] [Network] %v"); + network_logger->info("Network logger initialized."); + + } catch (const spdlog::spdlog_ex &ex) { + std::cerr << "Logger initialization failed: " << ex.what() << std::endl; + } +} + diff --git a/src/Utils/Logger.h b/src/Utils/Logger.h new file mode 100644 index 0000000..d1f962c --- /dev/null +++ b/src/Utils/Logger.h @@ -0,0 +1,15 @@ +#ifndef LOGGER_H +#define LOGGER_H + +#include + +class Logger { +public: + // Initializes the loggers with a specified log level + static void Initialize( + spdlog::level::level_enum log_level = spdlog::level::info); + + // Retrieves a logger by name; creates it if it doesn't exist +}; + +#endif // LOGGER_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3625c8..f86e6c9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,7 @@ file(GLOB_RECURSE MALASIM_TEST_SOURCES "*.cpp" "helpers/*.cpp" "Core/Random/*.cpp" + "Configuration/*.cpp" ) diff --git a/tests/Configuration/yaml_date_conversion_test.cpp b/tests/Configuration/yaml_date_conversion_test.cpp new file mode 100644 index 0000000..e662755 --- /dev/null +++ b/tests/Configuration/yaml_date_conversion_test.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include + +// Include the converter you're testing +#include "Configuration/YAMLConverters.h" + +class YamlDateConversionTest : public ::testing::Test { +protected: + date::year_month_day valid_date; + YAML::Node valid_node; + YAML::Node invalid_format_node; + YAML::Node non_scalar_node; + + void SetUp() override { + valid_date = + date::year_month_day{date::year{2023}, date::month{10}, date::day{2}}; + valid_node = YAML::Node("2023-10-02"); + invalid_format_node = YAML::Node("invalid-date-format"); + non_scalar_node = YAML::Node(YAML::NodeType::Sequence); // Non-scalar node + } + + void TearDown() override { + // Cleanup if necessary + } +}; + +TEST_F(YamlDateConversionTest, EncodeValidDate) { + YAML::Node node = YAML::convert::encode(valid_date); + + // Check if the encoded date is correct + EXPECT_EQ(node.as(), "2023-10-02"); +} + +TEST_F(YamlDateConversionTest, DecodeValidDate) { + date::year_month_day date; + + ASSERT_NO_THROW( + { YAML::convert::decode(valid_node, date); }); + + EXPECT_EQ(date, valid_date); +} + +TEST_F(YamlDateConversionTest, DecodeInvalidDateFormatThrows) { + date::year_month_day date; + + // Expect an exception due to invalid date format + EXPECT_THROW( + { + YAML::convert::decode(invalid_format_node, date); + }, + std::runtime_error); +} + +TEST_F(YamlDateConversionTest, DecodeNonScalarThrows) { + date::year_month_day date; + + // Expect an exception due to non-scalar node type + EXPECT_THROW( + { YAML::convert::decode(non_scalar_node, date); }, + std::runtime_error); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/tests/Configuration/yaml_model_settings_conversion_test.cpp b/tests/Configuration/yaml_model_settings_conversion_test.cpp new file mode 100644 index 0000000..2df30f6 --- /dev/null +++ b/tests/Configuration/yaml_model_settings_conversion_test.cpp @@ -0,0 +1,92 @@ +#include +#include + +#include "Configuration/ModelSettings.h" +#include "Configuration/YAMLConverters.h" + +class ModelSettingsTest : public ::testing::Test { +protected: + ModelSettings default_settings; + + void SetUp() override { + // Initialize default ModelSettings object + default_settings.days_between_stdout_output = 10; + default_settings.initial_seed_number = 123; + default_settings.record_genome_db = true; + default_settings.starting_date = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + default_settings.start_of_comparison_period = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + default_settings.ending_date = + date::year_month_day{date::year{2024}, date::month{10}, date::day{2}}; + default_settings.start_collect_data_day = 1; + } +}; + +// Test encoding functionality +TEST_F(ModelSettingsTest, EncodeModelSettings) { + YAML::Node node = YAML::convert::encode(default_settings); + + EXPECT_EQ(node["days_between_stdout_output"].as(), + default_settings.days_between_stdout_output); + EXPECT_EQ(node["initial_seed_number"].as(), + default_settings.initial_seed_number); + EXPECT_EQ(node["record_genome_db"].as(), + default_settings.record_genome_db); + EXPECT_EQ(node["starting_date"].as(), + default_settings.starting_date); + EXPECT_EQ(node["start_of_comparison_period"].as(), + default_settings.start_of_comparison_period); + EXPECT_EQ(node["ending_date"].as(), + default_settings.ending_date); + EXPECT_EQ(node["start_collect_data_day"].as(), + default_settings.start_collect_data_day); +} + +// Test decoding functionality +TEST_F(ModelSettingsTest, DecodeModelSettings) { + YAML::Node node; + node["days_between_stdout_output"] = 10; + node["initial_seed_number"] = 123; + node["record_genome_db"] = true; + node["starting_date"] = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + node["start_of_comparison_period"] = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + node["ending_date"] = + date::year_month_day{date::year{2024}, date::month{10}, date::day{2}}; + node["start_collect_data_day"] = 1; + + ModelSettings decoded_settings; + EXPECT_NO_THROW(YAML::convert::decode(node, decoded_settings)); + + EXPECT_EQ(decoded_settings.days_between_stdout_output, 10); + EXPECT_EQ(decoded_settings.initial_seed_number, 123); + EXPECT_EQ(decoded_settings.record_genome_db, true); + + auto expected_starting_date = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + EXPECT_EQ(decoded_settings.starting_date, expected_starting_date); + + auto expected_start_of_comparison_period = + date::year_month_day{date::year{2024}, date::month{10}, date::day{1}}; + EXPECT_EQ(decoded_settings.start_of_comparison_period, + expected_start_of_comparison_period); + + auto expected_ending_date = + date::year_month_day{date::year{2024}, date::month{10}, date::day{2}}; + EXPECT_EQ(decoded_settings.ending_date, expected_ending_date); + + EXPECT_EQ(decoded_settings.start_collect_data_day, 1); +} + +// Test missing fields during decoding +TEST_F(ModelSettingsTest, DecodeModelSettingsMissingField) { + YAML::Node node; + node["initial_seed_number"] = 123; // intentionally omit other fields + + ModelSettings decoded_settings; + EXPECT_THROW(YAML::convert::decode(node, decoded_settings), + std::runtime_error); +} + diff --git a/tests/Configuration/yaml_population_demographic_conversion_test.cpp b/tests/Configuration/yaml_population_demographic_conversion_test.cpp new file mode 100644 index 0000000..c3b4779 --- /dev/null +++ b/tests/Configuration/yaml_population_demographic_conversion_test.cpp @@ -0,0 +1,129 @@ +#include +#include + +#include "Configuration/PopulationDemographic.h" +#include "Configuration/YAMLConverters.h" + +// Test Fixture Class +class PopulationDemographicTest : public ::testing::Test { +protected: + PopulationDemographic default_demographic; + + void SetUp() override { + // Initialize default PopulationDemographic object + default_demographic.number_of_age_classes = 5; + default_demographic.age_structure = {100, 150, 200, 150, 100}; + default_demographic.initial_age_structure = {100, 150, 200, 150, 100}; + default_demographic.birth_rate = 0.02; + default_demographic.death_rate_by_age_class = {0.01, 0.015, 0.02, 0.015, + 0.01}; + default_demographic.mortality_when_treatment_fail_by_age_class = { + 0.05, 0.07, 0.1, 0.07, 0.05}; + default_demographic.artificial_rescaling_of_population_size = 1.0; + } +}; + +// Test encoding functionality +TEST_F(PopulationDemographicTest, EncodePopulationDemographic) { + YAML::Node node = + YAML::convert::encode(default_demographic); + + EXPECT_EQ(node["number_of_age_classes"].as(), + default_demographic.number_of_age_classes); + EXPECT_EQ(node["age_structure"].as>(), + default_demographic.age_structure); + EXPECT_EQ(node["initial_age_structure"].as>(), + default_demographic.initial_age_structure); + EXPECT_DOUBLE_EQ(node["birth_rate"].as(), + default_demographic.birth_rate); + EXPECT_EQ(node["death_rate_by_age_class"].as>(), + default_demographic.death_rate_by_age_class); + EXPECT_EQ(node["mortality_when_treatment_fail_by_age_class"] + .as>(), + default_demographic.mortality_when_treatment_fail_by_age_class); + EXPECT_DOUBLE_EQ(node["artificial_rescaling_of_population_size"].as(), + default_demographic.artificial_rescaling_of_population_size); +} + +// Test decoding functionality +TEST_F(PopulationDemographicTest, DecodePopulationDemographic) { + YAML::Node node; + node["number_of_age_classes"] = 5; + node["age_structure"] = std::vector{100, 150, 200, 150, 100}; + node["initial_age_structure"] = std::vector{100, 150, 200, 150, 100}; + node["birth_rate"] = 0.02; + node["death_rate_by_age_class"] = + std::vector{0.01, 0.015, 0.02, 0.015, 0.01}; + node["mortality_when_treatment_fail_by_age_class"] = + std::vector{0.05, 0.07, 0.1, 0.07, 0.05}; + node["artificial_rescaling_of_population_size"] = 1.0; + + PopulationDemographic decoded_demographic; + EXPECT_NO_THROW( + YAML::convert::decode(node, decoded_demographic)); + + EXPECT_EQ(decoded_demographic.number_of_age_classes, 5); + EXPECT_EQ(decoded_demographic.age_structure, + std::vector({100, 150, 200, 150, 100})); + EXPECT_EQ(decoded_demographic.initial_age_structure, + std::vector({100, 150, 200, 150, 100})); + EXPECT_DOUBLE_EQ(decoded_demographic.birth_rate, 0.02); + EXPECT_EQ(decoded_demographic.death_rate_by_age_class, + std::vector({0.01, 0.015, 0.02, 0.015, 0.01})); + EXPECT_EQ(decoded_demographic.mortality_when_treatment_fail_by_age_class, + std::vector({0.05, 0.07, 0.1, 0.07, 0.05})); + EXPECT_DOUBLE_EQ(decoded_demographic.artificial_rescaling_of_population_size, + 1.0); +} + +// Test decoding with missing fields +TEST_F(PopulationDemographicTest, DecodePopulationDemographicMissingField) { + YAML::Node node; + node["number_of_age_classes"] = 5; + // Intentionally omit other fields to trigger exceptions + + PopulationDemographic decoded_demographic; + EXPECT_THROW( + YAML::convert::decode(node, decoded_demographic), + std::runtime_error); +} + +// Test decoding with partial missing fields +TEST_F(PopulationDemographicTest, + DecodePopulationDemographicPartialMissingFields) { + YAML::Node node; + node["number_of_age_classes"] = 5; + node["age_structure"] = std::vector{100, 150, 200, 150, 100}; + // Missing 'initial_age_structure' and other fields + + PopulationDemographic decoded_demographic; + EXPECT_THROW( + YAML::convert::decode(node, decoded_demographic), + std::runtime_error); +} + +// Test encoding and then decoding to ensure consistency +TEST_F(PopulationDemographicTest, EncodeDecodeConsistency) { + YAML::Node node = + YAML::convert::encode(default_demographic); + + PopulationDemographic decoded_demographic; + EXPECT_NO_THROW( + YAML::convert::decode(node, decoded_demographic)); + + EXPECT_EQ(decoded_demographic.number_of_age_classes, + default_demographic.number_of_age_classes); + EXPECT_EQ(decoded_demographic.age_structure, + default_demographic.age_structure); + EXPECT_EQ(decoded_demographic.initial_age_structure, + default_demographic.initial_age_structure); + EXPECT_DOUBLE_EQ(decoded_demographic.birth_rate, + default_demographic.birth_rate); + EXPECT_EQ(decoded_demographic.death_rate_by_age_class, + default_demographic.death_rate_by_age_class); + EXPECT_EQ(decoded_demographic.mortality_when_treatment_fail_by_age_class, + default_demographic.mortality_when_treatment_fail_by_age_class); + EXPECT_DOUBLE_EQ(decoded_demographic.artificial_rescaling_of_population_size, + default_demographic.artificial_rescaling_of_population_size); +} + diff --git a/tests/SpdlogEnvironment.cpp b/tests/SpdlogEnvironment.cpp new file mode 100644 index 0000000..0fc499a --- /dev/null +++ b/tests/SpdlogEnvironment.cpp @@ -0,0 +1,19 @@ + +#include +#include +#include + +#include "Utils/Logger.h" + +// Define a test environment to initialize spdlog +class SpdlogEnvironment : public ::testing::Environment { +public: + void SetUp() override { Logger::Initialize(spdlog::level::info); } + + void TearDown() override { spdlog::shutdown(); } +}; + +// Register the test environment +::testing::Environment* const global_env = + ::testing::AddGlobalTestEnvironment(new SpdlogEnvironment()); + diff --git a/tests/example_test.cpp b/tests/example_test.cpp deleted file mode 100644 index cfd251a..0000000 --- a/tests/example_test.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include "example.h" - -#include - -TEST(AdditionTest, PositiveNumbers) { EXPECT_EQ(add(1, 2), 3); } - -TEST(AdditionTest, NegativeNumbers) { EXPECT_EQ(add(-1, -2), -3); } diff --git a/vcpkg.json b/vcpkg.json index a5be3e2..654ea46 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -3,6 +3,10 @@ "fmt", "gsl", "gtest", - "yaml-cpp" + "lua", + "sol2", + "spdlog", + "yaml-cpp", + "date" ] }