Skip to content

Commit

Permalink
Fix: Turin kernels for spdot (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
real-eren authored Feb 19, 2025
1 parent d3ee357 commit 5044fef
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions include/simsimd/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //

// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
if (a_length < 64 && b_length < 64) {
simsimd_intersect_u16_serial(a, b, a_length, b_length, results);
simsimd_spdot_weights_u16_serial(a, b, a_weights, b_weights, a_length, b_length, results);
return;
}

Expand Down Expand Up @@ -751,9 +751,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
a += a_step, a_weights += a_step;
b += b_step, b_weights += b_step;
}

simsimd_intersect_u16_serial(a, b, a_end - a, b_end - b, results);
*results += intersection_size;
simsimd_spdot_weights_u16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, results);
results[0] += intersection_size;
results[1] += _mm512_reduce_add_ps(_mm512_insertf32x8(_mm512_setzero_ps(), product_vec.ymmps, 0));
}

SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
Expand All @@ -764,7 +764,7 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //

// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
if (a_length < 64 && b_length < 64) {
simsimd_intersect_u16_serial(a, b, a_length, b_length, results);
simsimd_spdot_counts_u16_serial(a, b, a_weights, b_weights, a_length, b_length, results);
return;
}

Expand Down Expand Up @@ -837,8 +837,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
b += b_step, b_weights += b_step;
}

simsimd_intersect_u16_serial(a, b, a_end - a, b_end - b, results);
*results += intersection_size;
simsimd_spdot_counts_u16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, results);
results[0] += intersection_size;
results[1] += _mm512_reduce_add_epi32(_mm512_inserti64x4(_mm512_setzero_si512(), product_vec.ymm, 0));
}

#pragma clang attribute pop
Expand Down

0 comments on commit 5044fef

Please sign in to comment.