Skip to content

Commit

Permalink
AMREX_ENUM and ParmParse support for enum class (#4069)
Browse files Browse the repository at this point in the history
This adds AMREX_ENUM that can be used to define enum class with
reflection. The new feature allows us to support enum class in
ParmParse.
  • Loading branch information
WeiqunZhang committed Aug 14, 2024
1 parent 6890ce0 commit 8535675
Show file tree
Hide file tree
Showing 15 changed files with 500 additions and 51 deletions.
50 changes: 50 additions & 0 deletions Docs/sphinx_documentation/source/Basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,56 @@ will become
foo.a = 2
foo.b = 2

Enum Class
----------

.. versionadded:: 24.09
Enum class support in :cpp:`ParmParse`.

AMReX provides a macro :cpp:`AMREX_ENUM` for defining :cpp:`enum class` that
supports reflection. For example,

.. highlight:: c++

::

AMREX_ENUM(MyColor, red, green, blue);

void f ()
{
MyColor color = amrex::getEnum<MyColor>("red"); // MyColor::red
std::string name = amrex::getEnumNameString(MyColor::blue); // "blue"
std::vector<std::string> names = amrex::getEnumNameStrings<MyColor>();
// names = {"red", "green", "blue"};
std::string class_name = amrex::getEnumClassName<MyColor>(); // "MyColor"
}

This allows us to read :cpp:`ParmParse` parameters into enum class objects.

.. highlight:: python

::

color1 = red
color2 = BLue

The following code shows how to query the enumerators.

.. highlight:: c++

::

AMREX_ENUM(MyColor, none, red, green, blue);

void f (MyColor& c1, MyColor& c2)
{
ParmParse pp;
pp.query("color1", c1); // c1 becomes MyColor::red
pp.query_enum_case_insensitive("color2", c2); // c2 becomes MyColor::blue
MyColor default_color; // MyColor::none
pp.query("color3", default_color); // Still MyColor::none
}

Overriding Parameters with Command-Line Arguments
-------------------------------------------------

Expand Down
81 changes: 81 additions & 0 deletions Src/Base/AMReX_Enum.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#ifndef AMREX_ENUM_H_
#define AMREX_ENUM_H_

#include <AMReX_String.H>

#include <algorithm>
#include <array>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

template <typename T>
using amrex_enum_traits = decltype(amrex_get_enum_traits(std::declval<T>()));

namespace amrex {
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
T getEnum (std::string_view const& s)
{
auto pos = ET::enum_names.find(s);
if (pos == std::string_view::npos) {
std::string error_msg("amrex::getEnum: Unknown enum: ");
error_msg.append(s).append(" in AMREX_ENUM(").append(ET::class_name)
.append(", ").append(ET::enum_names).append(").");
throw std::runtime_error(error_msg);
}
auto count = std::count(ET::enum_names.begin(),
ET::enum_names.begin()+pos, ',');
return static_cast<T>(count);
}

template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
std::string getEnumNameString (T const& v)
{
auto n = static_cast<int>(v);
std::size_t pos = 0;
for (int i = 0; i < n; ++i) {
pos = ET::enum_names.find(',', pos);
if (pos == std::string::npos) {
std::string error_msg("amrex::getEnum: Unknown enum value: ");
error_msg.append(std::to_string(n)).append(" in AMREX_ENUM(")
.append(ET::class_name).append(", ").append(ET::enum_names)
.append(").");
throw std::runtime_error(error_msg);
}
++pos;
}
auto pos2 = ET::enum_names.find(',', pos);
return amrex::trim(std::string(ET::enum_names.substr(pos,pos2-pos)));
}

template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
std::vector<std::string> getEnumNameStrings ()
{
return amrex::split(std::string(ET::enum_names), ", ");
}

template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
std::string getEnumClassName ()
{
return std::string(ET::class_name);
}
}

#define AMREX_ENUM(CLASS, ...) \
enum class CLASS : int { __VA_ARGS__ }; \
struct CLASS##_EnumTraits { \
using enum_class_t = CLASS; \
static constexpr bool value = true; \
static constexpr std::string_view class_name{#CLASS}; \
static constexpr std::string_view enum_names{#__VA_ARGS__}; \
}; \
CLASS##_EnumTraits amrex_get_enum_traits(CLASS)

#endif
133 changes: 133 additions & 0 deletions Src/Base/AMReX_ParmParse.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX_BLassert.H>
#include <AMReX_Enum.H>
#include <AMReX_INT.H>
#include <AMReX_IParser.H>
#include <AMReX_Parser.H>
Expand Down Expand Up @@ -1118,6 +1119,138 @@ public:
}
}

/**
* \brief. Query enum value using given name.
*
* Here T is an enum class defined by AMREX_ENUM. The return value
* indicates if `name` is found. An exception is thrown, if the found
* string associated with the name cannot be converted to an enumerator
* (i.e., the string does not match any names in the definition of T).
*/
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
int query (const char* name, T& ref)
{
std::string s;
int exist = this->query(name, s);
if (exist) {
try {
ref = amrex::getEnum<T>(s);
} catch (...) {
throw;
}
}
return exist;
}

/**
* \brief. Get enum value using given name.
*
* Here T is an enum class defined by AMREX_ENUM. It's a runtime error,
* if `name` is not found. An exception is thrown, if the found string
* associated with the name cannot be converted to an enumerator (i.e.,
* the string does not match any names in the definition of T).
*/
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
void get (const char* name, T& ref)
{
std::string s;
this->get(name, s);
try {
ref = amrex::getEnum<T>(s);
} catch (...) {
throw;
}
}

//! Query an array of enum values using given name.
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
int queryarr (const char* name, std::vector<T>& ref)
{
std::vector<std::string> s;
int exist = this->queryarr(name, s);
if (exist) {
ref.resize(s.size());
for (std::size_t i = 0; i < s.size(); ++i) {
ref[i] = amrex::getEnum<T>(s[i]);
}
}
return exist;
}

//! Get an array of enum values using given name.
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
void getarr (const char* name, std::vector<T>& ref)
{
std::vector<std::string> s;
this->getarr(name, s);
ref.resize(s.size());
for (std::size_t i = 0; i < s.size(); ++i) {
ref[i] = amrex::getEnum<T>(s[i]);
}
}

/**
* \brief. Query enum value using given name.
*
* Here T is an enum class defined by AMREX_ENUM. The return value
* indicates if `name` is found. An exception is thrown, if the found
* string associated with the name cannot be case-insensitively
* converted to an enumerator (i.e., the found string, not `name`, does
* not case-insensitively match any names in the definition of T). If
* there are multiple matches, the first one is used.
*/
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
int query_enum_case_insensitive (const char* name, T& ref)
{
std::string s;
int exist = this->query(name, s);
if (exist) {
s = amrex::toLower(s);
auto const& enum_names = amrex::getEnumNameStrings<T>();
auto found = std::find_if(enum_names.begin(), enum_names.end(),
[&] (std::string const& ename) {
return amrex::toLower(ename) == s;
});
if (found != enum_names.end()) {
ref = static_cast<T>(std::distance(enum_names.begin(), found));
} else {
std::string msg("query_enum_case_insensitive(\"");
msg.append(name).append("\",").append(amrex::getEnumClassName<T>())
.append("&) failed.");
throw std::runtime_error(msg);
}
}
return exist;
}

/**
* \brief. Get enum value using given name.
*
* Here T is an enum class defined by AMREX_ENUM. It's a runtime error,
* if `name` is not found. An exception is thrown, if the found string
* associated with the name cannot be case-insensitively converted to an
* enumerator (i.e., the found string, not `name`, does not
* case-insensitively match any names in the definition of T). If there
* are multiple matches, the first one is used.
*/
template <typename T, typename ET = amrex_enum_traits<T>,
std::enable_if_t<ET::value,int> = 0>
void get_enum_case_insensitive (const char* name, T& ref)
{
int exist = this->query_enum_case_insensitive(name, ref);
if (!exist) {
std::string msg("get_enum_case_insensitive(\"");
msg.append(name).append("\",").append(amrex::getEnumClassName<T>())
.append("&) failed.");
amrex::Abort(msg);
}
}

//! Remove given name from the table.
int remove (const char* name);

Expand Down
30 changes: 30 additions & 0 deletions Src/Base/AMReX_String.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef AMREX_STRING_H_
#define AMREX_STRING_H_
#include <AMReX_Config.H>

#include <string>
#include <vector>

namespace amrex {

//! Converts all characters of the string into lower case based on std::locale
std::string toLower (std::string s);

//! Converts all characters of the string into uppercase based on std::locale
std::string toUpper (std::string s);

//! Trim leading and trailing characters in the optional `space`
//! argument.
std::string trim (std::string s, std::string const& space = " \t");

//! Returns rootNNNN where NNNN == num.
std::string Concatenate (const std::string& root,
int num,
int mindigits = 5);

//! Split a string using given tokens in `sep`.
std::vector<std::string> split (std::string const& s,
std::string const& sep = " \t");
}

#endif
54 changes: 54 additions & 0 deletions Src/Base/AMReX_String.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include <AMReX_String.H>
#include <AMReX_BLassert.H>

#include <algorithm>
#include <cctype>
#include <iomanip>
#include <sstream>

namespace amrex {

std::string toLower (std::string s)
{
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c) { return std::tolower(c); });
return s;
}

std::string toUpper (std::string s)
{
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c) { return std::toupper(c); });
return s;
}

std::string trim(std::string s, std::string const& space)
{
const auto sbegin = s.find_first_not_of(space);
if (sbegin == std::string::npos) { return std::string{}; }
const auto send = s.find_last_not_of(space);
s = s.substr(sbegin, send-sbegin+1);
return s;
}

std::string Concatenate (const std::string& root, int num, int mindigits)
{
BL_ASSERT(mindigits >= 0);
std::stringstream result;
result << root << std::setfill('0') << std::setw(mindigits) << num;
return result.str();
}

std::vector<std::string> split (std::string const& s, std::string const& sep)
{
std::vector<std::string> result;
std::size_t pos_begin, pos_end = 0;
while ((pos_begin = s.find_first_not_of(sep,pos_end)) != std::string::npos) {
pos_end = s.find_first_of(sep,pos_begin);
result.push_back(s.substr(pos_begin,pos_end-pos_begin));
if (pos_end == std::string::npos) { break; }
}
return result;
}

}
12 changes: 1 addition & 11 deletions Src/Base/AMReX_Utility.H
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <AMReX_Random.H>
#include <AMReX_GpuQualifiers.H>
#include <AMReX_FileSystem.H>
#include <AMReX_String.H>

#include <cfloat>
#include <chrono>
Expand Down Expand Up @@ -44,17 +45,6 @@ namespace amrex
const std::vector<std::string>& Tokenize (const std::string& instr,
const std::string& separators);

//! Converts all characters of the string into lower or uppercase based on std::locale
std::string toLower (std::string s);
std::string toUpper (std::string s);

//! Trim leading and trailing white space
std::string trim (std::string s, std::string const& space = " \t");

//! Returns rootNNNN where NNNN == num.
std::string Concatenate (const std::string& root,
int num,
int mindigits = 5);
/**
* \brief Creates the specified directories. path may be either a full pathname
* or a relative pathname. It will create all the directories in the
Expand Down
Loading

0 comments on commit 8535675

Please sign in to comment.