Skip to content

Commit f8acb69

Browse files
author
Kieran Elmes
committed
fix real-value upper bound for interactions
1 parent 818d3b4 commit f8acb69

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

coverage-badge.svg

+1-1
Loading

src/Pint.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
328328
for (int_fast64_t i = 0; i < p; i++)
329329
X[i] = (int_fast64_t*)malloc(n * sizeof *X[i]);
330330

331+
float overall_max_val = 0.0;
331332
for (int_fast64_t i = 0; i < p; i++) {
332333
float col_max_val = 0.0;
333334
for (int_fast64_t j = 0; j < n; j++) {
@@ -342,11 +343,15 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
342343
}
343344
}
344345
col_max_vals[i] = col_max_val;
346+
if (fabs(col_max_val) > overall_max_val) {
347+
overall_max_val = fabs(col_max_val);
348+
}
345349
}
346350
struct continuous_info ci;
347351
ci.col_max_vals = col_max_vals;
348352
ci.col_real_vals = col_real_vals;
349353
ci.use_cont = continuous_X;
354+
ci.overall_max_val = overall_max_val;
350355

351356
float* Y = (float*)malloc(n * sizeof(float));
352357
for (int_fast64_t i = 0; i < n; i++) {

src/liblasso.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ struct continuous_info {
103103
bool use_cont;
104104
std::vector<float>* col_real_vals;
105105
float* col_max_vals;
106+
float overall_max_val;
107+
int depth;
106108
};
107109
void free_continuous_info(struct continuous_info ci);
108110

@@ -245,4 +247,4 @@ extern int_pair* cached_nums;
245247
extern bool VERBOSE;
246248
extern float total_sqrt_error;
247249

248-
#endif
250+
#endif

src/pruning.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ float l2_combined_estimate(X_uncompressed X, float lambda, int_fast64_t k,
7070
alpha = 0.0;
7171

7272
float remainder = pessimistic_estimate(alpha, last_rowsum, rowsum, col, X.host_col_nz[k]);
73-
if (ci->use_cont)
74-
remainder *= fabs(ci->col_max_vals[k]);
73+
if (ci->use_cont) {
74+
if (ci->depth == 2)
75+
remainder *= fabs(ci->overall_max_val) * fabs(ci->col_max_vals[k]);
76+
else
77+
remainder *= fabs(ci->overall_max_val) * fabs(ci->overall_max_val) * fabs(ci->col_max_vals[k]);
78+
}
7579

7680
float total_estimate = fabs(last_max * alpha) + remainder;
7781
return total_estimate;

tests/func-tests.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,8 @@ void check_small_continuous() {
13921392
ci.col_real_vals = new vector<float>[p];
13931393
ci.col_max_vals = new float[p];
13941394
ci.use_cont = true;
1395+
ci.overall_max_val = 0.0;
1396+
ci.depth = 2;
13951397

13961398
XMatrix xm;
13971399
xm.X = calloc(p, sizeof(int_fast64_t*));
@@ -1408,14 +1410,19 @@ void check_small_continuous() {
14081410
ci.col_real_vals[3] = {0.2, 0.2, 0.4, -3.3, 2.1};
14091411
ci.col_real_vals[4] = {-0.2, -1.2, 3.2, 3.5, 0.1};
14101412

1413+
float overall_max = 0.0;
14111414
for (int j = 0; j < p; j++) {
14121415
float max_val = 0.0;
14131416
for (auto v : ci.col_real_vals[j]) {
14141417
if (fabs(v) > fabs(max_val))
14151418
max_val = v;
14161419
}
14171420
ci.col_max_vals[j] = max_val;
1421+
if (fabs(max_val) > overall_max) {
1422+
overall_max = fabs(max_val);
1423+
}
14181424
}
1425+
ci.overall_max_val = overall_max;
14191426

14201427
std::vector<float> beta = {0.3, 1.1, 0.9, -2.2, 1.5};
14211428
float* Y = calloc(n, sizeof(float));
@@ -1479,6 +1486,8 @@ void check_continous_ones(UpdateFixture* fixture,
14791486
ci.col_max_vals = col_max_vals;
14801487
ci.col_real_vals = col_real_vals;
14811488
ci.use_cont = true;
1489+
ci.overall_max_val = 1.0;
1490+
ci.depth = 3;
14821491
bool check_duplicates = true;
14831492
Lasso_Result lr = simple_coordinate_descent_lasso(fixture->xmatrix, fixture->Y, fixture->n, fixture->p,
14841493
-1, 0.01, 200,

0 commit comments

Comments
 (0)