Skip to content

Commit 367004f

Browse files
author
Kieran Elmes
committed
more efficient & accurate strong hierarchy
1 parent b76298f commit 367004f

15 files changed

+107
-81
lines changed

R/pint.R

+2-16
Original file line numberDiff line numberDiff line change
@@ -112,23 +112,9 @@ interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -
112112
}
113113
rm(tmp)
114114

115-
if (approximate_hierarchy && depth > 1) {
116-
if (verbose) {
117-
print("Finding main effects")
118-
}
119-
main_only <- .Call(lasso_, X, Ym, lambda_min, lambda_max, halt_error_diff, max_interaction_distance, max_nz_beta*2, max_lambdas, verbose, log_filename, 1, log_level_enum, estimate_unbiased, use_intercept, num_threads, check_duplicates, continuous_X)
120-
main_only_result <- process_result(X, main_only[[1]])
121-
# we'll use these to allow the second run to report all identical columns again.
122-
if (check_duplicates) {
123-
all_equiv <- unique(unlist(main_only_result$main_effects$equivalent))
124-
} else {
125-
all_equiv <- unique(unlist(main_only_result$main_effects$i))
126-
}
127-
128-
X <- X[,all_equiv]
129-
}
130115

131-
result <- .Call(lasso_, X, Ym, lambda_min, lambda_max, halt_error_diff, max_interaction_distance, max_nz_beta, max_lambdas, verbose, log_filename, depth, log_level_enum, estimate_unbiased, use_intercept, num_threads, check_duplicates, continuous_X)
116+
117+
result <- .Call(lasso_, X, Ym, lambda_min, lambda_max, halt_error_diff, max_interaction_distance, max_nz_beta, max_lambdas, verbose, log_filename, depth, log_level_enum, estimate_unbiased, use_intercept, num_threads, check_duplicates, continuous_X, approximate_hierarchy)
132118

133119
rm(Ym)
134120

coverage-badge.svg

+1-1
Loading

install_and_run.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ if (length(args >= 2)) {
1616
# f <- "../data/simulated_8k/n8000_p4000_SNR5_nbi40_nbij800_nlethals200_viol0_91159.rds"
1717
# f <- "../data/simulated_small_data/n1000_p100_SNR5_nbi0_nbij100_nlethals0_viol0_11754.rds"
1818
# f <- "../data/simulated_large_data/n10000_p1000_SNR10_nbi0_nbij1000_nlethals0_viol0_11504.rds"
19-
f <- "../data/simulated_8k/n2000_p1000_SNR5_nbi10_nbij200_nlethals50_viol0_11057.rds"
20-
# f <- "../data/simulated_8k/n8000_p4000_SNR5_nbi40_nbij800_nlethals200_viol0_78715.rds"
19+
# f <- "../data/simulated_8k/n2000_p1000_SNR5_nbi10_nbij200_nlethals50_viol0_11057.rds"
20+
f <- "../data/simulated_8k/n8000_p4000_SNR5_nbi40_nbij800_nlethals200_viol0_78715.rds"
2121
# f <- "../infx_lasso_data/3way_data_to_run/n1000_p100_SNR4_nbi10_nbij252_nbijk1666_nlethals0_70443.rds"
2222
# f <- "./weirdly_slow_case/n1000_p100_SNR10_nbi0_nbij100_nlethals0_viol0_33859.rds"
2323
# f <- "./antibio_data.rds"
@@ -43,7 +43,7 @@ Y <- d$Y
4343
# result <- interaction_lasso(X, Y, depth = 3)
4444
# result <- interaction_lasso(X, Y, depth = 2)
4545
# result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 150, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, strong_hierarchy = TRUE, check_duplicates = TRUE, continuous_X = TRUE)
46-
result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 150, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, approximate_hierarchy = FALSE, check_duplicates = TRUE, continuous_X = FALSE)
46+
result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 250, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, approximate_hierarchy = TRUE, check_duplicates = TRUE, continuous_X = FALSE)
4747
# print(result)
4848

4949
# q()

src/Pint.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
275275
SEXP halt_error_diff_, SEXP max_interaction_distance_, SEXP max_nz_beta_,
276276
SEXP max_lambdas_, SEXP verbose_, SEXP log_filename_, SEXP depth_, SEXP log_level_,
277277
SEXP estimate_unbiased_, SEXP use_intercept_, SEXP use_cores_,
278-
SEXP check_duplicates_, SEXP continuous_X_)
278+
SEXP check_duplicates_, SEXP continuous_X_, SEXP use_hierarchy_)
279279
{
280280
double* x = REAL(X_);
281281
double* y = REAL(Y_);
@@ -293,6 +293,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
293293
int_fast64_t use_cores = asInteger(use_cores_);
294294
bool check_duplicates = asLogical(check_duplicates_);
295295
bool continuous_X = asLogical(continuous_X_);
296+
bool use_hierarchy = asLogical(use_hierarchy_);
296297

297298
VERBOSE = verbose;
298299

@@ -359,7 +360,8 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
359360
Lasso_Result lasso_result = simple_coordinate_descent_lasso(
360361
xmatrix, Y, n, p, max_interaction_distance, asReal(lambda_min_),
361362
asReal(lambda_max_), max_lambdas, verbose, halt_error_diff, log_level, NULL, 0,
362-
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, &ci);
363+
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, &ci,
364+
use_hierarchy);
363365
float final_lambda = lasso_result.final_lambda;
364366
float regularized_intercept = lasso_result.regularized_intercept;
365367
float unbiased_intercept = lasso_result.unbiased_intercept;
@@ -392,7 +394,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
392394
}
393395

394396
static const R_CallMethodDef CallEntries[] = {
395-
{ "lasso_", (DL_FUNC)&lasso_, 17 },
397+
{ "lasso_", (DL_FUNC)&lasso_, 18 },
396398
{ "read_log_", (DL_FUNC)&read_log_, 1 },
397399
{ NULL, NULL, 0 }
398400
};

src/liblasso.h

+7-6
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ typedef struct {
122122
// int_fast64_t total_found_hash_count;
123123
} IndiCols;
124124

125+
typedef struct {
126+
robin_hood::unordered_flat_map<int_fast64_t, float> beta1;
127+
robin_hood::unordered_flat_map<int_fast64_t, float> beta2;
128+
robin_hood::unordered_flat_map<int_fast64_t, float> beta3;
129+
int_fast64_t p;
130+
} Beta_Value_Sets;
131+
125132
#include "s8b.h"
126133
#include "sparse_matrix.h"
127134
#include "tuple_val.h"
@@ -149,12 +156,6 @@ struct AS_Entry {
149156
std::vector<float> real_vals;
150157
};
151158

152-
typedef struct {
153-
robin_hood::unordered_flat_map<int_fast64_t, float> beta1;
154-
robin_hood::unordered_flat_map<int_fast64_t, float> beta2;
155-
robin_hood::unordered_flat_map<int_fast64_t, float> beta3;
156-
int_fast64_t p;
157-
} Beta_Value_Sets;
158159

159160
typedef struct {
160161
Beta_Value_Sets regularized_result;

src/regression.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ void subproblem_only(Iter_Vars* vars, float lambda, float* rowsum,
244244
int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsum,
245245
float* old_rowsum, Active_Set* active_set,
246246
int_fast64_t depth, const bool use_intercept, IndiCols* indi, const bool check_duplicates,
247-
struct continuous_info* cont_inf)
247+
struct continuous_info* cont_inf, const bool use_hierarchy)
248248
{
249249
XMatrixSparse Xc = vars->Xc;
250250
X_uncompressed Xu = vars->Xu;
@@ -354,7 +354,8 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
354354
clock_gettime(CLOCK_MONOTONIC_RAW, &start_time);
355355
auto working_set_results = update_working_set(vars->Xu, Xc, rowsum, wont_update, p, n, lambda,
356356
updateable_items, count_may_update, active_set,
357-
thread_caches, last_max, depth, indi, &new_cols, max_interaction_distance, check_duplicates, cont_inf);
357+
thread_caches, last_max, depth, indi, &new_cols, max_interaction_distance, check_duplicates, cont_inf,
358+
use_hierarchy, beta_sets);
358359
bool increased_set = working_set_results.first;
359360
auto vals_to_remove = working_set_results.second;
360361
for (auto val : vals_to_remove) {
@@ -496,7 +497,7 @@ Lasso_Result simple_coordinate_descent_lasso(
496497
float hed, enum LOG_LEVEL log_level,
497498
const char** job_args, int_fast64_t job_args_num,
498499
int_fast64_t mnz_beta, const char* log_filename, int_fast64_t depth,
499-
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf)
500+
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf, const bool use_hierarchy)
500501
{
501502
const bool continuous_X = cont_inf->use_cont;
502503
int_fast64_t max_nz_beta = mnz_beta;
@@ -727,7 +728,7 @@ Lasso_Result simple_coordinate_descent_lasso(
727728
int_fast64_t last_iter_count = 0;
728729

729730
nz_beta += run_lambda_iters_pruned(&iter_vars_pruned, lambda, rowsum,
730-
old_rowsum, &active_set, depth, use_intercept, &indi, check_duplicates, cont_inf);
731+
old_rowsum, &active_set, depth, use_intercept, &indi, check_duplicates, cont_inf, use_hierarchy);
731732

732733
{
733734
int_fast64_t nonzero = beta_sets.beta1.size() + beta_sets.beta2.size() + beta_sets.beta3.size();

src/regression.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Lasso_Result simple_coordinate_descent_lasso(
3030
float lambda_min, float lambda_max, int_fast64_t max_iter, const bool VERBOSE,
3131
float halt_beta_diff,
3232
enum LOG_LEVEL log_level, const char** job_args, int_fast64_t job_args_num,
33-
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf);
33+
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf, const bool use_hierarchy);
3434
float update_intercept_cyclic(float intercept, int_fast64_t** X, float* Y,
3535
robin_hood::unordered_flat_map<int_fast64_t, float>* beta, int_fast64_t n, int_fast64_t p);
3636
// Changes update_beta_cyclic(XMatrixSparse xmatrix_sparse, float *Y,

src/sparse_matrix.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ void free_row_set(struct row_set rs)
3737
}
3838

3939
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu,
40-
bool* remove,
41-
Thread_Cache* thread_caches, struct continuous_info* ci)
40+
bool* remove, Thread_Cache* thread_caches, struct continuous_info* ci,
41+
const bool use_hierarchy, Beta_Value_Sets* beta_sets)
4242
{
4343
int_fast64_t p = Xc.p;
4444
int_fast64_t n = Xc.n;
@@ -59,8 +59,10 @@ struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu,
5959
for (int_fast64_t i = 0; i < Xu.host_row_nz[row]; i++) {
6060
int_fast64_t col = Xu.host_X_row[Xu.host_row_offsets[row] + i];
6161
if (!remove[col]) {
62-
row_cache[row_pos] = col;
63-
row_pos++;
62+
if (!use_hierarchy || beta_sets->beta1.contains(col)) {
63+
row_cache[row_pos] = col;
64+
row_pos++;
65+
}
6466
}
6567
}
6668

src/sparse_matrix.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ XMatrixSparse sparse_X_from_X(int_fast64_t** X, int_fast64_t n, int_fast64_t p,
5858
int_fast64_t shuffle);
5959
XMatrixSparse sparsify_X(int_fast64_t** X, int_fast64_t n, int_fast64_t p);
6060

61-
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu, bool* remove, Thread_Cache* thread_caches, struct continuous_info* ci);
61+
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu, bool* remove, Thread_Cache* thread_caches, struct continuous_info* ci, const bool use_hierarchy, Beta_Value_Sets* beta_sets);
6262
void free_row_set(struct row_set rs);
6363
X_uncompressed construct_host_X(XMatrixSparse* Xc);
6464
void free_host_X(X_uncompressed *Xu);

src/update_working_set.cpp

+23-17
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ char update_working_set_cpu(struct XMatrixSparse Xc,
166166
X_uncompressed Xu, float* rowsum,
167167
bool* wont_update, int_fast64_t p, int_fast64_t n, float lambda,
168168
int_fast64_t* updateable_items, int_fast64_t count_may_update,
169-
float* last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t>* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info* ci)
169+
float* last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t>* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info* ci, const bool use_hierarchy)
170170
{
171171
int_fast64_t* host_X = Xu.host_X;
172172
int_fast64_t* host_col_nz = Xu.host_col_nz;
@@ -239,22 +239,28 @@ char update_working_set_cpu(struct XMatrixSparse Xc,
239239
sum_with_col[main] += rowsum_diff; //TODO: slow
240240
if (depth > 1) {
241241
int_fast64_t ri = 0;
242-
int_fast64_t jump_dist = relevant_row_set.row_lengths[row_main]/2;
242+
const int_fast64_t row_length = relevant_row_set.row_lengths[row_main];
243243
const int_fast64_t* row = relevant_row_set.rows[row_main];
244-
int_fast64_t tmp = row[ri];
245-
while (tmp != main) {
246-
if (tmp < main)
247-
ri += jump_dist;
248-
else if (tmp > main)
249-
ri -= jump_dist;
250-
else if (tmp == main)
251-
break;
252-
jump_dist = std::max((int_fast64_t)1,jump_dist/2);
253-
tmp = row[ri];
244+
if (use_hierarchy) {
245+
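// with use_hierarchy the row only keeps columns present in beta1, so column `main` may be absent from this reduced row; scan forward to the first entry >= main instead of searching for an exact match.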
246+
while (ri < row_length && row[ri] < main)
247+
ri++;
248+
} else {
249+
int_fast64_t jump_dist = row_length/2;
250+
int_fast64_t tmp = row[ri];
251+
while (tmp != main) {
252+
if (tmp < main)
253+
ri += jump_dist;
254+
else if (tmp > main)
255+
ri -= jump_dist;
256+
else if (tmp == main)
257+
break;
258+
jump_dist = std::max((int_fast64_t)1,jump_dist/2);
259+
tmp = row[ri];
260+
}
261+
ri++;
254262
}
255-
ri++;
256263

257-
const int_fast64_t row_length = relevant_row_set.row_lengths[row_main];
258264
for (; ri < row_length; ri++) {
259265
int_fast64_t inter = relevant_row_set.rows[row_main][ri];
260266
float inter_val = 1.0;
@@ -549,15 +555,15 @@ std::pair<bool, std::vector<int_fast64_t>> update_working_set(X_uncompressed Xu,
549555
Active_Set* as, Thread_Cache* thread_caches,
550556
float* last_max,
551557
int_fast64_t depth, IndiCols *indicols, robin_hood::unordered_flat_set<int_fast64_t>* new_cols,
552-
int_fast64_t max_interaction_distance, bool check_duplicates, struct continuous_info* ci)
558+
int_fast64_t max_interaction_distance, bool check_duplicates, struct continuous_info* ci, bool use_hierarchy, Beta_Value_Sets* beta_sets)
553559
{
554-
struct row_set new_row_set = row_list_without_columns(Xc, Xu, wont_update, thread_caches, ci);
560+
struct row_set new_row_set = row_list_without_columns(Xc, Xu, wont_update, thread_caches, ci, use_hierarchy, beta_sets);
555561
std::vector<int_fast64_t> vals_to_remove;
556562
if (check_duplicates)
557563
vals_to_remove = update_main_indistinguishable_cols(Xu, wont_update, new_row_set, indicols, new_cols, ci);
558564
char increased_set = update_working_set_cpu(
559565
Xc, new_row_set, thread_caches, as, Xu, rowsum, wont_update, p, n, lambda,
560-
updateable_items, count_may_update, last_max, depth, indicols, new_cols, max_interaction_distance, check_duplicates, ci);
566+
updateable_items, count_may_update, last_max, depth, indicols, new_cols, max_interaction_distance, check_duplicates, ci, use_hierarchy);
561567

562568
free_row_set(new_row_set);
563569
return std::make_pair(increased_set, vals_to_remove);

src/update_working_set.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ std::pair<bool, std::vector<int_fast64_t>> update_working_set(
1111
float* rowsum, bool* wont_update, int_fast64_t p, int_fast64_t n,
1212
float lambda, int_fast64_t* updateable_items, int_fast64_t count_may_update, Active_Set* as,
1313
Thread_Cache* thread_caches,
14-
float* last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t>* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info* ci);
14+
float* last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t>* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info* ci, bool use_hierarchy, Beta_Value_Sets* beta_sets);
1515

1616
void free_inter_cache(int_fast64_t p);
1717
//struct OpenCL_Setup setup_working_set_kernel(

0 commit comments

Comments
 (0)