@@ -158,7 +158,7 @@ static auto rng = std::default_random_engine();
158
158
159
159
void subproblem_only (Iter_Vars* vars, float lambda, float * rowsum,
160
160
float * old_rowsum, Active_Set* active_set,
161
- int_fast64_t depth, char use_intercept)
161
+ int_fast64_t depth, char use_intercept, struct continuous_info * cont_inf )
162
162
{
163
163
float ** last_rowsum = vars->last_rowsum ;
164
164
Thread_Cache* thread_caches = vars->thread_caches ;
@@ -205,8 +205,8 @@ void subproblem_only(Iter_Vars* vars, float lambda, float* rowsum,
205
205
if (entry.present ) {
206
206
if (current_beta_set->contains (k) && fabs (current_beta_set->at (k)) != 0.0 ) {
207
207
update_beta_cyclic (
208
- entry. col , Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept ,
209
- thread_caches[omp_get_thread_num ()].col_i );
208
+ & entry, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept ,
209
+ thread_caches[omp_get_thread_num ()].col_i , cont_inf );
210
210
}
211
211
}
212
212
}
@@ -265,6 +265,7 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
265
265
for (int_fast64_t i = 0 ; i < n; i++) {
266
266
error += rowsum[i] * rowsum[i];
267
267
}
268
+ error = sqrt (error);
268
269
269
270
// run several iterations of will_update to make sure we catch any new
270
271
// columns
@@ -405,8 +406,8 @@ int_fast64_t run_lambda_iters_pruned(Iter_Vars* vars, float lambda, float* rowsu
405
406
was_zero = FALSE ;
406
407
total_beta_updates++;
407
408
Changes changes = update_beta_cyclic (
408
- entry. col , Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept ,
409
- thread_caches[omp_get_thread_num ()].col_i );
409
+ & entry, Y, rowsum, n, p, lambda, current_beta_set, k, vars->intercept ,
410
+ thread_caches[omp_get_thread_num ()].col_i , cont_inf );
410
411
if (changes.actual_diff == 0.0 ) {
411
412
total_unchanged++;
412
413
} else {
@@ -495,8 +496,9 @@ Lasso_Result simple_coordinate_descent_lasso(
495
496
float hed, enum LOG_LEVEL log_level,
496
497
const char ** job_args, int_fast64_t job_args_num,
497
498
int_fast64_t mnz_beta, const char * log_filename, int_fast64_t depth,
498
- const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, const bool continuous_X, struct continuous_info * cont_inf)
499
+ const bool estimate_unbiased, const bool use_intercept, const bool check_duplicates, struct continuous_info * cont_inf)
499
500
{
501
+ const bool continuous_X = cont_inf->use_cont ;
500
502
int_fast64_t max_nz_beta = mnz_beta;
501
503
if (verbose)
502
504
printf (" n: %ld, p: %ld\n " , n, p);
@@ -857,7 +859,7 @@ Lasso_Result simple_coordinate_descent_lasso(
857
859
};
858
860
// run_lambda_iters_pruned(&iter_vars_pruned, 0.0, rowsum,
859
861
subproblem_only (&iter_vars_pruned, 0.0 , rowsum,
860
- old_rowsum, &active_set, depth, use_intercept);
862
+ old_rowsum, &active_set, depth, use_intercept, cont_inf );
861
863
if (verbose)
862
864
printf (" un-regularized error: %f\n " , calculate_error (Y, rowsum, n));
863
865
unbiased_intercept = iter_vars_pruned.intercept ;
@@ -965,18 +967,22 @@ Changes update_beta_cyclic_old(
965
967
966
968
return changes;
967
969
}
968
- Changes update_beta_cyclic (S8bCol col , float * Y, float * rowsum, int_fast64_t n, int_fast64_t p,
970
+ Changes update_beta_cyclic (AS_Entry* as_entry , float * Y, float * rowsum, int_fast64_t n, int_fast64_t p,
969
971
float lambda,
970
972
robin_hood::unordered_flat_map<int_fast64_t , float >* beta,
971
- int_fast64_t k, float intercept, int_fast64_t * column_entry_cache)
973
+ int_fast64_t k, float intercept, int_fast64_t * column_entry_cache, struct continuous_info * ci )
972
974
{
975
+ S8bCol col = as_entry->col ;
973
976
float sumk = col.nz ;
977
+ if (ci->use_cont )
978
+ sumk = 0.0 ;
974
979
float bk = 0.0 ;
975
980
if (beta->contains (k)) {
976
981
bk = beta->at (k);
977
982
}
978
- float sumn = col. nz * bk ;
983
+ float sumn = 0.0 ;
979
984
int_fast64_t * column_entries = column_entry_cache;
985
+ std::vector<float > col_real_vals = as_entry->real_vals ;
980
986
981
987
int_fast64_t col_entry_pos = 0 ;
982
988
int_fast64_t entry = -1 ;
@@ -988,12 +994,20 @@ Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n,
988
994
if (diff != 0 ) {
989
995
entry += diff;
990
996
column_entries[col_entry_pos] = entry;
991
- sumn -= rowsum[entry];
997
+ float rs = rowsum[entry];
998
+ if (ci->use_cont ) {
999
+ float cv = col_real_vals[col_entry_pos];
1000
+ rs *= cv;
1001
+ sumk += cv*cv;
1002
+ // sumn += cv*bk;
1003
+ }
1004
+ sumn -= rs;
992
1005
col_entry_pos++;
993
1006
}
994
1007
values >>= item_width[word.selector ];
995
1008
}
996
1009
}
1010
+ sumn += sumk*bk;
997
1011
998
1012
float new_value = soft_threshold (sumn, lambda * total_sqrt_error) / sumk; // square root lasso
999
1013
float Bk_diff = new_value - bk;
@@ -1011,8 +1025,11 @@ Changes update_beta_cyclic(S8bCol col, float* Y, float* rowsum, int_fast64_t n,
1011
1025
if (Bk_diff != 0 ) {
1012
1026
for (int_fast64_t e = 0 ; e < col.nz ; e++) {
1013
1027
int_fast64_t i = column_entries[e];
1028
+ float offset = Bk_diff;
1029
+ if (ci->use_cont )
1030
+ offset *= col_real_vals[e];
1014
1031
#pragma omp atomic
1015
- rowsum[i] += Bk_diff ;
1032
+ rowsum[i] += offset ;
1016
1033
}
1017
1034
} else {
1018
1035
zero_updates++;
0 commit comments