fft_params: allow splitting on fast or slow FFT dim in tests
Previously, we would always split the slowest FFT dim.
evetsso authored Jan 19, 2024
1 parent 4f11fb1 commit 77bdad5
Showing 4 changed files with 131 additions and 77 deletions.
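To make the interface change concrete, here is a minimal sketch (not taken from the commit; the include path, function name, and lengths are illustrative) of how a caller moves from the implicit slowest-dimension split to the new explicit SplitType argument:

#include "fft_params.h" // illustrative include; the header lives in shared/

void distribute_example(int deviceCount)
{
    fft_params params;
    params.length = {256, 256, 256};
    params.nbatch = 1;

    // Before: the split was implicit and always on the slowest FFT dim:
    //     params.distribute_input(deviceCount);
    // After: the caller chooses the split explicitly.
    params.distribute_input(deviceCount, fft_params::SplitType::SLOWEST);  // matches old behaviour
    params.distribute_output(deviceCount, fft_params::SplitType::FASTEST); // newly possible
}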
4 changes: 2 additions & 2 deletions clients/tests/gtest_main.cpp
@@ -534,8 +534,8 @@ TEST(manual, vs_fftw) // MANUAL TESTS HERE

if(manual_devices > 1)
{
params.distribute_input(manual_devices);
params.distribute_output(manual_devices);
params.distribute_input(manual_devices, fft_params::SplitType::SLOWEST);
params.distribute_output(manual_devices, fft_params::SplitType::SLOWEST);
}

// Run an individual test using the provided command-line parameters.
51 changes: 30 additions & 21 deletions clients/tests/multi_device_test.cpp
@@ -28,7 +28,9 @@ static const std::vector<std::vector<size_t>> multi_gpu_sizes = {
{256, 256, 256},
};

std::vector<fft_params> param_generator_multi_gpu()
std::vector<fft_params> param_generator_multi_gpu(const fft_params::SplitType input_split,
const fft_params::SplitType output_split,
size_t min_fft_rank = 1)
{
int deviceCount = 0;
(void)hipGetDeviceCount(&deviceCount);
@@ -59,31 +61,24 @@ std::vector<fft_params> param_generator_multi_gpu()

std::vector<fft_params> all_params;

auto distribute_params = [&all_params, deviceCount](const std::vector<fft_params>& params) {
auto distribute_params = [=, &all_params](const std::vector<fft_params>& params) {
for(auto& p : params)
{
// run tests for:
// - multi-device input, normal output
// - multi-device output, normal input
// - multi-device both
auto p_in = p;
p_in.distribute_input(deviceCount);
auto p_out = p;
p_out.distribute_output(deviceCount);
auto p_both = p;
p_both.distribute_input(deviceCount);
p_both.distribute_output(deviceCount);
if(p.length.size() < min_fft_rank)
continue;

auto p_dist = p;
p_dist.distribute_input(deviceCount, input_split);
p_dist.distribute_output(deviceCount, output_split);

// "placement" flag is meaningless if exactly one of
// input+output is a field. So just add those cases if
// the flag is "out-of-place", since "in-place" is
// exactly the same test case.
if(p.placement == fft_placement_notinplace)
{
all_params.emplace_back(std::move(p_in));
all_params.emplace_back(std::move(p_out));
}
all_params.emplace_back(std::move(p_both));
if(p_dist.placement == fft_placement_inplace
&& p_dist.ifields.empty() != p_dist.ofields.empty())
continue;
all_params.push_back(std::move(p_dist));
}
};

@@ -93,7 +88,21 @@ std::vector<fft_params> param_generator_multi_gpu()
return all_params;
}

INSTANTIATE_TEST_SUITE_P(multi_gpu,
// split both input and output on slowest FFT dim
INSTANTIATE_TEST_SUITE_P(multi_gpu_slowest_dim,
accuracy_test,
::testing::ValuesIn(param_generator_multi_gpu(
fft_params::SplitType::SLOWEST, fft_params::SplitType::SLOWEST)),
accuracy_test::TestName);

// split slowest FFT dim only on input, or only on output
INSTANTIATE_TEST_SUITE_P(multi_gpu_slowest_input_dim,
accuracy_test,
::testing::ValuesIn(param_generator_multi_gpu(
fft_params::SplitType::SLOWEST, fft_params::SplitType::NONE)),
accuracy_test::TestName);
INSTANTIATE_TEST_SUITE_P(multi_gpu_slowest_output_dim,
accuracy_test,
::testing::ValuesIn(param_generator_multi_gpu()),
::testing::ValuesIn(param_generator_multi_gpu(
fft_params::SplitType::NONE, fft_params::SplitType::SLOWEST)),
accuracy_test::TestName);
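The three suites above cover SLOWEST and NONE splits. A fastest-dimension split can be instantiated with the same generator; the sketch below is not part of this commit (the suite name is hypothetical), and it passes min_fft_rank = 2 so that 1D cases, where the fastest and slowest dimensions coincide, are skipped:

// sketch only: fastest-dim split on both input and output
INSTANTIATE_TEST_SUITE_P(multi_gpu_fastest_dim,
                         accuracy_test,
                         ::testing::ValuesIn(param_generator_multi_gpu(
                             fft_params::SplitType::FASTEST, fft_params::SplitType::FASTEST, 2)),
                         accuracy_test::TestName);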
61 changes: 36 additions & 25 deletions shared/fft_params.h
@@ -1845,49 +1845,56 @@ class fft_params
// buffer where transform output needs to go for validation
virtual void multi_gpu_finalize(std::vector<gpubuf>& obuffer, std::vector<void*>& pobuffer) {}

enum class SplitType
{
// do not split this field into bricks
NONE,
// split field on fastest FFT dimension
FASTEST,
// split field on slowest FFT dimension
SLOWEST,
};

// create bricks in the specified field for the specified number
// of devices. The field is split along the highest FFT
// dimension, and the length only includes FFT lengths, not batch
// dimension.
// of devices. Field length includes batch dimension.
void distribute_field(int deviceCount,
std::vector<fft_field>& fields,
const std::vector<size_t>& field_length)
const std::vector<size_t>& field_length,
SplitType type)
{
size_t slowLen = field_length.front();
if(slowLen < static_cast<size_t>(deviceCount))
if(type == SplitType::NONE)
return;

// batch is the first index, slowest FFT length is index 1
size_t splitDimIdx = type == SplitType::SLOWEST ? 1 : field_length.size() - 1;

size_t splitLen = field_length[splitDimIdx];
if(splitLen < static_cast<size_t>(deviceCount))
throw std::runtime_error("too many devices to distribute length "
+ std::to_string(slowLen));
+ std::to_string(splitLen));

auto& field = fields.emplace_back();

for(int i = 0; i < deviceCount; ++i)
{
// start at origin
std::vector<size_t> field_lower(field_length.size());
std::vector<size_t> field_upper(field_length.size());
std::vector<size_t> field_upper = field_length;

// note: slowest FFT dim is index 0 in these coordinates
field_lower[0] = slowLen / deviceCount * i;
field_lower[splitDimIdx] = splitLen / deviceCount * i;

// last brick needs to include the whole slow len
// last brick needs to include the whole split len
if(i == deviceCount - 1)
{
field_upper[0] = slowLen;
field_upper[splitDimIdx] = splitLen;
}
else
{
field_upper[0] = std::min(slowLen, field_lower[0] + slowLen / deviceCount);
field_upper[splitDimIdx]
= std::min(splitLen, field_lower[splitDimIdx] + splitLen / deviceCount);
}

for(unsigned int upperDim = 1; upperDim < field_length.size(); ++upperDim)
{
field_upper[upperDim] = field_length[upperDim];
}

// field coordinates also need to include batch
field_lower.insert(field_lower.begin(), 0);
field_upper.insert(field_upper.begin(), nbatch);

// bricks have contiguous strides
size_t brick_dist = 1;
std::vector<size_t> brick_stride(field_lower.size());
@@ -1902,14 +1909,18 @@ class fft_params
}
}

void distribute_input(int deviceCount)
void distribute_input(int deviceCount, SplitType type)
{
distribute_field(deviceCount, ifields, length);
auto len = length;
len.insert(len.begin(), nbatch);
distribute_field(deviceCount, ifields, len, type);
}

void distribute_output(int deviceCount)
void distribute_output(int deviceCount, SplitType type)
{
distribute_field(deviceCount, ofields, olength());
auto len = olength();
len.insert(len.begin(), nbatch);
distribute_field(deviceCount, ofields, len, type);
}
};

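A worked example of what the reworked distribute_field produces (illustrative values, assuming a complex transform so that olength() equals length): distribute_input prepends the batch, so with length = {8, 4}, nbatch = 2 and two devices the field length becomes {2, 8, 4}; SLOWEST then splits index 1 while FASTEST splits the last index.

void distribute_field_example()
{
    fft_params params;
    params.length = {8, 4}; // slowest, fastest FFT dims
    params.nbatch = 2;

    params.distribute_input(2, fft_params::SplitType::SLOWEST);
    // ifields.front() holds two bricks over field length {2, 8, 4}:
    //   lower {0, 0, 0}, upper {2, 4, 4}
    //   lower {0, 4, 0}, upper {2, 8, 4}

    params.distribute_output(2, fft_params::SplitType::FASTEST);
    // ofields.front() splits the last dimension instead:
    //   lower {0, 0, 0}, upper {2, 8, 2}
    //   lower {0, 0, 2}, upper {2, 8, 4}
}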
92 changes: 63 additions & 29 deletions shared/rocfft_params.h
@@ -454,6 +454,11 @@ class rocfft_params : public fft_params
// we'll be allocating new ones for each brick
pbuffer.clear();

auto length_with_batch = copy_input ? length : olength();
length_with_batch.insert(length_with_batch.begin(), nbatch);
const auto splitDimIdx = get_split_dimension(field, length_with_batch);
const size_t elem_size_bytes = var_size<size_t>(precision, array_type);

for(const auto& b : field.bricks)
{
// get brick's length - note that this includes batch
@@ -462,7 +467,6 @@
const auto brick_stride = b.stride;

const size_t brick_size_elems = product(brick_len.begin(), brick_len.end());
const size_t elem_size_bytes = var_size<size_t>(precision, array_type);
const size_t brick_size_bytes = brick_size_elems * elem_size_bytes;

// set device for the alloc, but we want to return to the
@@ -477,14 +481,13 @@

if(copy_input)
{
// For now, assume we're only splitting on highest FFT
// dimension, lower-dimensional FFT data is all
// contiguous, and batches are contiguous in each brick.
//
// That means we can express this as a 2D memcpy.
const size_t unbatched_elems_per_brick
= product(brick_len.begin() + 1, brick_len.end());
const size_t unbatched_elems_per_fft = product(length.begin(), length.end());
// get contiguous elems before and after the split
const auto brick_length_before_split
= product(brick_len.begin() + splitDimIdx, brick_len.end());
const auto fft_length_with_split
= product(length_with_batch.begin() + splitDimIdx, length_with_batch.end());
const auto length_after_split
= product(brick_len.begin(), brick_len.begin() + splitDimIdx);

// get this brick's starting offset in the field
const size_t brick_offset
@@ -494,11 +497,11 @@
// assuming interleaved data so ibuffer has only one
// gpubuf
if(hipMemcpy2D(pbuffer.back(),
unbatched_elems_per_brick * elem_size_bytes,
brick_length_before_split * elem_size_bytes,
ibuffer.front().data_offset(brick_offset),
unbatched_elems_per_fft * elem_size_bytes,
unbatched_elems_per_brick * elem_size_bytes,
brick_len.front(),
fft_length_with_split * elem_size_bytes,
brick_length_before_split * elem_size_bytes,
length_after_split,
hipMemcpyHostToDevice)
!= hipSuccess)
throw std::runtime_error("hipMemcpy failure");
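// Worked example for the 2D copy above (illustrative numbers, assuming
// copy_input, interleaved data, length = {8, 4}, nbatch = 2 and a two-brick
// SLOWEST split, so length_with_batch = {2, 8, 4} and brick_len = {2, 4, 4}):
//   splitDimIdx               = 1
//   brick_length_before_split = 4 * 4 = 16   (contiguous elements per copied row)
//   fft_length_with_split     = 8 * 4 = 32   (row pitch within the full field)
//   length_after_split        = 2            (rows to copy; here, the batch)
// hipMemcpy2D therefore moves 2 rows of 16 elements each from a source whose
// rows are 32 elements apart, which is exactly one brick's worth of data.
// The gather copy in multi_gpu_finalize below uses the same geometry in the
// opposite direction.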
@@ -534,41 +537,41 @@
if(ofields.empty())
return;

auto length_with_batch = olength();
length_with_batch.insert(length_with_batch.begin(), nbatch);
const auto splitDimIdx = get_split_dimension(ofields.front(), length_with_batch);
const size_t elem_size_bytes = var_size<size_t>(precision, otype);

for(size_t i = 0; i < ofields.front().bricks.size(); ++i)
{
const auto& b = ofields.front().bricks[i];
const auto& brick_ptr = pobuffer[i];

const auto brick_len = b.length();

const size_t elem_size_bytes = var_size<size_t>(precision, otype);
// get contiguous elems before and after the split
const auto brick_length_before_split
= product(brick_len.begin() + splitDimIdx, brick_len.end());
const auto fft_length_with_split
= product(length_with_batch.begin() + splitDimIdx, length_with_batch.end());
const auto length_after_split
= product(brick_len.begin(), brick_len.begin() + splitDimIdx);

// get this brick's starting offset in the field
const size_t brick_offset = b.lower_field_offset(ostride, odist) * elem_size_bytes;

// switch device to where we're copying from
rocfft_scoped_device dev(b.device);

// For now, assume we're only splitting on highest FFT
// dimension, lower-dimensional FFT data is all
// contiguous, and batches are contiguous in each brick.
//
// That means we can express this as a 2D memcpy.
const size_t unbatched_elems_per_brick
= product(brick_len.begin() + 1, brick_len.end());
const auto output_length = olength();
const size_t unbatched_elems_per_fft
= product(output_length.begin(), output_length.end());

// copy to original output buffer - note that
// we're assuming interleaved data so obuffer
// has only one gpubuf
if(hipMemcpy2D(obuffer.front().data_offset(brick_offset),
unbatched_elems_per_fft * elem_size_bytes,
fft_length_with_split * elem_size_bytes,
brick_ptr,
unbatched_elems_per_brick * elem_size_bytes,
unbatched_elems_per_brick * elem_size_bytes,
brick_len.front(),
brick_length_before_split * elem_size_bytes,
brick_length_before_split * elem_size_bytes,
length_after_split,
hipMemcpyDeviceToDevice)
!= hipSuccess)
throw std::runtime_error("hipMemcpy failure");
Expand All @@ -580,6 +583,37 @@ class rocfft_params : public fft_params
pobuffer.clear();
pobuffer.push_back(obuffer.front().data());
}

private:
// return the dimension index that a set of bricks is splitting up
static size_t get_split_dimension(const fft_field& f,
const std::vector<size_t>& length_with_batch)
{
size_t splitDim = std::numeric_limits<size_t>::max();
for(size_t dimIdx = 0; dimIdx < length_with_batch.size(); ++dimIdx)
{
// if bricks are all same length as this dim's actual length,
// they're not splitting on this dimension.
if(std::all_of(f.bricks.begin(), f.bricks.end(), [&](const fft_brick& b) {
return b.length()[dimIdx] == length_with_batch[dimIdx];
}))
continue;

// otherwise, the bricks are splitting this dimension
if(splitDim != std::numeric_limits<size_t>::max())
{
// we already found a dimension that was split
throw std::runtime_error("bricks split on dimensions " + std::to_string(splitDim)
+ " and " + std::to_string(dimIdx));
}
splitDim = dimIdx;
}
if(splitDim == std::numeric_limits<size_t>::max())
{
throw std::runtime_error("could not find a split dimension");
}
return splitDim;
}
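// Illustrative walk-through (not part of this commit): for a field with
// length_with_batch = {2, 8, 4} covered by two bricks of length {2, 4, 4},
// only index 1 differs from the field length, so this function returns 1.
// If a brick instead had length {1, 4, 4}, indices 0 and 1 would both differ
// and the function would throw "bricks split on dimensions 0 and 1"; bricks
// that match the field on every dimension trigger
// "could not find a split dimension".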
};

#endif
