Skip to content

Commit 7a9276a

Browse files
author
Kieran Elmes
committed
probably working continuous X
1 parent 04c3de8 commit 7a9276a

15 files changed

+262
-97
lines changed

R/pint.R

+14-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
val_to_list_name <- function(val, X) {
1313
range <- ncol(X)
1414
names <- colnames(X)
15+
print(sprintf("val: %s", val))
16+
print(sprintf("range: %s", range))
1517
if (val < range) {
1618
return(names[val + 1])
1719
} else if (val < range * range) {
@@ -29,8 +31,8 @@ val_to_list_name <- function(val, X) {
2931
process_result <- function(X, result) {
3032
i <- colnames(X)[result[[1]]]
3133
strength <- result[[2]]
32-
equiv_list = c()
33-
if (length(result[[3]]) > 0) {
34+
equiv_list <- c()
35+
if (length(result[[3]][0]) > 0) {
3436
equiv_list <- lapply(result[[3]], val_to_list_name, X)
3537
names(equiv_list) <- i
3638
}
@@ -39,8 +41,8 @@ process_result <- function(X, result) {
3941
i <- colnames(X)[result[[4]]]
4042
j <- colnames(X)[result[[5]]]
4143
strength <- result[[6]]
42-
equiv_list = c()
43-
if (length(result[[7]]) > 0) {
44+
equiv_list <- c()
45+
if (length(result[[7]][0]) > 0) {
4446
equiv_list <- lapply(result[[7]], val_to_list_name, X)
4547
names(equiv_list) <- paste0(i, ",", j)
4648
}
@@ -49,8 +51,8 @@ process_result <- function(X, result) {
4951
i <- colnames(X)[result[[8]]]
5052
j <- colnames(X)[result[[9]]]
5153
k <- colnames(X)[result[[10]]]
52-
equiv_list = c()
53-
if (length(result[[12]]) > 0) {
54+
equiv_list <- c()
55+
if (length(result[[12]][0]) > 0) {
5456
equiv_list <- lapply(result[[12]], val_to_list_name, X)
5557
names(equiv_list) <- paste0(i, ",", j, ",", k)
5658
}
@@ -69,7 +71,7 @@ read_log <- function(log_filename="regression.log") {
6971
return(process_result(result))
7072
}
7173

72-
interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -1, halt_error_diff=1.01, max_interaction_distance=-1, max_nz_beta=-1, max_lambdas=200, verbose=FALSE, log_filename="regression.log", depth=2, log_level="none", estimate_unbiased=FALSE, use_intercept=TRUE, num_threads=-1, strong_hierarchy=FALSE, check_duplicates=FALSE) {
74+
interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -1, halt_error_diff=1.01, max_interaction_distance=-1, max_nz_beta=-1, max_lambdas=200, verbose=FALSE, log_filename="regression.log", depth=2, log_level="none", estimate_unbiased=FALSE, use_intercept=TRUE, num_threads=-1, strong_hierarchy=FALSE, check_duplicates=FALSE, continuous_X=FALSE) {
7375
Ym = as.matrix(Y)
7476
if (!dim(Ym)[1] == n) {
7577
stop("Y does not have the same number of rows as X, or the format is wrong")
@@ -79,7 +81,6 @@ interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -
7981
if (continuous_X) {
8082
check_duplicates <- FALSE
8183
}
82-
continuous_X <- FALSE # not implemented yet.
8384

8485
log_level_enum = 0;
8586
if (log_level == "lambda") {
@@ -112,7 +113,11 @@ interaction_lasso <- function(X, Y, n = dim(X)[1], p = dim(X)[2], lambda_min = -
112113
main_only = .Call(lasso_, X, Ym, lambda_min, lambda_max, halt_error_diff, max_interaction_distance, max_nz_beta*2, max_lambdas, verbose, log_filename, 1, log_level_enum, estimate_unbiased, use_intercept, num_threads, check_duplicates, continuous_X)
113114
main_only_result <- process_result(X, main_only[[1]])
114115
# we'll use these to allow the second run to report all identical columns again.
115-
all_equiv <- unique(unlist(main_only_result$main_effects$equivalent))
116+
if (check_duplicates) {
117+
all_equiv <- unique(unlist(main_only_result$main_effects$equivalent))
118+
} else {
119+
all_equiv <- unique(unlist(main_only_result$main_effects$i))
120+
}
116121

117122
X <- X[,all_equiv]
118123
}

coverage-badge.svg

+1-1
Loading

install_and_run.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ Y <- d$Y
4242
## result <- interaction_lasso(X, Y, lambda_min = -1, max_interaction_distance = -1, use_adaptive_calibration = FALSE, max_nz_beta = 200, depth = 3)
4343
# result <- interaction_lasso(X, Y, depth = 3)
4444
# result <- interaction_lasso(X, Y, depth = 2)
45-
result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 150, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, strong_hierarchy = TRUE, check_duplicates = TRUE)
45+
# result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 150, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, strong_hierarchy = TRUE, check_duplicates = TRUE, continuous_X = TRUE)
46+
result <- interaction_lasso(X, Y, depth = 2, max_nz_beta = 150, estimate_unbiased = TRUE, num_threads = 4, verbose=TRUE, strong_hierarchy = FALSE, check_duplicates = TRUE, continuous_X = TRUE)
4647
# print(result)
4748

4849
# q()

meson.build

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_project_arguments([
1010
'-ffast-math',
1111
'-fno-stack-protector',
1212
'-fpermissive',
13+
'-Wno-unused',
1314
#'-Ofast',
1415
# '-g',
1516
# '-Os',

src/Pint.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ SEXP lasso_(SEXP X_, SEXP Y_, SEXP lambda_min_, SEXP lambda_max_,
359359
Lasso_Result lasso_result = simple_coordinate_descent_lasso(
360360
xmatrix, Y, n, p, max_interaction_distance, asReal(lambda_min_),
361361
asReal(lambda_max_), max_lambdas, verbose, halt_error_diff, log_level, NULL, 0,
362-
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, continuous_X, ci);
362+
max_nz_beta, log_filename, depth, estimate_unbiased, use_intercept, check_duplicates, &ci);
363363
float final_lambda = lasso_result.final_lambda;
364364
float regularized_intercept = lasso_result.regularized_intercept;
365365
float unbiased_intercept = lasso_result.unbiased_intercept;

src/liblasso.h

+2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct continuous_info {
104104
std::vector<float>* col_real_vals;
105105
float* col_max_vals;
106106
};
107+
void free_continuous_info(struct continuous_info ci);
107108

108109
typedef struct {
109110
robin_hood::unordered_flat_map<XXH64_hash_t, robin_hood::unordered_flat_map<XXH64_hash_t, robin_hood::unordered_flat_set<int_fast64_t>>> cols_for_hash;
@@ -145,6 +146,7 @@ struct AS_Entry {
145146
S8bCol col;
146147
float *last_rowsum;
147148
float last_max;
149+
std::vector<float> real_vals;
148150
};
149151

150152
typedef struct {

src/queue.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ void* queue_pop_head(Queue* q)
6161
void queue_free(Queue* q)
6262
{
6363
Queue_Item* current_item = q->first_item;
64-
Queue_Item* next_item = (Queue_Item*)current_item->next;
6564

6665
// free the queue contents
6766
while (current_item != NULL) {
6867
free(current_item->contents);
69-
next_item = (Queue_Item*)current_item->next;
68+
Queue_Item* next_item = (Queue_Item*)current_item->next;
7069
free(current_item);
7170
current_item = next_item;
7271
}

src/regression.cpp

+29-12
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ static auto rng = std::default_random_engine();
158158

159159
void subproblem_only(Iter_Vars* vars, float lambda, float* rowsum,
160160
float* old_rowsum, Active_Set* active_set,
161-
int_fast64_t depth, char use_intercept)
161+
int_fast64_t depth, char use_intercept, struct continuous_info* cont_inf)
162162
{
163163
float** last_rowsum = vars->last_rowsum;
164164
Thread_Cache* thread_caches = vars->thread_caches;
@@ -205,8 +205,8 @@ void subproblem_only(Iter_Vars* vars, float lambda, float* rowsum,
205205
if (entry.present) {
206206
if (current_beta_set->contains(k) && fabs(current_beta_set->at(k)) != 0.0) {
207207
update_beta_cyclic(
208-
entry.col, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept,
209-
thread_caches[omp_get_thread_num()].col_i);
208+
&entry, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept,
209+
thread_caches[omp_get_thread_num()].col_i, cont_inf);
210210
}
211211
}
212212
}
@@ -265,6 +265,7 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
265265
for (int_fast64_t i = 0; i < n; i++) {
266266
error += rowsum[i] * rowsum[i];
267267
}
268+
error = sqrt(error);
268269

269270
// run several iterations of will_update to make sure we catch any new
270271
// columns
@@ -405,8 +406,8 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
405406
was_zero = FALSE;
406407
total_beta_updates++;
407408
Changes changes = update_beta_cyclic(
408-
entry.col, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept,
409-
thread_caches[omp_get_thread_num()].col_i);
409+
&entry, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept,
410+
thread_caches[omp_get_thread_num()].col_i, cont_inf);
410411
if (changes.actual_diff == 0.0) {
411412
total_unchanged++;
412413
} else {
@@ -495,8 +496,9 @@ Lasso_Result simple_coordinate_descent_lasso(
495496
float hed, enum LOG_LEVEL log_level,
496497
const char** job_args, int_fast64_t job_args_num,
497498
int_fast64_t mnz_beta, const char* log_filename, int_fast64_t depth,
498-
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X, struct continuous_info* cont_inf)
499+
const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf)
499500
{
501+
const bool continuous_X = cont_inf->use_cont;
500502
int_fast64_t max_nz_beta = mnz_beta;
501503
if (verbose)
502504
printf("n: %ld, p: %ld\n", n, p);
@@ -857,7 +859,7 @@ Lasso_Result simple_coordinate_descent_lasso(
857859
};
858860
// run_lambda_iters_pruned(&iter_vars_pruned, 0.0, rowsum,
859861
subproblem_only(&iter_vars_pruned, 0.0, rowsum,
860-
old_rowsum, &active_set, depth, use_intercept);
862+
old_rowsum, &active_set, depth, use_intercept, cont_inf);
861863
if (verbose)
862864
printf("un-regularized error: %f\n", calculate_error(Y, rowsum, n));
863865
unbiased_intercept = iter_vars_pruned.intercept;
@@ -965,18 +967,22 @@ Changes update_beta_cyclic_old(
965967

966968
return changes;
967969
}
968-
Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n, int_fast64_t p,
970+
Changes update_beta_cyclic(AS_Entry* as_entry, float* Y, float* rowsum, int_fast64_t n, int_fast64_t p,
969971
float lambda,
970972
robin_hood::unordered_flat_map<int_fast64_t, float>* beta,
971-
int_fast64_t k, float intercept, int_fast64_t* column_entry_cache)
973+
int_fast64_t k, float intercept, int_fast64_t* column_entry_cache, struct continuous_info* ci)
972974
{
975+
S8bCol col = as_entry->col;
973976
float sumk = col.nz;
977+
if (ci->use_cont)
978+
sumk = 0.0;
974979
float bk = 0.0;
975980
if (beta->contains(k)) {
976981
bk = beta->at(k);
977982
}
978-
float sumn = col.nz * bk;
983+
float sumn = 0.0;
979984
int_fast64_t* column_entries = column_entry_cache;
985+
std::vector<float> col_real_vals = as_entry->real_vals;
980986

981987
int_fast64_t col_entry_pos = 0;
982988
int_fast64_t entry = -1;
@@ -988,12 +994,20 @@ Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n,
988994
if (diff != 0) {
989995
entry += diff;
990996
column_entries[col_entry_pos] = entry;
991-
sumn -= rowsum[entry];
997+
float rs = rowsum[entry];
998+
if (ci->use_cont) {
999+
float cv = col_real_vals[col_entry_pos];
1000+
rs *= cv;
1001+
sumk += cv*cv;
1002+
// sumn += cv*bk;
1003+
}
1004+
sumn -= rs;
9921005
col_entry_pos++;
9931006
}
9941007
values >>= item_width[word.selector];
9951008
}
9961009
}
1010+
sumn += sumk*bk;
9971011

9981012
float new_value = soft_threshold(sumn, lambda * total_sqrt_error) / sumk; // square root lasso
9991013
float Bk_diff = new_value - bk;
@@ -1011,8 +1025,11 @@ Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n,
10111025
if (Bk_diff != 0) {
10121026
for (int_fast64_t e = 0; e < col.nz; e++) {
10131027
int_fast64_t i = column_entries[e];
1028+
float offset = Bk_diff;
1029+
if (ci->use_cont)
1030+
offset *= col_real_vals[e];
10141031
#pragma omp atomic
1015-
rowsum[i] += Bk_diff;
1032+
rowsum[i] += offset;
10161033
}
10171034
} else {
10181035
zero_updates++;

src/regression.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ Lasso_Result simple_coordinate_descent_lasso(
3030
float lambda_min, float lambda_max, int_fast64_t max_iter, const bool VERBOSE,
3131
float halt_beta_diff,
3232
enum LOG_LEVEL log_level, const char** job_args, int_fast64_t job_args_num,
33-
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X, struct continuous_info* cont_inf);
33+
int_fast64_t max_nz_beta, const char* log_filename, int_fast64_t depth, const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info* cont_inf);
3434
float update_intercept_cyclic(float intercept, int_fast64_t** X, float* Y,
3535
robin_hood::unordered_flat_map<int_fast64_t, float>* beta, int_fast64_t n, int_fast64_t p);
3636
// Changes update_beta_cyclic(XMatrixSparse xmatrix_sparse, float *Y,
37-
Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n, int_fast64_t p,
37+
Changes update_beta_cyclic(AS_Entry* entry, float* Y, float* rowsum, int_fast64_t n, int_fast64_t p,
3838
float lambda, robin_hood::unordered_flat_map<int_fast64_t, float>* beta, int_fast64_t k,
39-
float intercept, int_fast64_t* column_cache);
39+
float intercept, int_fast64_t* column_cache, struct continuous_info* ci);
4040
Changes update_beta_cyclic_old(XMatrixSparse xmatrix_sparse, float* Y,
4141
float* rowsum, int_fast64_t n, int_fast64_t p, float lambda,
4242
robin_hood::unordered_flat_map<int_fast64_t, float>* beta, int_fast64_t k, float intercept,

src/sparse_matrix.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ std::vector<int_fast64_t> update_main_indistinguishable_cols(
192192
XXH3_128bits_reset(mh_state);
193193
XXH3_128bits_update(mh_state, column_entries, main_col_len * sizeof(int_fast64_t));
194194
if (ci->use_cont)
195-
XXH3_128bits_update(mh_state, &ci->col_real_vals[main][0], ci->col_real_vals[main].size() * sizeof(ci->col_real_vals[0]));
195+
XXH3_128bits_update(mh_state, &ci->col_real_vals[main][0], ci->col_real_vals[main].size() * sizeof(ci->col_real_vals[main][0]));
196196
XXH128_hash_t main_hash = XXH3_128bits_digest(mh_state);
197197
XXH3_freeState(mh_state);
198198

0 commit comments

Comments
 (0)