diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index 47c5a4ac..0a6470ea 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -49,11 +49,12 @@ class committed_descriptor_impl; template std::vector compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, Scalar* output, - const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue); + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, + Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, + const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, + IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, + const std::vector& dependencies, sycl::queue& queue); template sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct& kd_struct, diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 31962663..5c4574df 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -224,19 +224,23 @@ sycl::event transpose_level(const typename committed_descriptor_impl(ld_output), static_cast(16)); - std::size_t ld_input_rounded = detail::round_up_to_multiple(static_cast(ld_input), static_cast(16)); - PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", ld_output_rounded, ld_input_rounded, "local_size", 16, 16); + std::size_t ld_output_rounded = + detail::round_up_to_multiple(static_cast(ld_output), static_cast(16)); + std::size_t ld_input_rounded = + detail::round_up_to_multiple(static_cast(ld_input), static_cast(16)); + PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", ld_output_rounded, ld_input_rounded, + "local_size", 16, 16); cgh.parallel_for>( - sycl::nd_range<2>({ld_output_rounded, ld_input_rounded}, {16, 16}), [= + sycl::nd_range<2>({ld_output_rounded, ld_input_rounded}, {16, 16}), + [= #ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config + , + global_logging_config = detail::global_logging_config #endif ](sycl::nd_item<2> it, sycl::kernel_handler kh) { detail::global_data_struct global_data{ #ifdef PORTFFT_KERNEL_LOG - s, global_logging_config, + s, global_logging_config, #endif it}; global_data.log_message_global("entering transpose kernel - buffer impl"); @@ -246,14 +250,17 @@ sycl::event transpose_level(const typename committed_descriptor_impl(ld_output, ld_input, 16, offset_input + outer_batch_offset, - &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, global_data); + &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, + global_data); } else { detail::generic_transpose<1>(ld_output, ld_input, 16, offset_input + outer_batch_offset, - &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, global_data); + &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, + global_data); } } global_data.log_message_global("exiting transpose kernel - buffer impl"); @@ -374,10 +381,12 @@ std::vector compute_level( PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size", local_range); cgh.parallel_for>( - sycl::nd_range<1>(sycl::range<1>(static_cast(global_range)), sycl::range<1>(static_cast(local_range))), [= + sycl::nd_range<1>(sycl::range<1>(static_cast(global_range)), + sycl::range<1>(static_cast(local_range))), + [= #ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config + , + global_logging_config = detail::global_logging_config #endif ](sycl::nd_item<1> it, sycl::kernel_handler kh) PORTFFT_REQD_SUBGROUP_SIZE(SubgroupSize) { detail::global_data_struct global_data{ @@ -386,10 +395,11 @@ std::vector compute_level( #endif it}; dispatch_level( - &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag, subimpl_twiddles, - multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, - inner_batches, inclusive_scan, batch_size, global_data, kh); - }); + &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, + offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], + &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size, + global_data, kh); + }); })); } return events; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index fdf5a472..3c96a36f 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -345,13 +345,13 @@ struct committed_descriptor_impl::run_kernel_struct(factor_num) == dimension_data.num_factors - 1) { PORTFFT_LOG_TRACE("This is the last kernel"); - l2_events = - detail::compute_level( - current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), - desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, - factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), - static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); + l2_events = detail::compute_level( + current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), + desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, + factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, + static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), + static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); } else { l2_events = detail::compute_level(