diff --git a/src/codegen/codegen_coreneuron_cpp_visitor.cpp b/src/codegen/codegen_coreneuron_cpp_visitor.cpp index b477a58912..4304d5dec4 100644 --- a/src/codegen/codegen_coreneuron_cpp_visitor.cpp +++ b/src/codegen/codegen_coreneuron_cpp_visitor.cpp @@ -21,7 +21,6 @@ #include "config/config.h" #include "lexer/token_mapping.hpp" #include "parser/c11_driver.hpp" -#include "solver/solver.hpp" #include "utils/logger.hpp" #include "utils/string_utils.hpp" #include "visitors/defuse_analyze_visitor.hpp" @@ -1005,7 +1004,7 @@ void CodegenCoreneuronCppVisitor::print_coreneuron_includes() { #include )CODE"); if (info.eigen_newton_solver_exist) { - printer->add_multi_line(nmodl::solvers::newton_hpp); + printer->add_line("#include \"solver/newton.hpp\""); } if (info.eigen_linear_solver_exist) { if (std::accumulate(info.state_vars.begin(), @@ -1014,7 +1013,7 @@ void CodegenCoreneuronCppVisitor::print_coreneuron_includes() { [](int l, const SymbolType& variable) { return l += variable->get_length(); }) > 4) { - printer->add_multi_line(nmodl::solvers::crout_hpp); + printer->add_line("#include \"solver/crout.hpp\""); } else { printer->add_line("#include "); printer->add_line("#include "); diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 3ee0053610..d09d4adcb4 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -20,7 +20,6 @@ #include "codegen/codegen_utils.hpp" #include "codegen_naming.hpp" #include "config/config.h" -#include "solver/solver.hpp" #include "utils/string_utils.hpp" #include "visitors/rename_visitor.hpp" #include "visitors/var_usage_visitor.hpp" @@ -754,7 +753,7 @@ void CodegenNeuronCppVisitor::print_standard_includes() { #include )CODE"); if (info.eigen_newton_solver_exist) { - printer->add_multi_line(nmodl::solvers::newton_hpp); + printer->add_line("#include \"solver/newton.hpp\""); } } diff --git a/src/main.cpp b/src/main.cpp index be322c0e99..b9111f29c8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,11 +5,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include +#include #include #include #include -#include #include "ast/program.hpp" #include "codegen/codegen_acc_visitor.hpp" @@ -20,6 +21,7 @@ #include "config/config.h" #include "parser/nmodl_driver.hpp" #include "pybind/pyembed.hpp" +#include "solver/solver.hpp" #include "utils/common_utils.hpp" #include "utils/logger.hpp" #include "visitors/after_cvode_to_cnexp_visitor.hpp" @@ -60,6 +62,26 @@ using namespace codegen; using namespace visitor; using nmodl::parser::NmodlDriver; +fs::path get_solver_path(const fs::path& directory, const std::string& solver) { + auto path = directory / "solver" / solver; + path += ".hpp"; + return path; +} + +void write_shared_headers(const std::string& directory, + const std::vector& solvers = nmodl::solver::get_names()) { + fs::path output(directory); + + for (const auto& solver: solvers) { + const auto& path = get_solver_path(directory, solver); + fs::create_directories(path.parent_path()); + + std::ofstream fout(path); + fout << nmodl::solver::get_hpp(solver); + logger->info("Generated {}", path.string()); + } +} + // NOLINTNEXTLINE(readability-function-cognitive-complexity) int main(int argc, const char* argv[]) { CLI::App app{fmt::format("NMODL : Source-to-Source Code Generation Framework [{}]", @@ -180,6 +202,17 @@ int main(int argc, const char* argv[]) { app.add_option("-o,--output", output_dir, "Directory for backend code output") ->capture_default_str() ->ignore_case(); + + app.add_option_function>( + "--write-shared-headers", + [&](const std::vector& solvers) { + write_shared_headers(output_dir, solvers); + exit(0); + }, + "Create solver headers in directory and exit") + ->expected(1, 2) + ->check(CLI::IsMember(nmodl::solver::get_names())); + app.add_option("--scratch", scratch_dir, "Directory for intermediate code output") ->capture_default_str() ->ignore_case(); diff --git a/src/solver/CMakeLists.txt b/src/solver/CMakeLists.txt index cbb40784f2..3923d0b996 100644 --- a/src/solver/CMakeLists.txt +++ b/src/solver/CMakeLists.txt @@ -1,15 +1,11 @@ # Read headers, remove everything up to and including "#pragma once", then remove header includes. -file(READ ${CMAKE_CURRENT_SOURCE_DIR}/crout/crout.hpp NMODL_CROUT_HPP_RAW) -string(REGEX REPLACE ".*#pragma once[ \t\r\n]*" "" NMODL_CROUT_HPP "${NMODL_CROUT_HPP_RAW}") -file(READ ${CMAKE_CURRENT_SOURCE_DIR}/newton/newton.hpp NMODL_NEWTON_HPP_RAW) -string(REGEX REPLACE ".*#pragma once[ \t\r\n]*" "" NMODL_NEWTON_HPP_TMP "${NMODL_NEWTON_HPP_RAW}") -string(REGEX REPLACE "#include [ \t\r\n]*" "" NMODL_NEWTON_HPP - "${NMODL_NEWTON_HPP_TMP}") +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/crout.hpp NMODL_CROUT_HPP) +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/newton.hpp NMODL_NEWTON_HPP) set_property( DIRECTORY APPEND - PROPERTY CMAKE_CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/crout/crout.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/newton/newton.hpp) + PROPERTY CMAKE_CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/crout.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/newton.hpp) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/solver.hpp.inc ${CMAKE_CURRENT_BINARY_DIR}/solver.hpp) add_custom_target(nmodl_copy_solver_files ALL DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/solver.hpp") diff --git a/src/solver/crout/crout.hpp b/src/solver/crout.hpp similarity index 100% rename from src/solver/crout/crout.hpp rename to src/solver/crout.hpp diff --git a/src/solver/newton/newton.hpp b/src/solver/newton.hpp similarity index 99% rename from src/solver/newton/newton.hpp rename to src/solver/newton.hpp index 973b4cd332..20b14e6492 100644 --- a/src/solver/newton/newton.hpp +++ b/src/solver/newton.hpp @@ -15,7 +15,7 @@ * \brief Implementation of Newton method for solving system of non-linear equations */ -#include +#include "crout.hpp" #include #include diff --git a/src/solver/solver.hpp.inc b/src/solver/solver.hpp.inc index a73c8ee40d..dc0f6757d2 100644 --- a/src/solver/solver.hpp.inc +++ b/src/solver/solver.hpp.inc @@ -1,4 +1,5 @@ #include +#include // This file is generated from `crout/crout.hpp` and `newton/newton.hpp`. // @@ -8,13 +9,19 @@ // However, because we want to be able to test the headers separately we can't // move them here. -namespace nmodl::solvers { -const std::string crout_hpp = R"jiowi( -@NMODL_CROUT_HPP@ -)jiowi"; -const std::string newton_hpp = R"jiowi( -@NMODL_CROUT_HPP@ -@NMODL_NEWTON_HPP@ -)jiowi"; +namespace nmodl::solver { +const std::vector get_names() { + return {"crout", "newton"}; +} + +const std::string get_hpp(const std::string& solver) { + if (solver == "crout") { + return R"jiowi(@NMODL_CROUT_HPP@)jiowi"; + } else if (solver == "newton") { + return R"jiowi(@NMODL_NEWTON_HPP@)jiowi"; + } else { + throw std::runtime_error("unknown solver '" + solver + "'"); + } +}; } diff --git a/test/unit/crout/crout.cpp b/test/unit/crout/crout.cpp index cb90e9162e..bb9398ed93 100644 --- a/test/unit/crout/crout.cpp +++ b/test/unit/crout/crout.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "crout/crout.hpp" +#include "crout.hpp" #include diff --git a/test/unit/newton/newton.cpp b/test/unit/newton/newton.cpp index 33f4086980..2560a150f2 100644 --- a/test/unit/newton/newton.cpp +++ b/test/unit/newton/newton.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "newton/newton.hpp" +#include "newton.hpp" #include #include