diff --git a/src/api/libcellml/analysermodel.h b/src/api/libcellml/analysermodel.h index f5d745d15..e1c61427b 100644 --- a/src/api/libcellml/analysermodel.h +++ b/src/api/libcellml/analysermodel.h @@ -29,6 +29,7 @@ namespace libcellml { class LIBCELLML_EXPORT AnalyserModel { friend class Analyser; + friend class Generator; public: /** diff --git a/src/api/libcellml/generator.h b/src/api/libcellml/generator.h index 3888cc973..597d0fbee 100644 --- a/src/api/libcellml/generator.h +++ b/src/api/libcellml/generator.h @@ -67,6 +67,74 @@ class LIBCELLML_EXPORT Generator */ void setProfile(const GeneratorProfilePtr &profile); + /** + * @brief Track all the variables in the given @p model. + * + * Track all the variables in the given @p model. This will add all the variables in the model to the list of + * tracked variables. + * + * @param model The pointer to the @ref AnalyserModel which all the variables are to be tracked. + * + * @return @c true if all the variables in the model were tracked, @c false otherwise. + */ + bool trackAllVariables(const AnalyserModelPtr &model); + + /** + * @brief Untrack all the variables in the given @p model. + * + * Untrack all the variables in the given @p model. This will remove all the variables in the model from the list of + * tracked variables. + * + * @param model The pointer to the @ref AnalyserModel which all the variables are to be untracked. + * + * @return @c true if all the variables in the model were untracked, @c false otherwise. + */ + bool untrackAllVariables(const AnalyserModelPtr &model); + + /** + * @brief Track the given @p variable. + * + * Track the given @p variable. This will add the variable to the list of tracked variables. + * + * @param variable The pointer to the @ref Variable to track. + * + * @return @c true if the variable was tracked, @c false otherwise. + */ + bool trackVariable(const VariablePtr &variable); + + /** + * @brief Untrack the given @p variable. + * + * Untrack the given @p variable. This will remove the variable from the list of tracked variables. + * + * @param variable The pointer to the @ref Variable to untrack. + * + * @return @c true if the variable was untracked, @c false otherwise. + */ + bool untrackVariable(const VariablePtr &variable); + + /** + * @brief Get the number of tracked variables in the given @p model. + * + * Get the number of tracked variables in the given @p model. + * + * @param model The pointer to the @ref AnalyserModel for which to get the number of tracked variables. + * + * @return The number of tracked variables in the model. + */ + size_t trackedVariableCount(const AnalyserModelPtr &model); + + /** + * @brief Get the number of untracked variables in the given @p model. + * + * Get the number of untracked variables in the given @p model. + * + * @param model The pointer to the @ref AnalyserModel for which to get the number of untracked variables. + * + * @return The number of untracked variables in the model. + */ + size_t untrackedVariableCount(const AnalyserModelPtr &model); + /** * @brief Get the interface code for the @ref AnalyserModel. * diff --git a/src/generator.cpp b/src/generator.cpp index a72b6ecf3..daabf472a 100644 --- a/src/generator.cpp +++ b/src/generator.cpp @@ -27,6 +27,7 @@ limitations under the License. #include "libcellml/units.h" #include "libcellml/version.h" +#include "analysermodel_p.h" #include "commonutils.h" #include "generator_p.h" #include "generatorprofilesha1values.h" @@ -42,6 +43,85 @@ void Generator::GeneratorImpl::reset() mCode = {}; } +bool Generator::GeneratorImpl::doTrackVariable(const ModelPtr &model, const VariablePtr &variable, bool tracked) +{ + mTrackedVariables[model][variable] = tracked; + + return true; +} + +bool Generator::GeneratorImpl::trackVariable(const VariablePtr &variable) +{ + if (variable == nullptr) { + return false; + } + + return doTrackVariable(owningModel(variable), variable, true); +} + +bool Generator::GeneratorImpl::untrackVariable(const VariablePtr &variable) +{ + if (variable == nullptr) { + return false; + } + + return doTrackVariable(owningModel(variable), variable, false); +} + +bool Generator::GeneratorImpl::doTrackAllVariables(const AnalyserModelPtr &model, bool tracked) +{ + if (model == nullptr) { + return false; + } + + for (const auto &variable : variables(model, true)) { + doTrackVariable(model->mPimpl->mModel, variable->variable(), tracked); + } + + return true; +} + +bool Generator::GeneratorImpl::trackAllVariables(const AnalyserModelPtr &model) +{ + return doTrackAllVariables(model, true); +} + +bool Generator::GeneratorImpl::untrackAllVariables(const AnalyserModelPtr &model) +{ + return doTrackAllVariables(model, false); +} + +size_t Generator::GeneratorImpl::doTrackedVariableCount(const AnalyserModelPtr &model, bool tracked) +{ + if (model == nullptr) { + return 0; + } + + size_t res = 0; + + for (const auto &variable : variables(model, true)) { + if (mTrackedVariables[model->mPimpl->mModel].find(variable->variable()) == mTrackedVariables[model->mPimpl->mModel].end()) { + mTrackedVariables[model->mPimpl->mModel][variable->variable()] = true; + } + + if (mTrackedVariables[model->mPimpl->mModel][variable->variable()] == tracked) { + ++res; + } + } + + return res; +} + +size_t Generator::GeneratorImpl::trackedVariableCount(const AnalyserModelPtr &model) +{ + return doTrackedVariableCount(model, true); +} + +size_t Generator::GeneratorImpl::untrackedVariableCount(const AnalyserModelPtr &model) +{ + return doTrackedVariableCount(model, false); +} + bool Generator::GeneratorImpl::modelHasOdes(const AnalyserModelPtr &model) const { switch (model->type()) { @@ -2022,6 +2102,36 @@ void Generator::setProfile(const GeneratorProfilePtr &profile) mPimpl->mProfile = profile; } +bool Generator::trackVariable(const VariablePtr &variable) +{ + return mPimpl->trackVariable(variable); +} + +bool Generator::untrackVariable(const VariablePtr &variable) +{ + return mPimpl->untrackVariable(variable); +} + +bool Generator::trackAllVariables(const AnalyserModelPtr &model) +{ + return mPimpl->trackAllVariables(model); +} + +bool Generator::untrackAllVariables(const AnalyserModelPtr &model) +{ + return mPimpl->untrackAllVariables(model); +} + +size_t Generator::trackedVariableCount(const AnalyserModelPtr &model) +{ + return mPimpl->trackedVariableCount(model); +} + +size_t Generator::untrackedVariableCount(const AnalyserModelPtr &model) +{ + return mPimpl->untrackedVariableCount(model); +} + std::string Generator::interfaceCode(const AnalyserModelPtr &model) const { if ((model == nullptr) diff --git a/src/generator_p.h b/src/generator_p.h index 91f405cb3..af3ae5d4a 100644 --- a/src/generator_p.h +++ b/src/generator_p.h @@ -37,8 +37,25 @@ struct Generator::GeneratorImpl GeneratorProfilePtr mProfile = GeneratorProfile::create(); + std::map> mTrackedVariables; + void reset(); + bool doTrackVariable(const ModelPtr &model, const VariablePtr &variable, bool tracked); + + bool trackVariable(const VariablePtr &variable); + bool untrackVariable(const VariablePtr &variable); + + bool doTrackAllVariables(const AnalyserModelPtr &model, bool tracked); + + bool trackAllVariables(const AnalyserModelPtr &model); + bool untrackAllVariables(const AnalyserModelPtr &model); + + size_t doTrackedVariableCount(const AnalyserModelPtr &model, bool tracked); + + size_t trackedVariableCount(const AnalyserModelPtr &model); + size_t untrackedVariableCount(const AnalyserModelPtr &model); + bool modelHasOdes(const AnalyserModelPtr &model) const; bool modelHasNlas(const AnalyserModelPtr &model) const; diff --git a/src/utilities.cpp b/src/utilities.cpp index 996fc653f..9875d2c4f 100644 --- a/src/utilities.cpp +++ b/src/utilities.cpp @@ -1317,21 +1317,25 @@ XmlNodePtr mathmlChildNode(const XmlNodePtr &node, size_t index) return res; } -std::vector variables(const AnalyserModelPtr &model) +std::vector variables(const AnalyserModelPtr &model, bool onlyUntrackableVariables) { std::vector res; - if (model->voi() != nullptr) { - res.push_back(model->voi()); + if (!onlyUntrackableVariables) { + if (model->voi() != nullptr) { + res.push_back(model->voi()); + } + + auto states = model->states(); + + res.insert(res.end(), states.begin(), states.end()); } - auto states = model->states(); auto constants = model->constants(); auto computedConstants = model->computedConstants(); auto algebraic = model->algebraic(); auto externals = model->externals(); - res.insert(res.end(), states.begin(), states.end()); res.insert(res.end(), constants.begin(), constants.end()); res.insert(res.end(), computedConstants.begin(), computedConstants.end()); res.insert(res.end(), algebraic.begin(), algebraic.end()); diff --git a/src/utilities.h b/src/utilities.h index 87beda8af..32524c231 100644 --- a/src/utilities.h +++ b/src/utilities.h @@ -876,10 +876,11 @@ XmlNodePtr mathmlChildNode(const XmlNodePtr &node, size_t index); * Return the variables in the given model. * * @param model The model for which we want the variables. + * @param onlyUntrackableVariables If @c true, only return untrackable variables. * * @return The variables in the given model. */ -std::vector variables(const AnalyserModelPtr &model); +std::vector variables(const AnalyserModelPtr &model, bool onlyUntrackableVariables = false); /** * @brief Return the variables in the given equation. diff --git a/tests/generator/generatortrackedvariables.cpp b/tests/generator/generatortrackedvariables.cpp new file mode 100644 index 000000000..34aeeac34 --- /dev/null +++ b/tests/generator/generatortrackedvariables.cpp @@ -0,0 +1,168 @@ +/* +Copyright libCellML Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "test_utils.h" + +#include "gtest/gtest.h" + +#include + +TEST(GeneratorTrackedVariables, noModelOrVariable) +{ + auto generator = libcellml::Generator::create(); + + EXPECT_EQ(size_t(0), generator->trackedVariableCount(nullptr)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(nullptr)); + + EXPECT_FALSE(generator->trackVariable(nullptr)); + EXPECT_FALSE(generator->untrackVariable(nullptr)); + + EXPECT_FALSE(generator->trackAllVariables(nullptr)); + EXPECT_FALSE(generator->untrackAllVariables(nullptr)); +} + +TEST(GeneratorTrackedVariables, trackUntrackAllVariables) +{ + auto parser = libcellml::Parser::create(); + auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml")); + auto analyser = libcellml::Analyser::create(); + + analyser->analyseModel(model); + + auto analyserModel = analyser->model(); + auto generator = libcellml::Generator::create(); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->untrackAllVariables(analyserModel); + + EXPECT_EQ(size_t(0), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(18), generator->untrackedVariableCount(analyserModel)); + + generator->trackAllVariables(analyserModel); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); +} + +TEST(GeneratorTrackedVariables, trackUntrackStateVariable) +{ + auto parser = libcellml::Parser::create(); + auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml")); + auto analyser = libcellml::Analyser::create(); + + analyser->analyseModel(model); + + auto analyserModel = analyser->model(); + auto generator = libcellml::Generator::create(); + + auto variable = model->component("membrane")->variable("V"); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->untrackVariable(variable); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->trackVariable(variable); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); +} + +TEST(GeneratorTrackedVariables, trackUntrackConstant) +{ + auto parser = libcellml::Parser::create(); + auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml")); + auto analyser = libcellml::Analyser::create(); + + analyser->analyseModel(model); + + auto analyserModel = analyser->model(); + auto generator = libcellml::Generator::create(); + + auto variable = model->component("membrane")->variable("Cm"); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->untrackVariable(variable); + + EXPECT_EQ(size_t(17), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(1), generator->untrackedVariableCount(analyserModel)); + + generator->trackVariable(variable); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); +} + +TEST(GeneratorTrackedVariables, trackUntrackComputedConstant) +{ + auto parser = libcellml::Parser::create(); + auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml")); + auto analyser = libcellml::Analyser::create(); + + analyser->analyseModel(model); + + auto analyserModel = analyser->model(); + auto generator = libcellml::Generator::create(); + + auto variable = model->component("leakage_current")->variable("E_L"); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->untrackVariable(variable); + + EXPECT_EQ(size_t(17), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(1), generator->untrackedVariableCount(analyserModel)); + + generator->trackVariable(variable); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); +} + +TEST(GeneratorTrackedVariables, trackUntrackAlgebraicVariable) +{ + auto parser = libcellml::Parser::create(); + auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml")); + auto analyser = libcellml::Analyser::create(); + + analyser->analyseModel(model); + + auto analyserModel = analyser->model(); + auto generator = libcellml::Generator::create(); + + auto variable = model->component("membrane")->variable("i_Stim"); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); + + generator->untrackVariable(variable); + + EXPECT_EQ(size_t(17), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(1), generator->untrackedVariableCount(analyserModel)); + + generator->trackVariable(variable); + + EXPECT_EQ(size_t(18), generator->trackedVariableCount(analyserModel)); + EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel)); +} diff --git a/tests/generator/tests.cmake b/tests/generator/tests.cmake index c5fee6f22..552285d87 100644 --- a/tests/generator/tests.cmake +++ b/tests/generator/tests.cmake @@ -6,4 +6,5 @@ list(APPEND LIBCELLML_TESTS ${CURRENT_TEST}) set(${CURRENT_TEST}_SRCS ${CMAKE_CURRENT_LIST_DIR}/generator.cpp ${CMAKE_CURRENT_LIST_DIR}/generatorprofile.cpp + ${CMAKE_CURRENT_LIST_DIR}/generatortrackedvariables.cpp )