Skip to content

Commit

Permalink
Infer dtype for finufft_makeplan (#45)
Browse files Browse the repository at this point in the history
* Infer dtype for finufft_makeplan

Fixes #42

* Drop dtype keyword

* Add deprecation warning

* Remove explicitly passed dtype from tests

* Version bump

* Update README.md

* Version bump

Co-authored-by: Ludvig af Klinteberg <[email protected]>
  • Loading branch information
jkrimmer and ludvigak authored Sep 1, 2022
1 parent 162014a commit 6be1108
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FINUFFT"
uuid = "d8beea63-0952-562e-9c6a-8e8ef7364055"
author = "Ludvig af Klinteberg <[email protected]>"
version = "3.0.2"
version = "3.1.0"

[deps]
finufft_jll = "c41cd5a2-72a3-5203-9076-a500b088fc82"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![codecov](https://codecov.io/gh/ludvigak/FINUFFT.jl/branch/master/graph/badge.svg?token=Tkx7kma18J)](https://codecov.io/gh/ludvigak/FINUFFT.jl)
[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://ludvigak.github.io/FINUFFT.jl/latest/)

This is a full-featured Julia interface to [FINUFFT](https://github.com/flatironinstitute/finufft), which is a lightweight and fast parallel nonuniform fast Fourier transform (NUFFT) library released by the Flatiron Institute. This interface stands at v3.0.2, and it uses FINUFFT version 2.1.0 (note that the interface version number is distinct from the version of the wrapped binary FINUFFT library).
This is a full-featured Julia interface to [FINUFFT](https://github.com/flatironinstitute/finufft), which is a lightweight and fast parallel nonuniform fast Fourier transform (NUFFT) library released by the Flatiron Institute. This interface stands at v3.x, and it uses FINUFFT version 2.1.0 (note that the interface version number is distinct from the version of the wrapped binary FINUFFT library).

## Installation

Expand Down Expand Up @@ -38,7 +38,7 @@ An auto-generated reference for all provided Julia functions is [here](https://l
* Function calls mimic the C/C++ interface, with the exception that you don't need to pass the dimensions of any arrays in the argument (they are inferred using `size()`).
* A vectorized call (performing multiple transforms, each with different coefficient vectors but the same set of nonuniform points) can now be performed using the same functions as the single-transform interface, detected from the size of the input arrays.
* Both 64-bit and 32-bit precision calls are now supported using a single
set of function names, switched by a `dtype` keyword argument for clarity.
set of function names. Which precision to use is inferred from the type of the input arrays, except for in the guru interface where the `dtype` argument is required for `finufft_makeplan`. (NOTE: The use of the `dtype` argument in the simple interface is deprecated as of v3.1.0)
* The functions named `nufftDdN` return the output array.
* In contrast, the functions named `nufftDdN!` take the output array as an argument. This needs to be preallocated with the correct size.
* Likewise, in the guru interface, `finufft_exec` returns the output array,
Expand Down
1 change: 1 addition & 0 deletions src/FINUFFT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ end
function checkkwdtype(dtype::DataType; kwargs...)
for (key, value) in kwargs
if String(key)=="dtype"
@warn "Explicitly passing the dtype argument is discouraged and will be deprecated."
@assert value == dtype
end
end
Expand Down
36 changes: 18 additions & 18 deletions src/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function nufft1d1(xj::Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, ms, ntrans)
checkkwdtype(T; kwargs...)
nufft1d1!(xj, cj, iflag, eps, fk; kwargs...,dtype=T)
nufft1d1!(xj, cj, iflag, eps, fk; kwargs...)
return fk
end

Expand Down Expand Up @@ -94,7 +94,7 @@ function nufft2d1(xj :: Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, ms, mt, ntrans)
checkkwdtype(T; kwargs...)
nufft2d1!(xj, yj, cj, iflag, eps, fk;kwargs...,dtype=T)
nufft2d1!(xj, yj, cj, iflag, eps, fk;kwargs...)
return fk
end

Expand Down Expand Up @@ -152,7 +152,7 @@ function nufft3d1(xj :: Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, ms, mt, mu, ntrans)
checkkwdtype(T; kwargs...)
nufft3d1!(xj, yj, zj, cj, iflag, eps, fk;kwargs...,dtype=T)
nufft3d1!(xj, yj, zj, cj, iflag, eps, fk;kwargs...)
return fk
end

Expand Down Expand Up @@ -195,7 +195,7 @@ function nufft1d2(xj :: Array{T},
(ms, ntrans) = get_nmodes_from_fk(1,fk)
cj = Array{Complex{T}}(undef, nj, ntrans)
checkkwdtype(T; kwargs...)
nufft1d2!(xj, cj, iflag, eps, fk;kwargs...,dtype=T)
nufft1d2!(xj, cj, iflag, eps, fk;kwargs...)
return cj
end

Expand Down Expand Up @@ -240,7 +240,7 @@ function nufft2d2(xj :: Array{T},
(ms, mt, ntrans) = get_nmodes_from_fk(2,fk)
cj = Array{Complex{T}}(undef, nj, ntrans)
checkkwdtype(T; kwargs...)
nufft2d2!(xj, yj, cj, iflag, eps, fk;kwargs...,dtype=T)
nufft2d2!(xj, yj, cj, iflag, eps, fk;kwargs...)
return cj
end

Expand Down Expand Up @@ -288,7 +288,7 @@ function nufft3d2(xj :: Array{T},
(ms, mt, mu, ntrans) = get_nmodes_from_fk(3,fk)
cj = Array{Complex{T}}(undef, nj, ntrans)
checkkwdtype(T; kwargs...)
nufft3d2!(xj, yj, zj, cj, iflag, eps, fk;kwargs...,dtype=T)
nufft3d2!(xj, yj, zj, cj, iflag, eps, fk;kwargs...)
return cj
end

Expand Down Expand Up @@ -334,7 +334,7 @@ function nufft1d3(xj :: Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, nk, ntrans)
checkkwdtype(T; kwargs...)
nufft1d3!(xj, cj, iflag, eps, sk, fk;kwargs...,dtype=T)
nufft1d3!(xj, cj, iflag, eps, sk, fk;kwargs...)
return fk
end

Expand Down Expand Up @@ -382,7 +382,7 @@ function nufft2d3(xj :: Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, nk, ntrans)
checkkwdtype(T; kwargs...)
nufft2d3!(xj, yj, cj, iflag, eps, sk, tk, fk;kwargs...,dtype=T)
nufft2d3!(xj, yj, cj, iflag, eps, sk, tk, fk;kwargs...)
return fk
end

Expand Down Expand Up @@ -435,7 +435,7 @@ function nufft3d3(xj :: Array{T},
ntrans = valid_ntr(xj,cj)
fk = Array{Complex{T}}(undef, nk, ntrans)
checkkwdtype(T; kwargs...)
nufft3d3!(xj, yj, zj, cj, iflag, eps, sk, tk, uk, fk;kwargs...,dtype=T)
nufft3d3!(xj, yj, zj, cj, iflag, eps, sk, tk, uk, fk;kwargs...)
return fk
end

Expand Down Expand Up @@ -466,7 +466,7 @@ function nufft1d1!(xj :: Array{T},
(ms, ntrans_fk) = get_nmodes_from_fk(1,fk)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(1,[ms;],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(1,[ms;],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -495,7 +495,7 @@ function nufft1d2!(xj :: Array{T},
(ms, ntrans) = get_nmodes_from_fk(1,fk)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(2,[ms;],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(2,[ms;],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj)
finufft_exec!(plan,fk,cj)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -526,7 +526,7 @@ function nufft1d3!(xj :: Array{T},
ntrans = valid_ntr(xj,cj)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(3,1,iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(3,1,iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,T[],T[],sk)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -561,7 +561,7 @@ function nufft2d1!(xj :: Array{T},
@assert ntrans==ntrans_fk

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(1,[ms;mt],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(1,[ms;mt],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -592,7 +592,7 @@ function nufft2d2!(xj :: Array{T},
(ms, mt, ntrans) = get_nmodes_from_fk(2,fk)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(2,[ms;mt],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(2,[ms;mt],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj)
finufft_exec!(plan,fk,cj)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -626,7 +626,7 @@ function nufft2d3!(xj :: Array{T},
ntrans = valid_ntr(xj,cj)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(3,2,iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(3,2,iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj,T[],sk,tk)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -662,7 +662,7 @@ function nufft3d1!(xj :: Array{T},
@assert ntrans == ntrans_fk

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(1,[ms;mt;mu],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(1,[ms;mt;mu],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj,zj)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -694,7 +694,7 @@ function nufft3d2!(xj :: Array{T},
(ms, mt, mu, ntrans) = get_nmodes_from_fk(3,fk)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(2,[ms;mt;mu],iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(2,[ms;mt;mu],iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj,zj)
finufft_exec!(plan,fk,cj)
ret = finufft_destroy!(plan)
Expand Down Expand Up @@ -732,7 +732,7 @@ function nufft3d3!(xj :: Array{T},
ntrans = valid_ntr(xj,cj)

checkkwdtype(T; kwargs...)
plan = finufft_makeplan(3,3,iflag,ntrans,eps;kwargs...)
plan = finufft_makeplan(3,3,iflag,ntrans,eps;dtype=T,kwargs...)
finufft_setpts!(plan,xj,yj,zj,sk,tk,uk)
finufft_exec!(plan,cj,fk)
ret = finufft_destroy!(plan)
Expand Down
38 changes: 19 additions & 19 deletions test/test_nufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ function test_nufft(tol::Real, dtype::DataType)
end
end
# Simple, writing into array, setting some non-default opts...
nufft1d1!(x, c, 1, tol, out, debug=1, spread_sort=0, dtype=T)
nufft1d1!(x, c, 1, tol, out, debug=1, spread_sort=0)
relerr_1d1 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_1d1 < errfac*tol

# Different caller which returns array
out2 = nufft1d1(x, c, 1, tol, ms, dtype=T)
out2 = nufft1d1(x, c, 1, tol, ms)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol

Expand All @@ -85,7 +85,7 @@ function test_nufft(tol::Real, dtype::DataType)
@test relerr_guru_many < errfac*tol

# simple vectorized ("many")
fstack = nufft1d1(x,cstack,+1,tol,ms,dtype=T)
fstack = nufft1d1(x,cstack,+1,tol,ms)
relerr_many = norm(vec(fstack)-vec(refstack), Inf) / norm(vec(refstack), Inf)
@test relerr_many < errfac*tol
end
Expand All @@ -99,10 +99,10 @@ function test_nufft(tol::Real, dtype::DataType)
ref[j] += F1D[ss] * exp(1im*k1[ss]*x[j])
end
end
nufft1d2!(x, out, 1, tol, F1D, dtype=T)
nufft1d2!(x, out, 1, tol, F1D)
relerr_1d2 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_1d2 < errfac*tol
out2 = nufft1d2(x, 1, tol, F1D, dtype=T)
out2 = nufft1d2(x, 1, tol, F1D)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -116,10 +116,10 @@ function test_nufft(tol::Real, dtype::DataType)
ref[k] += c[j] * exp(1im*s[k]*x[j])
end
end
nufft1d3!(x,c,1,tol,s,out, dtype=T)
nufft1d3!(x,c,1,tol,s,out)
relerr_1d3 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_1d3 < errfac*tol
out2 = nufft1d3(x,c,1,tol,s, dtype=T)
out2 = nufft1d3(x,c,1,tol,s)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -138,10 +138,10 @@ function test_nufft(tol::Real, dtype::DataType)
end
end
end
nufft2d1!(x, y, c, 1, tol, out, dtype=T)
nufft2d1!(x, y, c, 1, tol, out)
relerr_2d1 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_2d1 < errfac*tol
out2 = nufft2d1(x, y, c, 1, tol, ms, mt, dtype=T)
out2 = nufft2d1(x, y, c, 1, tol, ms, mt)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -157,10 +157,10 @@ function test_nufft(tol::Real, dtype::DataType)
end
end
end
nufft2d2!(x, y, out, 1, tol, F2D, dtype=T)
nufft2d2!(x, y, out, 1, tol, F2D)
relerr_2d2 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_2d2 < errfac*tol
out2 = nufft2d2(x, y, 1, tol, F2D, dtype=T)
out2 = nufft2d2(x, y, 1, tol, F2D)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -174,10 +174,10 @@ function test_nufft(tol::Real, dtype::DataType)
ref[k] += c[j] * exp(1im*(s[k]*x[j]+t[k]*y[j]))
end
end
nufft2d3!(x,y,c,1,tol,s,t,out, dtype=T)
nufft2d3!(x,y,c,1,tol,s,t,out)
relerr_2d3 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_2d3 < errfac*tol
out2 = nufft2d3(x,y,c,1,tol,s,t, dtype=T)
out2 = nufft2d3(x,y,c,1,tol,s,t)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -198,10 +198,10 @@ function test_nufft(tol::Real, dtype::DataType)
end
end
end
nufft3d1!(x, y, z, c, 1, tol, out, dtype=T)
nufft3d1!(x, y, z, c, 1, tol, out)
relerr_3d1 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_3d1 < errfac*tol
out2 = nufft3d1(x, y, z, c, 1, tol, ms, mt, mu, dtype=T)
out2 = nufft3d1(x, y, z, c, 1, tol, ms, mt, mu)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -219,10 +219,10 @@ function test_nufft(tol::Real, dtype::DataType)
end
end
end
nufft3d2!(x, y, z, out, 1, tol, F3D, dtype=T)
nufft3d2!(x, y, z, out, 1, tol, F3D)
relerr_3d2 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_3d2 < errfac*tol
out2 = nufft3d2(x, y, z, 1, tol, F3D, dtype=T)
out2 = nufft3d2(x, y, z, 1, tol, F3D)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand All @@ -236,10 +236,10 @@ function test_nufft(tol::Real, dtype::DataType)
ref[k] += c[j] * exp(1im*(s[k]*x[j]+t[k]*y[j]+u[k]*z[j]))
end
end
nufft3d3!(x,y,z,c,1,tol,s,t,u,out, dtype=T)
nufft3d3!(x,y,z,c,1,tol,s,t,u,out)
relerr_3d3 = norm(vec(out)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_3d3 < errfac*tol
out2 = nufft3d3(x,y,z,c,1,tol,s,t,u, dtype=T)
out2 = nufft3d3(x,y,z,c,1,tol,s,t,u)
reldiff = norm(vec(out)-vec(out2), Inf) / norm(vec(out), Inf)
@test reldiff < errdifffac*tol
end
Expand Down

2 comments on commit 6be1108

@ludvigak
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67565

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.1.0 -m "<description of version>" 6be1108cf4f60da81281429d2b5784e25b264189
git push origin v3.1.0

Please sign in to comment.