Merge pull request #20 from chaithyagr/isign
Autograd support added
chaithyagr authored Jun 6, 2024
2 parents 0c34013 + 01b9cab commit 69264c9
Showing 4 changed files with 26 additions and 7 deletions.
16 changes: 14 additions & 2 deletions CUDA/inc/gpuNUFFT_operator.hpp
@@ -53,13 +53,13 @@ class GpuNUFFTOperator
GpuNUFFTOperator(IndType kernelWidth, IndType sectorWidth, DType osf,
Dimensions imgDims, bool loadKernel = true,
OperatorType operatorType = DEFAULT,
-bool matlabSharedMem = false)
+bool matlabSharedMem = false, bool grad_mode = false)
: operatorType(operatorType), osf(osf), kernelWidth(kernelWidth),
sectorWidth(sectorWidth), imgDims(imgDims), gpuMemAllocated(false),
debugTiming(DEBUG), sens_d(NULL), crds_d(NULL), density_comp_d(NULL),
deapo_d(NULL), gdata_d(NULL), sector_centers_d(NULL), sectors_d(NULL),
data_indices_d(NULL), data_sorted_d(NULL), allocatedCoils(0),
-matlabSharedMem(matlabSharedMem)
+matlabSharedMem(matlabSharedMem), grad_mode(grad_mode)
{
if (loadKernel)
initKernel();
@@ -342,6 +342,14 @@ class GpuNUFFTOperator
GpuNUFFTOutput gpuNUFFTOut);

void clean_memory();
+
+void setGradMode(bool grad_mode) {
+  this->grad_mode = grad_mode;
+}
+
+bool getGradMode() {
+  return this->grad_mode;
+}
/** \brief Check if density compensation data is available. */
bool applyDensComp()
{
@@ -452,6 +460,10 @@
*/
bool matlabSharedMem;
+
+/** \brief Flag for changing the isign, mainly used for gradients
+ */
+bool grad_mode;

/** \brief Return Grid Width (ImageWidth * osf) */
IndType getGridWidth()
{
7 changes: 7 additions & 0 deletions CUDA/src/gpu/python/gpuNUFFT_operator_python_factory.cpp
@@ -291,6 +291,12 @@ class GpuNUFFTPythonOperator
gpuNUFFTOp->clean_memory();
}
+
+void toggle_grad_mode()
+{
+  bool current_mode = gpuNUFFTOp->getGradMode();
+  gpuNUFFTOp->setGradMode(!current_mode);
+}

void set_smaps(py::array_t<std::complex<DType>> sense_maps)
{
CAST_POINTER_VARNAME(sense_maps, sensArray);
@@ -436,6 +442,7 @@ PYBIND11_MODULE(gpuNUFFT, m) {
.def("clean_memory", &GpuNUFFTPythonOperator::clean_memory)
.def("estimate_density_comp", &GpuNUFFTPythonOperator::estimate_density_comp, py::arg("max_iter") = 10)
.def("set_smaps", &GpuNUFFTPythonOperator::set_smaps)
.def("toggle_grad_mode", &GpuNUFFTPythonOperator::toggle_grad_mode)
.def("get_spectral_radius", &GpuNUFFTPythonOperator::get_spectral_radius, py::arg("max_iter") = 20, py::arg("tolerance") = 1e-6);
}
#endif // GPUNUFFT_OPERATOR_PYTHONFACTORY_H_INCLUDED
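Note: the new toggle_grad_mode binding is what an autograd wrapper would call around a forward transform to get the isign-flipped (gradient) sign convention. A minimal usage sketch follows, assuming an already-constructed gpuNUFFT operator object; the op() forward-method name and the wrapper function are assumptions for illustration and are not part of this diff.

```python
def forward_with_flipped_sign(nufft_op, image):
    """Run the forward NUFFT under the gradient (isign-flipped) convention.

    `nufft_op` is assumed to be a gpuNUFFT Python operator exposing the
    toggle_grad_mode binding added in this commit; the op() forward method
    name is an assumption used only for illustration.
    """
    nufft_op.toggle_grad_mode()      # switch to CUFFT_INVERSE / INVERSE shift
    try:
        return nufft_op.op(image)    # forward transform with the flipped sign
    finally:
        nufft_op.toggle_grad_mode()  # restore the default sign convention
```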
8 changes: 4 additions & 4 deletions CUDA/src/gpuNUFFT_operator.cpp
@@ -1026,7 +1026,7 @@ void gpuNUFFT::GpuNUFFTOperator::performForwardGpuNUFFT(
{
if ((err = pt2CufftExec(fft_plan, gdata_d + c * gi_host->gridDims_count,
gdata_d + c * gi_host->gridDims_count,
-CUFFT_FORWARD)) != CUFFT_SUCCESS)
+grad_mode?CUFFT_INVERSE:CUFFT_FORWARD)) != CUFFT_SUCCESS)
{
fprintf(stderr, "cufft has failed with err %i \n", err);
showMemoryInfo(true, stderr);
@@ -1037,7 +1037,7 @@
if (DEBUG && (cudaStreamSynchronize(new_stream) != cudaSuccess))
printf("error at thread synchronization 5: %s\n",
cudaGetErrorString(cudaGetLastError()));
-performFFTShift(gdata_d, FORWARD, getGridDims(), gi_host);
+performFFTShift(gdata_d, grad_mode?INVERSE:FORWARD, getGridDims(), gi_host);

if (DEBUG && (cudaStreamSynchronize(new_stream) != cudaSuccess))
printf("error at thread synchronization 6: %s\n",
@@ -1240,7 +1240,7 @@
{
if ((err = pt2CufftExec(fft_plan, gdata_d + c * gi_host->gridDims_count,
gdata_d + c * gi_host->gridDims_count,
-CUFFT_FORWARD)) != CUFFT_SUCCESS)
+grad_mode?CUFFT_INVERSE:CUFFT_FORWARD)) != CUFFT_SUCCESS)
{
fprintf(stderr, "cufft has failed with err %i \n", err);
showMemoryInfo(true, stderr);
@@ -1251,7 +1251,7 @@
if (DEBUG && (cudaStreamSynchronize(new_stream) != cudaSuccess))
printf("error at thread synchronization 5: %s\n",
cudaGetErrorString(cudaGetLastError()));
-performFFTShift(gdata_d, FORWARD, getGridDims(), gi_host);
+performFFTShift(gdata_d, grad_mode?INVERSE:FORWARD, getGridDims(), gi_host);

if (DEBUG && (cudaStreamSynchronize(new_stream) != cudaSuccess))
printf("error at thread synchronization 6: %s\n",
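Note: when grad_mode is set, the hunks above only flip the FFT direction and the FFT-shift direction inside the forward operator. CUFFT_INVERSE evaluates the DFT with the opposite exponent sign (the "isign" in the branch name) and, like CUFFT_FORWARD, applies no normalization. A small NumPy sketch of this sign-flip identity, for illustration only and not part of the library:

```python
import numpy as np

x = (np.random.randn(16) + 1j * np.random.randn(16)).astype(np.complex64)

# np.fft.fft uses exp(-2j*pi*k*n/N); np.fft.ifft uses exp(+2j*pi*k*n/N)
# with a 1/N factor, so multiplying by N mimics an unnormalized
# opposite-sign transform such as cuFFT's CUFFT_INVERSE.
sign_flipped = np.fft.ifft(x) * len(x)

# Flipping the exponent sign is the same as conjugating input and output.
reference = np.conj(np.fft.fft(np.conj(x)))

assert np.allclose(sign_flipped, reference, atol=1e-4)
```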
2 changes: 1 addition & 1 deletion setup.py
@@ -103,7 +103,7 @@ def build_extension(self, ext):

setup(
name="gpuNUFFT",
version="0.7.5",
version="0.8.0",
description="gpuNUFFT - An open source GPU Library for 3D Gridding and NUFFT",
ext_modules=[
CMakeExtension("gpuNUFFT", sourcedir=os.path.join("CUDA")),