Skip to content

Commit

Permalink
In-place real-to-complex FFTs (#65)
Browse files Browse the repository at this point in the history
* allocate ManyPencilArrayInplaceRFFT for RFFT! plan

* definition of ManyPencilArrayInplaceRFFT for RFFT!

* inplace RFFT! mul! and ldiv!  operations

* define RFFT! and BRFFT! + wrappers to FFTW plans

* updated tests for rfft!

* include ManyPencilArrayInplaceRFFT

* changes to ManyPencilArrayRFFT!

* separated scaling and backward transform

* fixed rfft! and brfft! plan creation

* further rfft! tests

* further rfft! tests

* fixed docstring

* fixed output of _scale!

* fixed triangular dispatch for expand_dims

* Test: use rtol so that commented 256^3 tests pass

* Revert "Test: use rtol so that commented 256^3 tests pass"

This reverts commit 1f82324.

* Formatting

* Simplify dispatch in allocate_input

* Use PencilArrays.AbstractManyPencilArray

* Update version

* Don't test on Julia 1.8

* Update docs

---------

Co-authored-by: Juan Ignacio Polanco <[email protected]>
  • Loading branch information
fbignonnet and jipolanco authored May 25, 2023
1 parent 309eec5 commit a5546f3
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 35 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
experimental: [false]
version:
- '1.7'
- '1.8'
- '~1.9.0-0'
- '1.9'
os:
- ubuntu-latest
arch:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PencilFFTs"
uuid = "4a48f351-57a6-4416-9ec4-c37015456aae"
authors = ["Juan Ignacio Polanco <[email protected]>"]
version = "0.14.4"
version = "0.15.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -16,7 +16,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
AbstractFFTs = "1"
FFTW = "1.6"
MPI = "0.19, 0.20"
PencilArrays = "0.17"
PencilArrays = "0.18"
Reexport = "1"
TimerOutputs = "0.5"
julia = "1.7"
6 changes: 6 additions & 0 deletions docs/src/PencilFFTs.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ scale_factor(::PencilFFTPlan)
timer(::PencilFFTPlan)
is_inplace(::PencilFFTPlan)
```

## Internals

```@docs
ManyPencilArrayRFFT!
```
2 changes: 2 additions & 0 deletions docs/src/Transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ BFFT
BFFT!
RFFT
RFFT!
BRFFT
BRFFT!
R2R
R2R!
Expand Down
1 change: 1 addition & 0 deletions src/PencilFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const AbstractTransformList{N} = NTuple{N, AbstractTransform} where N

include("global_params.jl")
include("plans.jl")
include("multiarrays_r2c.jl")
include("allocate.jl")
include("operations.jl")

Expand Down
85 changes: 76 additions & 9 deletions src/Transforms/r2c.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Real-to-complex and complex-to-real transforms.
using FFTW: FFTW

"""
RFFT()
Expand All @@ -10,6 +11,13 @@ See also
"""
struct RFFT <: AbstractTransform end

"""
    RFFT!()

In-place version of [`RFFT`](@ref).

See also [`BRFFT!`](@ref).
"""
struct RFFT! <: AbstractTransform end

"""
BRFFT(d::Integer)
BRFFT((d1, d2, ..., dN))
Expand Down Expand Up @@ -40,23 +48,40 @@ struct BRFFT <: AbstractTransform
even_output :: Bool
end

_show_extra_info(io::IO, tr::BRFFT) = print(io, tr.even_output ? "{even}" : "{odd}")
"""
    BRFFT!(d::Integer)
    BRFFT!((d1, d2, ..., dN))

In-place version of [`BRFFT`](@ref).
"""
struct BRFFT! <: AbstractTransform
    # Parity flag; the `BRFFT!(d::Integer)` constructor sets it to `iseven(d)`.
    even_output::Bool
end

const TransformR2C = Union{RFFT, RFFT!}
const TransformC2R = Union{BRFFT, BRFFT!}

_show_extra_info(io::IO, tr::TransformC2R) = print(io, tr.even_output ? "{even}" : "{odd}")

BRFFT(d::Integer) = BRFFT(iseven(d))
BRFFT(ts::Tuple) = BRFFT(last(ts)) # c2r transform is applied along the **last** dimension (opposite of FFTW)
BRFFT!(d::Integer) = BRFFT!(iseven(d))
BRFFT!(ts::Tuple) = BRFFT!(last(ts)) # c2r transform is applied along the **last** dimension (opposite of FFTW)

is_inplace(::Union{RFFT, BRFFT}) = false
is_inplace(::Union{RFFT!, BRFFT!}) = true

length_output(::RFFT, length_in::Integer) = div(length_in, 2) + 1
length_output(tr::BRFFT, length_in::Integer) = 2 * length_in - 1 - tr.even_output
length_output(::TransformR2C, length_in::Integer) = div(length_in, 2) + 1
length_output(tr::TransformC2R, length_in::Integer) = 2 * length_in - 1 - tr.even_output

eltype_output(::RFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_output(::BRFFT, ::Type{Complex{T}}) where {T <: FFTReal} = T
eltype_output(::TransformR2C, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_output(::TransformC2R, ::Type{Complex{T}}) where {T <: FFTReal} = T

eltype_input(::RFFT, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::BRFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}
eltype_input(::TransformR2C, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::TransformC2R, ::Type{T}) where {T <: FFTReal} = Complex{T}

plan(::RFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_rfft(A, args...; kwargs...)
plan(::RFFT!, A::AbstractArray, args...; kwargs...) = plan_rfft!(A, args...; kwargs...)

# NOTE: unlike most FFTW plans, this function also requires the length `d` of
# the transform output along the first transformed dimension.
Expand All @@ -65,23 +90,65 @@ function plan(tr::BRFFT, A::AbstractArray, dims; kwargs...)
d = length_output(tr, Nin)
FFTW.plan_brfft(A, d, dims; kwargs...)
end
# In-place variant of the `BRFFT` method: the plan is created by `plan_brfft!`.
function plan(tr::BRFFT!, A::AbstractArray, dims; kwargs...)
    # Output length along the first transformed dimension, obtained from the
    # corresponding input length of `A`.
    len_in = size(A, first(dims))
    len_out = length_output(tr, len_in)
    return plan_brfft!(A, len_out, dims; kwargs...)
end

binv(::RFFT, d) = BRFFT(d)
binv(::BRFFT, d) = RFFT()
binv(::RFFT!, d) = BRFFT!(d)
binv(::BRFFT!, d) = RFFT!()

function scale_factor(tr::BRFFT, A::ComplexArray, dims)
function scale_factor(tr::TransformC2R, A::ComplexArray, dims)
prod(dims; init = one(Int)) do i
n = size(A, i)
i == last(dims) ? length_output(tr, n) : n
end
end

scale_factor(::RFFT, A::RealArray, dims) = _prod_dims(A, dims)
scale_factor(::TransformR2C, A::RealArray, dims) = _prod_dims(A, dims)

# r2c along the first dimension, then c2c for the other dimensions.
expand_dims(tr::RFFT, ::Val{N}) where {N} =
N === 0 ? () : (tr, expand_dims(FFT(), Val(N - 1))...)
expand_dims(tr::RFFT!, ::Val{N}) where {N} =
N === 0 ? () : (tr, expand_dims(FFT!(), Val(N - 1))...)

expand_dims(tr::BRFFT, ::Val{N}) where {N} = (BFFT(), expand_dims(tr, Val(N - 1))...)
expand_dims(tr::BRFFT!, ::Val{N}) where {N} = (BFFT!(), expand_dims(tr, Val(N - 1))...)
expand_dims(tr::BRFFT, ::Val{1}) = (tr, )
expand_dims(tr::BRFFT, ::Val{0}) = ()
expand_dims(tr::BRFFT!, ::Val{1}) = (tr, )
expand_dims(tr::BRFFT!, ::Val{0}) = ()

## FFTW wrappers for inplace RFFT plans

# Create an FFTW plan for an in-place real-to-complex transform along `region`.
#
# Only the size of `X` is used here. The plan is built from a real array whose
# first transformed dimension is padded to twice the complex output length.
function plan_rfft!(
        X::StridedArray{T,N}, region;
        flags::Integer = FFTW.ESTIMATE,
        timelimit::Real = FFTW.NO_TIMELIMIT,
    ) where {T <: FFTW.fftwReal, N}
    dims_real = size(X)  # physical input size (real)
    dims_complex = FFTW.rfft_output_size(dims_real, region)  # output size (complex)
    d1 = first(region)
    # Padded input size (real).
    dims_padded = ntuple(i -> i == d1 ? 2 * dims_complex[i] : dims_complex[i], Val(N))

    estimate_only = (flags & FFTW.ESTIMATE) != 0
    if estimate_only
        # Time measurement not required: fake allocations are enough, since only
        # pointer, size and strides matter.
        X_padded = FFTW.FakeArray{T,N}(dims_real, FFTW.colmajorstrides(dims_padded))
        Y = FFTW.FakeArray{Complex{T}}(dims_complex)
    else
        # A new buffer must be allocated, since `X` is too small to hold the
        # padded data.
        buf = Array{T}(undef, prod(dims_padded))
        X_padded = view(reshape(buf, dims_padded), Base.OneTo.(dims_real)...)
        Y = reshape(reinterpret(Complex{T}, buf), dims_complex)
    end

    return FFTW.rFFTWPlan{T,FFTW.FORWARD,true,N}(X_padded, Y, region, flags, timelimit)
end

# Create an FFTW plan for an in-place complex-to-real (backward) transform of
# `X` along `region`; `d` is the length of the real output along the first
# transformed dimension.
#
# The real output is a view into the memory of `X`, reinterpreted as real data
# with a padded first transformed dimension.
function plan_brfft!(
        X::StridedArray{Complex{T},N}, d, region;
        flags::Integer = FFTW.ESTIMATE,
        timelimit::Real = FFTW.NO_TIMELIMIT,
    ) where {T <: FFTW.fftwReal, N}
    dims_complex = size(X)  # input size (complex)
    d1 = first(region)
    # Padded output size (real).
    dims_padded = ntuple(i -> i == d1 ? 2 * dims_complex[i] : dims_complex[i], Val(N))
    dims_real = FFTW.brfft_output_size(X, d, region)  # physical output size (real)

    flat_real = reinterpret(T, reshape(X, prod(dims_complex)))
    Y = view(reshape(flat_real, dims_padded), Base.OneTo.(dims_real)...)  # Y is padded

    return FFTW.rFFTWPlan{Complex{T},FFTW.BACKWARD,true,N}(X, Y, region, flags, timelimit)
end
42 changes: 33 additions & 9 deletions src/allocate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ size `dims`, and a tuple of `N` `PencilArray`s.
!!! note "In-place plans"
If `p` is an in-place plan, a
If `p` is an in-place real-to-real or complex-to-complex plan, a
[`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray)
is allocated. This
type holds `PencilArray` wrappers for the input and output transforms (as
is allocated. If `p` is an in-place real-to-complex plan, a
[`ManyPencilArrayRFFT!`](@ref) is allocated.
These types hold `PencilArray` wrappers for the input and output transforms (as
well as for intermediate transforms) which share the same space in memory.
The input and output `PencilArray`s should be respectively accessed by
calling [`first(::ManyPencilArray)`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#Base.first-Tuple{ManyPencilArray}) and
Expand All @@ -39,17 +41,26 @@ size `dims`, and a tuple of `N` `PencilArray`s.
# p * v_in # not allowed!!
```
"""
function allocate_input end
function allocate_input(p::PencilFFTPlan)
    # Lift the in-place flag of the plan to the type domain, so that the
    # matching `_allocate_input` method is selected by dispatch.
    return _allocate_input(Val(is_inplace(p)), p)
end

# Out-of-place version
function allocate_input(p::PencilFFTPlan{T,N,false} where {T,N})
function _allocate_input(inplace::Val{false}, p::PencilFFTPlan)
T = eltype_input(p)
pen = pencil_input(p)
PencilArray{T}(undef, pen, p.extra_dims...)
end

# In-place version
function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N})
function _allocate_input(inplace::Val{true}, p::PencilFFTPlan)
(; transforms,) = p.global_params
_allocate_input(inplace, p, transforms...)
end

# In-place: generic case
function _allocate_input(inplace::Val{true}, p::PencilFFTPlan, transforms...)
pencils = map(pp -> pp.pencil_in, p.plans)

# Note that for each 1D plan, the input and output pencils are the same.
Expand All @@ -61,6 +72,16 @@ function allocate_input(p::PencilFFTPlan{T,N,true} where {T,N})
ManyPencilArray{T}(undef, pencils...; extra_dims=p.extra_dims)
end

# In-place: specific case of RFFT! (followed only by FFT! transforms)
function _allocate_input(
        inplace::Val{true}, p::PencilFFTPlan{T},
        ::Transforms.RFFT!, ::Vararg{Transforms.FFT!},
    ) where {T}
    plan_first = first(p.plans)
    plans_rest = p.plans[2:end]
    # The container takes one real pencil followed by the complex ones: here the
    # input pencil of the first plan, then its output pencil and the input
    # pencil of every subsequent plan.
    pencil_real = plan_first.pencil_in
    pencils_complex = (plan_first.pencil_out, map(pp -> pp.pencil_in, plans_rest)...)
    return ManyPencilArrayRFFT!{T}(
        undef, pencil_real, pencils_complex...; extra_dims = p.extra_dims,
    )
end

allocate_input(p::PencilFFTPlan, dims...) =
_allocate_many(allocate_input, p, dims...)

Expand All @@ -76,17 +97,20 @@ If `p` is an in-place plan, a [`ManyPencilArray`](https://jipolanco.github.io/Pe
See [`allocate_input`](@ref) for details.
"""
function allocate_output end
function allocate_output(p::PencilFFTPlan)
    # Lift the in-place flag of the plan to the type domain, so that the
    # matching `_allocate_output` method is selected by dispatch.
    return _allocate_output(Val(is_inplace(p)), p)
end

# Out-of-place version.
function allocate_output(p::PencilFFTPlan{T,N,false} where {T,N})
function _allocate_output(inplace::Val{false}, p::PencilFFTPlan)
T = eltype_output(p)
pen = pencil_output(p)
PencilArray{T}(undef, pen, p.extra_dims...)
end

# For in-place plans, the output and input are the same ManyPencilArray.
allocate_output(p::PencilFFTPlan{T,N,true} where {T,N}) = allocate_input(p)
_allocate_output(inplace::Val{true}, p::PencilFFTPlan) = _allocate_input(inplace, p)

allocate_output(p::PencilFFTPlan, dims...) =
_allocate_many(allocate_output, p, dims...)
Expand Down
79 changes: 79 additions & 0 deletions src/multiarrays_r2c.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# copied and modified from https://github.com/jipolanco/PencilArrays.jl/blob/master/src/multiarrays.jl
import PencilArrays: AbstractManyPencilArray, _make_arrays

"""
    ManyPencilArrayRFFT!{T,N,M} <: AbstractManyPencilArray{N,M}

Container holding `M` different [`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray) views to the same
underlying data buffer. All views share the same dimensionality `N`.
The element type `T` of the first view is real, that of subsequent views is
`Complex{T}`.

This can be used to perform in-place real-to-complex transforms; see also
[`Transforms.RFFT!`](@ref).
It is used internally for such transforms by [`allocate_input`](@ref) and should
not be constructed directly.

---

    ManyPencilArrayRFFT!{T}(undef, pencils...; extra_dims=())

Create a `ManyPencilArrayRFFT!` container that can hold data of type `T` and
`Complex{T}` associated to all the given
[`Pencil`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.Pencil)s.

The optional `extra_dims` argument is the same as for
[`PencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.PencilArray).

See also [`ManyPencilArray`](https://jipolanco.github.io/PencilArrays.jl/dev/PencilArrays/#PencilArrays.ManyPencilArray)
"""
struct ManyPencilArrayRFFT!{
        T,  # element type of real array
        N,  # number of dimensions of each array (including extra_dims)
        M,  # number of arrays
        Arrays <: Tuple{Vararg{PencilArray,M}},
        DataVector <: AbstractVector{T},
        DataVectorComplex <: AbstractVector{Complex{T}},
    } <: AbstractManyPencilArray{N, M}
    data::DataVector                 # real buffer
    data_complex::DataVectorComplex  # complex wrapper around the memory of `data`
    arrays::Arrays                   # (real array, complex arrays...)

    function ManyPencilArrayRFFT!{T}(
            init, real_pencil::Pencil{Np}, complex_pencils::Vararg{Pencil{Np}};
            extra_dims::Dims = ()
        ) where {Np, T <: FFTReal}
        # `real_pencil` is a Pencil with dimensions `dims` of a real array with
        # no padding and no permutation; the padded dimensions are
        # (2*(dims[1] ÷ 2 + 1), dims[2:end]...).
        # `first(complex_pencils)` is a Pencil with the dimensions of a complex
        # array (dims[1] ÷ 2 + 1, dims[2:end]...) and no permutation.
        BufType = PencilArrays.typeof_array(real_pencil)
        @assert all(p -> PencilArrays.typeof_array(p) === BufType, complex_pencils)
        @assert size_global(real_pencil)[2:end] == size_global(first(complex_pencils))[2:end]
        @assert first(size_global(real_pencil)) ÷ 2 + 1 == first(size_global(first(complex_pencils)))

        # Number of real elements: twice the length of the largest complex pencil.
        nreal = max(2 .* length.(complex_pencils)...) * prod(extra_dims)
        buf_real = BufType{T}(init, nreal)

        # We don't use reinterpret(Complex{T}, buf_real) here, since there is an
        # issue with StridedView of ReinterpretArray, called by _permutedims in
        # PencilArrays.Transpositions. The real buffer is wrapped through its
        # pointer instead.
        buf_complex = unsafe_wrap(
            BufType, convert(Ptr{Complex{T}}, pointer(buf_real)), nreal ÷ 2,
        )

        views = (
            _make_real_array(buf_real, extra_dims, real_pencil),
            PencilArrays._make_arrays(buf_complex, extra_dims, complex_pencils...)...,
        )

        N = Np + length(extra_dims)
        M = 1 + length(complex_pencils)  # the real pencil plus all complex ones
        return new{T, N, M, typeof(views), typeof(buf_real), typeof(buf_complex)}(
            buf_real, buf_complex, views,
        )
    end
end

# Build the real `PencilArray` for pencil `p` on top of the buffer `data`.
#
# The parent array is padded along its first dimension to
# 2 * (n ÷ 2 + 1) elements; the returned array is a view of the unpadded region.
function _make_real_array(data, extra_dims, p)
    dims_local = size_local(p, MemoryOrder())
    n1_padded = 2 * (dims_local[1] ÷ 2 + 1)
    dims_parent = (n1_padded, dims_local[2:end]..., extra_dims...)

    # Reshape the leading part of the buffer into the padded parent array.
    nparent = prod(dims_parent)
    parent_arr = reshape(view(data, Base.OneTo(nparent)), dims_parent)

    # Restrict the parent to the unpadded local dimensions.
    inds = map(Base.OneTo, (dims_local..., extra_dims...))
    return PencilArray(p, view(parent_arr, inds...))
end
Loading

0 comments on commit a5546f3

Please sign in to comment.