Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Short-time Fourier transform and its inverse #587

Merged
merged 13 commits into from
Jul 4, 2024
8 changes: 0 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ jobs:
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: |
julia --color=yes --project=docs/ -e '
using NNlib
# using Pkg; Pkg.activate("docs")
using Documenter
using Documenter: doctest
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true)
doctest(NNlib)'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
32 changes: 6 additions & 26 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,31 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"

[compat]
AMDGPU = "0.8, 0.9"
AMDGPU = "0.9.4"
Adapt = "3.2, 4"
Atomix = "0.1"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
EnzymeCore = "0.5, 0.6, 0.7"
FFTW = "1.8.0"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
LinearAlgebra = "<0.0.1, 1"
Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.9"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"]
6 changes: 6 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
26 changes: 14 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using Documenter, NNlib

DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true)
DocMeta.setdocmeta!(NNlib, :DocTestSetup,
:(using FFTW, NNlib, UnicodePlots); recursive = true)

makedocs(modules = [NNlib],
sitename = "NNlib.jl",
doctest = false,
pages = ["Home" => "index.md",
"Reference" => "reference.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)
sitename = "NNlib.jl",
doctest = true,
pages = ["Home" => "index.md",
"Reference" => "reference.md",
"Audio" => "audio.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)

deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
target = "build",
Expand Down
Binary file added docs/src/assets/jfk.flac
Binary file not shown.
61 changes: 61 additions & 0 deletions docs/src/audio.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Reference

!!! note
Spectral functions are only available after the `FFTW` package has been loaded (`using FFTW`).

## Window functions

```@docs
hann_window
hamming_window
```

## Spectral

```@docs
stft
istft
NNlib.power_to_db
NNlib.db_to_power
```

## Spectrogram

```@docs
melscale_filterbanks
spectrogram
```

Example:

```@example 1
using FFTW # <- required for STFT support.
using NNlib
using FileIO
using Makie, CairoMakie
CairoMakie.activate!()

waveform, sampling_rate = load("./assets/jfk.flac")
fig = lines(reshape(waveform, :))
save("waveform.png", fig)

# Spectrogram.

n_fft = 1024
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
save("spectrogram.png", fig)

# Mel-scale spectrogram.

n_freqs = n_fft ÷ 2 + 1
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
save("mel-spectrogram.png", fig)
nothing # hide
```

|Waveform|Spectrogram|Mel Spectrogram|
|:---:|:---:|:---:|
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|
9 changes: 9 additions & 0 deletions ext/NNlibFFTWExt/NNlibFFTWExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Extension module that provides the FFTW-backed methods of `NNlib.stft` and
# `NNlib.istft` (implemented in `stft.jl`, included below).
module NNlibFFTWExt

# Supplies `fft`, `rfft`, `ifft` and `irfft` used by the transforms.
using FFTW
using NNlib
# Supplies `get_backend` and `KernelAbstractions.zeros` used for device checks
# and for allocating the padded window.
using KernelAbstractions

include("stft.jl")

end
127 changes: 127 additions & 0 deletions ext/NNlibFFTWExt/stft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# FFTW-backed method of `NNlib.stft`: short-time Fourier transform of `x`,
# taken along its first dimension.
#
# Keyword arguments:
# - `n_fft`: frame length; must satisfy `0 < n_fft ≤ size(x, 1)` (checked after
#   the optional centering pad has been applied).
# - `hop_length`: shift between the starts of consecutive frames; must be
#   strictly positive. Defaults to `n_fft ÷ 4`.
# - `window`: optional window with `0 < length(window) ≤ n_fft`, living on the
#   same backend as `x`. Shorter windows are zero-padded on both sides.
# - `center`: when `true`, `x` is reflect-padded by `n_fft ÷ 2` along dim 1.
# - `normalized`: when `true`, the result is scaled by `n_fft^-0.5`.
#
# Returns `fft(frames, 1)` for complex input and `rfft(frames, 1)` otherwise.
# Throws `ArgumentError` when any of the constraints above is violated.
function NNlib.stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # A zero hop would hit an integer division by zero when counting frames
    # below, so reject every non-positive value, as the message states.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    # NOTE: this branch was reported by Codecov as not covered by tests.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each column in `n_frames` is shifted by `hop_length`.
    trailing = ntuple(_ -> Colon(), ndims(x) - 1)
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = x[ids, trailing...]
    elseif n > n_fft
        # Single frame: keep only the `n_fft` samples that belong to it, so the
        # window broadcast and the FFT length match `n_fft`.
        x = x[1:n_fft, trailing...]
    end

    region = 1
    use_window && (x = x .* window;)
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)

    normalized && (y = y .* eltype(y)(n_fft^-0.5);)
    return y
end

# FFTW-backed method of `NNlib.istft`: inverse short-time Fourier transform.
#
# `y` holds the spectrum with frequencies along dim 1 and frames along dim 2;
# a third dimension, if present, is treated as the batch.
#
# Keyword arguments:
# - `n_fft`: frame length that was used for the forward transform.
# - `hop_length`: shift between consecutive frames; must be strictly positive.
#   Defaults to `n_fft ÷ 4`.
# - `window`: optional window with `0 < length(window) ≤ n_fft`, living on the
#   same backend as `y`. Shorter windows are zero-padded on both sides.
# - `center`: when `true`, `n_fft ÷ 2` samples are trimmed from the start (and,
#   unless `original_length` is given, from the end) of the result.
# - `normalized`: when `true`, `y` is first multiplied by `n_fft^0.5`.
# - `return_complex`: use `ifft` instead of `irfft(y, n_fft, 1)`.
# - `original_length`: if given, the output holds exactly this many samples
#   counted from the first untrimmed one.
#
# Throws `ArgumentError` when the window or `hop_length` constraints are violated.
function NNlib.istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Reject zero as well as negative hops, as the message states.
    hop_length > 0 || throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    # NOTE: this branch was reported by Codecov as not covered by tests.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    normalized && (y = y .* eltype(y)(n_fft^0.5);)

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    # NOTE(review): this divides by the window element-wise, so window entries
    # equal to zero (including zero padding added above) yield Inf/NaN at those
    # positions — confirm callers only rely on samples with a nonzero window.
    use_window && (x = x ./ window;)

    # col2time.
    expected_output_len = n_fft + hop_length * (n_frames - 1)

    # `ids[k]` is the linear (column-major) index into the `(n_fft, n_frames)`
    # frame matrix of the sample that lands at output position `k`.
    # NOTE(review): when `hop_length > n_fft` fewer than `expected_output_len`
    # entries get written and the tail of `ids` stays uninitialised — verify
    # that such hops are never used.
    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0

    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            # Skip samples whose time position was already covered by the
            # previous frame.
            v > prev_e || continue

            out_idx += 1
            ids[out_idx] = in_idx
        end
        # Last time position covered by this frame.
        prev_e = v
    end

    # In case of batched input, reshape it (n_fft, n_frames, batch) -> (:, batch).
    # The colon count is clamped at zero so that a 1-D single-frame input is
    # indexed directly instead of failing on a negative tuple length.
    nd = ntuple(_ -> Colon(), max(ndims(x) - 2, 0))
    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
    x = x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end
5 changes: 5 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,9 @@ include("deprecations.jl")
include("rotation.jl")
export imrotate, ∇imrotate

include("audio/stft.jl")
include("audio/spectrogram.jl")
include("audio/mel.jl")
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks

end # module NNlib
Loading
Loading