Skip to content

Commit

Permalink
Minor bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
NicChr committed Oct 9, 2024
1 parent 8b76962 commit 6e4a8cc
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/int64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ bool is_int64(SEXP x){
}

// Convert 64-bit integer vec to 32-bit integer vec
// We almost never want to convert back to 32-bit int

[[cpp11::register]]
SEXP cpp_int64_to_int(SEXP x){
Expand All @@ -30,6 +29,7 @@ SEXP cpp_int64_to_int(SEXP x){
}

// Convert 64-bit integer vec to double vec

[[cpp11::register]]
SEXP cpp_int64_to_double(SEXP x){
if (!is_int64(x)){
Expand Down Expand Up @@ -93,7 +93,7 @@ bool cpp_all_integerable(SEXP x, int shift = 0){
return out;
}

// Convert 64-bit integer to int if possible, otherwise double
// Convert 64-bit integer to 32-bit int if possible, otherwise double

[[cpp11::register]]
SEXP cpp_int64_to_numeric(SEXP x){
Expand Down
127 changes: 111 additions & 16 deletions src/scalars.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,57 @@
#include "cheapr_cpp.h"

// Relational operators
// define CHEAPR_OP_SWITCH
// switch(op){
// case 1: {
// c_op = equals;
// break;
// }
// case 2: {
// c_op = gt;
// break;
// }
// case 3: {
// c_op = lt;
// break;
// }
// case 4: {
// c_op = gte;
// break;
// }
// case 5: {
// c_op = lte;
// break;
// }
// case 6: {
// c_op = neq;
// break;
// }
// default: {
// Rf_error("Supported relational operations: `==`, `>`, `<`, `>=`, `<=`, `!=`");
// }
// }

// #define equals(a, b) ((int) a == b)
// #define gt(a, b) ((int) a > b);
// #define lt(a, b) ((int) a < b)
// #define gte(a, b) ((int) a >= b)
// #define lte(a, b) ((int) a <= b)
// #define neq(a, b) ((int) a != b)

// template <typename T1, typename T2>
// int equals(T1 a, T2 b) { return a == b; }
// template <typename T1, typename T2>
// int gt(T1 a, T2 b) { return a > b; }
// template <typename T1, typename T2>
// int lt(T1 a, T2 b) { return a < b; }
// template <typename T1, typename T2>
// int gte(T1 a, T2 b) { return a >= b; }
// template <typename T1, typename T2>
// int lte(T1 a, T2 b) { return a <= b; }
// template <typename T1, typename T2>
// int neq(T1 a, T2 b) { return a != b; }

bool is_scalar_na(SEXP x){
if (Rf_xlength(x) != 1){
Rf_error("x must be a scalar value");
Expand Down Expand Up @@ -33,7 +85,11 @@ bool is_scalar_na(SEXP x){

SEXP coerce_vector(SEXP source, SEXP target){
if (is_int64(target)){
return cpp_numeric_to_int64(Rf_coerceVector(source, REALSXP));
if (is_int64(source)){
return source;
} else {
return cpp_numeric_to_int64(Rf_coerceVector(source, REALSXP));
}
} else {
return Rf_coerceVector(source, TYPEOF(target));
}
Expand All @@ -53,15 +109,32 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){
R_xlen_t n = Rf_xlength(x);
R_xlen_t count = 0;
int NP = 0;
bool do_parallel = n >= 100000;
int n_cores = do_parallel ? num_cores() : 1;
int n_cores = n >= CHEAPR_OMP_THRESHOLD ? num_cores() : 1;

SEXP val_is_na = Rf_protect(cpp_is_na(value)); ++NP;
if (Rf_length(val_is_na) == 1 && LOGICAL(val_is_na)[0]){
// Can't count NA > NA for example
// if (op != 1){
// Rf_unprotect(NP);
// return 0;
// } else {
Rf_unprotect(NP);
return na_count(x, recursive);
// }
}
#define VAL_COUNT(_val_) for (R_xlen_t i = 0; i < n; ++i) count += (p_x[i] == _val_);
#define CHEAPR_VAL_COUNT(_val_) \
for (R_xlen_t i = 0; i < n; ++i){ \
count += (p_x[i] == _val_); \
} \
\

// Alternative that works for other equality operators
// _IS_NA_ is a arg that accepts a function like cheapr_is_na_int
// for (R_xlen_t i = 0; i < n; ++i){
// count += (c_op(p_x[i], _val_) && !_IS_NA_(p_x[i]));
// }


switch ( TYPEOF(x) ){
case NILSXP: {
Rf_unprotect(NP);
Expand All @@ -73,26 +146,45 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){
Rf_protect(value = Rf_coerceVector(value, INTSXP)); ++NP;
int val = Rf_asInteger(value);
int *p_x = INTEGER(x);
if (do_parallel){
// int (*c_op)(int, int);
// CHEAPR_OP_SWITCH;
if (n_cores > 1){
#pragma omp parallel for simd num_threads(n_cores) reduction(+:count)
VAL_COUNT(val)
CHEAPR_VAL_COUNT(val)
} else {
#pragma omp for simd
VAL_COUNT(val)
CHEAPR_VAL_COUNT(val)
}
break;
}
case REALSXP: {
if (implicit_na_coercion(value, x)) break;
Rf_protect(value = Rf_coerceVector(value, REALSXP)); ++NP;
double val = Rf_asReal(value);
double *p_x = REAL(x);
if (do_parallel){
if (is_int64(x)){
Rf_protect(value = coerce_vector(value, x)); ++NP;
long long int val = INTEGER64_PTR(value)[0];
long long int *p_x = INTEGER64_PTR(x);
// int (*c_op)(long long int, long long int);
// CHEAPR_OP_SWITCH;
if (n_cores > 1){
#pragma omp parallel for simd num_threads(n_cores) reduction(+:count)
VAL_COUNT(val)
CHEAPR_VAL_COUNT(val)
} else {
#pragma omp for simd
CHEAPR_VAL_COUNT(val)
}
} else {
Rf_protect(value = Rf_coerceVector(value, REALSXP)); ++NP;
double val = Rf_asReal(value);
double *p_x = REAL(x);
// int (*c_op)(double, double);
// CHEAPR_OP_SWITCH;
if (n_cores > 1){
#pragma omp parallel for simd num_threads(n_cores) reduction(+:count)
CHEAPR_VAL_COUNT(val)
} else {
#pragma omp for simd
VAL_COUNT(val)
CHEAPR_VAL_COUNT(val)
}
}
break;
}
Expand All @@ -101,12 +193,15 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){
Rf_protect(value = Rf_coerceVector(value, STRSXP)); ++NP;
SEXP val = Rf_protect(Rf_asChar(value)); ++NP;
const SEXP *p_x = STRING_PTR_RO(x);
if (do_parallel){
// int (*c_op)(SEXP, SEXP);
// CHEAPR_OP_SWITCH;

if (n_cores > 1){
#pragma omp parallel for simd num_threads(n_cores) reduction(+:count)
VAL_COUNT(val);
CHEAPR_VAL_COUNT(val);
} else {
#pragma omp for simd
VAL_COUNT(val);
CHEAPR_VAL_COUNT(val);
}
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/which.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ SEXP cpp_which_na(SEXP x){
}
case REALSXP: {
R_xlen_t count = na_count(x, true);
if (Rf_inherits(x, "integer64")){
if (is_int64(x)){
long long *p_x = INTEGER64_PTR(x);
if (is_short){
int out_size = count;
Expand Down

0 comments on commit 6e4a8cc

Please sign in to comment.