-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Workgroup strided transforms #143
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: HJA Bird <[email protected]>
…fit in local memory
global_data.log_message_global(__func__, "storing data from local to global memory (with 2 transposes)"); | ||
if (storage == complex_storage::INTERLEAVED_COMPLEX) { | ||
std::array<IdxGlobal, 4> global_strides{2 * output_distance, 2 * factor_n * output_stride, 2 * output_stride, | ||
1}; | ||
std::array<Idx, 4> local_strides{2, 2 * max_num_batches_in_local_mem, | ||
2 * factor_m * max_num_batches_in_local_mem, 1}; | ||
std::array<Idx, 4> copy_lengths{num_batches_in_local_mem, factor_m, factor_n, 2}; | ||
|
||
detail::md_view global_output_view{output, global_strides, output_global_offset}; | ||
detail::md_view local_output_view{loc_view, local_strides}; | ||
|
||
copy_group<level::WORKGROUP>(global_data, local_output_view, global_output_view, copy_lengths); | ||
} else { // storage == complex_storage::SPLIT_COMPLEX | ||
std::array<IdxGlobal, 3> global_strides{output_distance, factor_n * output_stride, output_stride}; | ||
std::array<Idx, 3> local_strides{1, max_num_batches_in_local_mem, factor_m * max_num_batches_in_local_mem}; | ||
std::array<Idx, 3> copy_lengths{num_batches_in_local_mem, factor_m, factor_n}; | ||
|
||
detail::md_view global_output_real_view{output, global_strides, output_global_offset}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets move this logic out into a new file and lower the number of lines in workgroup_dispatcher
, as the created views are not used anywhere else. similar to what I have done here.
Similarly for line 206 onwards
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look at options for splitting the code into smaller chunks, but your example only really hides the easy bit. I would rather just inline the definition of the view objects.
I find statements like this
subgroup_impl_local2global_strided_copy<level::SUBGROUP, 3, 3, 3>(
const_cast<T*>(input), loc_view, {input_distance * 2, input_stride * 2, 1}, {fft_size * 2, 2, 1},
input_distance * 2 * (i - static_cast<IdxGlobal>(id_of_fft_in_sg)), local_offset,
{n_ffts_worked_on_by_sg, fft_size, 2}, global_data, detail::transfer_direction::GLOBAL_TO_LOCAL);
quite awkward to understand since there a lot of things going on in one big statement and if I wanted to understand it I would need to somehow maps things out from parameter to argument. The copies in workgroup dispatcher are currently verbose, but it's because they split the definition into individual chunks of information which I find very helpful. We need to prioritise readability over writability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to prioritise readability over writability.
Yet we cannot prioritize readability to a point where our kernel functions start to span over 600+ lines (subgroup_dispatcher, where the majority of the lines come from view creations and then copying it).
All the views are always only temporarily required and add a LOT of lines. hence I would say its best to move it to a different function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
subgroup_dispatcher definitely needs some work but I would rather sort that in a different PR.
I would also like to see the *_impl functions become smaller. I'll have a look at the workgroup_impl and see what I can do.
* @return a factorization for workgroup dft or null if the size won't work with the implemenation of workgroup dfts. | ||
*/ | ||
template <typename Scalar> | ||
std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets move this to utils.hpp
, we have only device callable functions in the common folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently we have the factorization functions for workitem and subgroup in common/workitem.hpp
and common/subgroup.hpp
respectively, so I was following the example there.
factorize_sg
is not called from device anywhere, along with fits_in_sg
and fits_in_wi
, so I wouldn't say we only have device callable functions in the common folder.
If we do want to refactor to puts the factorization functions in a utility file, then we should group them and put them in a "factorization.hpp" or something like that. Generic util files are a bit of a code smell imo (though I am guilty of committing that sin).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd agree that a factorisation.hpp
would be better than having everything in utils.hpp
.
* @return a factorization for workgroup dft or null if the size won't work with the implemenation of workgroup dfts. | ||
*/ | ||
template <typename Scalar> | ||
std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd agree that a factorisation.hpp
would be better than having everything in utils.hpp
.
Add the option for strided transforms at the workgroup level + testing
Checklist
Tick if relevant: