Skip to content

Commit

Permalink
Move FFTW-dependent functions to extension
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jul 2, 2024
1 parent 8f77455 commit d85ff99
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 131 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.9.17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -18,28 +17,31 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"

[compat]
AMDGPU = "0.9.4"
Adapt = "3.2, 4"
Atomix = "0.1"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
EnzymeCore = "0.5, 0.6, 0.7"
FFTW = "1.8.0"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
LinearAlgebra = "<0.0.1, 1"
Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.9"
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Documenter, NNlib

DocMeta.setdocmeta!(NNlib, :DocTestSetup,
:(using NNlib, UnicodePlots); recursive = true)
:(using FFTW, NNlib, UnicodePlots); recursive = true)

makedocs(modules = [NNlib],
sitename = "NNlib.jl",
Expand Down
4 changes: 4 additions & 0 deletions docs/src/audio.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Reference

!!! note
    Spectral functions (`stft`, `istft`, `spectrogram`, ...) are only available
    after the `FFTW` package has been loaded, e.g. with `using FFTW`.

## Window functions

```@docs
Expand All @@ -26,6 +29,7 @@ spectrogram
Example:

```@example 1
using FFTW # <- required for STFT support.
using NNlib
using FileIO
using Makie, CairoMakie
Expand Down
9 changes: 9 additions & 0 deletions ext/NNlibFFTWExt/NNlibFFTWExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Package extension that is activated when `FFTW` is loaded together with `NNlib`.
# It supplies the FFT-backed methods that live in `stft.jl`.
module NNlibFFTWExt

# `FFTW` provides the transforms, `KernelAbstractions` the backend queries
# and device-side allocation used by the included methods.
using FFTW, NNlib, KernelAbstractions

include("stft.jl")

end
127 changes: 127 additions & 0 deletions ext/NNlibFFTWExt/stft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# FFTW-backed method of `NNlib.stft`: short-time Fourier transform of `x`
# taken along its first dimension.
#
# Keywords:
# - `n_fft`: frame length; must satisfy `0 < n_fft ≤ size(x, 1)` (checked after
#   the optional centering pad).
# - `hop_length`: shift between consecutive frames; must be positive.
# - `window`: optional window multiplied with every frame. It must live on the
#   same backend as `x` and is zero-padded on both sides up to `n_fft`.
# - `center`: reflect-pad `x` along dim 1 by `n_fft ÷ 2` before framing.
# - `normalized`: scale the result by `n_fft^-0.5`.
#
# Throws `ArgumentError` when any of the constraints above is violated.
function NNlib.stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Reject zero as well as negative hops: a zero hop would otherwise only
    # surface as a `DivideError` when the number of frames is computed below.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each row in `n_frames` is shifted by `hop_length`.
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
    end

    region = 1
    use_window && (x = x .* window;)
    # Complex input needs the full transform; real input uses the real-input FFT.
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)

    normalized && (y = y .* eltype(y)(n_fft^-0.5);)
    return y
end

# FFTW-backed method of `NNlib.istft`: inverse short-time Fourier transform.
# `y` is read as (frequency, n_frames[, batch]); the inverse FFT is taken along
# dim 1 and the frames are stitched back into a signal along dim 1.
#
# Keywords:
# - `n_fft`: frame length, also passed to `irfft` as the output length.
# - `hop_length`: shift between consecutive frames; must be positive.
# - `window`: optional window; the inverse-transformed frames are divided by it.
#   It must live on the same backend as `y` and is zero-padded up to `n_fft`.
# - `center`: trim `n_fft ÷ 2` samples from the start (and, when
#   `original_length` is not given, from the end) of the result.
# - `normalized`: undo the `n_fft^-0.5` scaling before inverting.
# - `return_complex`: use `ifft` instead of `irfft`.
# - `original_length`: if given, the result holds exactly this many samples
#   after the left trim.
#
# Throws `ArgumentError` when the window or `hop_length` constraints are violated.
function NNlib.istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # The check must match its message: zero is not a valid hop either.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    normalized && (y = y .* eltype(y)(n_fft^0.5);)

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    use_window && (x = x ./ window;)

    # col2time.
    expected_output_len = n_fft + hop_length * (n_frames - 1)

    # `ids[k]` is the linear (column-major) position inside the
    # (n_fft, n_frames) frame matrix that supplies output sample `k`.
    # For every frame after the first, only rows whose time index lies past
    # the previous frame's last index are kept, so overlapping samples are
    # always taken from the earlier frame.
    # NOTE(review): if `hop_length > n_fft`, fewer than `expected_output_len`
    # entries get written and the tail of `ids` stays uninitialized - confirm
    # that callers never rely on this case.
    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0

    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            v > prev_e || continue

            out_idx += 1
            ids[out_idx] = in_idx
        end
        prev_e = v
    end

    # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).
    # `nd` is computed before the reshape: one `Colon` for batched input, none otherwise.
    nd = ntuple(_ -> Colon(), ndims(x) - 2)
    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
    x = x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end
1 change: 0 additions & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ using Random
using Requires
using Statistics
using Statistics: mean
using FFTW

const libblas = Base.libblas_name

Expand Down
128 changes: 2 additions & 126 deletions src/audio/stft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,59 +168,7 @@ and ``m`` is the index of the sliding window.
Complex array of shape `(n_fft, n_frames, B)`,
where `B` is the optional batch dimension.
"""
function stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    # Validation: the optional window must share the backend of `x` and fit
    # into a frame, and the hop between frames must be strictly positive.
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Reject zero as well as negative hops: a zero hop would otherwise only
    # surface as a `DivideError` when the number of frames is computed below.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each row in `n_frames` is shifted by `hop_length`.
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
    end

    region = 1
    use_window && (x = x .* window;)
    # Complex input needs the full transform; real input uses the real-input FFT.
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)

    normalized && (y = y .* eltype(y)(n_fft^-0.5);)
    return y
end
function stft end

"""
istft(y;
Expand Down Expand Up @@ -255,76 +203,4 @@ Return the least squares estimation of the original signal
of the input to `stft`. Helps restore the exact `stft` input size.
Otherwise, the array might be a bit shorter.
"""
function istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    # Validation: the optional window must share the backend of `y` and fit
    # into a frame, and the hop between frames must be strictly positive.
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # The check must match its message: zero is not a valid hop either.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    normalized && (y = y .* eltype(y)(n_fft^0.5);)

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    use_window && (x = x ./ window;)

    # col2time.
    expected_output_len = n_fft + hop_length * (n_frames - 1)

    # `ids[k]` is the linear (column-major) position inside the
    # (n_fft, n_frames) frame matrix that supplies output sample `k`.
    # For every frame after the first, only rows whose time index lies past
    # the previous frame's last index are kept, so overlapping samples are
    # always taken from the earlier frame.
    # NOTE(review): if `hop_length > n_fft`, fewer than `expected_output_len`
    # entries get written and the tail of `ids` stays uninitialized - confirm
    # that callers never rely on this case.
    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0

    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            v > prev_e || continue

            out_idx += 1
            ids[out_idx] = in_idx
        end
        prev_e = v
    end

    # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).
    # `nd` is computed before the reshape: one `Colon` for batched input, none otherwise.
    nd = ntuple(_ -> Colon(), ndims(x) - 2)
    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
    x = x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end
function istft end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using Adapt
using ImageTransformations
using Interpolations: Constant
using KernelAbstractions
using FFTW
import ReverseDiff as RD # used in `pooling.jl`
import Pkg

Expand Down

0 comments on commit d85ff99

Please sign in to comment.