@@ -166,7 +166,7 @@ char update_working_set_cpu(struct XMatrixSparse Xc,
166
166
X_uncompressed Xu, float * rowsum,
167
167
bool * wont_update, int_fast64_t p, int_fast64_t n, float lambda,
168
168
int_fast64_t * updateable_items, int_fast64_t count_may_update,
169
- float * last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t >* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info * ci)
169
+ float * last_max, int_fast64_t depth, IndiCols* indicols, robin_hood::unordered_flat_set<int_fast64_t >* new_cols, int_fast64_t max_interaction_distance, const bool check_duplicates, struct continuous_info * ci, const bool use_hierarchy )
170
170
{
171
171
int_fast64_t * host_X = Xu.host_X ;
172
172
int_fast64_t * host_col_nz = Xu.host_col_nz ;
@@ -239,22 +239,28 @@ char update_working_set_cpu(struct XMatrixSparse Xc,
239
239
sum_with_col[main] += rowsum_diff; // TODO: slow
240
240
if (depth > 1 ) {
241
241
int_fast64_t ri = 0 ;
242
- int_fast64_t jump_dist = relevant_row_set.row_lengths [row_main]/ 2 ;
242
+ const int_fast64_t row_length = relevant_row_set.row_lengths [row_main];
243
243
const int_fast64_t * row = relevant_row_set.rows [row_main];
244
- int_fast64_t tmp = row[ri];
245
- while (tmp != main) {
246
- if (tmp < main)
247
- ri += jump_dist;
248
- else if (tmp > main)
249
- ri -= jump_dist;
250
- else if (tmp == main)
251
- break ;
252
- jump_dist = std::max ((int_fast64_t )1 ,jump_dist/2 );
253
- tmp = row[ri];
244
+ if (use_hierarchy) {
245
+ // we can't assume row_main is in the reduced row matrix.
246
+ while (ri < row_length && row[ri] < main)
247
+ ri++;
248
+ } else {
249
+ int_fast64_t jump_dist = row_length/2 ;
250
+ int_fast64_t tmp = row[ri];
251
+ while (tmp != main) {
252
+ if (tmp < main)
253
+ ri += jump_dist;
254
+ else if (tmp > main)
255
+ ri -= jump_dist;
256
+ else if (tmp == main)
257
+ break ;
258
+ jump_dist = std::max ((int_fast64_t )1 ,jump_dist/2 );
259
+ tmp = row[ri];
260
+ }
261
+ ri++;
254
262
}
255
- ri++;
256
263
257
- const int_fast64_t row_length = relevant_row_set.row_lengths [row_main];
258
264
for (; ri < row_length; ri++) {
259
265
int_fast64_t inter = relevant_row_set.rows [row_main][ri];
260
266
float inter_val = 1.0 ;
@@ -549,15 +555,15 @@ std::pair<bool, std::vector<int_fast64_t>> update_working_set(X_uncompressed Xu,
549
555
Active_Set* as, Thread_Cache* thread_caches,
550
556
float * last_max,
551
557
int_fast64_t depth, IndiCols *indicols, robin_hood::unordered_flat_set<int_fast64_t >* new_cols,
552
- int_fast64_t max_interaction_distance, bool check_duplicates, struct continuous_info * ci)
558
+ int_fast64_t max_interaction_distance, bool check_duplicates, struct continuous_info * ci, bool use_hierarchy, Beta_Value_Sets* beta_sets )
553
559
{
554
- struct row_set new_row_set = row_list_without_columns (Xc, Xu, wont_update, thread_caches, ci);
560
+ struct row_set new_row_set = row_list_without_columns (Xc, Xu, wont_update, thread_caches, ci, use_hierarchy, beta_sets );
555
561
std::vector<int_fast64_t > vals_to_remove;
556
562
if (check_duplicates)
557
563
vals_to_remove = update_main_indistinguishable_cols (Xu, wont_update, new_row_set, indicols, new_cols, ci);
558
564
char increased_set = update_working_set_cpu (
559
565
Xc, new_row_set, thread_caches, as, Xu, rowsum, wont_update, p, n, lambda,
560
- updateable_items, count_may_update, last_max, depth, indicols, new_cols, max_interaction_distance, check_duplicates, ci);
566
+ updateable_items, count_may_update, last_max, depth, indicols, new_cols, max_interaction_distance, check_duplicates, ci, use_hierarchy );
561
567
562
568
free_row_set (new_row_set);
563
569
return std::make_pair (increased_set, vals_to_remove);
0 commit comments