Skip to content

Commit

Permalink
test optim modes in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pcarruscag committed Jan 7, 2024
1 parent 722eaf9 commit c66ed4c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 21 deletions.
3 changes: 3 additions & 0 deletions definitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ namespace internal {
/// type double, each tree occupies ~4KB.
static constexpr int max_tree_size = MEL_MAX_TREE_SIZE;

/// Default optimization mode for expression trees.
static constexpr OptimMode default_optim_mode = OptimMode::TREE_SIZE;

/// Efficient representation of the supported operations and expression node
/// types, must match the order of "supported_operations".
enum class OpCode {
Expand Down
8 changes: 4 additions & 4 deletions mel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ void DepthFirstIndex(const int root, const TreeType& tree,

/// Type for an expression tree. The result of parsing expressions and used
/// to evaluate them.
template<class NumberType, OptimMode Mode = OptimMode::TREE_SIZE>
template<class NumberType, OptimMode Mode = internal::default_optim_mode>
struct ExpressionTree {
using type = NumberType;
static constexpr OptimMode mode = Mode;
Expand Down Expand Up @@ -534,7 +534,7 @@ struct ExpressionTree {
/// Preprocess an expression, create an expression tree for it, and extract its
/// symbols in the process. NumberType is the type used for stored constants
/// (i.e. literals).
template<class NumberType, OptimMode Mode = OptimMode::TREE_SIZE,
template<class NumberType, OptimMode Mode = internal::default_optim_mode,
class StringType, class StringListType>
ExpressionTree<NumberType, Mode> Parse(StringType expr,
StringListType& symbols) {
Expand Down Expand Up @@ -729,15 +729,15 @@ ReturnType Eval(const TreeType& tree, const StringListType& symbols,

/// Overload of Eval, evaluates a raw expression (string) assuming it does not
/// contain symbols (provided for convenience).
template<class ReturnType, class StringType>
template<class ReturnType, OptimMode Mode = OptimMode::NONE, class StringType>
ReturnType Eval(const StringType& expr) {
std::vector<str_t> s;
auto f = [&](int) {
assert(false && "Unexpected symbol");
return ReturnType{};
};
return internal::EvaluateExpressionTree<ReturnType>(
Parse<ReturnType>(str_t(expr), s), f);
Parse<ReturnType, Mode>(str_t(expr), s), f);
}

} // namespace mel
4 changes: 3 additions & 1 deletion tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

int main(int, char* []) {

if (mel::tests()) return 1;
if (mel::tests<mel::OptimMode::NONE>()) return 1;
if (mel::tests<mel::OptimMode::TREE_SIZE>()) return 1;
if (mel::tests<mel::OptimMode::STACK_SIZE>()) return 1;

return mel::benchmarks();
}
41 changes: 25 additions & 16 deletions tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@

namespace mel {

template <OptimMode Mode>
inline int tests() {
using namespace internal;
std::cout << "\nTests\n\n";

MEL_CHECK(sizeof(ExpressionTree<double>) == (max_tree_size*2+1)*sizeof(double))
MEL_CHECK(sizeof(ExpressionTree<double, Mode>) == (max_tree_size*2+1)*sizeof(double))

MEL_CHECK(BalancedParentheses(str_t("(((a+b)*c))")))
MEL_CHECK(!BalancedParentheses(str_t("a)*2*(c")))
Expand Down Expand Up @@ -78,23 +79,25 @@ inline int tests() {
MEL_CHECK(r3[0] == "-" && r3[1] == "(a+b)*c" && r3[2] == "d")

std::vector<str_t> symb;
Parse<double>(str_t("((a+b)*c - \"var 1\")"), symb);
Parse<double, Mode>(str_t("((a+b)*c - \"var 1\")"), symb);
MEL_CHECK(symb[0] == "a" && symb[1] == "b" && symb[2] == "c" && symb[3] == "\"var 1\"")

symb.clear();
Parse<double>(str_t("1 - 2)"), symb);
Parse<double, Mode>(str_t("1 - 2)"), symb);
MEL_CHECK(symb.front() == "2)")

symb.clear();
Parse<double>(str_t("(1 - xx"), symb);
Parse<double, Mode>(str_t("(1 - xx"), symb);
MEL_CHECK(symb.front() == "(1-xx")

// symb.clear();
// const auto tree = Parse<double>(str_t("(1 - x) * (x - 1)"), symb);
// MEL_CHECK(tree.size == 5)
if (Mode == OptimMode::TREE_SIZE) {
symb.clear();
const auto tree = Parse<double, Mode>(str_t("(1 - x) * (x - 1)"), symb);
MEL_CHECK(tree.size == 5)
}

#define MEL_CHECK_EXPR(EXPR) { \
auto v = Eval<double>(#EXPR); \
auto v = Eval<double, Mode>(#EXPR); \
std::cout << v << '\n'; \
MEL_CHECK(v == (EXPR)) }

Expand All @@ -116,7 +119,7 @@ inline int tests() {
#define MEL_CHECK_EXPR(EXPR, ...) { \
std::vector<str_t> s; \
const double x[] = {__VA_ARGS__}; \
auto t = Parse<double>(str_t(#EXPR), s); \
auto t = Parse<double, Mode>(str_t(#EXPR), s); \
auto v = Eval<double>(t, [&x](int i) {return x[i];}); \
std::cout << v << '\n'; \
MEL_CHECK(v == (EXPR)) \
Expand Down Expand Up @@ -146,6 +149,7 @@ struct Timer {
};

#define MEL_BENCHMARK(NAME, SIZE, SAMPLES, ...) \
template <OptimMode Mode> \
int benchmark_##NAME(const double tol, const double allowed_ratio) { \
constexpr int samples = SAMPLES; \
constexpr int n = SIZE; \
Expand All @@ -158,17 +162,17 @@ int benchmark_##NAME(const double tol, const double allowed_ratio) { \
\
const str_t expr = #__VA_ARGS__; \
std::vector<str_t> s; \
const auto t = Parse<double>(expr, s); \
const auto t = Parse<double, Mode>(expr, s); \
std::cout << expr << '\n'; \
Print(t, s, std::cout); \
std::cout << '\n'; \
PrintNodes(t, s, std::cout); \
\
auto t0 = Timer(); \
auto* tree = new ExpressionTree<double>; \
auto* tree = new ExpressionTree<double, Mode>; \
for (int k = 0; k < samples; ++k) { \
std::vector<str_t> s; \
*tree = Parse<double>(expr, s); \
*tree = Parse<double, Mode>(expr, s); \
} \
delete tree; \
const auto t_parse = t0.mark() / samples; \
Expand Down Expand Up @@ -220,10 +224,15 @@ MEL_BENCHMARK(4, 8192, 1024,

inline int benchmarks() {
std::cout << "\nBenchmarks\n\n";
if (internal::benchmark_1(0.0, 30)) return 1;
if (internal::benchmark_2(1e-15, 75)) return 1;
if (internal::benchmark_3(1e-16, 2.5)) return 1;
if (internal::benchmark_4(1e-12, 70)) return 1;
if (internal::benchmark_1<OptimMode::TREE_SIZE>(0.0, 30)) return 1;
if (internal::benchmark_2<OptimMode::TREE_SIZE>(1e-15, 75)) return 1;
if (internal::benchmark_3<OptimMode::TREE_SIZE>(1e-16, 2.5)) return 1;
if (internal::benchmark_4<OptimMode::TREE_SIZE>(1e-12, 70)) return 1;

if (internal::benchmark_1<OptimMode::STACK_SIZE>(0.0, 30)) return 1;
if (internal::benchmark_2<OptimMode::STACK_SIZE>(1e-15, 110)) return 1;
if (internal::benchmark_3<OptimMode::STACK_SIZE>(1e-16, 2.5)) return 1;
if (internal::benchmark_4<OptimMode::STACK_SIZE>(1e-12, 100)) return 1;
return 0;
}

Expand Down

0 comments on commit c66ed4c

Please sign in to comment.