From cab0f554ac04debec58b26169ba6e02acd4ff57c Mon Sep 17 00:00:00 2001 From: Quim Date: Thu, 23 Jan 2025 10:40:54 +0000 Subject: [PATCH] Add AVX2 support to breakpoint i2i and m2m Additionally, the code has been restructured for better readability, whith each vectorized version having its own function. --- wavefront/wavefront_bialign.c | 334 ++++++++++++++++++++++++++++++++-- 1 file changed, 323 insertions(+), 11 deletions(-) diff --git a/wavefront/wavefront_bialign.c b/wavefront/wavefront_bialign.c index f5d7365..d6fd683 100644 --- a/wavefront/wavefront_bialign.c +++ b/wavefront/wavefront_bialign.c @@ -190,7 +190,9 @@ int wavefront_bialign_base( /* * Bidirectional check breakpoints */ -void wavefront_bialign_breakpoint_indel2indel( +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ +void wavefront_bialign_breakpoint_indel2indel_avx512( wavefront_aligner_t* const wf_aligner, const bool breakpoint_forward, const int score_0, @@ -199,7 +201,6 @@ void wavefront_bialign_breakpoint_indel2indel( wavefront_t* const dwf_1, const affine2p_matrix_type component, wf_bialign_breakpoint_t* const breakpoint) { -#if __AVX2__ && __AVX512CD__ && __AVX512VL__ // AVX512 implementation of the bialign_breakpoint_indel2indel // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; @@ -328,7 +329,182 @@ void wavefront_bialign_breakpoint_indel2indel( } k_0 = initial_k0; } -#else // Scalar implementation of the bialign_breakpoint_indel2indel +} +#else // No AVX512, but AVX2 available +void wavefront_bialign_breakpoint_indel2indel_avx2( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const dwf_0, + wavefront_t* const dwf_1, + const affine2p_matrix_type component, + wf_bialign_breakpoint_t* const breakpoint) { + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + const int gap_open = + (component==affine2p_matrix_I1 || component==affine2p_matrix_D1) ? + wf_aligner->penalties.gap_opening1 : wf_aligner->penalties.gap_opening2; + + if (score_0 + score_1 - gap_open >= breakpoint->score) return; + + // Check wavefronts overlapping + const int lo_0 = dwf_0->lo; + const int hi_0 = dwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(dwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(dwf_1->lo,pattern_length,text_length); + + if (hi_1 < lo_0 || hi_0 < lo_1) return; + + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + const int elems_per_register = 8; + const int num_diagonals = min_hi - max_lo + 1; + const int loop_peeling_iters = num_diagonals % elems_per_register; + int k_0; + for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + const wf_offset_t doffset_0 = dwf_0->offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length) { + if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } + } + } + // Finish the remaining iterations in a vectorized manner + const __m256i tlens = _mm256_set1_epi32(text_length-1);//enable change >= to > + const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m256i doffsets_0 = _mm256_lddqu_si256((__m256i*)&dwf_0->offsets[k_0]); + __m256i doffsets_1 = _mm256_lddqu_si256((__m256i*)&dwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + doffsets_1 = _mm256_permutevar8x32_epi32(doffsets_1, rev); + __m256i dh_0_1 = _mm256_add_epi32(doffsets_0, doffsets_1); + __m256i mask = _mm256_cmpgt_epi32(dh_0_1, tlens); + int bp_found_mask = _mm256_movemask_epi8(mask); + if (__builtin_expect(bp_found_mask == 0, 1)) continue; + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t doffset_1 = dwf_1->offsets[k_1]; + const int dh_0 = WAVEFRONT_H(k_0,doffset_0); + const int dh_1 = WAVEFRONT_H(k_1,doffset_1); + // Check breakpoint d2d + if (dh_0 + dh_1 >= text_length) { + if (breakpoint_forward) { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_0,dh_0); + const int h = WAVEFRONT_H(k_0,dh_0); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = dh_0; + breakpoint->offset_reverse = dh_1; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } else { + // Check out-of-bounds coordinates + const int v = WAVEFRONT_V(k_1,dh_1); + const int h = WAVEFRONT_H(k_1,dh_1); + if (v > pattern_length || h > text_length) continue; + // Set breakpoint + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = dh_1; + breakpoint->offset_reverse = dh_0; + breakpoint->score = score_0 + score_1 - gap_open; + breakpoint->component = component; + return; + } + } + } + k_0 = initial_k0; + } +} +#endif // AVX512 +#endif // AVX2 +void wavefront_bialign_breakpoint_indel2indel( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const dwf_0, + wavefront_t* const dwf_1, + const affine2p_matrix_type component, + wf_bialign_breakpoint_t* const breakpoint) { +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ + wavefront_bialign_breakpoint_indel2indel_avx512( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + dwf_0, + dwf_1, + component, + breakpoint); +#else + wavefront_bialign_breakpoint_indel2indel_avx2( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + dwf_0, + dwf_1, + component, + breakpoint); +#endif // AVX512 +#else // Scalar implementation // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; @@ -391,11 +567,12 @@ void wavefront_bialign_breakpoint_indel2indel( // No need to keep searching return; } - } -#endif // AVX512 +#endif // AVX2 } -void wavefront_bialign_breakpoint_m2m( +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ +void wavefront_bialign_breakpoint_m2m_avx512( wavefront_aligner_t* const wf_aligner, const bool breakpoint_forward, const int score_0, @@ -403,9 +580,6 @@ void wavefront_bialign_breakpoint_m2m( wavefront_t* const mwf_0, wavefront_t* const mwf_1, wf_bialign_breakpoint_t* const breakpoint) { - // Parameters - if (score_0 + score_1 >= breakpoint->score) return; -#if __AVX2__ && __AVX512CD__ && __AVX512VL__ // AVX512 implementation of the bialign_breakpoint_indel2indel // Parameters wavefront_sequences_t* const sequences = &wf_aligner->sequences; @@ -512,7 +686,145 @@ void wavefront_bialign_breakpoint_m2m( } } } -#else // Scalar implementation of the bialign_breakpoint_indel2indel +} +#else // No AVX512, but AVX2 available +void wavefront_bialign_breakpoint_m2m_avx2( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const mwf_0, + wavefront_t* const mwf_1, + wf_bialign_breakpoint_t* const breakpoint) { + if ( score_0 + score_1 >= breakpoint->score) return; + // Parameters + wavefront_sequences_t* const sequences = &wf_aligner->sequences; + const int text_length = sequences->text_length; + const int pattern_length = sequences->pattern_length; + // Check wavefronts overlapping + const int lo_0 = mwf_0->lo; + const int hi_0 = mwf_0->hi; + const int lo_1 = WAVEFRONT_K_INVERSE(mwf_1->hi,pattern_length,text_length); + const int hi_1 = WAVEFRONT_K_INVERSE(mwf_1->lo,pattern_length,text_length); + if (hi_1 < lo_0 || hi_0 < lo_1) return; + // Compute overlapping interval + const int min_hi = MIN(hi_0,hi_1); + const int max_lo = MAX(lo_0,lo_1); + const int elems_per_register = 8; + const int num_diagonals = min_hi - max_lo + 1; + const int loop_peeling_iters = num_diagonals % elems_per_register; + int k_0; + for (k_0=max_lo;k_0 < max_lo+loop_peeling_iters; k_0++) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + const wf_offset_t moffset_0 = mwf_0->offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + return; + } + } + const __m256i tlens = _mm256_set1_epi32(text_length-1); //enable change >= to > + const __m256i rev = _mm256_set_epi32(0,1,2,3,4,5,6,7); + for (;k_0<=min_hi;k_0+=elems_per_register) { + const int k_1 = WAVEFRONT_K_INVERSE(k_0,pattern_length,text_length); + // Fetch offsets + __m256i moffsets_0 = _mm256_lddqu_si256((__m256i*)&mwf_0->offsets[k_0]); + __m256i moffsets_1 = _mm256_lddqu_si256((__m256i*)&mwf_1->offsets[k_1-elems_per_register+1]); + // doffsets_1 are in reverse order, so we need to reverse them + moffsets_1 = _mm256_permutevar8x32_epi32(moffsets_1, rev); + __m256i mh_0_1 = _mm256_add_epi32(moffsets_0, moffsets_1); + __m256i mask = _mm256_cmpgt_epi32(mh_0_1, tlens); + int bp_found_mask = _mm256_movemask_epi8(mask); + if (__builtin_expect(bp_found_mask == 0, 1)) continue; + // A breakpoint has been found! Check in which exact diagonal it is + // This can be done directly from the mask and vector registers, for now, + // it is implemented like the scalar implementation. This only happens + // when a BP is found, so it should not be a bottleneck. + int initial_k0 = k_0; + for (;k_0offsets[k_0]; + const wf_offset_t moffset_1 = mwf_1->offsets[k_1]; + const int mh_0 = WAVEFRONT_H(k_0,moffset_0); + const int mh_1 = WAVEFRONT_H(k_1,moffset_1); + // Check breakpoint m2m + if (mh_0 + mh_1 >= text_length) { + if (breakpoint_forward) { + breakpoint->score_forward = score_0; + breakpoint->score_reverse = score_1; + breakpoint->k_forward = k_0; + breakpoint->k_reverse = k_1; + breakpoint->offset_forward = moffset_0; + breakpoint->offset_reverse = moffset_1; + } else { + breakpoint->score_forward = score_1; + breakpoint->score_reverse = score_0; + breakpoint->k_forward = k_1; + breakpoint->k_reverse = k_0; + breakpoint->offset_forward = moffset_1; + breakpoint->offset_reverse = moffset_0; + } + breakpoint->score = score_0 + score_1; + breakpoint->component = affine2p_matrix_M; + return; + } + } + } +} +#endif // AVX2 +#endif // AVX512 +void wavefront_bialign_breakpoint_m2m( + wavefront_aligner_t* const wf_aligner, + const bool breakpoint_forward, + const int score_0, + const int score_1, + wavefront_t* const mwf_0, + wavefront_t* const mwf_1, + wf_bialign_breakpoint_t* const breakpoint) { + // Parameters + if (score_0 + score_1 >= breakpoint->score) return; +#if __AVX2__ +#if __AVX512CD__ && __AVX512VL__ + wavefront_bialign_breakpoint_m2m_avx512( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + mwf_0, + mwf_1, + breakpoint); +#else + wavefront_bialign_breakpoint_m2m_avx2( + wf_aligner, + breakpoint_forward, + score_0, + score_1, + mwf_0, + mwf_1, + breakpoint); +#endif // AVX2 +#else // Scalar implementation wavefront_sequences_t* const sequences = &wf_aligner->sequences; const int text_length = sequences->text_length; const int pattern_length = sequences->pattern_length; @@ -557,7 +869,7 @@ void wavefront_bialign_breakpoint_m2m( return; } } -#endif // AVX512 +#endif // AVX2 } /* * Bidirectional find overlaps