From 6e4a8ccb98d873ee36e2cfd6e7386daf37263e32 Mon Sep 17 00:00:00 2001 From: Nick Christofides <118103879+NicChr@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:19:37 +0100 Subject: [PATCH] Minor bug fixes. --- src/int64.cpp | 4 +- src/scalars.cpp | 127 ++++++++++++++++++++++++++++++++++++++++++------ src/which.cpp | 2 +- 3 files changed, 114 insertions(+), 19 deletions(-) diff --git a/src/int64.cpp b/src/int64.cpp index 85e4319..f0807f6 100644 --- a/src/int64.cpp +++ b/src/int64.cpp @@ -5,7 +5,6 @@ bool is_int64(SEXP x){ } // Convert 64-bit integer vec to 32-bit integer vec -// We almost never want to convert back to 32-bit int [[cpp11::register]] SEXP cpp_int64_to_int(SEXP x){ @@ -30,6 +29,7 @@ SEXP cpp_int64_to_int(SEXP x){ } // Convert 64-bit integer vec to double vec + [[cpp11::register]] SEXP cpp_int64_to_double(SEXP x){ if (!is_int64(x)){ @@ -93,7 +93,7 @@ bool cpp_all_integerable(SEXP x, int shift = 0){ return out; } -// Convert 64-bit integer to int if possible, otherwise double +// Convert 64-bit integer to 32-bit int if possible, otherwise double [[cpp11::register]] SEXP cpp_int64_to_numeric(SEXP x){ diff --git a/src/scalars.cpp b/src/scalars.cpp index 0480f06..da94483 100644 --- a/src/scalars.cpp +++ b/src/scalars.cpp @@ -1,5 +1,57 @@ #include "cheapr_cpp.h" +// Relational operators +// define CHEAPR_OP_SWITCH +// switch(op){ +// case 1: { +// c_op = equals; +// break; +// } +// case 2: { +// c_op = gt; +// break; +// } +// case 3: { +// c_op = lt; +// break; +// } +// case 4: { +// c_op = gte; +// break; +// } +// case 5: { +// c_op = lte; +// break; +// } +// case 6: { +// c_op = neq; +// break; +// } +// default: { +// Rf_error("Supported relational operations: `==`, `>`, `<`, `>=`, `<=`, `!=`"); +// } +// } + +// #define equals(a, b) ((int) a == b) +// #define gt(a, b) ((int) a > b); +// #define lt(a, b) ((int) a < b) +// #define gte(a, b) ((int) a >= b) +// #define lte(a, b) ((int) a <= b) +// #define neq(a, b) ((int) a != b) + +// template<typename T1, typename T2> +// 
int equals(T1 a, T2 b) { return a == b; } +// template<typename T1, typename T2> +// int gt(T1 a, T2 b) { return a > b; } +// template<typename T1, typename T2> +// int lt(T1 a, T2 b) { return a < b; } +// template<typename T1, typename T2> +// int gte(T1 a, T2 b) { return a >= b; } +// template<typename T1, typename T2> +// int lte(T1 a, T2 b) { return a <= b; } +// template<typename T1, typename T2> +// int neq(T1 a, T2 b) { return a != b; } + bool is_scalar_na(SEXP x){ if (Rf_xlength(x) != 1){ Rf_error("x must be a scalar value"); @@ -33,7 +85,11 @@ bool is_scalar_na(SEXP x){ SEXP coerce_vector(SEXP source, SEXP target){ if (is_int64(target)){ - return cpp_numeric_to_int64(Rf_coerceVector(source, REALSXP)); + if (is_int64(source)){ + return source; + } else { + return cpp_numeric_to_int64(Rf_coerceVector(source, REALSXP)); + } } else { return Rf_coerceVector(source, TYPEOF(target)); } @@ -53,15 +109,32 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){ R_xlen_t n = Rf_xlength(x); R_xlen_t count = 0; int NP = 0; - bool do_parallel = n >= 100000; - int n_cores = do_parallel ? num_cores() : 1; + int n_cores = n >= CHEAPR_OMP_THRESHOLD ? 
num_cores() : 1; SEXP val_is_na = Rf_protect(cpp_is_na(value)); ++NP; if (Rf_length(val_is_na) == 1 && LOGICAL(val_is_na)[0]){ + // Can't count NA > NA for example + // if (op != 1){ + // Rf_unprotect(NP); + // return 0; + // } else { Rf_unprotect(NP); return na_count(x, recursive); + // } } -#define VAL_COUNT(_val_) for (R_xlen_t i = 0; i < n; ++i) count += (p_x[i] == _val_); +#define CHEAPR_VAL_COUNT(_val_) \ + for (R_xlen_t i = 0; i < n; ++i){ \ + count += (p_x[i] == _val_); \ + } \ + \ + + // Alternative that works for other equality operators + // _IS_NA_ is an arg that accepts a function like cheapr_is_na_int + // for (R_xlen_t i = 0; i < n; ++i){ + // count += (c_op(p_x[i], _val_) && !_IS_NA_(p_x[i])); + // } + + switch ( TYPEOF(x) ){ case NILSXP: { Rf_unprotect(NP); @@ -73,26 +146,45 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){ Rf_protect(value = Rf_coerceVector(value, INTSXP)); ++NP; int val = Rf_asInteger(value); int *p_x = INTEGER(x); - if (do_parallel){ + // int (*c_op)(int, int); + // CHEAPR_OP_SWITCH; + if (n_cores > 1){ #pragma omp parallel for simd num_threads(n_cores) reduction(+:count) - VAL_COUNT(val) + CHEAPR_VAL_COUNT(val) } else { #pragma omp for simd - VAL_COUNT(val) + CHEAPR_VAL_COUNT(val) } break; } case REALSXP: { if (implicit_na_coercion(value, x)) break; - Rf_protect(value = Rf_coerceVector(value, REALSXP)); ++NP; - double val = Rf_asReal(value); - double *p_x = REAL(x); - if (do_parallel){ + if (is_int64(x)){ + Rf_protect(value = coerce_vector(value, x)); ++NP; + long long int val = INTEGER64_PTR(value)[0]; + long long int *p_x = INTEGER64_PTR(x); + // int (*c_op)(long long int, long long int); + // CHEAPR_OP_SWITCH; + if (n_cores > 1){ #pragma omp parallel for simd num_threads(n_cores) reduction(+:count) - VAL_COUNT(val) + CHEAPR_VAL_COUNT(val) + } else { +#pragma omp for simd + CHEAPR_VAL_COUNT(val) + } } else { + Rf_protect(value = Rf_coerceVector(value, REALSXP)); ++NP; + double val = Rf_asReal(value); + double *p_x 
= REAL(x); + // int (*c_op)(double, double); + // CHEAPR_OP_SWITCH; + if (n_cores > 1){ +#pragma omp parallel for simd num_threads(n_cores) reduction(+:count) + CHEAPR_VAL_COUNT(val) + } else { #pragma omp for simd - VAL_COUNT(val) + CHEAPR_VAL_COUNT(val) + } } break; } @@ -101,12 +193,15 @@ R_xlen_t scalar_count(SEXP x, SEXP value, bool recursive){ Rf_protect(value = Rf_coerceVector(value, STRSXP)); ++NP; SEXP val = Rf_protect(Rf_asChar(value)); ++NP; const SEXP *p_x = STRING_PTR_RO(x); - if (do_parallel){ + // int (*c_op)(SEXP, SEXP); + // CHEAPR_OP_SWITCH; + + if (n_cores > 1){ #pragma omp parallel for simd num_threads(n_cores) reduction(+:count) - VAL_COUNT(val); + CHEAPR_VAL_COUNT(val); } else { #pragma omp for simd - VAL_COUNT(val); + CHEAPR_VAL_COUNT(val); } break; } diff --git a/src/which.cpp b/src/which.cpp index 256ff51..0479e94 100644 --- a/src/which.cpp +++ b/src/which.cpp @@ -228,7 +228,7 @@ SEXP cpp_which_na(SEXP x){ } case REALSXP: { R_xlen_t count = na_count(x, true); - if (Rf_inherits(x, "integer64")){ + if (is_int64(x)){ long long *p_x = INTEGER64_PTR(x); if (is_short){ int out_size = count;