Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 committed Feb 5, 2024
1 parent f055d40 commit d5cafea
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
11 changes: 6 additions & 5 deletions src/portfft/committed_descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ class committed_descriptor_impl;
template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
typename TIn>
std::vector<sycl::event> compute_level(
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input, Scalar* output,
const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple,
IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset,
IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id,
Idx total_factors, complex_storage storage, const std::vector<sycl::event>& dependencies, sycl::queue& queue);
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset,
IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms,
IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage,
const std::vector<sycl::event>& dependencies, sycl::queue& queue);

template <typename Scalar, domain Domain, typename TOut>
sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct,
Expand Down
46 changes: 28 additions & 18 deletions src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,23 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
#ifdef PORTFFT_KERNEL_LOG
sycl::stream s{1024 * 16, 1024, cgh};
#endif
std::size_t ld_output_rounded = detail::round_up_to_multiple(static_cast<std::size_t>(ld_output), static_cast<std::size_t>(16));
std::size_t ld_input_rounded = detail::round_up_to_multiple(static_cast<std::size_t>(ld_input), static_cast<std::size_t>(16));
PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", ld_output_rounded, ld_input_rounded, "local_size", 16, 16);
std::size_t ld_output_rounded =
detail::round_up_to_multiple(static_cast<std::size_t>(ld_output), static_cast<std::size_t>(16));
std::size_t ld_input_rounded =
detail::round_up_to_multiple(static_cast<std::size_t>(ld_input), static_cast<std::size_t>(16));
PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", ld_output_rounded, ld_input_rounded,
"local_size", 16, 16);
cgh.parallel_for<detail::transpose_kernel<Scalar, Mem>>(
sycl::nd_range<2>({ld_output_rounded, ld_input_rounded}, {16, 16}), [=
sycl::nd_range<2>({ld_output_rounded, ld_input_rounded}, {16, 16}),
[=
#ifdef PORTFFT_KERNEL_LOG
,
global_logging_config = detail::global_logging_config
,
global_logging_config = detail::global_logging_config
#endif
](sycl::nd_item<2> it, sycl::kernel_handler kh) {
detail::global_data_struct global_data{
#ifdef PORTFFT_KERNEL_LOG
s, global_logging_config,
s, global_logging_config,
#endif
it};
global_data.log_message_global("entering transpose kernel - buffer impl");
Expand All @@ -246,14 +250,17 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num);
for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) {
global_data.log_message_subgroup("iter_value: ", iter_value);
IdxGlobal outer_batch_offset = get_outer_batch_offset(factors_triple, inner_batches, inclusive_scan, num_factors,
level_num, iter_value, outer_batch_product, storage);
IdxGlobal outer_batch_offset =
get_outer_batch_offset(factors_triple, inner_batches, inclusive_scan, num_factors, level_num,
iter_value, outer_batch_product, storage);
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
detail::generic_transpose<2>(ld_output, ld_input, 16, offset_input + outer_batch_offset,
&out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, global_data);
&out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc,
global_data);
} else {
detail::generic_transpose<1>(ld_output, ld_input, 16, offset_input + outer_batch_offset,
&out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, global_data);
&out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc,
global_data);
}
}
global_data.log_message_global("exiting transpose kernel - buffer impl");
Expand Down Expand Up @@ -374,10 +381,12 @@ std::vector<sycl::event> compute_level(
PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size",
local_range);
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
sycl::nd_range<1>(sycl::range<1>(static_cast<std::size_t>(global_range)), sycl::range<1>(static_cast<std::size_t>(local_range))), [=
sycl::nd_range<1>(sycl::range<1>(static_cast<std::size_t>(global_range)),
sycl::range<1>(static_cast<std::size_t>(local_range))),
[=
#ifdef PORTFFT_KERNEL_LOG
,
global_logging_config = detail::global_logging_config
,
global_logging_config = detail::global_logging_config
#endif
](sycl::nd_item<1> it, sycl::kernel_handler kh) PORTFFT_REQD_SUBGROUP_SIZE(SubgroupSize) {
detail::global_data_struct global_data{
Expand All @@ -386,10 +395,11 @@ std::vector<sycl::event> compute_level(
#endif
it};
dispatch_level<Scalar, LayoutIn, LayoutOut, SubgroupSize>(
&in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag, subimpl_twiddles,
multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple,
inner_batches, inclusive_scan, batch_size, global_data, kh);
});
&in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset,
offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size,
global_data, kh);
});
}));
}
return events;
Expand Down
14 changes: 7 additions & 7 deletions src/portfft/dispatcher/global_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, La
PORTFFT_LOG_TRACE("Dispatching the kernel for factor", factor_num, "of global implementation");
if (static_cast<Idx>(factor_num) == dimension_data.num_factors - 1) {
PORTFFT_LOG_TRACE("This is the last kernel");
l2_events =
detail::compute_level<Scalar, Domain, detail::layout::PACKED, detail::layout::PACKED, SubgroupSize, const Scalar*>(
current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(),
desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr,
factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size,
static_cast<Idx>(max_batches_in_l2), static_cast<IdxGlobal>(num_batches), static_cast<IdxGlobal>(i),
static_cast<Idx>(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue);
l2_events = detail::compute_level<Scalar, Domain, detail::layout::PACKED, detail::layout::PACKED,
SubgroupSize, const Scalar*>(
current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(),
desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr,
factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size,
static_cast<Idx>(max_batches_in_l2), static_cast<IdxGlobal>(num_batches), static_cast<IdxGlobal>(i),
static_cast<Idx>(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue);
} else {
l2_events = detail::compute_level<Scalar, Domain, detail::layout::BATCH_INTERLEAVED,
detail::layout::BATCH_INTERLEAVED, SubgroupSize, const Scalar*>(
Expand Down

0 comments on commit d5cafea

Please sign in to comment.