
Commit 4f84363

Remove test for free memory on FFTs (#864)
1 parent 007fa55

1 file changed: +6 -15 lines

include/matx/transforms/fft/fft_cuda.h

@@ -214,21 +214,12 @@ template <typename OutTensorType, typename InTensorType> class matxCUDAFFTPlan_t
         : i.Size(RANK - 1);
 
     if (i.IsContiguous() && o.IsContiguous()) {
-      size_t freeMem, totalMem;
-      [[maybe_unused]] auto err = cudaMemGetInfo(&freeMem, &totalMem);
-      MATX_ASSERT_STR(err == cudaSuccess, matxCudaError, "Failed to get memInfo from device");
-      // Use up to 30% of free memory to batch, assuming memory use matches batch size
-      double max_for_fft_workspace = static_cast<double>(freeMem) * 0.3;
-
-      params.batch = 1;
-      for (int dim = i.Rank() - 2; dim >= 0; dim--) {
-        if (static_cast<double>(params.batch * i.Size(dim) * sizeof(typename InTensorType::value_type)) > max_for_fft_workspace) {
-          break;
-        }
-
-        params.batch_dims++;
-        params.batch *= i.Size(dim);
-      }
+      // Previously we used cudaMemGetInfo to get free memory to determine batch size. This can be very slow,
+      // and for small FFTs this call can create extra latency. For now we'll just assume the user knows what
+      // they're doing and not try to batch FFTs that are too small
+      const auto shape = i.Shape();
+      params.batch = std::accumulate(std::begin(shape), std::end(shape) - 1, 1, std::multiplies<index_t>());
+      params.batch_dims = i.Rank() - 1;
     }
     else {
       if (RANK == 1) {
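For context on the new code path: the FFT runs along the innermost dimension, so for a contiguous tensor every outer dimension can be folded into the cuFFT batch count. Below is a minimal, self-contained sketch of the same std::accumulate idiom, using a plain std::array in place of MatX's tensor shape; the index_t alias and the example shape here are illustrative assumptions, not MatX's actual definitions.

#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>

using index_t = std::int64_t;  // stand-in for MatX's index type (assumption)

int main() {
  // A rank-3 tensor of shape [4, 8, 1024]: the FFT runs along the last
  // dimension (1024), so the 4 * 8 = 32 outer rows form the batch.
  const std::array<index_t, 3> shape{4, 8, 1024};

  // Product of every dimension except the innermost one, mirroring the
  // std::accumulate line added in fft_cuda.h above.
  const index_t batch = std::accumulate(std::begin(shape), std::end(shape) - 1,
                                        index_t{1}, std::multiplies<index_t>());
  const int batch_dims = static_cast<int>(shape.size()) - 1;

  std::cout << "batch=" << batch << ", batch_dims=" << batch_dims << '\n';
  // Prints: batch=32, batch_dims=2
}

When free device memory was plentiful, the removed loop arrived at the same result; the two paths diverged only when the running product exceeded the 30% free-memory cap, in which case the old code stopped folding in further dimensions. The new code drops that guard in exchange for avoiding the cudaMemGetInfo call, which the commit notes can add noticeable latency for small FFTs.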
