Skip to content

Commit

Permalink
first pass at #200
Browse files Browse the repository at this point in the history
  • Loading branch information
stevencarlislewalker committed May 3, 2024
1 parent 6965bf4 commit fdfef67
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 7 deletions.
1 change: 1 addition & 0 deletions R/enum.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ valid_func_sigs = c(
, "fwrap,fail: reulermultinom(size, rate, delta_t)"
, "fwrap,null: round(x)"
, "fwrap,fail: pgamma(q, shape, scale)"
, "fwrap,fail: safe_power(x,y)"
)
process_enum = function(x) {
RE = "(null|fail|binop|fwrap|bwrap|pwrap)[ ]*,[ ]*(null|fail|binop|fwrap|bwrap|pwrap)[ ]*:[ ]*\\`?([^`]*)\\`?\\((.*)(\\,.*)*\\)"
Expand Down
36 changes: 36 additions & 0 deletions misc/dev/dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ enum macpan2_func
, MP2_EULER_MULTINOM_SIM = 48 // fwrap,fail: reulermultinom(size, rate, delta_t)
, MP2_ROUND = 49 // fwrap,null: round(x)
, MP2_PGAMMA = 50 // fwrap,fail: pgamma(q, shape, scale)
, MP2_SAFEPOWER = 51 // fwrap,fail: safe_power(x,y)
};

enum macpan2_meth
Expand Down Expand Up @@ -1286,6 +1287,41 @@ class ExprEvaluator
#endif
return pow(args[0].array(), args[1].array()).matrix();
// return args[0].pow(args[1].coeff(0,0));

case MP2_SAFEPOWER: // SAFE_POWER, equivalent to (ifelse(x==0, 0, x^y))
if (n != 2) {
SetError(err_code, "safe_power requires exactly two arguments", row, table_x[row] + 1, args.all_rows(), args.all_cols(), args.all_type_ints());
return m;
}
args = args.recycle_for_bin_op();
err_code = args.get_error_code();
switch (err_code) {
case 201:
SetError(err_code, "The two operands do not have the same number of columns", row, table_x[row] + 1, args.all_rows(), args.all_cols(), args.all_type_ints());
return m;
case 202:
SetError(err_code, "The two operands do not have the same number of rows", row, table_x[row] + 1, args.all_rows(), args.all_cols(), args.all_type_ints());
return m;
case 203:
SetError(err_code, "The two operands do not have the same number of columns or rows", row, table_x[row] + 1, args.all_rows(), args.all_cols(), args.all_type_ints());
return m;
}
m1 = args[0];
m2 = args[1];
rows = m1.rows();
cols = m1.cols();
m = matrix<Type>::Zero(rows, cols);
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
m.coeffRef(i, j) = CppAD::CondExpEq(
m1.coeff(i, j),
Type(0),
Type(0),
pow(m1.coeff(i, j), m2.coeff(i, j))
);
}
}
return m;

// #' ## Unary Elementwise Math
// #'
Expand Down
16 changes: 9 additions & 7 deletions src/macpan2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ enum macpan2_func
, MP2_EULER_MULTINOM_SIM = 48 // fwrap,fail: reulermultinom(size, rate, delta_t)
, MP2_ROUND = 49 // fwrap,null: round(x)
, MP2_PGAMMA = 50 // fwrap,fail: pgamma(q, shape, scale)
// `^`(x, y), but with 0^0 defined as 0
, MP2_SAFEPOWER = 51 // fwrap,fail: safe_power(x,y)
};

Expand Down Expand Up @@ -1289,9 +1288,12 @@ class ExprEvaluator
#endif
return pow(args[0].array(), args[1].array()).matrix();
// return args[0].pow(args[1].coeff(0,0));

case MP2_SAFEPOWER: // SAFE_POWER, equivalent to (ifelse(x==0, 0, x^y))


case MP2_SAFEPOWER: // SAFE_POWER, equivalent to (ifelse(x==0, 0, x^y))
if (n != 2) {
SetError(err_code, "safe_power requires exactly two arguments", row, table_x[row] + 1, args.all_rows(), args.all_cols(), args.all_type_ints());
return m;
}
args = args.recycle_for_bin_op();
err_code = args.get_error_code();
switch (err_code) {
Expand All @@ -1310,11 +1312,11 @@ class ExprEvaluator
rows = m1.rows();
cols = m1.cols();
m = matrix<Type>::Zero(rows, cols);
for (int i; i < rows; i++) {
for (int j; j < cols; j++) {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
m.coeffRef(i, j) = CppAD::CondExpEq(
m1.coeff(i, j),
m2.coeff(i, j),
Type(0),
Type(0),
pow(m1.coeff(i, j), m2.coeff(i, j))
);
Expand Down
29 changes: 29 additions & 0 deletions tests/testthat/test-binop.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
library(macpan2); library(testthat); library(dplyr); library(tidyr); library(ggplot2)

test_that("elementwise binary operator executable specs match spec doc", {
## https://canmod.net/misc/elementwise_binary_operators
times = BinaryOperator(`*`)
Expand Down Expand Up @@ -83,3 +85,30 @@ test_that("elementwise binary operator executable specs match spec doc", {
test_that("equivalent unary and binary minus operators give the same answers", {
expect_equal(engine_eval(~-4), engine_eval(~0-4))
})

test_that("safe_power meets the requirements of #200", {
expect_equal(
matrix(0.1 ^ -5),
engine_eval(~safe_power(0.1, -5))
)
expect_equal(
matrix(0 ^ 5),
engine_eval(~safe_power(0, 5))
)
expect_equal(
matrix(c(1, 1, 0, 1, 1)), # != matrix((-2:2)^0),
engine_eval(~safe_power(-2:2, 0))
)
expect_equal(
matrix(rep(0, 5)), # != matrix(0^(-2:2))
engine_eval(~safe_power(0, -2:2))
)
expect_error(
engine_eval(~safe_power(1)),
regexp = "safe_power requires exactly two arguments"
)
expect_error(
engine_eval(~safe_power(1:3, 2:3)),
regexp = "The two operands do not have the same number of rows"
)
})

0 comments on commit fdfef67

Please sign in to comment.