Skip to content

Commit

Permalink
Add AVX2 support to breakpoint i2i and m2m
Browse files Browse the repository at this point in the history
Additionally, the code has been restructured for
better readability, whith each vectorized version
having its own function.
  • Loading branch information
quim0 committed Jan 23, 2025
1 parent 0a9db9f commit cab0f55
Showing 1 changed file with 323 additions and 11 deletions.
334 changes: 323 additions & 11 deletions wavefront/wavefront_bialign.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ int wavefront_bialign_base(
/*
* Bidirectional check breakpoints
*/
void wavefront_bialign_breakpoint_indel2indel(
#if __AVX2__
#if __AVX512CD__ && __AVX512VL__
void wavefront_bialign_breakpoint_indel2indel_avx512(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
Expand All @@ -199,7 +201,6 @@ void wavefront_bialign_breakpoint_indel2indel(
wavefront_t* const dwf_1,
const affine2p_matrix_type component,
wf_bialign_breakpoint_t* const breakpoint) {
#if __AVX2__ && __AVX512CD__ && __AVX512VL__
// AVX512 implementation of the bialign_breakpoint_indel2indel
// Parameters
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
Expand Down Expand Up @@ -328,7 +329,182 @@ void wavefront_bialign_breakpoint_indel2indel(
}
k_0 = initial_k0;
}
#else // Scalar implementation of the bialign_breakpoint_indel2indel
}
#else // No AVX512, but AVX2 available
void wavefront_bialign_breakpoint_indel2indel_avx2(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
const int score_1,
wavefront_t* const dwf_0,
wavefront_t* const dwf_1,
const affine2p_matrix_type component,
wf_bialign_breakpoint_t* const breakpoint) {
// Parameters
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
const int text_length = sequences->text_length;
const int pattern_length = sequences->pattern_length;
const int gap_open =
(component==affine2p_matrix_I1 || component==affine2p_matrix_D1) ?
wf_aligner->penalties.gap_opening1 : wf_aligner->penalties.gap_opening2;

if (score_0 + score_1 - gap_open >= breakpoint->score) return;

// Check wavefronts overlapping
const int lo_0 = dwf_0->lo;
const int hi_0 = dwf_0->hi;
const int lo_1 = WAVEFRONT_K_INVERSE(dwf_1->hi,pattern_length,text_length);
const int hi_1 = WAVEFRONT_K_INVERSE(dwf_1->lo,pattern_length,text_length);

if (hi_1 < lo_0 || hi_0 < lo_1) return;

// Compute overlapping interval
const int min_hi = MIN(hi_0,hi_1);
const int max_lo = MAX(lo_0,lo_1);
const int elems_per_register = 8;
const int num_diagonals = min_hi - max_lo + 1;
const int loop_peeling_iters = num_diagonals % elems_per_register;
int k_0;
for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
const wf_offset_t doffset_0 = dwf_0->offsets[k_0];
const wf_offset_t doffset_1 = dwf_1->offsets[k_1];
const int dh_0 = WAVEFRONT_H(k_0,doffset_0);
const int dh_1 = WAVEFRONT_H(k_1,doffset_1);
// Check breakpoint d2d
if (dh_0 + dh_1 >= text_length) {
if (breakpoint_forward) {
// Check out-of-bounds coordinates
const int v = WAVEFRONT_V(k_0,dh_0);
const int h = WAVEFRONT_H(k_0,dh_0);
if (v > pattern_length || h > text_length) continue;
// Set breakpoint
breakpoint->score_forward = score_0;
breakpoint->score_reverse = score_1;
breakpoint->k_forward = k_0;
breakpoint->k_reverse = k_1;
breakpoint->offset_forward = dh_0;
breakpoint->offset_reverse = dh_1;
breakpoint->score = score_0 + score_1 - gap_open;
breakpoint->component = component;
return;
} else {
// Check out-of-bounds coordinates
const int v = WAVEFRONT_V(k_1,dh_1);
const int h = WAVEFRONT_H(k_1,dh_1);
if (v > pattern_length || h > text_length) continue;
// Set breakpoint
breakpoint->score_forward = score_1;
breakpoint->score_reverse = score_0;
breakpoint->k_forward = k_1;
breakpoint->k_reverse = k_0;
breakpoint->offset_forward = dh_1;
breakpoint->offset_reverse = dh_0;
breakpoint->score = score_0 + score_1 - gap_open;
breakpoint->component = component;
return;
}
}
}
// Finish the remaining iterations in a vectorized manner
const __m256i tlens = _mm256_set1_epi32(text_length-1);//enable change >= to >
const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7);
for (;k_0<=min_hi;k_0+=elems_per_register) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
__m256i doffsets_0 = _mm256_lddqu_si256((__m256i*)&dwf_0->offsets[k_0]);
__m256i doffsets_1 = _mm256_lddqu_si256((__m256i*)&dwf_1->offsets[k_1-elems_per_register+1]);
// doffsets_1 are in reverse order, so we need to reverse them
doffsets_1 = _mm256_permutevar8x32_epi32(doffsets_1, rev);
__m256i dh_0_1 = _mm256_add_epi32(doffsets_0, doffsets_1);
__m256i mask = _mm256_cmpgt_epi32(dh_0_1, tlens);
int bp_found_mask = _mm256_movemask_epi8(mask);
if (__builtin_expect(bp_found_mask == 0, 1)) continue;
// A breakpoint has been found! Check in which exact diagonal it is
// This can be done directly from the mask and vector registers, for now,
// it is implemented like the scalar implementation. This only happens
// when a BP is found, so it should not be a bottleneck.
int initial_k0 = k_0;
for (;k_0<initial_k0+elems_per_register;k_0++) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
const wf_offset_t doffset_0 = dwf_0->offsets[k_0];
const wf_offset_t doffset_1 = dwf_1->offsets[k_1];
const int dh_0 = WAVEFRONT_H(k_0,doffset_0);
const int dh_1 = WAVEFRONT_H(k_1,doffset_1);
// Check breakpoint d2d
if (dh_0 + dh_1 >= text_length) {
if (breakpoint_forward) {
// Check out-of-bounds coordinates
const int v = WAVEFRONT_V(k_0,dh_0);
const int h = WAVEFRONT_H(k_0,dh_0);
if (v > pattern_length || h > text_length) continue;
// Set breakpoint
breakpoint->score_forward = score_0;
breakpoint->score_reverse = score_1;
breakpoint->k_forward = k_0;
breakpoint->k_reverse = k_1;
breakpoint->offset_forward = dh_0;
breakpoint->offset_reverse = dh_1;
breakpoint->score = score_0 + score_1 - gap_open;
breakpoint->component = component;
return;
} else {
// Check out-of-bounds coordinates
const int v = WAVEFRONT_V(k_1,dh_1);
const int h = WAVEFRONT_H(k_1,dh_1);
if (v > pattern_length || h > text_length) continue;
// Set breakpoint
breakpoint->score_forward = score_1;
breakpoint->score_reverse = score_0;
breakpoint->k_forward = k_1;
breakpoint->k_reverse = k_0;
breakpoint->offset_forward = dh_1;
breakpoint->offset_reverse = dh_0;
breakpoint->score = score_0 + score_1 - gap_open;
breakpoint->component = component;
return;
}
}
}
k_0 = initial_k0;
}
}
#endif // AVX512
#endif // AVX2
void wavefront_bialign_breakpoint_indel2indel(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
const int score_1,
wavefront_t* const dwf_0,
wavefront_t* const dwf_1,
const affine2p_matrix_type component,
wf_bialign_breakpoint_t* const breakpoint) {
#if __AVX2__
#if __AVX512CD__ && __AVX512VL__
wavefront_bialign_breakpoint_indel2indel_avx512(
wf_aligner,
breakpoint_forward,
score_0,
score_1,
dwf_0,
dwf_1,
component,
breakpoint);
#else
wavefront_bialign_breakpoint_indel2indel_avx2(
wf_aligner,
breakpoint_forward,
score_0,
score_1,
dwf_0,
dwf_1,
component,
breakpoint);
#endif // AVX512
#else // Scalar implementation
// Parameters
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
const int text_length = sequences->text_length;
Expand Down Expand Up @@ -391,21 +567,19 @@ void wavefront_bialign_breakpoint_indel2indel(
// No need to keep searching
return;
}

}
#endif // AVX512
#endif // AVX2
}
void wavefront_bialign_breakpoint_m2m(
#if __AVX2__
#if __AVX512CD__ && __AVX512VL__
void wavefront_bialign_breakpoint_m2m_avx512(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
const int score_1,
wavefront_t* const mwf_0,
wavefront_t* const mwf_1,
wf_bialign_breakpoint_t* const breakpoint) {
// Parameters
if (score_0 + score_1 >= breakpoint->score) return;
#if __AVX2__ && __AVX512CD__ && __AVX512VL__
// AVX512 implementation of the bialign_breakpoint_indel2indel
// Parameters
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
Expand Down Expand Up @@ -512,7 +686,145 @@ void wavefront_bialign_breakpoint_m2m(
}
}
}
#else // Scalar implementation of the bialign_breakpoint_indel2indel
}
#else // No AVX512, but AVX2 available
void wavefront_bialign_breakpoint_m2m_avx2(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
const int score_1,
wavefront_t* const mwf_0,
wavefront_t* const mwf_1,
wf_bialign_breakpoint_t* const breakpoint) {
if ( score_0 + score_1 >= breakpoint->score) return;
// Parameters
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
const int text_length = sequences->text_length;
const int pattern_length = sequences->pattern_length;
// Check wavefronts overlapping
const int lo_0 = mwf_0->lo;
const int hi_0 = mwf_0->hi;
const int lo_1 = WAVEFRONT_K_INVERSE(mwf_1->hi,pattern_length,text_length);
const int hi_1 = WAVEFRONT_K_INVERSE(mwf_1->lo,pattern_length,text_length);
if (hi_1 < lo_0 || hi_0 < lo_1) return;
// Compute overlapping interval
const int min_hi = MIN(hi_0,hi_1);
const int max_lo = MAX(lo_0,lo_1);
const int elems_per_register = 8;
const int num_diagonals = min_hi - max_lo + 1;
const int loop_peeling_iters = num_diagonals % elems_per_register;
int k_0;
for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
const wf_offset_t moffset_0 = mwf_0->offsets[k_0];
const wf_offset_t moffset_1 = mwf_1->offsets[k_1];
const int mh_0 = WAVEFRONT_H(k_0,moffset_0);
const int mh_1 = WAVEFRONT_H(k_1,moffset_1);
// Check breakpoint m2m
if (mh_0 + mh_1 >= text_length) {
if (breakpoint_forward) {
breakpoint->score_forward = score_0;
breakpoint->score_reverse = score_1;
breakpoint->k_forward = k_0;
breakpoint->k_reverse = k_1;
breakpoint->offset_forward = moffset_0;
breakpoint->offset_reverse = moffset_1;
} else {
breakpoint->score_forward = score_1;
breakpoint->score_reverse = score_0;
breakpoint->k_forward = k_1;
breakpoint->k_reverse = k_0;
breakpoint->offset_forward = moffset_1;
breakpoint->offset_reverse = moffset_0;
}
breakpoint->score = score_0 + score_1;
breakpoint->component = affine2p_matrix_M;
return;
}
}
const __m256i tlens = _mm256_set1_epi32(text_length-1); //enable change >= to >
const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7);
for (;k_0<=min_hi;k_0+=elems_per_register) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
__m256i moffsets_0 = _mm256_lddqu_si256((__m256i*)&mwf_0->offsets[k_0]);
__m256i moffsets_1 = _mm256_lddqu_si256((__m256i*)&mwf_1->offsets[k_1-elems_per_register+1]);
// doffsets_1 are in reverse order, so we need to reverse them
moffsets_1 = _mm256_permutevar8x32_epi32(moffsets_1, rev);
__m256i mh_0_1 = _mm256_add_epi32(moffsets_0, moffsets_1);
__m256i mask = _mm256_cmpgt_epi32(mh_0_1, tlens);
int bp_found_mask = _mm256_movemask_epi8(mask);
if (__builtin_expect(bp_found_mask == 0, 1)) continue;
// A breakpoint has been found! Check in which exact diagonal it is
// This can be done directly from the mask and vector registers, for now,
// it is implemented like the scalar implementation. This only happens
// when a BP is found, so it should not be a bottleneck.
int initial_k0 = k_0;
for (;k_0<initial_k0+elems_per_register;k_0++) {
const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length);
// Fetch offsets
const wf_offset_t moffset_0 = mwf_0->offsets[k_0];
const wf_offset_t moffset_1 = mwf_1->offsets[k_1];
const int mh_0 = WAVEFRONT_H(k_0,moffset_0);
const int mh_1 = WAVEFRONT_H(k_1,moffset_1);
// Check breakpoint m2m
if (mh_0 + mh_1 >= text_length) {
if (breakpoint_forward) {
breakpoint->score_forward = score_0;
breakpoint->score_reverse = score_1;
breakpoint->k_forward = k_0;
breakpoint->k_reverse = k_1;
breakpoint->offset_forward = moffset_0;
breakpoint->offset_reverse = moffset_1;
} else {
breakpoint->score_forward = score_1;
breakpoint->score_reverse = score_0;
breakpoint->k_forward = k_1;
breakpoint->k_reverse = k_0;
breakpoint->offset_forward = moffset_1;
breakpoint->offset_reverse = moffset_0;
}
breakpoint->score = score_0 + score_1;
breakpoint->component = affine2p_matrix_M;
return;
}
}
}
}
#endif // AVX2
#endif // AVX512
void wavefront_bialign_breakpoint_m2m(
wavefront_aligner_t* const wf_aligner,
const bool breakpoint_forward,
const int score_0,
const int score_1,
wavefront_t* const mwf_0,
wavefront_t* const mwf_1,
wf_bialign_breakpoint_t* const breakpoint) {
// Parameters
if (score_0 + score_1 >= breakpoint->score) return;
#if __AVX2__
#if __AVX512CD__ && __AVX512VL__
wavefront_bialign_breakpoint_m2m_avx512(
wf_aligner,
breakpoint_forward,
score_0,
score_1,
mwf_0,
mwf_1,
breakpoint);
#else
wavefront_bialign_breakpoint_m2m_avx2(
wf_aligner,
breakpoint_forward,
score_0,
score_1,
mwf_0,
mwf_1,
breakpoint);
#endif // AVX2
#else // Scalar implementation
wavefront_sequences_t* const sequences = &wf_aligner->sequences;
const int text_length = sequences->text_length;
const int pattern_length = sequences->pattern_length;
Expand Down Expand Up @@ -557,7 +869,7 @@ void wavefront_bialign_breakpoint_m2m(
return;
}
}
#endif // AVX512
#endif // AVX2
}
/*
* Bidirectional find overlaps
Expand Down

0 comments on commit cab0f55

Please sign in to comment.