From 8535675d4aef813693adb47c8703926805f2493a Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Wed, 14 Aug 2024 13:14:30 -0500 Subject: [PATCH] AMREX_ENUM and ParmParse support for enum class (#4069) This adds AMREX_ENUM that can be used to define enum class with reflection. The new feature allows us to support enum class in ParmParse. --- Docs/sphinx_documentation/source/Basics.rst | 50 ++++++++ Src/Base/AMReX_Enum.H | 81 ++++++++++++ Src/Base/AMReX_ParmParse.H | 133 ++++++++++++++++++++ Src/Base/AMReX_String.H | 30 +++++ Src/Base/AMReX_String.cpp | 54 ++++++++ Src/Base/AMReX_Utility.H | 12 +- Src/Base/AMReX_Utility.cpp | 39 ------ Src/Base/CMakeLists.txt | 3 + Src/Base/Make.package | 4 + Tests/CMakeLists.txt | 2 +- Tests/Enum/CMakeLists.txt | 9 ++ Tests/Enum/GNUmakefile | 24 ++++ Tests/Enum/Make.package | 1 + Tests/Enum/inputs | 10 ++ Tests/Enum/main.cpp | 99 +++++++++++++++ 15 files changed, 500 insertions(+), 51 deletions(-) create mode 100644 Src/Base/AMReX_Enum.H create mode 100644 Src/Base/AMReX_String.H create mode 100644 Src/Base/AMReX_String.cpp create mode 100644 Tests/Enum/CMakeLists.txt create mode 100644 Tests/Enum/GNUmakefile create mode 100644 Tests/Enum/Make.package create mode 100644 Tests/Enum/inputs create mode 100644 Tests/Enum/main.cpp diff --git a/Docs/sphinx_documentation/source/Basics.rst b/Docs/sphinx_documentation/source/Basics.rst index 5b9b137e30c..97e68fd1e68 100644 --- a/Docs/sphinx_documentation/source/Basics.rst +++ b/Docs/sphinx_documentation/source/Basics.rst @@ -419,6 +419,56 @@ will become foo.a = 2 foo.b = 2 +Enum Class +---------- + +.. versionadded:: 24.09 + Enum class support in :cpp:`ParmParse`. + +AMReX provides a macro :cpp:`AMREX_ENUM` for defining :cpp:`enum class` that +supports reflection. For example, + +.. highlight:: c++ + +:: + + AMREX_ENUM(MyColor, red, green, blue); + + void f () + { + MyColor color = amrex::getEnum("red"); // MyColor::red + std::string name = amrex::getEnumNameString(MyColor::blue); // "blue" + std::vector names = amrex::getEnumNameStrings(); + // names = {"red", "green", "blue"}; + std::string class_name = amrex::getEnumClassName(); // "MyColor" + } + +This allows us to read :cpp:`ParmParse` parameters into enum class objects. + +.. highlight:: python + +:: + + color1 = red + color2 = BLue + +The following code shows how to query the enumerators. + +.. highlight:: c++ + +:: + + AMREX_ENUM(MyColor, none, red, green, blue); + + void f (MyColor& c1, MyColor& c2) + { + ParmParse pp; + pp.query("color1", c1); // c1 becomes MyColor::red + pp.query_enum_case_insensitive("color2", c2); // c2 becomes MyColor::blue + MyColor default_color; // MyColor::none + pp.query("color3", default_color); // Still MyColor::none + } + Overriding Parameters with Command-Line Arguments ------------------------------------------------- diff --git a/Src/Base/AMReX_Enum.H b/Src/Base/AMReX_Enum.H new file mode 100644 index 00000000000..09583f5b73f --- /dev/null +++ b/Src/Base/AMReX_Enum.H @@ -0,0 +1,81 @@ +#ifndef AMREX_ENUM_H_ +#define AMREX_ENUM_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +template +using amrex_enum_traits = decltype(amrex_get_enum_traits(std::declval())); + +namespace amrex { + template , + std::enable_if_t = 0> + T getEnum (std::string_view const& s) + { + auto pos = ET::enum_names.find(s); + if (pos == std::string_view::npos) { + std::string error_msg("amrex::getEnum: Unknown enum: "); + error_msg.append(s).append(" in AMREX_ENUM(").append(ET::class_name) + .append(", ").append(ET::enum_names).append(")."); + throw std::runtime_error(error_msg); + } + auto count = std::count(ET::enum_names.begin(), + ET::enum_names.begin()+pos, ','); + return static_cast(count); + } + + template , + std::enable_if_t = 0> + std::string getEnumNameString (T const& v) + { + auto n = static_cast(v); + std::size_t pos = 0; + for (int i = 0; i < n; ++i) { + pos = ET::enum_names.find(',', pos); + if (pos == std::string::npos) { + std::string error_msg("amrex::getEnum: Unknown enum value: "); + error_msg.append(std::to_string(n)).append(" in AMREX_ENUM(") + .append(ET::class_name).append(", ").append(ET::enum_names) + .append(")."); + throw std::runtime_error(error_msg); + } + ++pos; + } + auto pos2 = ET::enum_names.find(',', pos); + return amrex::trim(std::string(ET::enum_names.substr(pos,pos2-pos))); + } + + template , + std::enable_if_t = 0> + std::vector getEnumNameStrings () + { + return amrex::split(std::string(ET::enum_names), ", "); + } + + template , + std::enable_if_t = 0> + std::string getEnumClassName () + { + return std::string(ET::class_name); + } +} + +#define AMREX_ENUM(CLASS, ...) \ + enum class CLASS : int { __VA_ARGS__ }; \ + struct CLASS##_EnumTraits { \ + using enum_class_t = CLASS; \ + static constexpr bool value = true; \ + static constexpr std::string_view class_name{#CLASS}; \ + static constexpr std::string_view enum_names{#__VA_ARGS__}; \ + }; \ + CLASS##_EnumTraits amrex_get_enum_traits(CLASS) + +#endif diff --git a/Src/Base/AMReX_ParmParse.H b/Src/Base/AMReX_ParmParse.H index c9e273643ec..a844dcc1aa6 100644 --- a/Src/Base/AMReX_ParmParse.H +++ b/Src/Base/AMReX_ParmParse.H @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -1118,6 +1119,138 @@ public: } } + /** + * \brief. Query enum value using given name. + * + * Here T is an enum class defined by AMREX_ENUM. The return value + * indicates if `name` is found. An exception is thrown, if the found + * string associated with the name cannot be converted to an enumerator + * (i.e., the string does not match any names in the definition of T). + */ + template , + std::enable_if_t = 0> + int query (const char* name, T& ref) + { + std::string s; + int exist = this->query(name, s); + if (exist) { + try { + ref = amrex::getEnum(s); + } catch (...) { + throw; + } + } + return exist; + } + + /** + * \brief. Get enum value using given name. + * + * Here T is an enum class defined by AMREX_ENUM. It's a runtime error, + * if `name` is not found. An exception is thrown, if the found string + * associated with the name cannot be converted to an enumerator (i.e., + * the string does not match any names in the definition of T). + */ + template , + std::enable_if_t = 0> + void get (const char* name, T& ref) + { + std::string s; + this->get(name, s); + try { + ref = amrex::getEnum(s); + } catch (...) { + throw; + } + } + + //! Query an array of enum values using given name. + template , + std::enable_if_t = 0> + int queryarr (const char* name, std::vector& ref) + { + std::vector s; + int exist = this->queryarr(name, s); + if (exist) { + ref.resize(s.size()); + for (std::size_t i = 0; i < s.size(); ++i) { + ref[i] = amrex::getEnum(s[i]); + } + } + return exist; + } + + //! Get an array of enum values using given name. + template , + std::enable_if_t = 0> + void getarr (const char* name, std::vector& ref) + { + std::vector s; + this->getarr(name, s); + ref.resize(s.size()); + for (std::size_t i = 0; i < s.size(); ++i) { + ref[i] = amrex::getEnum(s[i]); + } + } + + /** + * \brief. Query enum value using given name. + * + * Here T is an enum class defined by AMREX_ENUM. The return value + * indicates if `name` is found. An exception is thrown, if the found + * string associated with the name cannot be case-insensitively + * converted to an enumerator (i.e., the found string, not `name`, does + * not case-insensitively match any names in the definition of T). If + * there are multiple matches, the first one is used. + */ + template , + std::enable_if_t = 0> + int query_enum_case_insensitive (const char* name, T& ref) + { + std::string s; + int exist = this->query(name, s); + if (exist) { + s = amrex::toLower(s); + auto const& enum_names = amrex::getEnumNameStrings(); + auto found = std::find_if(enum_names.begin(), enum_names.end(), + [&] (std::string const& ename) { + return amrex::toLower(ename) == s; + }); + if (found != enum_names.end()) { + ref = static_cast(std::distance(enum_names.begin(), found)); + } else { + std::string msg("query_enum_case_insensitive(\""); + msg.append(name).append("\",").append(amrex::getEnumClassName()) + .append("&) failed."); + throw std::runtime_error(msg); + } + } + return exist; + } + + /** + * \brief. Get enum value using given name. + * + * Here T is an enum class defined by AMREX_ENUM. It's a runtime error, + * if `name` is not found. An exception is thrown, if the found string + * associated with the name cannot be case-insensitively converted to an + * enumerator (i.e., the found string, not `name`, does not + * case-insensitively match any names in the definition of T). If there + * are multiple matches, the first one is used. + */ + template , + std::enable_if_t = 0> + void get_enum_case_insensitive (const char* name, T& ref) + { + int exist = this->query_enum_case_insensitive(name, ref); + if (!exist) { + std::string msg("get_enum_case_insensitive(\""); + msg.append(name).append("\",").append(amrex::getEnumClassName()) + .append("&) failed."); + amrex::Abort(msg); + } + } + //! Remove given name from the table. int remove (const char* name); diff --git a/Src/Base/AMReX_String.H b/Src/Base/AMReX_String.H new file mode 100644 index 00000000000..147b7ab1870 --- /dev/null +++ b/Src/Base/AMReX_String.H @@ -0,0 +1,30 @@ +#ifndef AMREX_STRING_H_ +#define AMREX_STRING_H_ +#include + +#include +#include + +namespace amrex { + + //! Converts all characters of the string into lower case based on std::locale + std::string toLower (std::string s); + + //! Converts all characters of the string into uppercase based on std::locale + std::string toUpper (std::string s); + + //! Trim leading and trailing characters in the optional `space` + //! argument. + std::string trim (std::string s, std::string const& space = " \t"); + + //! Returns rootNNNN where NNNN == num. + std::string Concatenate (const std::string& root, + int num, + int mindigits = 5); + + //! Split a string using given tokens in `sep`. + std::vector split (std::string const& s, + std::string const& sep = " \t"); +} + +#endif diff --git a/Src/Base/AMReX_String.cpp b/Src/Base/AMReX_String.cpp new file mode 100644 index 00000000000..24dbce4532f --- /dev/null +++ b/Src/Base/AMReX_String.cpp @@ -0,0 +1,54 @@ +#include +#include + +#include +#include +#include +#include + +namespace amrex { + +std::string toLower (std::string s) +{ + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c) { return std::tolower(c); }); + return s; +} + +std::string toUpper (std::string s) +{ + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c) { return std::toupper(c); }); + return s; +} + +std::string trim(std::string s, std::string const& space) +{ + const auto sbegin = s.find_first_not_of(space); + if (sbegin == std::string::npos) { return std::string{}; } + const auto send = s.find_last_not_of(space); + s = s.substr(sbegin, send-sbegin+1); + return s; +} + +std::string Concatenate (const std::string& root, int num, int mindigits) +{ + BL_ASSERT(mindigits >= 0); + std::stringstream result; + result << root << std::setfill('0') << std::setw(mindigits) << num; + return result.str(); +} + +std::vector split (std::string const& s, std::string const& sep) +{ + std::vector result; + std::size_t pos_begin, pos_end = 0; + while ((pos_begin = s.find_first_not_of(sep,pos_end)) != std::string::npos) { + pos_end = s.find_first_of(sep,pos_begin); + result.push_back(s.substr(pos_begin,pos_end-pos_begin)); + if (pos_end == std::string::npos) { break; } + } + return result; +} + +} diff --git a/Src/Base/AMReX_Utility.H b/Src/Base/AMReX_Utility.H index 016b8adb0e2..6bec276dbf2 100644 --- a/Src/Base/AMReX_Utility.H +++ b/Src/Base/AMReX_Utility.H @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -44,17 +45,6 @@ namespace amrex const std::vector& Tokenize (const std::string& instr, const std::string& separators); - //! Converts all characters of the string into lower or uppercase based on std::locale - std::string toLower (std::string s); - std::string toUpper (std::string s); - - //! Trim leading and trailing white space - std::string trim (std::string s, std::string const& space = " \t"); - - //! Returns rootNNNN where NNNN == num. - std::string Concatenate (const std::string& root, - int num, - int mindigits = 5); /** * \brief Creates the specified directories. path may be either a full pathname * or a relative pathname. It will create all the directories in the diff --git a/Src/Base/AMReX_Utility.cpp b/Src/Base/AMReX_Utility.cpp index 1c79dfba92f..aa3d8a2d165 100644 --- a/Src/Base/AMReX_Utility.cpp +++ b/Src/Base/AMReX_Utility.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -113,44 +112,6 @@ amrex::Tokenize (const std::string& instr, return tokens; } -std::string -amrex::toLower (std::string s) -{ - std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c) { return std::tolower(c); }); - return s; -} - -std::string -amrex::toUpper (std::string s) -{ - std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c) { return std::toupper(c); }); - return s; -} - -std::string -amrex::trim(std::string s, std::string const& space) -{ - const auto sbegin = s.find_first_not_of(space); - if (sbegin == std::string::npos) { return std::string{}; } - const auto send = s.find_last_not_of(space); - s = s.substr(sbegin, send-sbegin+1); - return s; -} - -std::string -amrex::Concatenate (const std::string& root, - int num, - int mindigits) -{ - BL_ASSERT(mindigits >= 0); - std::stringstream result; - result << root << std::setfill('0') << std::setw(mindigits) << num; - return result.str(); -} - - bool amrex::UtilCreateDirectory (const std::string& path, mode_t mode, bool verbose) diff --git a/Src/Base/CMakeLists.txt b/Src/Base/CMakeLists.txt index cebd1f9bce1..0436ad032e4 100644 --- a/Src/Base/CMakeLists.txt +++ b/Src/Base/CMakeLists.txt @@ -12,6 +12,7 @@ foreach(D IN LISTS AMReX_SPACEDIM) AMReX_Array.H AMReX_BlockMutex.H AMReX_BlockMutex.cpp + AMReX_Enum.H AMReX_GpuComplex.H AMReX_Vector.H AMReX_TableData.H @@ -30,6 +31,8 @@ foreach(D IN LISTS AMReX_SPACEDIM) AMReX_parmparse_fi.cpp AMReX_ParmParse.H AMReX_Functional.H + AMReX_String.H + AMReX_String.cpp AMReX_Utility.H AMReX_Utility.cpp AMReX_FileSystem.H diff --git a/Src/Base/Make.package b/Src/Base/Make.package index dfbfb4f03a1..b009ebf7d65 100644 --- a/Src/Base/Make.package +++ b/Src/Base/Make.package @@ -2,6 +2,7 @@ AMREX_BASE=EXE C$(AMREX_BASE)_headers += AMReX_ccse-mpi.H AMReX_Algorithm.H AMReX_Any.H AMReX_Array.H +C$(AMREX_BASE)_headers += AMReX_Enum.H C$(AMREX_BASE)_headers += AMReX_Vector.H AMReX_TableData.H AMReX_Tuple.H AMReX_Math.H C$(AMREX_BASE)_headers += AMReX_TypeList.H @@ -22,6 +23,9 @@ C$(AMREX_BASE)_sources += AMReX_PODVector.cpp C$(AMREX_BASE)_headers += AMReX_BlockMutex.H C$(AMREX_BASE)_sources += AMReX_BlockMutex.cpp +C$(AMREX_BASE)_headers += AMReX_String.H +C$(AMREX_BASE)_sources += AMReX_String.cpp + C$(AMREX_BASE)_sources += AMReX_ParmParse.cpp AMReX_parmparse_fi.cpp AMReX_Utility.cpp C$(AMREX_BASE)_headers += AMReX_ParmParse.H AMReX_Utility.H AMReX_BLassert.H AMReX_ArrayLim.H C$(AMREX_BASE)_headers += AMReX_Functional.H AMReX_Reduce.H AMReX_Scan.H AMReX_Partition.H diff --git a/Tests/CMakeLists.txt b/Tests/CMakeLists.txt index 01f187b6642..3f801cde2b0 100644 --- a/Tests/CMakeLists.txt +++ b/Tests/CMakeLists.txt @@ -121,7 +121,7 @@ else() # # List of subdirectories to search for CMakeLists. # - set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CTOParFor DeviceGlobal + set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CTOParFor DeviceGlobal Enum MultiBlock Parser Parser2 Reinit RoundoffDomain) if (AMReX_PARTICLES) diff --git a/Tests/Enum/CMakeLists.txt b/Tests/Enum/CMakeLists.txt new file mode 100644 index 00000000000..9c0e7f321d0 --- /dev/null +++ b/Tests/Enum/CMakeLists.txt @@ -0,0 +1,9 @@ +foreach(D IN LISTS AMReX_SPACEDIM) + set(_sources main.cpp) + set(_input_files inputs) + + setup_test(${D} _sources _input_files) + + unset(_sources) + unset(_input_files) +endforeach() diff --git a/Tests/Enum/GNUmakefile b/Tests/Enum/GNUmakefile new file mode 100644 index 00000000000..d0d895ff522 --- /dev/null +++ b/Tests/Enum/GNUmakefile @@ -0,0 +1,24 @@ +AMREX_HOME := ../.. + +DEBUG = FALSE + +DIM = 3 + +COMP = gcc + +USE_MPI = FALSE +USE_OMP = FALSE +USE_CUDA = FALSE +USE_HIP = FALSE +USE_SYCL = FALSE + +BL_NO_FORT = TRUE + +TINY_PROFILE = FALSE + +include $(AMREX_HOME)/Tools/GNUMake/Make.defs + +include ./Make.package +include $(AMREX_HOME)/Src/Base/Make.package + +include $(AMREX_HOME)/Tools/GNUMake/Make.rules diff --git a/Tests/Enum/Make.package b/Tests/Enum/Make.package new file mode 100644 index 00000000000..6b4b865e8fc --- /dev/null +++ b/Tests/Enum/Make.package @@ -0,0 +1 @@ +CEXE_sources += main.cpp diff --git a/Tests/Enum/inputs b/Tests/Enum/inputs new file mode 100644 index 00000000000..0972043658f --- /dev/null +++ b/Tests/Enum/inputs @@ -0,0 +1,10 @@ + +color1 = red +color2 = green +color3 = blue +color4 = greenxxx +color5 = Blue + +colors = cyan yellow orange + + diff --git a/Tests/Enum/main.cpp b/Tests/Enum/main.cpp new file mode 100644 index 00000000000..6fb25b01a59 --- /dev/null +++ b/Tests/Enum/main.cpp @@ -0,0 +1,99 @@ +#include +#include +#include + +using namespace amrex; + +AMREX_ENUM(MyColor, red, green, blue ); + +namespace my_namespace { + AMREX_ENUM(MyColor, orange, yellow,cyan ); +} + +int main (int argc, char* argv[]) +{ + amrex::Initialize(argc, argv); + { + auto const& names = amrex::getEnumNameStrings(); + auto const& names2 = amrex::getEnumNameStrings(); + amrex::Print() << "colors:"; + for (auto const& name : names) { + amrex::Print() << " " << name; + } + amrex::Print() << "\n"; + amrex::Print() << "colors:"; + for (auto const& name : names2) { + amrex::Print() << " " << name; + } + amrex::Print() << "\n"; + + ParmParse pp; + { + auto color = static_cast(999); + pp.query("color1", color); + amrex::Print() << "color = " << amrex::getEnumNameString(color) << '\n'; + AMREX_ALWAYS_ASSERT(color == MyColor::red); + } + { + auto color = static_cast(999); + pp.get("color2", color); + amrex::Print() << "color = " << amrex::getEnumNameString(color) << '\n'; + AMREX_ALWAYS_ASSERT(color == MyColor::green); + } + { + auto color = static_cast(999); + pp.get("color3", color); + amrex::Print() << "color = " << amrex::getEnumNameString(color) << '\n'; + AMREX_ALWAYS_ASSERT(color == MyColor::blue); + } + { + auto color = static_cast(999); + try { + pp.query("color4", color); + } catch (std::runtime_error const& e) { + amrex::Print() << "As expected, " << e.what() << '\n'; + } + AMREX_ALWAYS_ASSERT(color == static_cast(999)); + try { + pp.get_enum_case_insensitive("color4", color); + } catch (std::runtime_error const& e) { + amrex::Print() << "As expected, " << e.what() << '\n'; + } + AMREX_ALWAYS_ASSERT(color == static_cast(999)); + } + { + auto color = static_cast(999); + try { + pp.query("color5", color); + } catch (std::runtime_error const& e) { + amrex::Print() << "As expected, " << e.what() << '\n'; + } + AMREX_ALWAYS_ASSERT(color == static_cast(999)); + pp.query_enum_case_insensitive("color5", color); + amrex::Print() << "color = " << amrex::getEnumNameString(color) << '\n'; + AMREX_ALWAYS_ASSERT(color == MyColor::blue); + } + { + std::vector color; + pp.getarr("colors", color); + AMREX_ALWAYS_ASSERT(color.size() == 3 && + color[0] == my_namespace::MyColor::cyan && + color[1] == my_namespace::MyColor::yellow && + color[2] == my_namespace::MyColor::orange); + std::vector color2; + pp.queryarr("colors", color2); + AMREX_ALWAYS_ASSERT(color.size() == 3 && + color == color2 && + color[0] == my_namespace::MyColor::cyan && + color[1] == my_namespace::MyColor::yellow && + color[2] == my_namespace::MyColor::orange); + amrex::Print() << "colors:"; + for (auto const& c : color) { + amrex::Print() << " " << amrex::getEnumNameString(c); + } + amrex::Print() << "\n"; + } + } + + amrex::Finalize(); +}