Skip to content

Commit 04c3de8

Browse files
author
Kieran Elmes
committed
wip: allow continuous X
1 parent 64dc869 commit 04c3de8

15 files changed

+143
-50
lines changed

R/pint.R

+4
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -
7575
stop("Y does not have the same number of rows as X, or the format is wrong")
7676
}
7777

78+
# combination currently not implemented
79+
if (continuous_X) {
80+
check_duplicates <- FALSE
81+
}
7882
continuous_X <- FALSE # not implemented yet.
7983

8084
log_level_enum = 0;

src/Pint.cpp

+19-3
Original file line numberDiff line numberDiff line change
@@ -321,20 +321,36 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
321321

322322
float halt_error_diff = asReal(halt_error_diff_);
323323

324+
std::vector<float>* col_real_vals = new std::vector<float>[p];
325+
float* col_max_vals = new float[p];
324326
int_fast64_t** X = (int_fast64_t**)malloc(p * sizeof *X);
325327
for (int_fast64_t i = 0; i < p; i++)
326328
X[i] = (int_fast64_t*)malloc(n * sizeof *X[i]);
327329

328330
for (int_fast64_t i = 0; i < p; i++) {
331+
float col_max_val = 0.0;
329332
for (int_fast64_t j = 0; j < n; j++) {
330-
X[i][j] = (int)(x[j + i * n]);
333+
float x_val = x[j + i * n];
334+
if (fabs(x_val) > 0.0) {
335+
col_real_vals[i].push_back(x_val);
336+
X[i][j] = 1;
337+
if (fabs(x_val) > fabs(col_max_val))
338+
col_max_val = x_val;
339+
} else {
340+
X[i][j] = 0;
341+
}
331342
}
343+
col_max_vals[i] = col_max_val;
332344
}
345+
struct continuous_info ci;
346+
ci.col_max_vals = col_max_vals;
347+
ci.col_real_vals = col_real_vals;
348+
ci.use_cont = continuous_X;
349+
333350
float* Y = (float*)malloc(n * sizeof(float));
334351
for (int_fast64_t i = 0; i < n; i++) {
335352
Y[i] = (float)y[i];
336353
}
337-
338354
XMatrix xmatrix;
339355
xmatrix.actual_cols = n;
340356
xmatrix.X = X;
@@ -343,7 +359,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
343359
Lasso_Result lasso_result = simple_coordinate_descent_lasso(
344360
xmatrix, Y, n, p, max_interaction_distance, asReal(lambda_min_),
345361
asReal(lambda_max_), max_lambdas, verbose, halt_error_diff, log_level, NULL, 0,
346-
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, continuous_X);
362+
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, continuous_X, &ci); // cont_inf is a struct continuous_info*, so pass the address of ci
347363
float final_lambda = lasso_result.final_lambda;
348364
float regularized_intercept = lasso_result.regularized_intercept;
349365
float unbiased_intercept = lasso_result.unbiased_intercept;

src/liblasso.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,9 @@ int_fast64_t** X2_from_X(int_fast64_t** X, int_fast64_t n, int_fast64_t p)
153153
}
154154
}
155155
return X2;
156+
}
157+
158+
void free_continuous_info(struct continuous_info ci) {
// Releases the two arrays referenced by `ci` with delete[], matching the
// new[] allocations made for them in lasso_ (Pint.cpp).
// `ci` is received by value, so the caller's own copy still holds the
// now-dangling pointers afterwards and must not be used or freed again.
159+
delete[] ci.col_real_vals;
160+
delete[] ci.col_max_vals;
156161
}

src/liblasso.h

+6
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ struct OpenCL_Setup {
9999
// int_fast64_t b : 64;
100100
// } int_128;
101101

102+
struct continuous_info {
// Side information describing the real (non-binary) values of X.
// As filled in by lasso_ (Pint.cpp):
//   use_cont      - copy of the continuous_X flag; the other fields are only
//                   consulted by the pruning bounds and column hashing when set.
//   col_real_vals - array of p vectors; col_real_vals[i] holds the nonzero
//                   values of column i in increasing row order.
//   col_max_vals  - array of p floats; col_max_vals[i] is the entry of column i
//                   with the largest magnitude (sign kept), 0 for an all-zero column.
// The struct has no destructor; the arrays are released by free_continuous_info().
103+
bool use_cont;
104+
std::vector<float>* col_real_vals;
105+
float* col_max_vals;
106+
};
107+
102108
typedef struct {
103109
robin_hood::unordered_flat_map<XXH64_hash_t, robin_hood::unordered_flat_map<XXH64_hash_t, robin_hood::unordered_flat_set<int_fast64_t>>> cols_for_hash;
104110
// robin_hood::unordered_flat_map<int64_t, std::vector<int64_t>> defining_co;

src/pruning.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ float pessimistic_estimate(float alpha, float* last_rowsum, float* rowsum,
4747
// the worst case effect is \leq last_max * alpha + pessimistic_estimate()
4848
float l2_combined_estimate(X_uncompressed X, float lambda, int_fast64_t k,
4949
float last_max, float* last_rowsum,
50-
float* rowsum)
50+
float* rowsum, struct continuous_info* ci)
5151
{
5252
float alpha = 0.0;
5353
// read through the compressed column
@@ -70,6 +70,8 @@ float l2_combined_estimate(X_uncompressed X, float lambda, int_fast64_t k,
7070
alpha = 0.0;
7171

7272
float remainder = pessimistic_estimate(alpha, last_rowsum, rowsum, col, X.host_col_nz[k]);
73+
if (ci->use_cont)
74+
remainder *= fabs(ci->col_max_vals[k]);
7375

7476
float total_estimate = fabs(last_max * alpha) + remainder;
7577
return total_estimate;
@@ -87,10 +89,10 @@ float l2_combined_estimate(X_uncompressed X, float lambda, int_fast64_t k,
8789
*/
8890
// TODO: should beta[k] be in here?
8991
bool wont_update_effect(X_uncompressed X, float lambda, int_fast64_t k, float last_max,
90-
float* last_rowsum, float* rowsum, int_fast64_t* column_cache)
92+
float* last_rowsum, float* rowsum, int_fast64_t* column_cache, struct continuous_info* ci)
9193
{
// Returns true when the upper bound that l2_combined_estimate() computes for
// column k is no larger than lambda * total_sqrt_error (total_sqrt_error is
// declared outside this function).
// `ci` is only forwarded to l2_combined_estimate(); `column_cache` is not read here.
9294
// int_fast64_t* cache = malloc(X.n * sizeof *column_cache);
93-
float upper_bound = l2_combined_estimate(X, lambda, k, last_max, last_rowsum, rowsum);
95+
float upper_bound = l2_combined_estimate(X, lambda, k, last_max, last_rowsum, rowsum, ci);
9496
return upper_bound <= lambda * total_sqrt_error;
9597
}
9698

@@ -115,7 +117,7 @@ float as_pessimistic_estimate(float alpha, float* last_rowsum, float* rowsum,
115117
return estimate;
116118
}
117119

118-
float as_combined_estimate(float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* cache)
120+
float as_combined_estimate(float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* cache, float col_max, bool use_cont)
119121
{
120122
float alpha = 0.0;
121123
// read through the compressed column
@@ -149,12 +151,14 @@ float as_combined_estimate(float lambda, float last_max, float* last_rowsum, flo
149151
alpha = 0.0;
150152

151153
float remainder = as_pessimistic_estimate(alpha, last_rowsum, rowsum, cache, col_entry_pos);
154+
if (use_cont)
155+
remainder *= fabs(col_max);
152156

153157
float total_estimate = fabs(last_max * alpha) + remainder;
154158
return total_estimate;
155159
}
156-
bool as_wont_update(X_uncompressed Xu, float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* column_cache)
160+
bool as_wont_update(X_uncompressed Xu, float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* column_cache, float col_max, bool use_cont)
157161
{
// Returns true when the upper bound that as_combined_estimate() computes for
// column `col` is no larger than lambda * total_sqrt_error (total_sqrt_error
// is declared outside this function).
// `col_max` and `use_cont` are forwarded unchanged to as_combined_estimate();
// `Xu` is not read here.
158-
float upper_bound = as_combined_estimate(lambda, last_max, last_rowsum, rowsum, col, column_cache);
162+
float upper_bound = as_combined_estimate(lambda, last_max, last_rowsum, rowsum, col, column_cache, col_max, use_cont);
159163
return upper_bound <= lambda * total_sqrt_error;
160164
}

src/pruning.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
bool wont_update_effect(X_uncompressed X, float lambda, int_fast64_t k, float last_max,
2-
float* last_rowsum, float* rowsum, int_fast64_t* column_cache);
3-
bool as_wont_update(X_uncompressed Xu, float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* column_cache);
2+
float* last_rowsum, float* rowsum, int_fast64_t* column_cache, struct continuous_info* ci);
3+
bool as_wont_update(X_uncompressed Xu, float lambda, float last_max, float* last_rowsum, float* rowsum, S8bCol col, int_fast64_t* column_cache, float col_max, bool use_cont);
44
bool as_pessimistic_est(float lambda, float* rowsum, S8bCol col);

src/regression.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ void subproblem_only(Iter_Vars* vars, float lambda, float* rowsum,
243243

244244
int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsum,
245245
float* old_rowsum, Active_Set* active_set,
246-
int_fast64_t depth, const bool use_intercept, IndiCols* indi, const bool check_duplicates)
246+
int_fast64_t depth, const bool use_intercept, IndiCols* indi, const bool check_duplicates,
247+
struct continuous_info* cont_inf)
247248
{
248249
XMatrixSparse Xc = vars->Xc;
249250
X_uncompressed Xu = vars->Xu;
@@ -297,7 +298,7 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
297298
for (int_fast64_t j = 0; j < p; j++) {
298299
bool prev_wont_update = wont_update[j];
299300
wont_update[j] = wont_update_effect(Xu, lambda, j, last_max[j], last_rowsum[j], rowsum,
300-
thread_caches[omp_get_thread_num()].col_j);
301+
thread_caches[omp_get_thread_num()].col_j, cont_inf);
301302
if (!wont_update[j] && !(*vars->seen_before)[j]) {
302303
// if (!wont_update[j] && !prev_wont_update) {
303304
// if (true) {
@@ -352,7 +353,7 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
352353
clock_gettime(CLOCK_MONOTONIC_RAW, &start_time);
353354
auto working_set_results = update_working_set(vars->Xu, Xc, rowsum, wont_update, p, n, lambda,
354355
updateable_items, count_may_update, active_set,
355-
thread_caches, last_max, depth, indi, &new_cols, max_interaction_distance, check_duplicates);
356+
thread_caches, last_max, depth, indi, &new_cols, max_interaction_distance, check_duplicates, cont_inf);
356357
bool increased_set = working_set_results.first;
357358
auto vals_to_remove = working_set_results.second;
358359
for (auto val : vals_to_remove) {
@@ -494,7 +495,7 @@ Lasso_Result simple_coordinate_descent_lasso(
494495
float hed, enum LOG_LEVEL log_level,
495496
const char** job_args, int_fast64_t job_args_num,
496497
int_fast64_t mnz_beta, const char* log_filename, int_fast64_t depth,
497-
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X)
498+
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X, struct continuous_info* cont_inf)
498499
{
499500
int_fast64_t max_nz_beta = mnz_beta;
500501
if (verbose)
@@ -724,7 +725,7 @@ Lasso_Result simple_coordinate_descent_lasso(
724725
int_fast64_t last_iter_count = 0;
725726

726727
nz_beta += run_lambda_iters_pruned(&iter_vars_pruned, lambda, rowsum,
727-
old_rowsum, &active_set, depth, use_intercept, &indi, check_duplicates);
728+
old_rowsum, &active_set, depth, use_intercept, &indi, check_duplicates, cont_inf);
728729

729730
{
730731
int_fast64_t nonzero = beta_sets.beta1.size() + beta_sets.beta2.size() + beta_sets.beta3.size();

src/regression.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Lasso_Result simple_coordinate_descent_lasso(
3030
float lambda_min, float lambda_max, int_fast64_t max_iter, const bool VERBOSE,
3131
float halt_beta_diff,
3232
enum LOG_LEVEL log_level, const char** job_args, int_fast64_t job_args_num,
33-
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X);
33+
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X, struct continuous_info* cont_inf);
3434
float update_intercept_cyclic(float intercept, int_fast64_t** X, float* Y,
3535
robin_hood::unordered_flat_map<int_fast64_t, float>* beta, int_fast64_t n, int_fast64_t p);
3636
// Changes update_beta_cyclic(XMatrixSparse xmatrix_sparse, float *Y,

src/sparse_matrix.cpp

+27-3
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,23 @@ void free_row_set(struct row_set rs)
3232
}
3333
free(rs.row_lengths);
3434
free(rs.rows);
35+
if (rs.row_real_vals != NULL)
36+
delete[] rs.row_real_vals;
3537
}
3638

3739
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu,
3840
bool* remove,
39-
Thread_Cache* thread_caches)
41+
Thread_Cache* thread_caches, struct continuous_info* ci)
4042
{
4143
int_fast64_t p = Xc.p;
4244
int_fast64_t n = Xc.n;
4345
struct row_set rs;
4446
rs.num_rows = n;
4547
int_fast64_t** new_rows = (int_fast64_t**)calloc(n, sizeof *new_rows);
4648
int_fast64_t* row_lengths = (int_fast64_t*)calloc(n, sizeof *row_lengths);
49+
std::vector<float>* row_real_vals = NULL;
50+
if (ci->use_cont)
51+
row_real_vals = new std::vector<float>[n];
4752

4853
// #pragma omp parallel for
4954
for (int_fast64_t row = 0; row < n; row++) {
@@ -64,9 +69,20 @@ struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu,
6469
new_rows[row] = (int_fast64_t*)malloc(row_pos * sizeof *new_rows);
6570
memcpy(new_rows[row], row_cache, row_pos * sizeof *new_rows);
6671
}
72+
if (ci->use_cont) {
73+
for (int col = 0; col < p; col++) {
74+
int_fast64_t* col_vals = &Xu.host_X[Xu.host_col_offsets[col]];
75+
for (int ri = 0; ri < Xu.host_col_nz[col]; ri++) {
76+
int_fast64_t row = col_vals[ri];
77+
float col_real_val = ci->col_real_vals[col][ri];
78+
row_real_vals[row].push_back(col_real_val);
79+
}
80+
}
81+
}
6782

6883
rs.rows = new_rows;
6984
rs.row_lengths = row_lengths;
85+
rs.row_real_vals = row_real_vals;
7086
return rs;
7187
}
7288

@@ -161,7 +177,8 @@ void free_indicols(IndiCols indi) {
161177

162178
std::vector<int_fast64_t> update_main_indistinguishable_cols(
163179
X_uncompressed Xu, bool* wont_update, struct row_set relevant_row_set,
164-
IndiCols* indi, robin_hood::unordered_flat_set<int_fast64_t>* new_cols)
180+
IndiCols* indi, robin_hood::unordered_flat_set<int_fast64_t>* new_cols,
181+
struct continuous_info* ci)
165182
{
166183
int_fast64_t total_cols_checked = 0;
167184
// robin_hood::unordered_flat_map<int64_t, std::vector<int64_t>>
@@ -171,7 +188,14 @@ std::vector<int_fast64_t> update_main_indistinguishable_cols(
171188
total_cols_checked++;
172189
int_fast64_t main_col_len = Xu.host_col_nz[main];
173190
int_fast64_t* column_entries = &Xu.host_X[Xu.host_col_offsets[main]];
174-
XXH128_hash_t main_hash = XXH3_128bits(column_entries, main_col_len * sizeof(int_fast64_t));
191+
XXH3_state_t* mh_state = XXH3_createState();
192+
XXH3_128bits_reset(mh_state);
193+
XXH3_128bits_update(mh_state, column_entries, main_col_len * sizeof(int_fast64_t));
194+
if (ci->use_cont)
195+
// Hash exactly size() floats. sizeof(ci->col_real_vals[0]) would be
// sizeof(std::vector<float>) and over-read the buffer; data() is also
// safe when the column has no nonzero entries.
XXH3_128bits_update(mh_state, ci->col_real_vals[main].data(), ci->col_real_vals[main].size() * sizeof(float));
196+
XXH128_hash_t main_hash = XXH3_128bits_digest(mh_state);
197+
XXH3_freeState(mh_state);
198+
175199

176200
if (indi->main_col_hashes[main_hash.high64].contains(main_hash.low64))
177201
// indi->skip_main_col_ids.insert(main);

src/sparse_matrix.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct row_set {
2020
int_fast64_t** rows;
2121
int_fast64_t* row_lengths;
2222
int_fast64_t num_rows;
23+
std::vector<float>* row_real_vals;
2324
// S8bCol* s8b_rows;
2425
};
2526

@@ -57,11 +58,11 @@ XMatrixSparse sparse_X2_from_X(int_fast64_t** X, int_fast64_t n, int_fast64_t p,
5758
int_fast64_t max_interaction_distance, int_fast64_t shuffle);
5859
XMatrixSparse sparsify_X(int_fast64_t** X, int_fast64_t n, int_fast64_t p);
5960

60-
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu, bool* remove, Thread_Cache* thread_caches);
61+
struct row_set row_list_without_columns(XMatrixSparse Xc, X_uncompressed Xu, bool* remove, Thread_Cache* thread_caches, struct continuous_info* ci);
6162
void free_row_set(struct row_set rs);
6263
X_uncompressed construct_host_X(XMatrixSparse* Xc);
6364
void free_host_X(X_uncompressed *Xu);
64-
std::vector<int_fast64_t> update_main_indistinguishable_cols(X_uncompressed Xu, bool* wont_update, struct row_set relevant_row_set, IndiCols* last_result, robin_hood::unordered_flat_set<int_fast64_t>* new_cols);
65+
std::vector<int_fast64_t> update_main_indistinguishable_cols(X_uncompressed Xu, bool* wont_update, struct row_set relevant_row_set, IndiCols* last_result, robin_hood::unordered_flat_set<int_fast64_t>* new_cols, struct continuous_info* ci);
6566
std::vector<int_fast64_t> get_col_by_id(X_uncompressed Xu, int_fast64_t id);
6667
IndiCols get_empty_indicols(int_fast64_t p);
6768
void free_indicols(IndiCols indi);

0 commit comments

Comments
 (0)