From 823b84fb8140045bbf61a30f4b26d1b78b1f65e4 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 19 Mar 2024 16:34:20 +0000 Subject: [PATCH] prevent OOB read from global memory --- src/portfft/common/subgroup_bluestein.hpp | 8 ++--- src/portfft/common/subgroup_ct.hpp | 35 +++++++++++-------- .../dispatcher/subgroup_dispatcher.hpp | 4 +-- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/portfft/common/subgroup_bluestein.hpp b/src/portfft/common/subgroup_bluestein.hpp index 3466410b..9ea3eeec 100644 --- a/src/portfft/common/subgroup_bluestein.hpp +++ b/src/portfft/common/subgroup_bluestein.hpp @@ -149,7 +149,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved( priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, twiddles_loc, static_cast(1. / (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); + factor_sg, factor_wi, wi_working, global_data); // TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as the // result of sg_dft is transposed. However, the ideal way to this is using shuffles. Implement a batched matrix @@ -190,7 +190,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved( detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); + wi_working, global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { global_data.log_message(__func__, "Applying complex conjugate on the output"); @@ -244,7 +244,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, loc_twiddles, static_cast(1. / static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); + factor_sg, factor_wi, wi_working, global_data); if (wi_working) { global_data.log_message(__func__, "storing result of the forward phase back to local memory"); @@ -276,7 +276,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); + wi_working, global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { global_data.log_message(__func__, "Applying complex conjugate on the output"); detail::conjugate_inplace(priv, factor_wi); diff --git a/src/portfft/common/subgroup_ct.hpp b/src/portfft/common/subgroup_ct.hpp index 291b1f54..52855b04 100644 --- a/src/portfft/common/subgroup_ct.hpp +++ b/src/portfft/common/subgroup_ct.hpp @@ -338,6 +338,7 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles * @param id_of_wi_in_fft workitem id withing the fft * @param factor_sg Number of workitems participating for one transform * @param factor_wi Number of complex elements per workitem for each transform + * @param wi_working Whether or not the workitem participates in the data transfers * @param global_data global_data_struct associated with the kernel launch */ template @@ -348,7 +349,7 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, - detail::global_data_struct<1>& global_data) { + bool wi_working, detail::global_data_struct<1>& global_data) { using vec2_t = sycl::vec; vec2_t modifier_vec; if (conjugate_on_load == detail::complex_conjugate::APPLIED) { @@ -356,13 +357,15 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element detail::conjugate_inplace(priv, factor_wi); } if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying load modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); + if (wi_working) { + global_data.log_message(__func__, "Applying load modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } } } sg_dft(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); @@ -373,13 +376,15 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element } if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying store modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); + if (wi_working) { + global_data.log_message(__func__, "Applying store modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } } } diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 44e2c040..756380d9 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -241,7 +241,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag sg_cooley_tukey(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, modifier_offset, - id_of_wi_in_fft, factor_sg, factor_wi, global_data); + id_of_wi_in_fft, factor_sg, factor_wi, working_inner, global_data); } else { sg_bluestein_batch_interleaved( priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles, @@ -409,7 +409,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - id_of_wi_in_fft, factor_sg, factor_wi, global_data); + id_of_wi_in_fft, factor_sg, factor_wi, working, global_data); } else { Idx loc_offset_store_view; Idx loc_offset_load_view;